summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--talpid-core/src/routing/android.rs30
-rw-r--r--talpid-core/src/routing/linux.rs170
-rw-r--r--talpid-core/src/routing/macos.rs146
-rw-r--r--talpid-core/src/routing/unix.rs87
-rw-r--r--talpid-core/src/routing/windows.rs69
-rw-r--r--talpid-core/src/tunnel/mod.rs15
-rw-r--r--talpid-core/src/tunnel/wireguard/mod.rs15
-rw-r--r--talpid-core/src/tunnel_state_machine/connected_state.rs15
-rw-r--r--talpid-core/src/tunnel_state_machine/connecting_state.rs20
-rw-r--r--talpid-core/src/tunnel_state_machine/mod.rs10
-rw-r--r--talpid-core/src/winnet.rs20
-rw-r--r--windows/winnet/src/winnet/routing/routemanager.cpp49
-rw-r--r--windows/winnet/src/winnet/routing/routemanager.h1
-rw-r--r--windows/winnet/src/winnet/winnet.cpp29
-rw-r--r--windows/winnet/src/winnet/winnet.h7
15 files changed, 474 insertions, 209 deletions
diff --git a/talpid-core/src/routing/android.rs b/talpid-core/src/routing/android.rs
index 1f47027343..d364f0769e 100644
--- a/talpid-core/src/routing/android.rs
+++ b/talpid-core/src/routing/android.rs
@@ -1,5 +1,5 @@
-use crate::routing::RequiredRoute;
-use futures01::{sync::oneshot, Async, Future};
+use crate::routing::{imp::RouteManagerCommand, RequiredRoute};
+use futures01::{stream::Stream, sync::mpsc};
use std::collections::HashSet;
/// Stub error type for routing errors on Android.
@@ -9,30 +9,26 @@ pub struct Error;
/// Stub route manager for Android
pub struct RouteManagerImpl {
- shutdown_rx: oneshot::Receiver<oneshot::Sender<()>>,
+ manage_rx: mpsc::UnboundedReceiver<RouteManagerCommand>,
}
impl RouteManagerImpl {
pub fn new(
_required_routes: HashSet<RequiredRoute>,
- shutdown_rx: oneshot::Receiver<oneshot::Sender<()>>,
+ manage_rx: mpsc::UnboundedReceiver<RouteManagerCommand>,
) -> Result<Self, Error> {
- Ok(RouteManagerImpl { shutdown_rx })
+ Ok(RouteManagerImpl { manage_rx })
}
-}
-
-impl Future for RouteManagerImpl {
- type Item = ();
- type Error = Error;
- fn poll(&mut self) -> Result<Async<()>, Error> {
- match self.shutdown_rx.poll() {
- Ok(Async::Ready(result_tx)) => {
- result_tx.send(()).map_err(|()| Error)?;
- Ok(Async::Ready(()))
+ pub fn wait(self) -> Result<(), Error> {
+ for msg in self.manage_rx.wait() {
+ if let Ok(command) = msg {
+ if let RouteManagerCommand::Shutdown(tx) = command {
+ tx.send(()).map_err(|()| Error)?;
+ break;
+ }
}
- Ok(Async::NotReady) => Ok(Async::NotReady),
- Err(_) => Ok(Async::Ready(())),
}
+ Ok(())
}
}
diff --git a/talpid-core/src/routing/linux.rs b/talpid-core/src/routing/linux.rs
index 124c25e319..7a7bfa4865 100644
--- a/talpid-core/src/routing/linux.rs
+++ b/talpid-core/src/routing/linux.rs
@@ -1,17 +1,21 @@
-use crate::routing::{NetNode, Node, RequiredRoute, Route};
+use crate::routing::{imp::RouteManagerCommand, NetNode, Node, RequiredRoute, Route};
+
+use talpid_types::ErrorExt;
use ipnetwork::IpNetwork;
use std::{
collections::{BTreeMap, HashSet},
io,
net::IpAddr,
+ thread,
};
-use futures01::sync::oneshot as old_oneshot;
+use futures01::{stream::Stream as old_stream, sync::mpsc as old_mpsc};
use futures::{
- channel::mpsc::UnboundedReceiver, compat::Future01CompatExt, future::FutureExt, StreamExt,
- TryStreamExt,
+ channel::mpsc::{self, UnboundedReceiver},
+ future::FutureExt,
+ StreamExt, TryStreamExt,
};
@@ -62,10 +66,13 @@ pub enum Error {
#[error(display = "Unknown device index - {}", _0)]
UnknownDeviceIndex(u32),
+
+ #[error(display = "Shutting down route manager")]
+ Shutdown,
}
pub struct RouteManagerImpl {
- shutdown_rx: old_oneshot::Receiver<old_oneshot::Sender<()>>,
+ manage_rx: old_mpsc::UnboundedReceiver<RouteManagerCommand>,
manager: RouteManagerImplInner,
runtime: tokio02::runtime::Runtime,
}
@@ -74,7 +81,7 @@ impl RouteManagerImpl {
/// Creates a new RouteManagerImplInner.
pub fn new(
required_routes: HashSet<RequiredRoute>,
- shutdown_rx: old_oneshot::Receiver<old_oneshot::Sender<()>>,
+ manage_rx: old_mpsc::UnboundedReceiver<RouteManagerCommand>,
) -> Result<Self> {
let mut runtime = tokio02::runtime::Builder::new()
.basic_scheduler()
@@ -87,7 +94,7 @@ impl RouteManagerImpl {
let manager = runtime.block_on(RouteManagerImplInner::new(required_routes))?;
Ok(Self {
- shutdown_rx,
+ manage_rx,
runtime,
manager,
})
@@ -95,11 +102,28 @@ impl RouteManagerImpl {
pub fn wait(self) -> Result<()> {
let Self {
- shutdown_rx,
+ manage_rx,
mut runtime,
manager,
} = self;
- runtime.block_on(manager.into_future(shutdown_rx))
+
+ let (new_manage_tx, new_manage_rx) = mpsc::unbounded();
+
+ thread::spawn(move || {
+ for msg in manage_rx.wait() {
+ match msg {
+ Ok(msg) => {
+ if new_manage_tx.unbounded_send(msg).is_err() {
+ log::error!("RouteManager receiver unexpectedly dropped");
+ break;
+ }
+ }
+ Err(_) => break,
+ }
+ }
+ });
+
+ runtime.block_on(manager.into_future(new_manage_rx))
}
}
@@ -140,32 +164,12 @@ impl RouteManagerImplInner {
let iface_map = Self::initialize_link_map(&handle).await?;
-
- let mut required_normal_routes = HashSet::new();
- let mut required_default_routes = HashSet::new();
-
- for route in required_routes {
- match route.node {
- NetNode::RealNode(node) => {
- required_normal_routes
- .insert(Route::new(node, route.prefix).table(route.table_id));
- }
- NetNode::DefaultNode => {
- required_default_routes.insert(RequiredDefaultRoute {
- table_id: route.table_id,
- destination: route.prefix,
- });
- }
- }
- }
-
-
let mut monitor = Self {
iface_map,
handle,
messages,
- required_default_routes,
+ required_default_routes: HashSet::new(),
added_routes: HashSet::new(),
default_routes: HashSet::new(),
@@ -179,24 +183,72 @@ impl RouteManagerImplInner {
monitor.best_default_node_v6 =
Self::pick_best_default_node(&monitor.default_routes, IpVersion::V6);
+ monitor.add_required_routes(required_routes).await?;
- for normal_route in required_normal_routes.into_iter() {
- monitor.add_route(normal_route).await?;
- }
+ Ok(monitor)
+ }
- for route in monitor.required_default_routes.clone().into_iter() {
+ async fn add_required_default_routes(
+ &mut self,
+ required_default_routes: HashSet<RequiredDefaultRoute>,
+ ) -> Result<()> {
+ for route in required_default_routes.into_iter() {
if let (false, _, Some(default_node)) | (true, Some(default_node), _) = (
route.destination.is_ipv4(),
- &monitor.best_default_node_v4,
- &monitor.best_default_node_v6,
+ &self.best_default_node_v4,
+ &self.best_default_node_v6,
) {
// best to pick a single node identifier rather than device + ip
let new_route =
Route::new(default_node.clone(), route.destination).table(route.table_id);
- monitor.add_route(new_route).await?;
+ self.add_route(new_route).await?;
}
+ self.required_default_routes.insert(route);
}
- Ok(monitor)
+ Ok(())
+ }
+
+ async fn add_required_routes(&mut self, required_routes: HashSet<RequiredRoute>) -> Result<()> {
+ let mut required_normal_routes = HashSet::new();
+ let mut required_default_routes = HashSet::new();
+
+ for route in required_routes {
+ match route.node {
+ NetNode::RealNode(node) => {
+ required_normal_routes
+ .insert(Route::new(node, route.prefix).table(route.table_id));
+ }
+ NetNode::DefaultNode => {
+ required_default_routes.insert(RequiredDefaultRoute {
+ table_id: route.table_id,
+ destination: route.prefix,
+ });
+ }
+ }
+ }
+
+ for normal_route in required_normal_routes.into_iter() {
+ self.add_route(normal_route).await?;
+ }
+
+ if self
+ .add_required_default_routes(required_default_routes.clone())
+ .await
+ .is_err()
+ {
+ log::trace!("Refreshing default routes which may be stale");
+
+ self.default_routes = self.get_default_routes().await?;
+ self.best_default_node_v4 =
+ Self::pick_best_default_node(&self.default_routes, IpVersion::V4);
+ self.best_default_node_v6 =
+ Self::pick_best_default_node(&self.default_routes, IpVersion::V6);
+
+ self.add_required_default_routes(required_default_routes)
+ .await?;
+ }
+
+ Ok(())
}
async fn get_default_routes(&self) -> Result<HashSet<Route>> {
@@ -398,27 +450,47 @@ impl RouteManagerImplInner {
pub async fn into_future(
mut self,
- shutdown_rx: futures01::sync::oneshot::Receiver<futures01::sync::oneshot::Sender<()>>,
+ mut manage_rx: UnboundedReceiver<RouteManagerCommand>,
) -> Result<()> {
- let mut shutdown = shutdown_rx.compat().fuse();
loop {
futures::select! {
- shutdown_signal = shutdown => {
- log::trace!("Shutting down route manager");
- self.cleanup_routes().await;
- log::trace!("Route manager done");
- if let Ok(shutdown_signal) = shutdown_signal {
- let _ = shutdown_signal.send(());
- }
- return Ok(());
+ command = manage_rx.select_next_some().fuse() => {
+ self.process_command(command).await?;
},
(route_change, socket) = self.messages.select_next_some().fuse() => {
- self.process_netlink_message(route_change).await?;
+ if let Err(error) = self.process_netlink_message(route_change).await {
+ log::error!("{}", error.display_chain_with_msg("Failed to process netlink message"));
+ }
}
};
}
}
+ async fn process_command(&mut self, command: RouteManagerCommand) -> Result<()> {
+ match command {
+ RouteManagerCommand::Shutdown(shutdown_signal) => {
+ log::trace!("Shutting down route manager");
+ self.cleanup_routes().await;
+ log::trace!("Route manager done");
+ let _ = shutdown_signal.send(());
+ return Err(Error::Shutdown);
+ }
+ RouteManagerCommand::AddRoutes(routes, result_rx) => {
+ log::debug!("Adding routes: {:?}", routes);
+ if let Err(error) = self.add_required_routes(routes.clone()).await {
+ let _ = result_rx.send(Err(error));
+ } else {
+ let _ = result_rx.send(Ok(()));
+ }
+ }
+ RouteManagerCommand::ClearRoutes => {
+ log::debug!("Clearing routes");
+ self.cleanup_routes().await;
+ }
+ }
+ Ok(())
+ }
+
async fn process_netlink_message(&mut self, msg: NetlinkMessage<RtnlMessage>) -> Result<()> {
match msg.payload {
NetlinkPayload::InnerMessage(RtnlMessage::NewLink(new_link)) => {
diff --git a/talpid-core/src/routing/macos.rs b/talpid-core/src/routing/macos.rs
index 4f05927824..56f7257024 100644
--- a/talpid-core/src/routing/macos.rs
+++ b/talpid-core/src/routing/macos.rs
@@ -1,4 +1,4 @@
-use crate::routing::{NetNode, Node, RequiredRoute, Route};
+use crate::routing::{imp::RouteManagerCommand, NetNode, Node, RequiredRoute, Route};
use ipnetwork::IpNetwork;
use std::{
@@ -8,7 +8,11 @@ use std::{
process::{Command, ExitStatus, Stdio},
};
-use futures01::{stream, sync::oneshot, Async, Future, IntoFuture, Stream};
+use futures01::{
+ stream,
+ sync::{mpsc, oneshot},
+ Async, Future, IntoFuture, Stream,
+};
use tokio_process::{Child, CommandExt};
@@ -69,18 +73,16 @@ pub struct RouteManagerImpl {
current_state: RouteManagerState,
v4_gateway: Option<Node>,
v6_gateway: Option<Node>,
- shutdown_rx: Option<oneshot::Receiver<oneshot::Sender<()>>>,
+ manage_rx: Option<mpsc::UnboundedReceiver<RouteManagerCommand>>,
}
impl RouteManagerImpl {
pub fn new(
required_routes: HashSet<RequiredRoute>,
- shutdown_rx: oneshot::Receiver<oneshot::Sender<()>>,
+ manage_rx: mpsc::UnboundedReceiver<RouteManagerCommand>,
) -> Result<Self> {
- let mut applied_routes = HashSet::new();
- let mut routes_to_apply = vec![];
- let mut default_destinations = HashSet::new();
+ let change_listener = ChangeListener::new().map_err(Error::FailedToMonitorRoutes)?;
let v4_gateway = Self::get_default_node_cmd("-inet").wait()?;
let v6_gateway = Self::get_default_node_cmd("-inet6").wait()?;
@@ -89,6 +91,24 @@ impl RouteManagerImpl {
return Err(Error::NoDefaultRoute);
}
+ let mut manager = Self {
+ default_destinations: HashSet::new(),
+ applied_routes: HashSet::new(),
+ current_state: RouteManagerState::Listening(change_listener),
+ manage_rx: Some(manage_rx),
+ v4_gateway,
+ v6_gateway,
+ };
+
+ manager.add_required_routes(required_routes)?;
+
+ Ok(manager)
+ }
+
+ fn add_required_routes(&mut self, required_routes: HashSet<RequiredRoute>) -> Result<()> {
+ let mut routes_to_apply = vec![];
+ let mut default_destinations = HashSet::new();
+
for route in required_routes {
match route.node {
NetNode::DefaultNode => {
@@ -99,47 +119,22 @@ impl RouteManagerImpl {
}
}
- let apply_routes_fn = || -> Result<()> {
- for route in routes_to_apply {
- Self::add_route(&route).wait()?;
- applied_routes.insert(route);
- }
- for destination in default_destinations.iter() {
- match (&v4_gateway, &v6_gateway, destination.is_ipv4()) {
- (Some(gateway), _, true) | (_, Some(gateway), false) => {
- let route = Route::new(gateway.clone(), *destination);
- Self::add_route(&route).wait()?;
- applied_routes.insert(route);
- }
- _ => (),
- };
- }
-
- Ok(())
- };
-
- if let Err(e) = apply_routes_fn() {
- log::error!("Failed to apply routes - {}", e);
- for applied_route in applied_routes.iter() {
- if let Err(removal_err) = Self::delete_route(applied_route.prefix).wait() {
- log::error!(
- "Failed to clean up routes after failing to set them up - {}",
- removal_err
- );
+ for route in routes_to_apply {
+ Self::add_route(&route).wait()?;
+ self.applied_routes.insert(route);
+ }
+ for destination in default_destinations.iter() {
+ match (&self.v4_gateway, &self.v6_gateway, destination.is_ipv4()) {
+ (Some(gateway), _, true) | (_, Some(gateway), false) => {
+ let route = Route::new(gateway.clone(), *destination);
+ Self::add_route(&route).wait()?;
+ self.applied_routes.insert(route);
}
- }
- return Err(e);
+ _ => (),
+ };
}
- let change_listener = ChangeListener::new().map_err(Error::FailedToMonitorRoutes)?;
- Ok(Self {
- default_destinations,
- applied_routes,
- current_state: RouteManagerState::Listening(change_listener),
- shutdown_rx: Some(shutdown_rx),
- v4_gateway,
- v6_gateway,
- })
+ Ok(())
}
// Retrieves the node that's currently used to reach 0.0.0.0/0
@@ -230,10 +225,7 @@ impl RouteManagerImpl {
.map_err(Error::FailedToAddRoute)
}
- fn shutdown_future(
- &self,
- shutdown_done_tx: Option<oneshot::Sender<()>>,
- ) -> impl Future<Item = (), Error = ()> + Send {
+ fn cleanup_routes(&self) -> impl Future<Item = (), Error = ()> + Send {
let remove_route_future = |route: &Route| {
Self::delete_route(route.prefix).then(|removal| {
match removal {
@@ -261,16 +253,21 @@ impl RouteManagerImpl {
_ => None,
}
}));
- stream::futures_ordered(routes_to_remove)
- .for_each(|_| Ok(()))
- .and_then(|_| {
- if let Some(tx) = shutdown_done_tx {
- if tx.send(()).is_err() {
- log::debug!("RouteManager already dropped")
- }
+ stream::futures_ordered(routes_to_remove).for_each(|_| Ok(()))
+ }
+
+ fn shutdown_future(
+ &self,
+ shutdown_done_tx: Option<oneshot::Sender<()>>,
+ ) -> impl Future<Item = (), Error = ()> + Send {
+ self.cleanup_routes().and_then(|_| {
+ if let Some(tx) = shutdown_done_tx {
+ if tx.send(()).is_err() {
+ log::debug!("RouteManager already dropped")
}
- Ok(())
- })
+ }
+ Ok(())
+ })
}
fn apply_new_default_routes(
@@ -323,20 +320,35 @@ impl Future for RouteManagerImpl {
type Item = ();
type Error = Error;
fn poll(&mut self) -> Result<Async<()>> {
- if let Some(mut shutdown_rx) = self.shutdown_rx.take() {
- match shutdown_rx.poll() {
- Ok(Async::Ready(shutdown_tx)) => {
- self.current_state = RouteManagerState::Shutdown(Box::new(
- self.shutdown_future(Some(shutdown_tx)),
- ));
- }
+ if let Some(mut manage_rx) = self.manage_rx.take() {
+ match manage_rx.poll() {
+ Ok(Async::Ready(Some(command))) => match command {
+ RouteManagerCommand::Shutdown(tx) => {
+ self.current_state =
+ RouteManagerState::Shutdown(Box::new(self.shutdown_future(Some(tx))));
+ }
+ RouteManagerCommand::AddRoutes(routes, result_tx) => {
+ self.manage_rx = Some(manage_rx);
+ log::debug!("Adding routes: {:?}", routes);
+ if let Err(error) = self.add_required_routes(routes) {
+ let _ = result_tx.send(Err(error));
+ } else {
+ let _ = result_tx.send(Ok(()));
+ }
+ }
+ RouteManagerCommand::ClearRoutes => {
+ self.manage_rx = Some(manage_rx);
+ log::debug!("Clearing routes");
+ let _ = self.cleanup_routes().wait();
+ }
+ },
// handle is already dropped
- Err(_) => {
+ Ok(Async::Ready(None)) | Err(_) => {
self.current_state =
RouteManagerState::Shutdown(Box::new(self.shutdown_future(None)));
}
Ok(Async::NotReady) => {
- self.shutdown_rx = Some(shutdown_rx);
+ self.manage_rx = Some(manage_rx);
}
};
}
diff --git a/talpid-core/src/routing/unix.rs b/talpid-core/src/routing/unix.rs
index 8dbfdf4d56..3f50ef3db5 100644
--- a/talpid-core/src/routing/unix.rs
+++ b/talpid-core/src/routing/unix.rs
@@ -2,8 +2,15 @@
#![cfg_attr(target_os = "windows", allow(dead_code))]
// TODO: remove the allow(dead_code) for android once it's up to scratch.
use super::RequiredRoute;
-use futures01::{sync::oneshot, Future};
+use futures01::{
+ sync::{
+ mpsc::{unbounded, UnboundedSender},
+ oneshot,
+ },
+ Future,
+};
use std::{collections::HashSet, sync::mpsc::sync_channel};
+use talpid_types::ErrorExt;
#[cfg(target_os = "macos")]
#[path = "macos.rs"]
@@ -25,19 +32,32 @@ pub enum Error {
/// Routing manager thread panicked before starting routing manager
#[error(display = "Routing manager thread panicked before starting routing manager")]
RoutingManagerThreadPanic,
- /// Platform sepcific error occured
- #[error(display = "Failed to create route manager")]
- FailedToInitializeManager(#[error(source)] imp::Error),
+ /// Platform specific error occured
+ #[error(display = "Internal route manager error")]
+ PlatformError(#[error(source)] imp::Error),
/// Failed to spawn route manager future
#[error(display = "Failed to spawn route manager on the provided executor")]
FailedToSpawnManager,
+ /// Attempt to use route manager that has been dropped
+ #[error(display = "Cannot send message to route manager since it is down")]
+ RouteManagerDown,
+}
+
+#[derive(Debug)]
+pub enum RouteManagerCommand {
+ AddRoutes(
+ HashSet<RequiredRoute>,
+ oneshot::Sender<Result<(), PlatformError>>,
+ ),
+ ClearRoutes,
+ Shutdown(oneshot::Sender<()>),
}
/// RouteManager applies a set of routes to the route table.
/// If a destination has to be routed through the default node,
/// the route will be adjusted dynamically when the default route changes.
pub struct RouteManager {
- tx: Option<oneshot::Sender<oneshot::Sender<()>>>,
+ manage_tx: Option<UnboundedSender<RouteManagerCommand>>,
}
impl RouteManager {
@@ -45,11 +65,11 @@ impl RouteManager {
/// Takes a set of network destinations and network nodes as an argument, and applies said
/// routes.
pub fn new(required_routes: HashSet<RequiredRoute>) -> Result<Self, Error> {
- let (tx, rx) = oneshot::channel();
+ let (manage_tx, manage_rx) = unbounded();
let (start_tx, start_rx) = sync_channel(1);
std::thread::spawn(
- move || match imp::RouteManagerImpl::new(required_routes, rx) {
+ move || match imp::RouteManagerImpl::new(required_routes, manage_rx) {
Ok(route_manager) => {
let _ = start_tx.send(Ok(()));
if let Err(e) = route_manager.wait() {
@@ -57,12 +77,14 @@ impl RouteManager {
}
}
Err(e) => {
- let _ = start_tx.send(Err(Error::FailedToInitializeManager(e)));
+ let _ = start_tx.send(Err(Error::PlatformError(e)));
}
},
);
match start_rx.recv() {
- Ok(Ok(())) => Ok(Self { tx: Some(tx) }),
+ Ok(Ok(())) => Ok(Self {
+ manage_tx: Some(manage_tx),
+ }),
Ok(Err(e)) => Err(e),
Err(_) => Err(Error::RoutingManagerThreadPanic),
}
@@ -70,9 +92,13 @@ impl RouteManager {
/// Stops RouteManager and removes all of the applied routes.
pub fn stop(&mut self) {
- if let Some(tx) = self.tx.take() {
+ if let Some(tx) = self.manage_tx.take() {
let (wait_tx, wait_rx) = oneshot::channel();
- if tx.send(wait_tx).is_err() {
+
+ if tx
+ .unbounded_send(RouteManagerCommand::Shutdown(wait_tx))
+ .is_err()
+ {
log::error!("RouteManager already down!");
return;
}
@@ -82,6 +108,45 @@ impl RouteManager {
}
}
}
+
+ /// Applies the given routes until [`RouteManager::stop`] is called.
+ pub fn add_routes(&mut self, routes: HashSet<RequiredRoute>) -> Result<(), Error> {
+ if let Some(tx) = &self.manage_tx {
+ let (result_tx, result_rx) = oneshot::channel();
+ if tx
+ .unbounded_send(RouteManagerCommand::AddRoutes(routes, result_tx))
+ .is_err()
+ {
+ return Err(Error::RouteManagerDown);
+ }
+
+ match result_rx.wait() {
+ Ok(result) => result.map_err(Error::PlatformError),
+ Err(error) => {
+ log::trace!(
+ "{}",
+ error.display_chain_with_msg("oneshot channel is closed")
+ );
+ Ok(())
+ }
+ }
+ } else {
+ Err(Error::RouteManagerDown)
+ }
+ }
+
+ /// Removes all routes previously applied in [`RouteManager::new`] or
+ /// [`RouteManager::add_routes`].
+ pub fn clear_routes(&mut self) -> Result<(), Error> {
+ if let Some(tx) = &self.manage_tx {
+ if tx.unbounded_send(RouteManagerCommand::ClearRoutes).is_err() {
+ return Err(Error::RouteManagerDown);
+ }
+ Ok(())
+ } else {
+ Err(Error::RouteManagerDown)
+ }
+ }
}
impl Drop for RouteManager {
diff --git a/talpid-core/src/routing/windows.rs b/talpid-core/src/routing/windows.rs
index 0e75cab953..412384574b 100644
--- a/talpid-core/src/routing/windows.rs
+++ b/talpid-core/src/routing/windows.rs
@@ -5,9 +5,15 @@ use std::collections::HashSet;
/// Windows routing errors.
#[derive(err_derive::Error, Debug)]
pub enum Error {
- /// Failure to apply a route
+ /// Failure to initialize route manager
#[error(display = "Failed to start route manager")]
FailedToStartManager,
+ /// Failure to add routes
+ #[error(display = "Failed to add routes")]
+ AddRoutesFailed,
+ /// Failure to clear routes
+ #[error(display = "Failed to clear applied routes")]
+ ClearRoutesFailed,
}
pub type Result<T> = std::result::Result<T, Error>;
@@ -22,27 +28,15 @@ impl RouteManager {
/// Creates a new route manager that will apply the provided routes and ensure they exist until
/// it's stopped.
pub fn new(required_routes: HashSet<RequiredRoute>) -> Result<Self> {
- let routes: Vec<_> = required_routes
- .iter()
- .map(|route| {
- let destination = winnet::WinNetIpNetwork::from(route.prefix);
- match &route.node {
- NetNode::DefaultNode => winnet::WinNetRoute::through_default_node(destination),
- NetNode::RealNode(node) => {
- winnet::WinNetRoute::new(winnet::WinNetNode::from(node), destination)
- }
- }
- })
- .collect();
-
- if !winnet::activate_routing_manager(&routes) {
+ if !winnet::activate_routing_manager() {
return Err(Error::FailedToStartManager);
}
-
- Ok(Self {
+ let manager = Self {
callback_handles: vec![],
is_stopped: false,
- })
+ };
+ manager.add_routes(required_routes)?;
+ Ok(manager)
}
/// Sets a callback that is called whenever the default route changes.
@@ -67,6 +61,13 @@ impl RouteManager {
}
}
+ /// Removes all routes previously applied in [`RouteManager::new`] or
+ /// [`RouteManager::add_routes`].
+ pub fn clear_default_route_callbacks(&mut self) {
+ // `WinNetCallbackHandle::drop` removes these callbacks.
+ self.callback_handles.clear();
+ }
+
/// Stops the routing manager and invalidates the route manager - no new default route callbacks
/// can be added
pub fn stop(&mut self) {
@@ -76,6 +77,38 @@ impl RouteManager {
self.is_stopped = true;
}
}
+
+ /// Applies the given routes until [`RouteManager::stop`] is called.
+ pub fn add_routes(&self, routes: HashSet<RequiredRoute>) -> Result<()> {
+ let routes: Vec<_> = routes
+ .iter()
+ .map(|route| {
+ let destination = winnet::WinNetIpNetwork::from(route.prefix);
+ match &route.node {
+ NetNode::DefaultNode => winnet::WinNetRoute::through_default_node(destination),
+ NetNode::RealNode(node) => {
+ winnet::WinNetRoute::new(winnet::WinNetNode::from(node), destination)
+ }
+ }
+ })
+ .collect();
+
+ if winnet::routing_manager_add_routes(&routes) {
+ Ok(())
+ } else {
+ Err(Error::AddRoutesFailed)
+ }
+ }
+
+ /// Removes all routes previously applied in [`RouteManager::new`] or
+ /// [`RouteManager::add_routes`].
+ pub fn clear_routes(&self) -> Result<()> {
+ if winnet::routing_manager_delete_applied_routes() {
+ Ok(())
+ } else {
+ Err(Error::ClearRoutesFailed)
+ }
+ }
}
impl Drop for RouteManager {
diff --git a/talpid-core/src/tunnel/mod.rs b/talpid-core/src/tunnel/mod.rs
index 6881b02311..8e19fe3813 100644
--- a/talpid-core/src/tunnel/mod.rs
+++ b/talpid-core/src/tunnel/mod.rs
@@ -1,5 +1,5 @@
use self::tun_provider::TunProvider;
-use crate::logging;
+use crate::{logging, routing::RouteManager};
#[cfg(not(target_os = "android"))]
use std::collections::HashMap;
use std::{
@@ -149,6 +149,7 @@ impl TunnelMonitor {
resource_dir: &Path,
on_event: L,
tun_provider: &mut TunProvider,
+ route_manager: &mut RouteManager,
) -> Result<Self>
where
L: Fn(TunnelEvent) + Send + Clone + Sync + 'static,
@@ -164,9 +165,13 @@ impl TunnelMonitor {
#[cfg(target_os = "android")]
TunnelParameters::OpenVpn(_) => Err(Error::UnsupportedPlatform),
- TunnelParameters::Wireguard(config) => {
- Self::start_wireguard_tunnel(&config, log_file, on_event, tun_provider)
- }
+ TunnelParameters::Wireguard(config) => Self::start_wireguard_tunnel(
+ &config,
+ log_file,
+ on_event,
+ tun_provider,
+ route_manager,
+ ),
}
}
@@ -175,6 +180,7 @@ impl TunnelMonitor {
log: Option<PathBuf>,
on_event: L,
tun_provider: &mut TunProvider,
+ route_manager: &mut RouteManager,
) -> Result<Self>
where
L: Fn(TunnelEvent) + Send + Sync + Clone + 'static,
@@ -185,6 +191,7 @@ impl TunnelMonitor {
log.as_ref().map(|p| p.as_path()),
on_event,
tun_provider,
+ route_manager,
)?;
Ok(TunnelMonitor {
monitor: InternalTunnelMonitor::Wireguard(monitor),
diff --git a/talpid-core/src/tunnel/wireguard/mod.rs b/talpid-core/src/tunnel/wireguard/mod.rs
index d19eb54550..6324d13e80 100644
--- a/talpid-core/src/tunnel/wireguard/mod.rs
+++ b/talpid-core/src/tunnel/wireguard/mod.rs
@@ -46,8 +46,6 @@ pub enum Error {
pub struct WireguardMonitor {
/// Tunnel implementation
tunnel: Arc<Mutex<Option<Box<dyn Tunnel>>>>,
- /// Route manager
- route_handle: routing::RouteManager,
/// Callback to signal tunnel events
event_callback: Box<dyn Fn(TunnelEvent) + Send + Sync + 'static>,
close_msg_sender: mpsc::Sender<CloseMsg>,
@@ -62,6 +60,7 @@ impl WireguardMonitor {
log_path: Option<&Path>,
on_event: F,
tun_provider: &mut TunProvider,
+ route_manager: &mut routing::RouteManager,
) -> Result<WireguardMonitor> {
let tunnel = Box::new(WgGoTunnel::start_tunnel(
&config,
@@ -70,12 +69,12 @@ impl WireguardMonitor {
Self::get_tunnel_routes(config),
)?);
let iface_name = tunnel.get_interface_name().to_string();
- #[cfg_attr(not(windows), allow(unused_mut))]
- let mut route_handle = routing::RouteManager::new(Self::get_routes(&iface_name, &config))
+ route_manager
+ .add_routes(Self::get_routes(&iface_name, &config))
.map_err(Error::SetupRoutingError)?;
#[cfg(target_os = "windows")]
- route_handle
+ route_manager
.add_default_route_callback(Some(WgGoTunnel::default_route_changed_callback), ());
let event_callback = Box::new(on_event.clone());
@@ -83,7 +82,6 @@ impl WireguardMonitor {
let (pinger_tx, pinger_rx) = mpsc::channel();
let monitor = WireguardMonitor {
tunnel: Arc::new(Mutex::new(Some(tunnel))),
- route_handle,
event_callback,
close_msg_sender,
close_msg_receiver,
@@ -144,11 +142,6 @@ impl WireguardMonitor {
let _ = self.pinger_stop_sender.send(());
- // Clear routes manually - otherwise there will be some log spam since the tunnel device
- // can be removed before the routes are cleared, which automatically clears some of the
- // routes that were set.
- self.route_handle.stop();
-
self.stop_tunnel();
(self.event_callback)(TunnelEvent::Down);
diff --git a/talpid-core/src/tunnel_state_machine/connected_state.rs b/talpid-core/src/tunnel_state_machine/connected_state.rs
index 47bc8d0f36..03cbc41cf6 100644
--- a/talpid-core/src/tunnel_state_machine/connected_state.rs
+++ b/talpid-core/src/tunnel_state_machine/connected_state.rs
@@ -89,12 +89,25 @@ impl ConnectedState {
}
}
+ fn reset_routes(shared_values: &mut SharedTunnelStateValues) {
+ #[cfg(windows)]
+ shared_values.route_manager.clear_default_route_callbacks();
+ if let Err(error) = shared_values.route_manager.clear_routes() {
+ log::error!(
+ "Failed to clear routes: {:?}",
+ error.display_chain_with_msg("Failed to clear routes")
+ );
+ }
+ }
+
fn disconnect(
self,
shared_values: &mut SharedTunnelStateValues,
after_disconnect: AfterDisconnect,
) -> EventConsequence<Self> {
Self::reset_dns(shared_values);
+ Self::reset_routes(shared_values);
+
EventConsequence::NewState(DisconnectingState::enter(
shared_values,
(self.close_handle, self.tunnel_close_event, after_disconnect),
@@ -185,6 +198,7 @@ impl ConnectedState {
match poll_result {
Ok(Async::Ready(block_reason)) => {
if let Some(reason) = block_reason {
+ Self::reset_routes(shared_values);
return NewState(ErrorState::enter(shared_values, reason));
}
}
@@ -194,6 +208,7 @@ impl ConnectedState {
log::info!("Tunnel closed. Reconnecting.");
Self::reset_dns(shared_values);
+ Self::reset_routes(shared_values);
NewState(ConnectingState::enter(shared_values, 0))
}
}
diff --git a/talpid-core/src/tunnel_state_machine/connecting_state.rs b/talpid-core/src/tunnel_state_machine/connecting_state.rs
index 14cd2cc51d..5d70568239 100644
--- a/talpid-core/src/tunnel_state_machine/connecting_state.rs
+++ b/talpid-core/src/tunnel_state_machine/connecting_state.rs
@@ -5,6 +5,7 @@ use super::{
};
use crate::{
firewall::FirewallPolicy,
+ routing::RouteManager,
tunnel::{
self, tun_provider::TunProvider, CloseHandle, TunnelEvent, TunnelMetadata, TunnelMonitor,
},
@@ -68,18 +69,21 @@ impl ConnectingState {
log_dir: &Option<PathBuf>,
resource_dir: &Path,
tun_provider: &mut TunProvider,
+ route_manager: &mut RouteManager,
retry_attempt: u32,
) -> crate::tunnel::Result<Self> {
let (event_tx, event_rx) = mpsc::unbounded();
let on_tunnel_event = move |event| {
let _ = event_tx.unbounded_send(event);
};
+
let monitor = TunnelMonitor::start(
&parameters,
log_dir,
resource_dir,
on_tunnel_event,
tun_provider,
+ route_manager,
)?;
let close_handle = Some(monitor.close_handle());
let tunnel_close_event = Self::spawn_tunnel_monitor_wait_thread(monitor);
@@ -165,11 +169,24 @@ impl ConnectingState {
}
}
+ fn reset_routes(shared_values: &mut SharedTunnelStateValues) {
+ #[cfg(windows)]
+ shared_values.route_manager.clear_default_route_callbacks();
+ if let Err(error) = shared_values.route_manager.clear_routes() {
+ log::error!(
+ "Failed to clear routes: {:?}",
+ error.display_chain_with_msg("Failed to clear routes")
+ );
+ }
+ }
+
fn disconnect(
self,
shared_values: &mut SharedTunnelStateValues,
after_disconnect: AfterDisconnect,
) -> EventConsequence<Self> {
+ Self::reset_routes(shared_values);
+
EventConsequence::NewState(DisconnectingState::enter(
shared_values,
(self.close_handle, self.tunnel_close_event, after_disconnect),
@@ -270,6 +287,7 @@ impl ConnectingState {
match poll_result {
Ok(Async::Ready(block_reason)) => {
if let Some(reason) = block_reason {
+ Self::reset_routes(shared_values);
return EventConsequence::NewState(ErrorState::enter(shared_values, reason));
}
}
@@ -281,6 +299,7 @@ impl ConnectingState {
"Tunnel closed. Reconnecting, attempt {}.",
self.retry_attempt + 1
);
+ Self::reset_routes(shared_values);
EventConsequence::NewState(ConnectingState::enter(
shared_values,
self.retry_attempt + 1,
@@ -359,6 +378,7 @@ impl TunnelState for ConnectingState {
&shared_values.log_dir,
&shared_values.resource_dir,
&mut shared_values.tun_provider,
+ &mut shared_values.route_manager,
retry_attempt,
) {
Ok(connecting_state) => {
diff --git a/talpid-core/src/tunnel_state_machine/mod.rs b/talpid-core/src/tunnel_state_machine/mod.rs
index baf52c4b2b..2ffce2bec9 100644
--- a/talpid-core/src/tunnel_state_machine/mod.rs
+++ b/talpid-core/src/tunnel_state_machine/mod.rs
@@ -19,6 +19,7 @@ use crate::{
firewall::{Firewall, FirewallArguments},
mpsc::Sender,
offline,
+ routing::RouteManager,
tunnel::tun_provider::TunProvider,
};
@@ -27,6 +28,7 @@ use futures01::{
Async, Future, Poll, Stream,
};
use std::{
+ collections::HashSet,
io,
path::{Path, PathBuf},
sync::{mpsc as sync_mpsc, Arc},
@@ -56,6 +58,10 @@ pub enum Error {
#[error(display = "Failed to initialize the system DNS manager and monitor")]
InitDnsMonitorError(#[error(source)] crate::dns::Error),
+ /// Failed to initialize the route manager.
+ #[error(display = "Failed to initialize the route manager")]
+ InitRouteManagerError(#[error(source)] crate::routing::Error),
+
/// Failed to initialize tunnel state machine event loop executor
#[error(display = "Failed to initialize tunnel state machine event loop executor")]
ReactorError(#[error(source)] io::Error),
@@ -231,9 +237,12 @@ impl TunnelStateMachine {
};
let firewall = Firewall::new(args).map_err(Error::InitFirewallError)?;
let dns_monitor = DnsMonitor::new(cache_dir).map_err(Error::InitDnsMonitorError)?;
+ let route_manager =
+ RouteManager::new(HashSet::new()).map_err(Error::InitRouteManagerError)?;
let mut shared_values = SharedTunnelStateValues {
firewall,
dns_monitor,
+ route_manager,
allow_lan,
block_when_disconnected,
is_offline,
@@ -317,6 +326,7 @@ pub trait TunnelParametersGenerator: Send + 'static {
struct SharedTunnelStateValues {
firewall: Firewall,
dns_monitor: DnsMonitor,
+ route_manager: RouteManager,
/// Should LAN access be allowed outside the tunnel.
allow_lan: bool,
/// Should network access be allowed when in the disconnected state.
diff --git a/talpid-core/src/winnet.rs b/talpid-core/src/winnet.rs
index a0b183d779..ab9dff5d06 100644
--- a/talpid-core/src/winnet.rs
+++ b/talpid-core/src/winnet.rs
@@ -289,17 +289,8 @@ impl Drop for WinNetRoute {
}
}
-pub fn activate_routing_manager(routes: &[WinNetRoute]) -> bool {
- if unsafe { WinNet_ActivateRouteManager(Some(log_sink), logging_context()) } {
- if routing_manager_add_routes(routes) {
- true
- } else {
- deactivate_routing_manager();
- false
- }
- } else {
- false
- }
+pub fn activate_routing_manager() -> bool {
+ unsafe { WinNet_ActivateRouteManager(Some(log_sink), logging_context()) }
}
pub struct WinNetCallbackHandle {
@@ -360,6 +351,10 @@ pub fn routing_manager_add_routes(routes: &[WinNetRoute]) -> bool {
unsafe { WinNet_AddRoutes(ptr, length) }
}
+pub fn routing_manager_delete_applied_routes() -> bool {
+ unsafe { WinNet_DeleteAppliedRoutes() }
+}
+
pub fn deactivate_routing_manager() {
unsafe { WinNet_DeactivateRouteManager() }
}
@@ -400,6 +395,9 @@ mod api {
// #[link_name = "WinNet_DeleteRoute"]
// pub fn WinNet_DeleteRoute(route: *const super::WinNetRoute) -> bool;
+ #[link_name = "WinNet_DeleteAppliedRoutes"]
+ pub fn WinNet_DeleteAppliedRoutes() -> bool;
+
#[link_name = "WinNet_DeactivateRouteManager"]
pub fn WinNet_DeactivateRouteManager();
diff --git a/windows/winnet/src/winnet/routing/routemanager.cpp b/windows/winnet/src/winnet/routing/routemanager.cpp
index 81ee4e3d96..9d5b36ff98 100644
--- a/windows/winnet/src/winnet/routing/routemanager.cpp
+++ b/windows/winnet/src/winnet/routing/routemanager.cpp
@@ -209,27 +209,7 @@ RouteManager::~RouteManager()
m_routeMonitorV4.reset();
m_routeMonitorV6.reset();
- //
- // Delete all routes owned by us.
- //
-
- for (const auto &record : m_routes)
- {
- try
- {
- deleteFromRoutingTable(record.registeredRoute);
- }
- catch (const std::exception &ex)
- {
- std::wstringstream ss;
-
- ss << L"Failed to delete route as part of cleaning up, Route: "
- << FormatRegisteredRoute(record.registeredRoute);
-
- m_logSink->error(common::string::ToAnsi(ss.str()).c_str());
- m_logSink->error(ex.what());
- }
- }
+ deleteAppliedRoutes();
}
void RouteManager::addRoutes(const std::vector<Route> &routes)
@@ -302,6 +282,33 @@ void RouteManager::deleteRoutes(const std::vector<Route> &routes)
}
}
+void RouteManager::deleteAppliedRoutes()
+{
+ //
+ // Delete all routes owned by us.
+ //
+
+ for (const auto &record : m_routes)
+ {
+ try
+ {
+ deleteFromRoutingTable(record.registeredRoute);
+ }
+ catch (const std::exception & ex)
+ {
+ std::wstringstream ss;
+
+ ss << L"Failed to delete route while clearing applied routes, Route: "
+ << FormatRegisteredRoute(record.registeredRoute);
+
+ m_logSink->error(common::string::ToAnsi(ss.str()).c_str());
+ m_logSink->error(ex.what());
+ }
+ }
+
+ m_routes.clear();
+}
+
RouteManager::CallbackHandle RouteManager::registerDefaultRouteChangedCallback(DefaultRouteChangedCallback callback)
{
AutoRecursiveLockType lock(m_defaultRouteCallbacksLock);
diff --git a/windows/winnet/src/winnet/routing/routemanager.h b/windows/winnet/src/winnet/routing/routemanager.h
index 92c712e25d..07cc7dbf40 100644
--- a/windows/winnet/src/winnet/routing/routemanager.h
+++ b/windows/winnet/src/winnet/routing/routemanager.h
@@ -32,6 +32,7 @@ public:
void addRoutes(const std::vector<Route> &routes);
void deleteRoutes(const std::vector<Route> &routes);
+ void deleteAppliedRoutes();
using DefaultRouteChangedEventType = DefaultRouteMonitor::EventType;
diff --git a/windows/winnet/src/winnet/winnet.cpp b/windows/winnet/src/winnet/winnet.cpp
index ac56e94ff6..c52fc57d60 100644
--- a/windows/winnet/src/winnet/winnet.cpp
+++ b/windows/winnet/src/winnet/winnet.cpp
@@ -385,6 +385,35 @@ extern "C"
WINNET_LINKAGE
bool
WINNET_API
+WinNet_DeleteAppliedRoutes()
+{
+ AutoLockType lock(g_RouteManagerLock);
+
+ if (nullptr == g_RouteManager)
+ {
+ return false;
+ }
+
+ try
+ {
+ g_RouteManager->deleteAppliedRoutes();
+ return true;
+ }
+ catch (const std::exception & err)
+ {
+ common::error::UnwindException(err, g_RouteManagerLogSink);
+ return false;
+ }
+ catch (...)
+ {
+ return false;
+ }
+}
+
+extern "C"
+WINNET_LINKAGE
+bool
+WINNET_API
WinNet_DeleteRoute(
const WINNET_ROUTE *route
)
diff --git a/windows/winnet/src/winnet/winnet.h b/windows/winnet/src/winnet/winnet.h
index 98b0083f03..3597b01a26 100644
--- a/windows/winnet/src/winnet/winnet.h
+++ b/windows/winnet/src/winnet/winnet.h
@@ -173,6 +173,13 @@ WinNet_DeleteRoute(
const WINNET_ROUTE *route
);
+extern "C"
+WINNET_LINKAGE
+bool
+WINNET_API
+WinNet_DeleteAppliedRoutes(
+);
+
enum WINNET_DEFAULT_ROUTE_CHANGED_EVENT_TYPE
{
// Best default route changed.