summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorOdd Stranne <odd@mullvad.net>2019-10-24 12:29:37 +0200
committerOdd Stranne <odd@mullvad.net>2019-11-25 13:46:13 +0100
commitc82f4dbf8c1d128009f23e635a6b02b0cb124b3b (patch)
treeadebd8112bdda69ea361347d22f1ee17bc780b8d
parent7a595e110b978e4e3f12f3d724e1644a621e1c63 (diff)
downloadmullvadvpn-c82f4dbf8c1d128009f23e635a6b02b0cb124b3b.tar.xz
mullvadvpn-c82f4dbf8c1d128009f23e635a6b02b0cb124b3b.zip
Add notification for when the default route changes
-rw-r--r--windows/winnet/src/winnet/routemanager.cpp45
-rw-r--r--windows/winnet/src/winnet/routemanager.h33
-rw-r--r--windows/winnet/src/winnet/winnet.cpp105
-rw-r--r--windows/winnet/src/winnet/winnet.h46
4 files changed, 228 insertions, 1 deletions
diff --git a/windows/winnet/src/winnet/routemanager.cpp b/windows/winnet/src/winnet/routemanager.cpp
index 35aef395ba..c1e897b578 100644
--- a/windows/winnet/src/winnet/routemanager.cpp
+++ b/windows/winnet/src/winnet/routemanager.cpp
@@ -625,6 +625,31 @@ void RouteManager::deleteRoute(const Route &route)
m_routes.erase(record);
}
+RouteManager::CallbackHandle RouteManager::registerDefaultRouteChangedCallback(DefaultRouteChangedCallback callback)
+{
+ LockType lock(m_defaultRouteCallbacksLock);
+
+ m_defaultRouteCallbacks.emplace_back(callback);
+
+ // Return raw address of record in list.
+ return &m_defaultRouteCallbacks.back();
+}
+
+void RouteManager::unregisterDefaultRouteChangedCallback(CallbackHandle handle)
+{
+ LockType lock(m_defaultRouteCallbacksLock);
+
+ for (auto it = m_defaultRouteCallbacks.begin(); it != m_defaultRouteCallbacks.end(); ++it)
+ {
+ // Match on raw address of record.
+ if (&*it == handle)
+ {
+ m_defaultRouteCallbacks.erase(it);
+ return;
+ }
+ }
+}
+
std::list<RouteManager::RouteRecord>::iterator RouteManager::findRouteRecord(const Network &network)
{
return std::find_if(m_routes.begin(), m_routes.end(), [&network](const auto &candidate)
@@ -930,4 +955,24 @@ std::wstring RouteManager::FormatRegisteredRoute(const RegisteredRoute &route)
return ss.str();
}
+void RouteManager::notifyNewBestDefaultRoute(ADDRESS_FAMILY addressFamily, NET_LUID iface)
+{
+ LockType lock(m_defaultRouteCallbacksLock);
+
+ for (const auto &callback : m_defaultRouteCallbacks)
+ {
+ callback(DefaultRouteChangedEvent::Updated, addressFamily, iface);
+ }
+}
+
+void RouteManager::notifyNoDefaultRoutesExist(ADDRESS_FAMILY addressFamily)
+{
+ LockType lock(m_defaultRouteCallbacksLock);
+
+ for (const auto &callback : m_defaultRouteCallbacks)
+ {
+ callback(DefaultRouteChangedEvent::Removed, addressFamily, NET_LUID{ 0 });
+ }
+}
+
}
diff --git a/windows/winnet/src/winnet/routemanager.h b/windows/winnet/src/winnet/routemanager.h
index 6161dcd62a..9ae007b757 100644
--- a/windows/winnet/src/winnet/routemanager.h
+++ b/windows/winnet/src/winnet/routemanager.h
@@ -13,6 +13,7 @@
#include <ws2ipdef.h>
#include <iphlpapi.h>
#include <netioapi.h>
+#include <functional>
// Custom header files below here.
// So broken networking headers don't get confused and break the compilation.
@@ -149,6 +150,31 @@ public:
void deleteRoutes(const std::vector<Route> &routes);
void deleteRoute(const Route &route);
+ enum class DefaultRouteChangedEvent
+ {
+ // The best default route changed.
+ Updated,
+
+ // No default routes exist.
+ Removed,
+ };
+
+ using DefaultRouteChangedCallback = std::function<void
+ (
+ DefaultRouteChangedEvent eventType,
+
+ // Signals which IP family the event relates to.
+ ADDRESS_FAMILY addressFamily,
+
+ // For update events, signals the interface associated with the new best default route.
+ NET_LUID iface
+ )>;
+
+ using CallbackHandle = void*;
+
+ CallbackHandle registerDefaultRouteChangedCallback(DefaultRouteChangedCallback callback);
+ void unregisterDefaultRouteChangedCallback(CallbackHandle handle);
+
private:
std::shared_ptr<common::logging::ILogSink> m_logSink;
@@ -169,9 +195,11 @@ private:
};
std::list<RouteRecord> m_routes;
-
std::recursive_mutex m_routesLock;
+ std::list<DefaultRouteChangedCallback> m_defaultRouteCallbacks;
+ std::recursive_mutex m_defaultRouteCallbacksLock;
+
// Find record based on destination and mask.
std::list<RouteRecord>::iterator findRouteRecord(const Network &network);
@@ -202,6 +230,9 @@ private:
void routeChangeCallback(MIB_IPFORWARD_ROW2 *row, MIB_NOTIFICATION_TYPE notificationType);
static std::wstring FormatRegisteredRoute(const RegisteredRoute &route);
+
+ void notifyNewBestDefaultRoute(ADDRESS_FAMILY addressFamily, NET_LUID iface);
+ void notifyNoDefaultRoutesExist(ADDRESS_FAMILY addressFamily);
};
}
diff --git a/windows/winnet/src/winnet/winnet.cpp b/windows/winnet/src/winnet/winnet.cpp
index 4240f6a706..d36a60fd90 100644
--- a/windows/winnet/src/winnet/winnet.cpp
+++ b/windows/winnet/src/winnet/winnet.cpp
@@ -520,6 +520,111 @@ WinNet_DeleteRoute(
extern "C"
WINNET_LINKAGE
+bool
+WINNET_API
+WinNet_RegisterDefaultRouteChangedCallback(
+ WinNetDefaultRouteChangedCallback callback,
+ void *context,
+ void **registrationHandle
+)
+{
+ if (nullptr == g_RouteManager)
+ {
+ return false;
+ }
+
+ try
+ {
+ auto forwarder = [callback, context]
+ (RouteManager::DefaultRouteChangedEvent eventType, ADDRESS_FAMILY addressFamily, NET_LUID iface)
+ {
+ WINNET_DEFAULT_ROUTE_CHANGED_EVENT_TYPE translatedType;
+
+ switch (eventType)
+ {
+ case RouteManager::DefaultRouteChangedEvent::Updated:
+ {
+ translatedType = WINNET_DEFAULT_ROUTE_CHANGED_EVENT_TYPE_UPDATED;
+ break;
+ }
+ case RouteManager::DefaultRouteChangedEvent::Removed:
+ {
+ translatedType = WINNET_DEFAULT_ROUTE_CHANGED_EVENT_TYPE_REMOVED;
+ break;
+ }
+ default:
+ {
+ throw std::runtime_error("Unexpected default-route-changed event type");
+ }
+ }
+
+ WINNET_IP_FAMILY translatedFamily;
+
+ switch (addressFamily)
+ {
+ case AF_INET:
+ {
+ translatedFamily = WINNET_IP_FAMILY_V4;
+ break;
+ }
+ case AF_INET6:
+ {
+ translatedFamily = WINNET_IP_FAMILY_V6;
+ break;
+ }
+ default:
+ {
+ throw std::runtime_error("Unexpected default-route-changed address family");
+ }
+ }
+
+ callback(translatedType, translatedFamily, iface.Value, context);
+ };
+
+ *registrationHandle = g_RouteManager->registerDefaultRouteChangedCallback(forwarder);
+
+ return true;
+ }
+ catch (const std::exception &err)
+ {
+ common::error::UnwindException(err, g_RouteManagerLogSink);
+ return false;
+ }
+ catch (...)
+ {
+ return false;
+ }
+}
+
+extern "C"
+WINNET_LINKAGE
+void
+WINNET_API
+WinNet_UnregisterDefaultRouteChangedCallback(
+ void *registrationHandle
+)
+{
+ if (nullptr == g_RouteManager)
+ {
+ return;
+ }
+
+ try
+ {
+ g_RouteManager->unregisterDefaultRouteChangedCallback(registrationHandle);
+ }
+ catch (const std::exception &err)
+ {
+ g_RouteManagerLogSink->error("Failed to unregister default-route-changed callback");
+ common::error::UnwindException(err, g_RouteManagerLogSink);
+ }
+ catch (...)
+ {
+ }
+}
+
+extern "C"
+WINNET_LINKAGE
void
WINNET_API
WinNet_DeactivateRouteManager(
diff --git a/windows/winnet/src/winnet/winnet.h b/windows/winnet/src/winnet/winnet.h
index 7b4272fc72..c7a161c3d8 100644
--- a/windows/winnet/src/winnet/winnet.h
+++ b/windows/winnet/src/winnet/winnet.h
@@ -169,6 +169,52 @@ WinNet_DeleteRoute(
const WINNET_ROUTE *route
);
+enum WINNET_DEFAULT_ROUTE_CHANGED_EVENT_TYPE
+{
+ // Best default route changed.
+ WINNET_DEFAULT_ROUTE_CHANGED_EVENT_TYPE_UPDATED = 0,
+
+ // No default routes exist.
+ WINNET_DEFAULT_ROUTE_CHANGED_EVENT_TYPE_REMOVED = 1,
+};
+
+enum WINNET_IP_FAMILY
+{
+ WINNET_IP_FAMILY_V4 = 0,
+ WINNET_IP_FAMILY_V6 = 1,
+};
+
+typedef void (WINNET_API *WinNetDefaultRouteChangedCallback)
+(
+ WINNET_DEFAULT_ROUTE_CHANGED_EVENT_TYPE eventType,
+
+ // Signals which IP family the event relates to.
+ WINNET_IP_FAMILY family,
+
+ // For update events, signals the interface associated with the new best default route.
+ uint64_t interfaceLuid,
+
+ void *context
+);
+
+extern "C"
+WINNET_LINKAGE
+bool
+WINNET_API
+WinNet_RegisterDefaultRouteChangedCallback(
+ WinNetDefaultRouteChangedCallback callback,
+ void *context,
+ void **registrationHandle
+);
+
+extern "C"
+WINNET_LINKAGE
+void
+WINNET_API
+WinNet_UnregisterDefaultRouteChangedCallback(
+ void *registrationHandle
+);
+
extern "C"
WINNET_LINKAGE
void