summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--talpid-core/src/routing/windows.rs69
-rw-r--r--talpid-core/src/tunnel_state_machine/connected_state.rs2
-rw-r--r--talpid-core/src/tunnel_state_machine/connecting_state.rs2
-rw-r--r--talpid-core/src/winnet.rs20
-rw-r--r--windows/winnet/src/winnet/routing/routemanager.cpp49
-rw-r--r--windows/winnet/src/winnet/routing/routemanager.h1
-rw-r--r--windows/winnet/src/winnet/winnet.cpp29
-rw-r--r--windows/winnet/src/winnet/winnet.h7
8 files changed, 129 insertions, 50 deletions
diff --git a/talpid-core/src/routing/windows.rs b/talpid-core/src/routing/windows.rs
index 0e75cab953..412384574b 100644
--- a/talpid-core/src/routing/windows.rs
+++ b/talpid-core/src/routing/windows.rs
@@ -5,9 +5,15 @@ use std::collections::HashSet;
/// Windows routing errors.
#[derive(err_derive::Error, Debug)]
pub enum Error {
- /// Failure to apply a route
+ /// Failure to initialize route manager
#[error(display = "Failed to start route manager")]
FailedToStartManager,
+ /// Failure to add routes
+ #[error(display = "Failed to add routes")]
+ AddRoutesFailed,
+ /// Failure to clear routes
+ #[error(display = "Failed to clear applied routes")]
+ ClearRoutesFailed,
}
pub type Result<T> = std::result::Result<T, Error>;
@@ -22,27 +28,15 @@ impl RouteManager {
/// Creates a new route manager that will apply the provided routes and ensure they exist until
/// it's stopped.
pub fn new(required_routes: HashSet<RequiredRoute>) -> Result<Self> {
- let routes: Vec<_> = required_routes
- .iter()
- .map(|route| {
- let destination = winnet::WinNetIpNetwork::from(route.prefix);
- match &route.node {
- NetNode::DefaultNode => winnet::WinNetRoute::through_default_node(destination),
- NetNode::RealNode(node) => {
- winnet::WinNetRoute::new(winnet::WinNetNode::from(node), destination)
- }
- }
- })
- .collect();
-
- if !winnet::activate_routing_manager(&routes) {
+ if !winnet::activate_routing_manager() {
return Err(Error::FailedToStartManager);
}
-
- Ok(Self {
+ let manager = Self {
callback_handles: vec![],
is_stopped: false,
- })
+ };
+ manager.add_routes(required_routes)?;
+ Ok(manager)
}
/// Sets a callback that is called whenever the default route changes.
@@ -67,6 +61,13 @@ impl RouteManager {
}
}
+ /// Removes all routes previously applied in [`RouteManager::new`] or
+ /// [`RouteManager::add_routes`].
+ pub fn clear_default_route_callbacks(&mut self) {
+ // `WinNetCallbackHandle::drop` removes these callbacks.
+ self.callback_handles.clear();
+ }
+
/// Stops the routing manager and invalidates the route manager - no new default route callbacks
/// can be added
pub fn stop(&mut self) {
@@ -76,6 +77,38 @@ impl RouteManager {
self.is_stopped = true;
}
}
+
+ /// Applies the given routes until [`RouteManager::stop`] is called.
+ pub fn add_routes(&self, routes: HashSet<RequiredRoute>) -> Result<()> {
+ let routes: Vec<_> = routes
+ .iter()
+ .map(|route| {
+ let destination = winnet::WinNetIpNetwork::from(route.prefix);
+ match &route.node {
+ NetNode::DefaultNode => winnet::WinNetRoute::through_default_node(destination),
+ NetNode::RealNode(node) => {
+ winnet::WinNetRoute::new(winnet::WinNetNode::from(node), destination)
+ }
+ }
+ })
+ .collect();
+
+ if winnet::routing_manager_add_routes(&routes) {
+ Ok(())
+ } else {
+ Err(Error::AddRoutesFailed)
+ }
+ }
+
+ /// Removes all routes previously applied in [`RouteManager::new`] or
+ /// [`RouteManager::add_routes`].
+ pub fn clear_routes(&self) -> Result<()> {
+ if winnet::routing_manager_delete_applied_routes() {
+ Ok(())
+ } else {
+ Err(Error::ClearRoutesFailed)
+ }
+ }
}
impl Drop for RouteManager {
diff --git a/talpid-core/src/tunnel_state_machine/connected_state.rs b/talpid-core/src/tunnel_state_machine/connected_state.rs
index dab6e37619..03cbc41cf6 100644
--- a/talpid-core/src/tunnel_state_machine/connected_state.rs
+++ b/talpid-core/src/tunnel_state_machine/connected_state.rs
@@ -90,6 +90,8 @@ impl ConnectedState {
}
fn reset_routes(shared_values: &mut SharedTunnelStateValues) {
+ #[cfg(windows)]
+ shared_values.route_manager.clear_default_route_callbacks();
if let Err(error) = shared_values.route_manager.clear_routes() {
log::error!(
"Failed to clear routes: {:?}",
diff --git a/talpid-core/src/tunnel_state_machine/connecting_state.rs b/talpid-core/src/tunnel_state_machine/connecting_state.rs
index 4bd1345339..5d70568239 100644
--- a/talpid-core/src/tunnel_state_machine/connecting_state.rs
+++ b/talpid-core/src/tunnel_state_machine/connecting_state.rs
@@ -170,6 +170,8 @@ impl ConnectingState {
}
fn reset_routes(shared_values: &mut SharedTunnelStateValues) {
+ #[cfg(windows)]
+ shared_values.route_manager.clear_default_route_callbacks();
if let Err(error) = shared_values.route_manager.clear_routes() {
log::error!(
"Failed to clear routes: {:?}",
diff --git a/talpid-core/src/winnet.rs b/talpid-core/src/winnet.rs
index a0b183d779..ab9dff5d06 100644
--- a/talpid-core/src/winnet.rs
+++ b/talpid-core/src/winnet.rs
@@ -289,17 +289,8 @@ impl Drop for WinNetRoute {
}
}
-pub fn activate_routing_manager(routes: &[WinNetRoute]) -> bool {
- if unsafe { WinNet_ActivateRouteManager(Some(log_sink), logging_context()) } {
- if routing_manager_add_routes(routes) {
- true
- } else {
- deactivate_routing_manager();
- false
- }
- } else {
- false
- }
+pub fn activate_routing_manager() -> bool {
+ unsafe { WinNet_ActivateRouteManager(Some(log_sink), logging_context()) }
}
pub struct WinNetCallbackHandle {
@@ -360,6 +351,10 @@ pub fn routing_manager_add_routes(routes: &[WinNetRoute]) -> bool {
unsafe { WinNet_AddRoutes(ptr, length) }
}
+pub fn routing_manager_delete_applied_routes() -> bool {
+ unsafe { WinNet_DeleteAppliedRoutes() }
+}
+
pub fn deactivate_routing_manager() {
unsafe { WinNet_DeactivateRouteManager() }
}
@@ -400,6 +395,9 @@ mod api {
// #[link_name = "WinNet_DeleteRoute"]
// pub fn WinNet_DeleteRoute(route: *const super::WinNetRoute) -> bool;
+ #[link_name = "WinNet_DeleteAppliedRoutes"]
+ pub fn WinNet_DeleteAppliedRoutes() -> bool;
+
#[link_name = "WinNet_DeactivateRouteManager"]
pub fn WinNet_DeactivateRouteManager();
diff --git a/windows/winnet/src/winnet/routing/routemanager.cpp b/windows/winnet/src/winnet/routing/routemanager.cpp
index 81ee4e3d96..9d5b36ff98 100644
--- a/windows/winnet/src/winnet/routing/routemanager.cpp
+++ b/windows/winnet/src/winnet/routing/routemanager.cpp
@@ -209,27 +209,7 @@ RouteManager::~RouteManager()
m_routeMonitorV4.reset();
m_routeMonitorV6.reset();
- //
- // Delete all routes owned by us.
- //
-
- for (const auto &record : m_routes)
- {
- try
- {
- deleteFromRoutingTable(record.registeredRoute);
- }
- catch (const std::exception &ex)
- {
- std::wstringstream ss;
-
- ss << L"Failed to delete route as part of cleaning up, Route: "
- << FormatRegisteredRoute(record.registeredRoute);
-
- m_logSink->error(common::string::ToAnsi(ss.str()).c_str());
- m_logSink->error(ex.what());
- }
- }
+ deleteAppliedRoutes();
}
void RouteManager::addRoutes(const std::vector<Route> &routes)
@@ -302,6 +282,33 @@ void RouteManager::deleteRoutes(const std::vector<Route> &routes)
}
}
+void RouteManager::deleteAppliedRoutes()
+{
+ //
+ // Delete all routes owned by us.
+ //
+
+ for (const auto &record : m_routes)
+ {
+ try
+ {
+ deleteFromRoutingTable(record.registeredRoute);
+ }
+ catch (const std::exception & ex)
+ {
+ std::wstringstream ss;
+
+ ss << L"Failed to delete route while clearing applied routes, Route: "
+ << FormatRegisteredRoute(record.registeredRoute);
+
+ m_logSink->error(common::string::ToAnsi(ss.str()).c_str());
+ m_logSink->error(ex.what());
+ }
+ }
+
+ m_routes.clear();
+}
+
RouteManager::CallbackHandle RouteManager::registerDefaultRouteChangedCallback(DefaultRouteChangedCallback callback)
{
AutoRecursiveLockType lock(m_defaultRouteCallbacksLock);
diff --git a/windows/winnet/src/winnet/routing/routemanager.h b/windows/winnet/src/winnet/routing/routemanager.h
index 92c712e25d..07cc7dbf40 100644
--- a/windows/winnet/src/winnet/routing/routemanager.h
+++ b/windows/winnet/src/winnet/routing/routemanager.h
@@ -32,6 +32,7 @@ public:
void addRoutes(const std::vector<Route> &routes);
void deleteRoutes(const std::vector<Route> &routes);
+ void deleteAppliedRoutes();
using DefaultRouteChangedEventType = DefaultRouteMonitor::EventType;
diff --git a/windows/winnet/src/winnet/winnet.cpp b/windows/winnet/src/winnet/winnet.cpp
index ac56e94ff6..c52fc57d60 100644
--- a/windows/winnet/src/winnet/winnet.cpp
+++ b/windows/winnet/src/winnet/winnet.cpp
@@ -385,6 +385,35 @@ extern "C"
WINNET_LINKAGE
bool
WINNET_API
+WinNet_DeleteAppliedRoutes()
+{
+ AutoLockType lock(g_RouteManagerLock);
+
+ if (nullptr == g_RouteManager)
+ {
+ return false;
+ }
+
+ try
+ {
+ g_RouteManager->deleteAppliedRoutes();
+ return true;
+ }
+ catch (const std::exception & err)
+ {
+ common::error::UnwindException(err, g_RouteManagerLogSink);
+ return false;
+ }
+ catch (...)
+ {
+ return false;
+ }
+}
+
+extern "C"
+WINNET_LINKAGE
+bool
+WINNET_API
WinNet_DeleteRoute(
const WINNET_ROUTE *route
)
diff --git a/windows/winnet/src/winnet/winnet.h b/windows/winnet/src/winnet/winnet.h
index 98b0083f03..3597b01a26 100644
--- a/windows/winnet/src/winnet/winnet.h
+++ b/windows/winnet/src/winnet/winnet.h
@@ -173,6 +173,13 @@ WinNet_DeleteRoute(
const WINNET_ROUTE *route
);
+extern "C"
+WINNET_LINKAGE
+bool
+WINNET_API
+WinNet_DeleteAppliedRoutes(
+);
+
enum WINNET_DEFAULT_ROUTE_CHANGED_EVENT_TYPE
{
// Best default route changed.