diff options
| -rw-r--r-- | talpid-core/src/routing/windows.rs | 69 | ||||
| -rw-r--r-- | talpid-core/src/tunnel_state_machine/connected_state.rs | 2 | ||||
| -rw-r--r-- | talpid-core/src/tunnel_state_machine/connecting_state.rs | 2 | ||||
| -rw-r--r-- | talpid-core/src/winnet.rs | 20 | ||||
| -rw-r--r-- | windows/winnet/src/winnet/routing/routemanager.cpp | 49 | ||||
| -rw-r--r-- | windows/winnet/src/winnet/routing/routemanager.h | 1 | ||||
| -rw-r--r-- | windows/winnet/src/winnet/winnet.cpp | 29 | ||||
| -rw-r--r-- | windows/winnet/src/winnet/winnet.h | 7 |
8 files changed, 129 insertions, 50 deletions
diff --git a/talpid-core/src/routing/windows.rs b/talpid-core/src/routing/windows.rs index 0e75cab953..412384574b 100644 --- a/talpid-core/src/routing/windows.rs +++ b/talpid-core/src/routing/windows.rs @@ -5,9 +5,15 @@ use std::collections::HashSet; /// Windows routing errors. #[derive(err_derive::Error, Debug)] pub enum Error { - /// Failure to apply a route + /// Failure to initialize route manager #[error(display = "Failed to start route manager")] FailedToStartManager, + /// Failure to add routes + #[error(display = "Failed to add routes")] + AddRoutesFailed, + /// Failure to clear routes + #[error(display = "Failed to clear applied routes")] + ClearRoutesFailed, } pub type Result<T> = std::result::Result<T, Error>; @@ -22,27 +28,15 @@ impl RouteManager { /// Creates a new route manager that will apply the provided routes and ensure they exist until /// it's stopped. pub fn new(required_routes: HashSet<RequiredRoute>) -> Result<Self> { - let routes: Vec<_> = required_routes - .iter() - .map(|route| { - let destination = winnet::WinNetIpNetwork::from(route.prefix); - match &route.node { - NetNode::DefaultNode => winnet::WinNetRoute::through_default_node(destination), - NetNode::RealNode(node) => { - winnet::WinNetRoute::new(winnet::WinNetNode::from(node), destination) - } - } - }) - .collect(); - - if !winnet::activate_routing_manager(&routes) { + if !winnet::activate_routing_manager() { return Err(Error::FailedToStartManager); } - - Ok(Self { + let manager = Self { callback_handles: vec![], is_stopped: false, - }) + }; + manager.add_routes(required_routes)?; + Ok(manager) } /// Sets a callback that is called whenever the default route changes. @@ -67,6 +61,13 @@ impl RouteManager { } } + /// Removes all routes previously applied in [`RouteManager::new`] or + /// [`RouteManager::add_routes`]. + pub fn clear_default_route_callbacks(&mut self) { + // `WinNetCallbackHandle::drop` removes these callbacks. + self.callback_handles.clear(); + } + /// Stops the routing manager and invalidates the route manager - no new default route callbacks /// can be added pub fn stop(&mut self) { @@ -76,6 +77,38 @@ impl RouteManager { self.is_stopped = true; } } + + /// Applies the given routes until [`RouteManager::stop`] is called. + pub fn add_routes(&self, routes: HashSet<RequiredRoute>) -> Result<()> { + let routes: Vec<_> = routes + .iter() + .map(|route| { + let destination = winnet::WinNetIpNetwork::from(route.prefix); + match &route.node { + NetNode::DefaultNode => winnet::WinNetRoute::through_default_node(destination), + NetNode::RealNode(node) => { + winnet::WinNetRoute::new(winnet::WinNetNode::from(node), destination) + } + } + }) + .collect(); + + if winnet::routing_manager_add_routes(&routes) { + Ok(()) + } else { + Err(Error::AddRoutesFailed) + } + } + + /// Removes all routes previously applied in [`RouteManager::new`] or + /// [`RouteManager::add_routes`]. + pub fn clear_routes(&self) -> Result<()> { + if winnet::routing_manager_delete_applied_routes() { + Ok(()) + } else { + Err(Error::ClearRoutesFailed) + } + } } impl Drop for RouteManager { diff --git a/talpid-core/src/tunnel_state_machine/connected_state.rs b/talpid-core/src/tunnel_state_machine/connected_state.rs index dab6e37619..03cbc41cf6 100644 --- a/talpid-core/src/tunnel_state_machine/connected_state.rs +++ b/talpid-core/src/tunnel_state_machine/connected_state.rs @@ -90,6 +90,8 @@ impl ConnectedState { } fn reset_routes(shared_values: &mut SharedTunnelStateValues) { + #[cfg(windows)] + shared_values.route_manager.clear_default_route_callbacks(); if let Err(error) = shared_values.route_manager.clear_routes() { log::error!( "Failed to clear routes: {:?}", diff --git a/talpid-core/src/tunnel_state_machine/connecting_state.rs b/talpid-core/src/tunnel_state_machine/connecting_state.rs index 4bd1345339..5d70568239 100644 --- a/talpid-core/src/tunnel_state_machine/connecting_state.rs +++ b/talpid-core/src/tunnel_state_machine/connecting_state.rs @@ -170,6 +170,8 @@ impl ConnectingState { } fn reset_routes(shared_values: &mut SharedTunnelStateValues) { + #[cfg(windows)] + shared_values.route_manager.clear_default_route_callbacks(); if let Err(error) = shared_values.route_manager.clear_routes() { log::error!( "Failed to clear routes: {:?}", diff --git a/talpid-core/src/winnet.rs b/talpid-core/src/winnet.rs index a0b183d779..ab9dff5d06 100644 --- a/talpid-core/src/winnet.rs +++ b/talpid-core/src/winnet.rs @@ -289,17 +289,8 @@ impl Drop for WinNetRoute { } } -pub fn activate_routing_manager(routes: &[WinNetRoute]) -> bool { - if unsafe { WinNet_ActivateRouteManager(Some(log_sink), logging_context()) } { - if routing_manager_add_routes(routes) { - true - } else { - deactivate_routing_manager(); - false - } - } else { - false - } +pub fn activate_routing_manager() -> bool { + unsafe { WinNet_ActivateRouteManager(Some(log_sink), logging_context()) } } pub struct WinNetCallbackHandle { @@ -360,6 +351,10 @@ pub fn routing_manager_add_routes(routes: &[WinNetRoute]) -> bool { unsafe { WinNet_AddRoutes(ptr, length) } } +pub fn routing_manager_delete_applied_routes() -> bool { + unsafe { WinNet_DeleteAppliedRoutes() } +} + pub fn deactivate_routing_manager() { unsafe { WinNet_DeactivateRouteManager() } } @@ -400,6 +395,9 @@ mod api { // #[link_name = "WinNet_DeleteRoute"] // pub fn WinNet_DeleteRoute(route: *const super::WinNetRoute) -> bool; + #[link_name = "WinNet_DeleteAppliedRoutes"] + pub fn WinNet_DeleteAppliedRoutes() -> bool; + #[link_name = "WinNet_DeactivateRouteManager"] pub fn WinNet_DeactivateRouteManager(); diff --git a/windows/winnet/src/winnet/routing/routemanager.cpp b/windows/winnet/src/winnet/routing/routemanager.cpp index 81ee4e3d96..9d5b36ff98 100644 --- a/windows/winnet/src/winnet/routing/routemanager.cpp +++ b/windows/winnet/src/winnet/routing/routemanager.cpp @@ -209,27 +209,7 @@ RouteManager::~RouteManager() m_routeMonitorV4.reset(); m_routeMonitorV6.reset(); - // - // Delete all routes owned by us. - // - - for (const auto &record : m_routes) - { - try - { - deleteFromRoutingTable(record.registeredRoute); - } - catch (const std::exception &ex) - { - std::wstringstream ss; - - ss << L"Failed to delete route as part of cleaning up, Route: " - << FormatRegisteredRoute(record.registeredRoute); - - m_logSink->error(common::string::ToAnsi(ss.str()).c_str()); - m_logSink->error(ex.what()); - } - } + deleteAppliedRoutes(); } void RouteManager::addRoutes(const std::vector<Route> &routes) @@ -302,6 +282,33 @@ void RouteManager::deleteRoutes(const std::vector<Route> &routes) } } +void RouteManager::deleteAppliedRoutes() +{ + // + // Delete all routes owned by us. + // + + for (const auto &record : m_routes) + { + try + { + deleteFromRoutingTable(record.registeredRoute); + } + catch (const std::exception & ex) + { + std::wstringstream ss; + + ss << L"Failed to delete route while clearing applied routes, Route: " + << FormatRegisteredRoute(record.registeredRoute); + + m_logSink->error(common::string::ToAnsi(ss.str()).c_str()); + m_logSink->error(ex.what()); + } + } + + m_routes.clear(); +} + RouteManager::CallbackHandle RouteManager::registerDefaultRouteChangedCallback(DefaultRouteChangedCallback callback) { AutoRecursiveLockType lock(m_defaultRouteCallbacksLock); diff --git a/windows/winnet/src/winnet/routing/routemanager.h b/windows/winnet/src/winnet/routing/routemanager.h index 92c712e25d..07cc7dbf40 100644 --- a/windows/winnet/src/winnet/routing/routemanager.h +++ b/windows/winnet/src/winnet/routing/routemanager.h @@ -32,6 +32,7 @@ public: void addRoutes(const std::vector<Route> &routes); void deleteRoutes(const std::vector<Route> &routes); + void deleteAppliedRoutes(); using DefaultRouteChangedEventType = DefaultRouteMonitor::EventType; diff --git a/windows/winnet/src/winnet/winnet.cpp b/windows/winnet/src/winnet/winnet.cpp index ac56e94ff6..c52fc57d60 100644 --- a/windows/winnet/src/winnet/winnet.cpp +++ b/windows/winnet/src/winnet/winnet.cpp @@ -385,6 +385,35 @@ extern "C" WINNET_LINKAGE
bool
WINNET_API
+WinNet_DeleteAppliedRoutes()
+{
+ AutoLockType lock(g_RouteManagerLock);
+
+ if (nullptr == g_RouteManager)
+ {
+ return false;
+ }
+
+ try
+ {
+ g_RouteManager->deleteAppliedRoutes();
+ return true;
+ }
+ catch (const std::exception & err)
+ {
+ common::error::UnwindException(err, g_RouteManagerLogSink);
+ return false;
+ }
+ catch (...)
+ {
+ return false;
+ }
+}
+
+extern "C"
+WINNET_LINKAGE
+bool
+WINNET_API
WinNet_DeleteRoute(
const WINNET_ROUTE *route
)
diff --git a/windows/winnet/src/winnet/winnet.h b/windows/winnet/src/winnet/winnet.h index 98b0083f03..3597b01a26 100644 --- a/windows/winnet/src/winnet/winnet.h +++ b/windows/winnet/src/winnet/winnet.h @@ -173,6 +173,13 @@ WinNet_DeleteRoute( const WINNET_ROUTE *route ); +extern "C" +WINNET_LINKAGE +bool +WINNET_API +WinNet_DeleteAppliedRoutes( +); + enum WINNET_DEFAULT_ROUTE_CHANGED_EVENT_TYPE { // Best default route changed. |
