diff options
| author | Odd Stranne <odd@mullvad.net> | 2019-10-31 23:19:49 +0100 |
|---|---|---|
| committer | Odd Stranne <odd@mullvad.net> | 2019-11-25 13:49:39 +0100 |
| commit | 61f840bc4827863e1c7d5ca281161508ded401c6 (patch) | |
| tree | ed5c2e546ab817aa7934fc1458c334711db02333 | |
| parent | c82f4dbf8c1d128009f23e635a6b02b0cb124b3b (diff) | |
| download | mullvadvpn-61f840bc4827863e1c7d5ca281161508ded401c6.tar.xz mullvadvpn-61f840bc4827863e1c7d5ca281161508ded401c6.zip | |
Rearrange all the things
| -rw-r--r-- | windows/winnet/src/winnet/adapters.cpp | 81 | ||||
| -rw-r--r-- | windows/winnet/src/winnet/adapters.h | 40 | ||||
| -rw-r--r-- | windows/winnet/src/winnet/routemanager.h | 238 | ||||
| -rw-r--r-- | windows/winnet/src/winnet/routing/defaultroutemonitor.cpp | 177 | ||||
| -rw-r--r-- | windows/winnet/src/winnet/routing/defaultroutemonitor.h | 69 | ||||
| -rw-r--r-- | windows/winnet/src/winnet/routing/helpers.cpp | 275 | ||||
| -rw-r--r-- | windows/winnet/src/winnet/routing/helpers.h | 46 | ||||
| -rw-r--r-- | windows/winnet/src/winnet/routing/routemanager.cpp (renamed from windows/winnet/src/winnet/routemanager.cpp) | 610 | ||||
| -rw-r--r-- | windows/winnet/src/winnet/routing/routemanager.h | 112 | ||||
| -rw-r--r-- | windows/winnet/src/winnet/routing/types.cpp | 84 | ||||
| -rw-r--r-- | windows/winnet/src/winnet/routing/types.h | 77 | ||||
| -rw-r--r-- | windows/winnet/src/winnet/winnet.cpp | 117 | ||||
| -rw-r--r-- | windows/winnet/src/winnet/winnet.vcxproj | 16 | ||||
| -rw-r--r-- | windows/winnet/src/winnet/winnet.vcxproj.filters | 33 |
14 files changed, 1120 insertions, 855 deletions
diff --git a/windows/winnet/src/winnet/adapters.cpp b/windows/winnet/src/winnet/adapters.cpp deleted file mode 100644 index a497c3484d..0000000000 --- a/windows/winnet/src/winnet/adapters.cpp +++ /dev/null @@ -1,81 +0,0 @@ -#include "stdafx.h" -#include "adapters.h" -#include "libcommon/error.h" -#include <sstream> -#include <stdexcept> - -const IP_ADAPTER_ADDRESSES *Adapters::next() const -{ - if (nullptr == m_currentEntry) - { - return nullptr; - } - - auto entry = m_currentEntry; - m_currentEntry = m_currentEntry->Next; - - return entry; -} - -Adapters::Adapters(DWORD family, DWORD flags) -{ - std::vector<uint8_t> buffer; - - static const size_t MSDN_RECOMMENDED_STARTING_BUFFER_SIZE = 1024 * 15; - buffer.resize(MSDN_RECOMMENDED_STARTING_BUFFER_SIZE); - - ULONG bufferSize = static_cast<ULONG>(buffer.size()); - auto bufferPointer = reinterpret_cast<IP_ADAPTER_ADDRESSES *>(&buffer[0]); - - // - // Acquire interfaces. - // - - for (;;) - { - const auto status = GetAdaptersAddresses(family, flags, nullptr, bufferPointer, &bufferSize); - - if (ERROR_SUCCESS == status) - { - break; - } - - if (ERROR_NO_DATA == status) - { - m_buffer.clear(); - m_currentEntry = nullptr; - - return; - } - - THROW_UNLESS(ERROR_BUFFER_OVERFLOW, status, "Probe required buffer size for GetAdaptersAddresses"); - - buffer.resize(bufferSize); - bufferPointer = reinterpret_cast<IP_ADAPTER_ADDRESSES *>(&buffer[0]); - } - - // - // Verify structure compatibility. - // The structure has been extended many times. - // - - const auto systemSize = bufferPointer->Length; - const auto codeSize = sizeof(IP_ADAPTER_ADDRESSES); - - if (systemSize < codeSize) - { - std::stringstream ss; - - ss << "Expecting IP_ADAPTER_ADDRESSES to have size " << codeSize << " bytes. " - << "Found structure with size " << systemSize << " bytes."; - - throw std::runtime_error(ss.str()); - } - - // - // Initialize members. - // - - m_buffer = std::move(buffer); - reset(); -} diff --git a/windows/winnet/src/winnet/adapters.h b/windows/winnet/src/winnet/adapters.h deleted file mode 100644 index e2f2e82e53..0000000000 --- a/windows/winnet/src/winnet/adapters.h +++ /dev/null @@ -1,40 +0,0 @@ -#pragma once - -#include <vector> -#include <winsock2.h> -#include <windows.h> -#include <iphlpapi.h> - -// -// This is a thin wrapper on top of GetAdaptersAddresses() -// in order to simplify memory management. -// - -class Adapters -{ - std::vector<uint8_t> m_buffer; - mutable const IP_ADAPTER_ADDRESSES *m_currentEntry; - -public: - - Adapters(const Adapters &) = delete; - Adapters &operator=(const Adapters &) = delete; - - Adapters(Adapters &&rhs) - : m_buffer(std::move(rhs.m_buffer)) - , m_currentEntry(rhs.m_currentEntry) - { - } - - Adapters(DWORD family, DWORD flags); - - const IP_ADAPTER_ADDRESSES *next() const; - - void reset() const - { - if (false == m_buffer.empty()) - { - m_currentEntry = reinterpret_cast<const IP_ADAPTER_ADDRESSES *>(&m_buffer[0]); - } - } -}; diff --git a/windows/winnet/src/winnet/routemanager.h b/windows/winnet/src/winnet/routemanager.h deleted file mode 100644 index 9ae007b757..0000000000 --- a/windows/winnet/src/winnet/routemanager.h +++ /dev/null @@ -1,238 +0,0 @@ -#pragma once - -#include <string> -#include <memory> -#include <vector> -#include <list> -#include <stdexcept> -#include <optional> -#include <mutex> -#include <winsock2.h> -#include <windows.h> -#include <ws2def.h> -#include <ws2ipdef.h> -#include <iphlpapi.h> -#include <netioapi.h> -#include <functional> - -// Custom header files below here. -// So broken networking headers don't get confused and break the compilation. -// === -#include <libcommon/string.h> -#include <libcommon/logging/ilogsink.h> - -namespace routemanager { - -using Network = IP_ADDRESS_PREFIX; -using NodeAddress = SOCKADDR_INET; - -bool EqualAddress(const Network &lhs, const Network &rhs); -bool EqualAddress(const NodeAddress &lhs, const NodeAddress &rhs); - -class Node -{ -public: - - Node(const std::optional<std::wstring> &deviceName, const std::optional<NodeAddress> &gateway) - : m_deviceName(deviceName) - , m_gateway(gateway) - { - if (false == m_deviceName.has_value() && false == m_gateway.has_value()) - { - throw std::runtime_error("Invalid node definition"); - } - - if (m_deviceName.has_value()) - { - const auto trimmed = common::string::Trim<>(m_deviceName.value()); - - if (trimmed.empty()) - { - throw std::runtime_error("Invalid device name in node definition"); - } - - m_deviceName = std::move(trimmed); - } - } - - const std::optional<std::wstring> &deviceName() const - { - return m_deviceName; - } - - const std::optional<NodeAddress> &gateway() const - { - return m_gateway; - } - - bool operator==(const Node &rhs) const - { - if (m_deviceName.has_value()) - { - if (false == rhs.m_deviceName.has_value() - || 0 != _wcsicmp(m_deviceName.value().c_str(), rhs.deviceName().value().c_str())) - { - return false; - } - } - - if (m_gateway.has_value()) - { - if (false == rhs.m_gateway.has_value() - || false == EqualAddress(m_gateway.value(), rhs.gateway().value())) - { - return false; - } - } - - return true; - } - -private: - - std::optional<std::wstring> m_deviceName; - std::optional<NodeAddress> m_gateway; -}; - -class Route -{ -public: - - Route(const Network &network, const std::optional<Node> &node) - : m_network(network) - , m_node(node) - { - } - - const Network &network() const - { - return m_network; - } - - const std::optional<Node> &node() const - { - return m_node; - } - - bool operator==(const Route &rhs) const - { - if (m_node.has_value()) - { - return rhs.node().has_value() - && EqualAddress(m_network, rhs.network()) - && m_node.value() == rhs.node().value(); - } - - return false == rhs.node().has_value() - && EqualAddress(m_network, rhs.network()); - } - -private: - - Network m_network; - std::optional<Node> m_node; -}; - -class RouteManager -{ -public: - - RouteManager(std::shared_ptr<common::logging::ILogSink> logSink); - ~RouteManager(); - - RouteManager(const RouteManager &) = delete; - RouteManager &operator=(const RouteManager &) = delete; - RouteManager(RouteManager &&) = default; - - void addRoutes(const std::vector<Route> &routes); - void addRoute(const Route &route); - - void deleteRoutes(const std::vector<Route> &routes); - void deleteRoute(const Route &route); - - enum class DefaultRouteChangedEvent - { - // The best default route changed. - Updated, - - // No default routes exist. - Removed, - }; - - using DefaultRouteChangedCallback = std::function<void - ( - DefaultRouteChangedEvent eventType, - - // Signals which IP family the event relates to. - ADDRESS_FAMILY addressFamily, - - // For update events, signals the interface associated with the new best default route. - NET_LUID iface - )>; - - using CallbackHandle = void*; - - CallbackHandle registerDefaultRouteChangedCallback(DefaultRouteChangedCallback callback); - void unregisterDefaultRouteChangedCallback(CallbackHandle handle); - -private: - - std::shared_ptr<common::logging::ILogSink> m_logSink; - - // These are the exact details derived from the route specification (`Route`). - // They are used when registering and deleting a route in the system. - struct RegisteredRoute - { - Network network; - NET_LUID luid; - NodeAddress nextHop; - }; - - struct RouteRecord - { - Route route; - RegisteredRoute registeredRoute; - }; - - std::list<RouteRecord> m_routes; - std::recursive_mutex m_routesLock; - - std::list<DefaultRouteChangedCallback> m_defaultRouteCallbacks; - std::recursive_mutex m_defaultRouteCallbacksLock; - - // Find record based on destination and mask. - std::list<RouteRecord>::iterator findRouteRecord(const Network &network); - - // Note: Same as above! - std::list<RouteRecord>::iterator findRouteRecord(const Route &route); - - RegisteredRoute addIntoRoutingTable(const Route &route); - void restoreIntoRoutingTable(const RegisteredRoute &route); - void deleteFromRoutingTable(const RegisteredRoute &route); - - enum class EventType - { - ADD_ROUTE, - DELETE_ROUTE, - }; - - struct EventEntry - { - EventType type; - RouteRecord record; - }; - - void undoEvents(const std::vector<EventEntry> &eventLog); - - HANDLE m_notificationHandle; - - static void NETIOAPI_API_ RouteChangeCallback(void *context, MIB_IPFORWARD_ROW2 *row, MIB_NOTIFICATION_TYPE notificationType); - void routeChangeCallback(MIB_IPFORWARD_ROW2 *row, MIB_NOTIFICATION_TYPE notificationType); - - static std::wstring FormatRegisteredRoute(const RegisteredRoute &route); - - void notifyNewBestDefaultRoute(ADDRESS_FAMILY addressFamily, NET_LUID iface); - void notifyNoDefaultRoutesExist(ADDRESS_FAMILY addressFamily); -}; - -} diff --git a/windows/winnet/src/winnet/routing/defaultroutemonitor.cpp b/windows/winnet/src/winnet/routing/defaultroutemonitor.cpp new file mode 100644 index 0000000000..55d7560904 --- /dev/null +++ b/windows/winnet/src/winnet/routing/defaultroutemonitor.cpp @@ -0,0 +1,177 @@ +#include "stdafx.h" +#include <libcommon/error.h> +#include "defaultroutemonitor.h" +#include "helpers.h" + +namespace winnet::routing +{ + +namespace +{ + +const uint32_t POINT_TWO_SECOND_BURST = 200; +const uint32_t TWO_SECOND_INTERFERENCE = 2000; + +} // anonymous namespace + +DefaultRouteMonitor::DefaultRouteMonitor +( + ADDRESS_FAMILY family, + Callback callback, + std::shared_ptr<common::logging::ILogSink> logSink +) + : m_family(family) + , m_callback(callback) + , m_logSink(logSink) + , m_evaluateRoutesGuard(std::make_unique<common::BurstGuard>( + std::bind(&DefaultRouteMonitor::evaluateRoutes, this), + POINT_TWO_SECOND_BURST, + TWO_SECOND_INTERFERENCE + )) +{ + try + { + m_bestRoute = GetBestDefaultRoute(m_family); + } + catch (...) + { + } + + const auto status = NotifyRouteChange2(AF_UNSPEC, RouteChangeCallback, this, FALSE, &m_routeNotificationHandle); + + THROW_UNLESS(NO_ERROR, status, "Register for route table change notifications"); + + try + { + const auto s2 = NotifyIpInterfaceChange(AF_UNSPEC, InterfaceChangeCallback, this, + FALSE, &m_interfaceNotificationHandle); + + THROW_UNLESS(NO_ERROR, status, "Register for network interface change notifications"); + } + catch (...) + { + CancelMibChangeNotify2(m_routeNotificationHandle); + throw; + } +} + +DefaultRouteMonitor::~DefaultRouteMonitor() +{ + // + // Cancel notifications to stop triggering the BurstGuard. + // + + CancelMibChangeNotify2(m_interfaceNotificationHandle); + CancelMibChangeNotify2(m_routeNotificationHandle); + + // + // Controlled destruction of BurstGuard to prevent it from calling here + // after other member variables have been destructed. + // + + m_evaluateRoutesGuard.reset(); +} + +//static +void NETIOAPI_API_ DefaultRouteMonitor::RouteChangeCallback +( + void *context, + MIB_IPFORWARD_ROW2 *row, + MIB_NOTIFICATION_TYPE +) +{ + // + // We're only interested in changes that add/remove/update a default route. + // + + if (0 != row->DestinationPrefix.PrefixLength + || false == RouteHasGateway(*row)) + { + return; + } + + reinterpret_cast<DefaultRouteMonitor *>(context)->m_evaluateRoutesGuard->trigger(); +} + +//static +void NETIOAPI_API_ DefaultRouteMonitor::InterfaceChangeCallback +( + void *context, + MIB_IPINTERFACE_ROW *, + MIB_NOTIFICATION_TYPE +) +{ + reinterpret_cast<DefaultRouteMonitor *>(context)->m_evaluateRoutesGuard->trigger(); +} + +void DefaultRouteMonitor::evaluateRoutes() +{ + std::scoped_lock<std::mutex> lock(m_evaluationLock); + + try + { + evaluateRoutesInner(); + } + catch (const std::exception &ex) + { + const auto msg = std::string("Failure while evaluating route table: ").append(ex.what()); + m_logSink->error(msg.c_str()); + } + catch (...) + { + m_logSink->error("Unspecified failure while evaluating route table"); + } +} + +void DefaultRouteMonitor::evaluateRoutesInner() +{ + std::optional<InterfaceAndGateway> currentBestRoute; + + try + { + currentBestRoute = GetBestDefaultRoute(m_family); + } + catch (...) + { + } + + // + // If there was no default route previously. + // + + if (false == m_bestRoute.has_value()) + { + if (currentBestRoute.has_value()) + { + m_bestRoute = currentBestRoute; + m_callback(EventType::Updated, m_bestRoute); + } + + return; + } + + // + // There used to be a default route. + // If there is not currently a default route. + // + + if (false == currentBestRoute.has_value()) + { + m_bestRoute.reset(); + m_callback(EventType::Removed, std::nullopt); + + return; + } + + // + // The current best route may have changed. + // + + if (m_bestRoute.value() != currentBestRoute.value()) + { + m_bestRoute = currentBestRoute; + m_callback(EventType::Updated, m_bestRoute); + } +} + +} diff --git a/windows/winnet/src/winnet/routing/defaultroutemonitor.h b/windows/winnet/src/winnet/routing/defaultroutemonitor.h new file mode 100644 index 0000000000..5575685a82 --- /dev/null +++ b/windows/winnet/src/winnet/routing/defaultroutemonitor.h @@ -0,0 +1,69 @@ +#pragma once + +#include <ifdef.h> +#include <ws2def.h> +#include <functional> +#include <optional> +#include <memory> +#include <mutex> +#include <libcommon/logging/ilogsink.h> +#include <libcommon/burstguard.h> +#include "types.h" + +namespace winnet::routing +{ + +class DefaultRouteMonitor +{ +public: + + enum class EventType + { + // The best default route changed. + Updated, + + // No default routes exist. + Removed, + }; + + using Callback = std::function<void + ( + EventType eventType, + + // For update events, data associated with the new best default route. + const std::optional<InterfaceAndGateway> &route + )>; + + DefaultRouteMonitor(ADDRESS_FAMILY family, Callback callback, std::shared_ptr<common::logging::ILogSink> logSink); + ~DefaultRouteMonitor(); + + DefaultRouteMonitor(const DefaultRouteMonitor &) = delete; + DefaultRouteMonitor(DefaultRouteMonitor &&) = delete; + DefaultRouteMonitor &operator=(const DefaultRouteMonitor &) = delete; + DefaultRouteMonitor &operator=(DefaultRouteMonitor &&) = delete; + +private: + + ADDRESS_FAMILY m_family; + Callback m_callback; + std::shared_ptr<common::logging::ILogSink> m_logSink; + + // This can't be a plain member variable. + // We need to be able to delete it explicitly in order to have a controlled tear down. + std::unique_ptr<common::BurstGuard> m_evaluateRoutesGuard; + + std::optional<InterfaceAndGateway> m_bestRoute; + + HANDLE m_routeNotificationHandle; + HANDLE m_interfaceNotificationHandle; + + std::mutex m_evaluationLock; + + static void NETIOAPI_API_ RouteChangeCallback(void *context, MIB_IPFORWARD_ROW2 *row, MIB_NOTIFICATION_TYPE notificationType); + static void NETIOAPI_API_ InterfaceChangeCallback(void *context, MIB_IPINTERFACE_ROW *row, MIB_NOTIFICATION_TYPE notificationType); + + void evaluateRoutes(); + void evaluateRoutesInner(); +}; + +} diff --git a/windows/winnet/src/winnet/routing/helpers.cpp b/windows/winnet/src/winnet/routing/helpers.cpp new file mode 100644 index 0000000000..cabf19bce6 --- /dev/null +++ b/windows/winnet/src/winnet/routing/helpers.cpp @@ -0,0 +1,275 @@ +#include "stdafx.h" +#include "helpers.h" +#include <stdexcept> +#include <ws2def.h> +#include <in6addr.h> +#include <numeric> +//#include <netioapi.h> +#include <libcommon/error.h> +#include <libcommon/memory.h> + +namespace winnet::routing +{ + +bool EqualAddress(const Network &lhs, const Network &rhs) +{ + if (lhs.PrefixLength != rhs.PrefixLength) + { + return false; + } + + return EqualAddress(lhs.Prefix, rhs.Prefix); +} + +bool EqualAddress(const NodeAddress &lhs, const NodeAddress &rhs) +{ + if (lhs.si_family != rhs.si_family) + { + return false; + } + + switch (lhs.si_family) + { + case AF_INET: + { + return lhs.Ipv4.sin_addr.s_addr == rhs.Ipv4.sin_addr.s_addr; + } + case AF_INET6: + { + return 0 == memcmp(&lhs.Ipv6.sin6_addr, &rhs.Ipv6.sin6_addr, sizeof(IN6_ADDR)); + } + default: + { + throw std::runtime_error("Invalid address family for network address"); + } + } +} + +bool EqualAddress(const SOCKADDR_INET *lhs, const SOCKET_ADDRESS *rhs) +{ + if (lhs->si_family != rhs->lpSockaddr->sa_family) + { + return false; + } + + switch (lhs->si_family) + { + case AF_INET: + { + auto typedRhs = reinterpret_cast<const SOCKADDR_IN *>(rhs->lpSockaddr); + return lhs->Ipv4.sin_addr.s_addr == typedRhs->sin_addr.s_addr; + } + case AF_INET6: + { + auto typedRhs = reinterpret_cast<const SOCKADDR_IN6 *>(rhs->lpSockaddr); + return 0 == memcmp(lhs->Ipv6.sin6_addr.u.Byte, typedRhs->sin6_addr.u.Byte, 16); + } + default: + { + throw std::runtime_error("Missing case handler in switch clause"); + } + } +} + +bool GetAdapterInterface(NET_LUID adapter, ADDRESS_FAMILY addressFamily, MIB_IPINTERFACE_ROW *iface) +{ + memset(iface, 0, sizeof(MIB_IPINTERFACE_ROW)); + + iface->Family = addressFamily; + iface->InterfaceLuid = adapter; + + return NO_ERROR == GetIpInterfaceEntry(iface); +} + +std::vector<AnnotatedRoute> AnnotateRoutes(const std::vector<const MIB_IPFORWARD_ROW2 *> &routes) +{ + std::vector<AnnotatedRoute> annotated; + annotated.reserve(routes.size()); + + for (auto route : routes) + { + MIB_IPINTERFACE_ROW iface; + + if (false == GetAdapterInterface(route->InterfaceLuid, route->DestinationPrefix.Prefix.si_family, &iface)) + { + continue; + } + + annotated.emplace_back + ( + AnnotatedRoute{ route, bool_cast(iface.Connected), route->Metric + iface.Metric } + ); + } + + return annotated; +} + +bool RouteHasGateway(const MIB_IPFORWARD_ROW2 &route) +{ + switch (route.NextHop.si_family) + { + case AF_INET: + { + return 0 != route.NextHop.Ipv4.sin_addr.s_addr; + } + case AF_INET6: + { + const uint8_t *begin = &route.NextHop.Ipv6.sin6_addr.u.Byte[0]; + const uint8_t *end = begin + 16; + + return 0 != std::accumulate(begin, end, 0); + } + default: + { + return false; + } + }; +} + +InterfaceAndGateway GetBestDefaultRoute(ADDRESS_FAMILY family) +{ + PMIB_IPFORWARD_TABLE2 table; + + auto status = GetIpForwardTable2(family, &table); + + THROW_UNLESS(NO_ERROR, status, "Acquire route table"); + + common::memory::ScopeDestructor sd; + + sd += [table] + { + FreeMibTable(table); + }; + + std::vector<const MIB_IPFORWARD_ROW2 *> candidates; + candidates.reserve(table->NumEntries); + + // + // Enumerate routes looking for: route 0/0 && gateway specified. + // + + for (ULONG i = 0; i < table->NumEntries; ++i) + { + const MIB_IPFORWARD_ROW2 &candidate = table->Table[i]; + + if (0 == candidate.DestinationPrefix.PrefixLength + && RouteHasGateway(candidate)) + { + candidates.emplace_back(&candidate); + } + } + + auto annotated = AnnotateRoutes(candidates); + + if (annotated.empty()) + { + throw std::runtime_error("Unable to determine details of default route"); + } + + // + // Sort on (active, effectiveMetric) ascending by metric. + // + + std::sort(annotated.begin(), annotated.end(), [](const AnnotatedRoute &lhs, const AnnotatedRoute &rhs) + { + if (lhs.active == rhs.active) + { + return lhs.effectiveMetric < rhs.effectiveMetric; + } + + return lhs.active && false == rhs.active; + }); + + // + // Ensure the top rated route is active. + // + + if (false == annotated[0].active) + { + throw std::runtime_error("Unable to identify active default route"); + } + + return InterfaceAndGateway { annotated[0].route->InterfaceLuid, annotated[0].route->NextHop }; +} + +bool AdapterInterfaceEnabled(const IP_ADAPTER_ADDRESSES *adapter, ADDRESS_FAMILY family) +{ + switch (family) + { + case AF_INET: + { + return 0 != adapter->Ipv4Enabled; + } + case AF_INET6: + { + return 0 != adapter->Ipv6Enabled; + } + default: + { + throw std::runtime_error("Missing case handler in switch clause"); + } + } +} + +std::vector<const SOCKET_ADDRESS *> IsolateGatewayAddresses +( + PIP_ADAPTER_GATEWAY_ADDRESS_LH head, + ADDRESS_FAMILY family +) +{ + std::vector<const SOCKET_ADDRESS *> matches; + + for (auto gateway = head; nullptr != gateway; gateway = gateway->Next) + { + if (family == gateway->Address.lpSockaddr->sa_family) + { + matches.emplace_back(&gateway->Address); + } + } + + return matches; +} + +bool AddressPresent(const std::vector<const SOCKET_ADDRESS *> &hay, const SOCKADDR_INET *needle) +{ + for (const auto candidate : hay) + { + if (EqualAddress(needle, candidate)) + { + return true; + } + } + + return false; +} + +//NodeAddress ConvertSocketAddress(const SOCKET_ADDRESS *sa) +//{ +// NodeAddress out = { 0 }; +// +// switch (sa->lpSockaddr->sa_family) +// { +// case AF_INET: +// { +// out.si_family = AF_INET; +// out.Ipv4 = *reinterpret_cast<SOCKADDR_IN *>(sa->lpSockaddr); +// +// break; +// } +// case AF_INET6: +// { +// out.si_family = AF_INET6; +// out.Ipv6 = *reinterpret_cast<SOCKADDR_IN6 *>(sa->lpSockaddr); +// +// break; +// } +// default: +// { +// throw std::runtime_error("Missing case handler in switch clause"); +// } +// }; +// +// return out; +//} + +} diff --git a/windows/winnet/src/winnet/routing/helpers.h b/windows/winnet/src/winnet/routing/helpers.h new file mode 100644 index 0000000000..3ef5e85b75 --- /dev/null +++ b/windows/winnet/src/winnet/routing/helpers.h @@ -0,0 +1,46 @@ +#pragma once + +#include "types.h" +#include <vector> + +namespace winnet::routing +{ + +bool EqualAddress(const Network &lhs, const Network &rhs); +bool EqualAddress(const NodeAddress &lhs, const NodeAddress &rhs); +bool EqualAddress(const SOCKADDR_INET *lhs, const SOCKET_ADDRESS *rhs); + +bool GetAdapterInterface(NET_LUID adapter, ADDRESS_FAMILY addressFamily, MIB_IPINTERFACE_ROW *iface); + +struct AnnotatedRoute +{ + const MIB_IPFORWARD_ROW2 *route; + bool active; + uint32_t effectiveMetric; +}; + +template<typename T> +bool bool_cast(const T &value) +{ + return 0 != value; +} + +std::vector<AnnotatedRoute> AnnotateRoutes(const std::vector<const MIB_IPFORWARD_ROW2 *> &routes); + +bool RouteHasGateway(const MIB_IPFORWARD_ROW2 &route); + +InterfaceAndGateway GetBestDefaultRoute(ADDRESS_FAMILY family); + +bool AdapterInterfaceEnabled(const IP_ADAPTER_ADDRESSES *adapter, ADDRESS_FAMILY family); + +std::vector<const SOCKET_ADDRESS *> IsolateGatewayAddresses +( + PIP_ADAPTER_GATEWAY_ADDRESS_LH head, + ADDRESS_FAMILY family +); + +bool AddressPresent(const std::vector<const SOCKET_ADDRESS *> &hay, const SOCKADDR_INET *needle); + +//NodeAddress ConvertSocketAddress(const SOCKET_ADDRESS *sa); + +} diff --git a/windows/winnet/src/winnet/routemanager.cpp b/windows/winnet/src/winnet/routing/routemanager.cpp index c1e897b578..668e64bb68 100644 --- a/windows/winnet/src/winnet/routemanager.cpp +++ b/windows/winnet/src/winnet/routing/routemanager.cpp @@ -1,234 +1,29 @@ #include "stdafx.h" #include "routemanager.h" -#include "adapters.h" +#include "helpers.h" #include <libcommon/error.h> #include <libcommon/memory.h> #include <libcommon/string.h> +#include <libcommon/network/adapters.h> #include <vector> #include <algorithm> #include <numeric> #include <sstream> +#include <stdexcept> -using LockType = std::scoped_lock<std::recursive_mutex>; +using AutoLockType = std::scoped_lock<std::mutex>; +using AutoRecursiveLockType = std::scoped_lock<std::recursive_mutex>; +using namespace std::placeholders; -namespace -{ - -bool GetAdapterInterface(NET_LUID adapter, ADDRESS_FAMILY addressFamily, MIB_IPINTERFACE_ROW *iface) -{ - memset(iface, 0, sizeof(MIB_IPINTERFACE_ROW)); - - iface->Family = addressFamily; - iface->InterfaceLuid = adapter; - - return NO_ERROR == GetIpInterfaceEntry(iface); -} - -struct AnnotatedRoute -{ - const MIB_IPFORWARD_ROW2 *route; - bool active; - uint32_t effectiveMetric; -}; - -template<typename T> -bool bool_cast(const T &value) -{ - return 0 != value; -} - -std::vector<AnnotatedRoute> AnnotateRoutes(const std::vector<const MIB_IPFORWARD_ROW2 *> &routes) -{ - std::vector<AnnotatedRoute> annotated; - annotated.reserve(routes.size()); - - for (auto route : routes) - { - MIB_IPINTERFACE_ROW iface; - - if (false == GetAdapterInterface(route->InterfaceLuid, route->DestinationPrefix.Prefix.si_family, &iface)) - { - continue; - } - - annotated.emplace_back - ( - AnnotatedRoute{ route, bool_cast(iface.Connected), route->Metric + iface.Metric } - ); - } - - return annotated; -} - -bool RouteHasGateway(const MIB_IPFORWARD_ROW2 &route) -{ - switch (route.NextHop.si_family) - { - case AF_INET: - { - return 0 != route.NextHop.Ipv4.sin_addr.s_addr; - } - case AF_INET6: - { - const uint8_t *begin = &route.NextHop.Ipv6.sin6_addr.u.Byte[0]; - const uint8_t *end = begin + 16; - - return 0 != std::accumulate(begin, end, 0); - } - default: - { - return false; - } - }; -} - -struct InterfaceAndGateway -{ - NET_LUID iface; - routemanager::NodeAddress gateway; -}; - -InterfaceAndGateway ResolveNodeFromDefaultRoute(ADDRESS_FAMILY family) -{ - PMIB_IPFORWARD_TABLE2 table; - - auto status = GetIpForwardTable2(family, &table); - - THROW_UNLESS(NO_ERROR, status, "Acquire route table"); - - common::memory::ScopeDestructor sd; - - sd += [table] - { - FreeMibTable(table); - }; - - std::vector<const MIB_IPFORWARD_ROW2 *> candidates; - candidates.reserve(table->NumEntries); - - // - // Enumerate routes looking for: route 0/0 && gateway specified. - // - - for (ULONG i = 0; i < table->NumEntries; ++i) - { - const MIB_IPFORWARD_ROW2 &candidate = table->Table[i]; - - if (0 == candidate.DestinationPrefix.PrefixLength - && RouteHasGateway(candidate)) - { - candidates.emplace_back(&candidate); - } - } - - auto annotated = AnnotateRoutes(candidates); - - if (annotated.empty()) - { - throw std::runtime_error("Unable to determine details of default route"); - } - - // - // Sort on (active, effectiveMetric) ascending by metric. - // - - std::sort(annotated.begin(), annotated.end(), [](const AnnotatedRoute &lhs, const AnnotatedRoute &rhs) - { - if (lhs.active == rhs.active) - { - return lhs.effectiveMetric < rhs.effectiveMetric; - } - - return lhs.active && false == rhs.active; - }); - - // - // Ensure the top rated route is active. - // - - if (false == annotated[0].active) - { - throw std::runtime_error("Unable to identify active default route"); - } - - return InterfaceAndGateway { annotated[0].route->InterfaceLuid, annotated[0].route->NextHop }; -} - -bool AdapterInterfaceEnabled(const IP_ADAPTER_ADDRESSES *adapter, ADDRESS_FAMILY family) -{ - switch (family) - { - case AF_INET: - { - return 0 != adapter->Ipv4Enabled; - } - case AF_INET6: - { - return 0 != adapter->Ipv6Enabled; - } - default: - { - throw std::runtime_error("Missing case handler in switch clause"); - } - } -} - -std::vector<const SOCKET_ADDRESS *> IsolateGatewayAddresses( - PIP_ADAPTER_GATEWAY_ADDRESS_LH head, ADDRESS_FAMILY family) -{ - std::vector<const SOCKET_ADDRESS *> matches; - - for (auto gateway = head; nullptr != gateway; gateway = gateway->Next) - { - if (family == gateway->Address.lpSockaddr->sa_family) - { - matches.emplace_back(&gateway->Address); - } - } - - return matches; -} - -bool EqualAddress(const SOCKADDR_INET *lhs, const SOCKET_ADDRESS *rhs) +namespace winnet::routing { - if (lhs->si_family != rhs->lpSockaddr->sa_family) - { - return false; - } - switch (lhs->si_family) - { - case AF_INET: - { - auto typedRhs = reinterpret_cast<const SOCKADDR_IN *>(rhs->lpSockaddr); - return lhs->Ipv4.sin_addr.s_addr == typedRhs->sin_addr.s_addr; - } - case AF_INET6: - { - auto typedRhs = reinterpret_cast<const SOCKADDR_IN6 *>(rhs->lpSockaddr); - return 0 == memcmp(lhs->Ipv6.sin6_addr.u.Byte, typedRhs->sin6_addr.u.Byte, 16); - } - default: - { - throw std::runtime_error("Missing case handler in switch clause"); - } - } -} - -bool AddressPresent(const std::vector<const SOCKET_ADDRESS *> &hay, const SOCKADDR_INET *needle) +namespace { - for (const auto candidate : hay) - { - if (EqualAddress(needle, candidate)) - { - return true; - } - } - return false; -} +using Adapters = common::network::Adapters; -NET_LUID InterfaceLuidFromGateway(const routemanager::NodeAddress &gateway) +NET_LUID InterfaceLuidFromGateway(const NodeAddress &gateway) { const DWORD adapterFlags = GAA_FLAG_SKIP_ANYCAST | GAA_FLAG_SKIP_MULTICAST | GAA_FLAG_SKIP_DNS_SERVER | GAA_FLAG_SKIP_FRIENDLY_NAME | GAA_FLAG_INCLUDE_GATEWAYS; @@ -284,7 +79,41 @@ NET_LUID InterfaceLuidFromGateway(const routemanager::NodeAddress &gateway) return matches[0]->Luid; } -InterfaceAndGateway ResolveNode(ADDRESS_FAMILY family, const std::optional<routemanager::Node> &optionalNode) +bool ParseStringEncodedLuid(const std::wstring &encodedLuid, NET_LUID &luid) +{ + // + // The `#` is a valid character in adapter names so we use `?` instead. + // The LUID is thus prefixed with `?` and hex encoded and left-padded with zeroes. + // E.g. `?deadbeefcafebabe` or `?000dbeefcafebabe`. + // + + static const size_t StringEncodedLuidLength = 17; + + if (encodedLuid.size() != StringEncodedLuidLength + || L'?' != encodedLuid[0]) + { + return false; + } + + try + { + std::wstringstream ss; + + ss << std::hex << &encodedLuid[1]; + ss >> luid.Value; + } + catch (...) + { + const auto ansi = common::string::ToAnsi(encodedLuid); + const auto err = std::string("Failed to parse string encoded LUID: ").append(ansi); + + std::throw_with_nested(std::runtime_error(err)); + } + + return true; +} + +InterfaceAndGateway ResolveNode(ADDRESS_FAMILY family, const std::optional<Node> &optionalNode) { // // There are four cases: @@ -297,7 +126,7 @@ InterfaceAndGateway ResolveNode(ADDRESS_FAMILY family, const std::optional<route if (false == optionalNode.has_value()) { - return ResolveNodeFromDefaultRoute(family); + return GetBestDefaultRoute(family); } const auto &node = optionalNode.value(); @@ -307,34 +136,8 @@ InterfaceAndGateway ResolveNode(ADDRESS_FAMILY family, const std::optional<route const auto &deviceName = node.deviceName().value(); NET_LUID luid; - // - // Try to parse a string encoded LUID. - // The `#` is a valid character in adapter names so we use `?` instead. - // The LUID is thus prefixed with `?` and hex encoded and left-padded with zeroes - // E.g. `?deadbeefcafebabe` or `?000dbeefcafebabe` - // - - static const size_t StringEncodedLuidLength = 17; - - if (StringEncodedLuidLength == deviceName.size() - && L'?' == deviceName[0]) - { - try - { - std::wstringstream ss; - - ss << std::hex << &deviceName[1]; - ss >> luid.Value; - } - catch (...) - { - const auto ansiName = common::string::ToAnsi(deviceName); - const auto err = std::string("Failed to parse string encoded LUID: ").append(ansiName); - - std::throw_with_nested(std::runtime_error(err)); - } - } - else if (0 != ConvertInterfaceAliasToLuid(deviceName.c_str(), &luid)) + if (false == ParseStringEncodedLuid(deviceName, luid) + && 0 != ConvertInterfaceAliasToLuid(deviceName.c_str(), &luid)) { const auto ansiName = common::string::ToAnsi(deviceName); const auto err = std::string("Unable to derive interface LUID from interface alias: ").append(ansiName); @@ -344,7 +147,7 @@ InterfaceAndGateway ResolveNode(ADDRESS_FAMILY family, const std::optional<route auto onLinkProvider = [&family]() { - routemanager::NodeAddress onLink = { 0 }; + NodeAddress onLink = { 0 }; onLink.si_family = family; return onLink; @@ -360,42 +163,25 @@ InterfaceAndGateway ResolveNode(ADDRESS_FAMILY family, const std::optional<route return InterfaceAndGateway{ InterfaceLuidFromGateway(node.gateway().value()), node.gateway().value() }; } -routemanager::NodeAddress ConvertSocketAddress(const SOCKET_ADDRESS *sa) +// TODO: Move to libcommon +uint32_t ByteSwap(uint32_t val) { - routemanager::NodeAddress out = { 0 }; - - switch (sa->lpSockaddr->sa_family) - { - case AF_INET: - { - out.si_family = AF_INET; - out.Ipv4 = *reinterpret_cast<SOCKADDR_IN *>(sa->lpSockaddr); - - break; - } - case AF_INET6: - { - out.si_family = AF_INET6; - out.Ipv6 = *reinterpret_cast<SOCKADDR_IN6 *>(sa->lpSockaddr); - - break; - } - default: - { - throw std::runtime_error("Missing case handler in switch clause"); - } - }; - - return out; + return + ( + ((val & 0xFF) << 24) | + ((val & 0xFF00) << 8) | + ((val & 0xFF0000) >> 8) | + ((val & 0xFF000000) >> 24) + ); } -std::wstring FormatNetwork(const routemanager::Network &network) +std::wstring FormatNetwork(const Network &network) { switch (network.Prefix.si_family) { case AF_INET: { - return common::string::FormatIpv4(network.Prefix.Ipv4.sin_addr.s_addr, network.PrefixLength); + return common::string::FormatIpv4(ByteSwap(network.Prefix.Ipv4.sin_addr.s_addr), network.PrefixLength); } case AF_INET6: { @@ -408,55 +194,35 @@ std::wstring FormatNetwork(const routemanager::Network &network) } } -} // anon namespace - -namespace routemanager { - -bool EqualAddress(const Network &lhs, const Network &rhs) -{ - if (lhs.PrefixLength != rhs.PrefixLength) - { - return false; - } - - return EqualAddress(lhs.Prefix, rhs.Prefix); -} - -bool EqualAddress(const NodeAddress &lhs, const NodeAddress &rhs) -{ - if (lhs.si_family != rhs.si_family) - { - return false; - } - - switch (lhs.si_family) - { - case AF_INET: - { - return lhs.Ipv4.sin_addr.s_addr == rhs.Ipv4.sin_addr.s_addr; - } - case AF_INET6: - { - return 0 == memcmp(&lhs.Ipv6.sin6_addr, &rhs.Ipv6.sin6_addr, sizeof(IN6_ADDR)); - } - default: - { - throw std::runtime_error("Invalid address family for network address"); - } - } -} +} // anonymous namespace RouteManager::RouteManager(std::shared_ptr<common::logging::ILogSink> logSink) : m_logSink(logSink) + , m_routeMonitorV4(std::make_unique<DefaultRouteMonitor>( + static_cast<ADDRESS_FAMILY>(AF_INET), + std::bind(&RouteManager::defaultRouteChanged, this, static_cast<ADDRESS_FAMILY>(AF_INET), _1, _2), + logSink + )) + , m_routeMonitorV6(std::make_unique<DefaultRouteMonitor>( + static_cast<ADDRESS_FAMILY>(AF_INET6), + std::bind(&RouteManager::defaultRouteChanged, this, static_cast<ADDRESS_FAMILY>(AF_INET6), _1, _2), + logSink + )) { - const auto status = NotifyRouteChange2(AF_UNSPEC, RouteChangeCallback, this, FALSE, &m_notificationHandle); - - THROW_UNLESS(NO_ERROR, status, "Register for route table change notifications"); } RouteManager::~RouteManager() { - CancelMibChangeNotify2(m_notificationHandle); + // + // Stop callbacks that are triggered by events in Windows from coming in. + // + + m_routeMonitorV4.reset(); + m_routeMonitorV6.reset(); + + // + // Delete all routes owned by us. + // for (const auto &record : m_routes) { @@ -479,7 +245,7 @@ RouteManager::~RouteManager() void RouteManager::addRoutes(const std::vector<Route> &routes) { - LockType lock(m_routesLock); + AutoLockType lock(m_routesLock); std::vector<EventEntry> eventLog; @@ -512,7 +278,7 @@ void RouteManager::addRoutes(const std::vector<Route> &routes) void RouteManager::addRoute(const Route &route) { - LockType lock(m_routesLock); + AutoLockType lock(m_routesLock); std::optional<RouteRecord> deletedRecord; @@ -572,7 +338,7 @@ void RouteManager::addRoute(const Route &route) void RouteManager::deleteRoutes(const std::vector<Route> &routes) { - LockType lock(m_routesLock); + AutoLockType lock(m_routesLock); std::vector<EventEntry> eventLog; @@ -607,7 +373,7 @@ void RouteManager::deleteRoutes(const std::vector<Route> &routes) void RouteManager::deleteRoute(const Route &route) { - LockType lock(m_routesLock); + AutoLockType lock(m_routesLock); auto record = findRouteRecord(route); @@ -627,7 +393,7 @@ void RouteManager::deleteRoute(const Route &route) RouteManager::CallbackHandle RouteManager::registerDefaultRouteChangedCallback(DefaultRouteChangedCallback callback) { - LockType lock(m_defaultRouteCallbacksLock); + AutoRecursiveLockType lock(m_defaultRouteCallbacksLock); m_defaultRouteCallbacks.emplace_back(callback); @@ -637,7 +403,7 @@ RouteManager::CallbackHandle RouteManager::registerDefaultRouteChangedCallback(D void RouteManager::unregisterDefaultRouteChangedCallback(CallbackHandle handle) { - LockType lock(m_defaultRouteCallbacksLock); + AutoRecursiveLockType lock(m_defaultRouteCallbacksLock); for (auto it = m_defaultRouteCallbacks.begin(); it != m_defaultRouteCallbacks.end(); ++it) { @@ -765,7 +531,6 @@ void RouteManager::undoEvents(const std::vector<EventEntry> &eventLog) throw std::logic_error("Missing case handler in switch clause"); } } - } catch (const std::exception &ex) { @@ -775,46 +540,95 @@ void RouteManager::undoEvents(const std::vector<EventEntry> &eventLog) } } -//static -void NETIOAPI_API_ -RouteManager::RouteChangeCallback(void *context, MIB_IPFORWARD_ROW2 *row, MIB_NOTIFICATION_TYPE notificationType) +// static +std::wstring RouteManager::FormatRegisteredRoute(const RegisteredRoute &route) { - auto instance = reinterpret_cast<RouteManager *>(context); + // + // TODO: Fix broken IP formatting + // Update FormatIpv4 function with an additional argument to specify network/host byte order. + // - try + std::wstringstream ss; + + if (AF_INET == route.network.Prefix.si_family) { - instance->routeChangeCallback(row, notificationType); + std::wstring gateway(L"\"On-link\""); + + if (0 != route.nextHop.Ipv4.sin_addr.s_addr) + { + gateway = common::string::FormatIpv4(ByteSwap(route.nextHop.Ipv4.sin_addr.s_addr)); + } + + ss << common::string::FormatIpv4(ByteSwap(route.network.Prefix.Ipv4.sin_addr.s_addr), route.network.PrefixLength) + << L" with gateway " << gateway + << L" on interface with LUID 0x" << std::hex << route.luid.Value; } - catch (const std::exception &ex) + else if (AF_INET6 == route.network.Prefix.si_family) { - auto msg = std::string("Failure while processing route change notification: ").append(ex.what()); - instance->m_logSink->error(msg.c_str()); + std::wstring gateway(L"\"On-link\""); + + const uint8_t *begin = &route.nextHop.Ipv6.sin6_addr.u.Byte[0]; + const uint8_t *end = begin + 16; + + if (0 != std::accumulate(begin, end, 0)) + { + gateway = common::string::FormatIpv6(route.nextHop.Ipv6.sin6_addr.u.Byte); + } + + ss << common::string::FormatIpv6(route.network.Prefix.Ipv6.sin6_addr.u.Byte, route.network.PrefixLength) + << L" with gateway " << gateway + << L" on interface with LUID 0x" << std::hex << route.luid.Value; } - catch (...) + else { - instance->m_logSink->error("Unspecified failure while processing route change notification"); + ss << L"Failed to format route details"; } + + return ss.str(); } -void RouteManager::routeChangeCallback(MIB_IPFORWARD_ROW2 *row, MIB_NOTIFICATION_TYPE) +void RouteManager::defaultRouteChanged(ADDRESS_FAMILY family, DefaultRouteMonitor::EventType eventType, + const std::optional<InterfaceAndGateway> &route) { // - // We're only interested in changes that add/remove/update a default route. + // Forward event to all registered listeners. // - if (0 != row->DestinationPrefix.PrefixLength - || false == RouteHasGateway(*row)) + m_defaultRouteCallbacksLock.lock(); + + for (const auto &callback : m_defaultRouteCallbacks) { - return; + try + { + callback(eventType, family, route); + } + catch (const std::exception &ex) + { + const auto msg = std::string("Failure in default-route-changed callback: ").append(ex.what()); + m_logSink->error(msg.c_str()); + } + catch (...) + { + m_logSink->error("Unspecified failure in default-route-changed callback"); + } } + m_defaultRouteCallbacksLock.unlock(); + // - // Are we managing any static routes that rely on the default route? + // Examine event to determine if best default route has changed. // - const auto family = row->DestinationPrefix.Prefix.si_family; + if (DefaultRouteMonitor::EventType::Updated != eventType) + { + return; + } - LockType lock(m_routesLock); + // + // Examine our routes to see if any of them are policy bound to the best default route. + // + + AutoLockType routesLock(m_routesLock); using RecordIterator = std::list<RouteRecord>::iterator; @@ -835,50 +649,16 @@ void RouteManager::routeChangeCallback(MIB_IPFORWARD_ROW2 *row, MIB_NOTIFICATION } // - // Assume all of our affected routes are using the same gateway on the same adapter. - // - // Has the current best default route changed? - // - - const auto oldBestRoute = InterfaceAndGateway - { - affectedRoutes.front()->registeredRoute.luid, - affectedRoutes.front()->registeredRoute.nextHop - }; - - InterfaceAndGateway newBestRoute = { 0 }; - - try - { - newBestRoute = ResolveNodeFromDefaultRoute(family); - } - catch (const std::exception &ex) - { - const auto msg = std::string("Failed to resolve default network route. " \ - "Assuming there isn't one: ").append(ex.what()); - - m_logSink->info(msg.c_str()); - - return; - } - - if (oldBestRoute.iface.Value == newBestRoute.iface.Value - && EqualAddress(oldBestRoute.gateway, newBestRoute.gateway)) - { - return; - } - - // - // Best default route has changed. Update affected routes. + // Update all affected routes. // - m_logSink->info("Default route has changed. Refreshing dependent routes"); + m_logSink->info("Best default route has changed. Refreshing dependent routes"); - for (auto route : affectedRoutes) + for (auto &it : affectedRoutes) { try { - deleteFromRoutingTable(route->registeredRoute); + deleteFromRoutingTable(it->registeredRoute); } catch (const std::exception &ex) { @@ -890,12 +670,12 @@ void RouteManager::routeChangeCallback(MIB_IPFORWARD_ROW2 *row, MIB_NOTIFICATION continue; } - route->registeredRoute.luid = newBestRoute.iface; - route->registeredRoute.nextHop = newBestRoute.gateway; + it->registeredRoute.luid = route.value().iface; + it->registeredRoute.nextHop = route.value().gateway; try { - restoreIntoRoutingTable(route->registeredRoute); + restoreIntoRoutingTable(it->registeredRoute); } catch (const std::exception &ex) { @@ -907,72 +687,6 @@ void RouteManager::routeChangeCallback(MIB_IPFORWARD_ROW2 *row, MIB_NOTIFICATION continue; } } - - // - // TODO-MAYBE: Notify clients about new default route. - // -} - -// static -std::wstring RouteManager::FormatRegisteredRoute(const RegisteredRoute &route) -{ - std::wstringstream ss; - - if (AF_INET == route.network.Prefix.si_family) - { - std::wstring gateway(L"\"On-link\""); - - if (0 != route.nextHop.Ipv4.sin_addr.s_addr) - { - gateway = common::string::FormatIpv4(route.nextHop.Ipv4.sin_addr.s_addr); - } - - ss << common::string::FormatIpv4(route.network.Prefix.Ipv4.sin_addr.s_addr, route.network.PrefixLength) - << L" with gateway " << gateway - << L" on interface with LUID 0x" << std::hex << route.luid.Value; - } - else if (AF_INET6 == route.network.Prefix.si_family) - { - std::wstring gateway(L"\"On-link\""); - - const uint8_t *begin = &route.nextHop.Ipv6.sin6_addr.u.Byte[0]; - const uint8_t *end = begin + 16; - - if (0 != std::accumulate(begin, end, 0)) - { - gateway = common::string::FormatIpv6(route.nextHop.Ipv6.sin6_addr.u.Byte); - } - - ss << common::string::FormatIpv6(route.network.Prefix.Ipv6.sin6_addr.u.Byte, route.network.PrefixLength) - << L" with gateway " << gateway - << L" on interface with LUID 0x" << std::hex << route.luid.Value; - } - else - { - ss << L"Failed to format route details"; - } - - return ss.str(); -} - -void RouteManager::notifyNewBestDefaultRoute(ADDRESS_FAMILY addressFamily, NET_LUID iface) -{ - LockType lock(m_defaultRouteCallbacksLock); - - for (const auto &callback : m_defaultRouteCallbacks) - { - callback(DefaultRouteChangedEvent::Updated, addressFamily, iface); - } -} - -void RouteManager::notifyNoDefaultRoutesExist(ADDRESS_FAMILY addressFamily) -{ - LockType lock(m_defaultRouteCallbacksLock); - - for (const auto &callback : m_defaultRouteCallbacks) - { - callback(DefaultRouteChangedEvent::Removed, addressFamily, NET_LUID{ 0 }); - } } } diff --git a/windows/winnet/src/winnet/routing/routemanager.h b/windows/winnet/src/winnet/routing/routemanager.h new file mode 100644 index 0000000000..981c8e6834 --- /dev/null +++ b/windows/winnet/src/winnet/routing/routemanager.h @@ -0,0 +1,112 @@ +#pragma once + +#include <string> +#include <memory> +#include <vector> +#include <list> +#include <optional> +#include <mutex> +#include <functional> +#include <windows.h> +#include <ws2def.h> +#include <ifdef.h> +#include <libcommon/string.h> +#include <libcommon/logging/ilogsink.h> +#include "defaultroutemonitor.h" + +namespace winnet::routing +{ + +class RouteManager +{ +public: + + RouteManager(std::shared_ptr<common::logging::ILogSink> logSink); + ~RouteManager(); + + RouteManager(const RouteManager &) = delete; + RouteManager(RouteManager &&) = default; + RouteManager &operator=(const RouteManager &) = delete; + RouteManager &operator=(RouteManager &&) = delete; + + void addRoutes(const std::vector<Route> &routes); + void addRoute(const Route &route); + + void deleteRoutes(const std::vector<Route> &routes); + void deleteRoute(const Route &route); + + using DefaultRouteChangedEventType = DefaultRouteMonitor::EventType; + + using DefaultRouteChangedCallback = std::function<void + ( + DefaultRouteChangedEventType eventType, + ADDRESS_FAMILY family, + + // For update events, data associated with the new best default route. + const std::optional<InterfaceAndGateway> &route + )>; + + using CallbackHandle = void*; + + CallbackHandle registerDefaultRouteChangedCallback(DefaultRouteChangedCallback callback); + void unregisterDefaultRouteChangedCallback(CallbackHandle handle); + +private: + + std::shared_ptr<common::logging::ILogSink> m_logSink; + + std::unique_ptr<DefaultRouteMonitor> m_routeMonitorV4; + std::unique_ptr<DefaultRouteMonitor> m_routeMonitorV6; + + // These are the exact details derived from the route specification (`Route`). + // They are used when registering and deleting a route in the system. + struct RegisteredRoute + { + Network network; + NET_LUID luid; + NodeAddress nextHop; + }; + + struct RouteRecord + { + Route route; + RegisteredRoute registeredRoute; + }; + + std::list<RouteRecord> m_routes; + std::mutex m_routesLock; + + std::list<DefaultRouteChangedCallback> m_defaultRouteCallbacks; + std::recursive_mutex m_defaultRouteCallbacksLock; + + // Find record based on destination and mask. + std::list<RouteRecord>::iterator findRouteRecord(const Network &network); + + // Note: Same as above! + std::list<RouteRecord>::iterator findRouteRecord(const Route &route); + + RegisteredRoute addIntoRoutingTable(const Route &route); + void restoreIntoRoutingTable(const RegisteredRoute &route); + void deleteFromRoutingTable(const RegisteredRoute &route); + + enum class EventType + { + ADD_ROUTE, + DELETE_ROUTE, + }; + + struct EventEntry + { + EventType type; + RouteRecord record; + }; + + void undoEvents(const std::vector<EventEntry> &eventLog); + + static std::wstring FormatRegisteredRoute(const RegisteredRoute &route); + + void defaultRouteChanged(ADDRESS_FAMILY family, DefaultRouteMonitor::EventType eventType, + const std::optional<InterfaceAndGateway> &route); +}; + +} diff --git a/windows/winnet/src/winnet/routing/types.cpp b/windows/winnet/src/winnet/routing/types.cpp new file mode 100644 index 0000000000..ac71c8108f --- /dev/null +++ b/windows/winnet/src/winnet/routing/types.cpp @@ -0,0 +1,84 @@ +#include "stdafx.h" +#include "types.h" +#include "helpers.h" +#include <libcommon/string.h> + +namespace winnet::routing +{ + +Node::Node(const std::optional<std::wstring> &deviceName, const std::optional<NodeAddress> &gateway) + : m_deviceName(deviceName) + , m_gateway(gateway) +{ + if (false == m_deviceName.has_value() && false == m_gateway.has_value()) + { + throw std::runtime_error("Invalid node definition"); + } + + if (m_deviceName.has_value()) + { + const auto trimmed = common::string::Trim<>(m_deviceName.value()); + + if (trimmed.empty()) + { + throw std::runtime_error("Invalid device name in node definition"); + } + + m_deviceName = std::move(trimmed); + } +} + +bool Node::operator==(const Node &rhs) const +{ + if (m_deviceName.has_value()) + { + if (false == rhs.m_deviceName.has_value() + || 0 != _wcsicmp(m_deviceName.value().c_str(), rhs.deviceName().value().c_str())) + { + return false; + } + } + + if (m_gateway.has_value()) + { + if (false == rhs.m_gateway.has_value() + || false == EqualAddress(m_gateway.value(), rhs.gateway().value())) + { + return false; + } + } + + return true; +} + +Route::Route(const Network &network, const std::optional<Node> &node) + : m_network(network) + , m_node(node) +{ +} + +bool Route::operator==(const Route &rhs) const +{ + if (m_node.has_value()) + { + return rhs.node().has_value() + && EqualAddress(m_network, rhs.network()) + && m_node.value() == rhs.node().value(); + } + + return false == rhs.node().has_value() + && EqualAddress(m_network, rhs.network()); +} + +bool InterfaceAndGateway::operator==(const InterfaceAndGateway &rhs) +{ + return iface.Value == rhs.iface.Value + && EqualAddress(gateway, rhs.gateway); +} + +bool InterfaceAndGateway::operator!=(const InterfaceAndGateway &rhs) +{ + return !(*this == rhs); +} + +} diff --git a/windows/winnet/src/winnet/routing/types.h b/windows/winnet/src/winnet/routing/types.h new file mode 100644 index 0000000000..1e132feb00 --- /dev/null +++ b/windows/winnet/src/winnet/routing/types.h @@ -0,0 +1,77 @@ +#pragma once + +#include <string> +#include <optional> +#include <winsock2.h> +#include <windows.h> +#include <ws2def.h> +#include <ws2ipdef.h> +#include <iphlpapi.h> +//#include <netioapi.h> +//#include <functional> + + +namespace winnet::routing +{ + +using Network = IP_ADDRESS_PREFIX; +using NodeAddress = SOCKADDR_INET; + +class Node +{ +public: + + Node(const std::optional<std::wstring> &deviceName, const std::optional<NodeAddress> &gateway); + + const std::optional<std::wstring> &deviceName() const + { + return m_deviceName; + } + + const std::optional<NodeAddress> &gateway() const + { + return m_gateway; + } + + bool operator==(const Node &rhs) const; + +private: + + std::optional<std::wstring> m_deviceName; + std::optional<NodeAddress> m_gateway; +}; + +class Route +{ +public: + + Route(const Network &network, const std::optional<Node> &node); + + const Network &network() const + { + return m_network; + } + + const std::optional<Node> &node() const + { + return m_node; + } + + bool operator==(const Route &rhs) const; + +private: + + Network m_network; + std::optional<Node> m_node; +}; + +struct InterfaceAndGateway +{ + NET_LUID iface; + NodeAddress gateway; + + bool operator==(const InterfaceAndGateway &rhs); + bool operator!=(const InterfaceAndGateway &rhs); +}; + +} diff --git a/windows/winnet/src/winnet/winnet.cpp b/windows/winnet/src/winnet/winnet.cpp index d36a60fd90..48d12b5ea3 100644 --- a/windows/winnet/src/winnet/winnet.cpp +++ b/windows/winnet/src/winnet/winnet.cpp @@ -3,22 +3,25 @@ #include "NetworkInterfaces.h"
#include "interfaceutils.h"
#include "offlinemonitor.h"
+#include "routing/routemanager.h"
#include "../../shared/logsinkadapter.h"
#include <libcommon/error.h>
#include <libcommon/network.h>
-#include "routemanager.h"
#include <cstdint>
#include <stdexcept>
#include <memory>
#include <optional>
+#include <mutex>
-using namespace routemanager;
+using namespace winnet::routing;
+using AutoLockType = std::scoped_lock<std::mutex>;
namespace
{
OfflineMonitor *g_OfflineMonitor = nullptr;
+std::mutex g_RouteManagerLock;
RouteManager *g_RouteManager = nullptr;
std::shared_ptr<shared::LogSinkAdapter> g_RouteManagerLogSink;
@@ -369,6 +372,8 @@ WinNet_ActivateRouteManager( void *logSinkContext
)
{
+ AutoLockType lock(g_RouteManagerLock);
+
try
{
if (nullptr != g_RouteManager)
@@ -401,6 +406,8 @@ WinNet_AddRoutes( uint32_t numRoutes
)
{
+ AutoLockType lock(g_RouteManagerLock);
+
if (nullptr == g_RouteManager)
{
return false;
@@ -430,6 +437,8 @@ WinNet_AddRoute( const WINNET_ROUTE *route
)
{
+ AutoLockType lock(g_RouteManagerLock);
+
if (nullptr == g_RouteManager)
{
return false;
@@ -464,6 +473,8 @@ WinNet_DeleteRoutes( uint32_t numRoutes
)
{
+ AutoLockType lock(g_RouteManagerLock);
+
if (nullptr == g_RouteManager)
{
return false;
@@ -493,6 +504,8 @@ WinNet_DeleteRoute( const WINNET_ROUTE *route
)
{
+ AutoLockType lock(g_RouteManagerLock);
+
if (nullptr == g_RouteManager)
{
return false;
@@ -518,6 +531,26 @@ WinNet_DeleteRoute( }
}
+//
+// TODO: Move to libcommon.
+//
+struct ValueMapper
+{
+ template<typename T, typename U, std::size_t S>
+ static U map(T t, const std::pair<T, U> (&dictionary)[S])
+ {
+ for (const auto &entry : dictionary)
+ {
+ if (t == entry.first)
+ {
+ return entry.second;
+ }
+ }
+
+ throw std::runtime_error("Could not map between values");
+ }
+};
+
extern "C"
WINNET_LINKAGE
bool
@@ -528,6 +561,8 @@ WinNet_RegisterDefaultRouteChangedCallback( void **registrationHandle
)
{
+ AutoLockType lock(g_RouteManagerLock);
+
if (nullptr == g_RouteManager)
{
return false;
@@ -535,50 +570,52 @@ WinNet_RegisterDefaultRouteChangedCallback( try
{
- auto forwarder = [callback, context]
- (RouteManager::DefaultRouteChangedEvent eventType, ADDRESS_FAMILY addressFamily, NET_LUID iface)
+ auto forwarder = [callback, context](RouteManager::DefaultRouteChangedEventType eventType,
+ ADDRESS_FAMILY family, const std::optional<InterfaceAndGateway> &route)
{
- WINNET_DEFAULT_ROUTE_CHANGED_EVENT_TYPE translatedType;
+ //
+ // Translate the event type.
+ //
- switch (eventType)
+ using from_t = RouteManager::DefaultRouteChangedEventType;
+ using to_t = WINNET_DEFAULT_ROUTE_CHANGED_EVENT_TYPE;
+
+ static const std::pair<from_t, to_t> eventTypeMap[] =
{
- case RouteManager::DefaultRouteChangedEvent::Updated:
- {
- translatedType = WINNET_DEFAULT_ROUTE_CHANGED_EVENT_TYPE_UPDATED;
- break;
- }
- case RouteManager::DefaultRouteChangedEvent::Removed:
- {
- translatedType = WINNET_DEFAULT_ROUTE_CHANGED_EVENT_TYPE_REMOVED;
- break;
- }
- default:
- {
- throw std::runtime_error("Unexpected default-route-changed event type");
- }
- }
+ { from_t::Updated, WINNET_DEFAULT_ROUTE_CHANGED_EVENT_TYPE_UPDATED },
+ { from_t::Removed, WINNET_DEFAULT_ROUTE_CHANGED_EVENT_TYPE_REMOVED }
+ };
- WINNET_IP_FAMILY translatedFamily;
+ const auto translatedEventType = ValueMapper::map<>(eventType, eventTypeMap);
- switch (addressFamily)
+ //
+ // Translate the family type.
+ //
+
+ static const std::pair<ADDRESS_FAMILY, WINNET_IP_FAMILY> familyMap[] =
{
- case AF_INET:
- {
- translatedFamily = WINNET_IP_FAMILY_V4;
- break;
- }
- case AF_INET6:
- {
- translatedFamily = WINNET_IP_FAMILY_V6;
- break;
- }
- default:
- {
- throw std::runtime_error("Unexpected default-route-changed address family");
- }
+ { static_cast<ADDRESS_FAMILY>(AF_INET), WINNET_IP_FAMILY_V4 },
+ { static_cast<ADDRESS_FAMILY>(AF_INET6), WINNET_IP_FAMILY_V6 }
+ };
+
+ const auto translatedFamily = ValueMapper::map<>(family, familyMap);
+
+ //
+ // Determine which LUID to forward.
+ //
+
+ uint64_t translatedLuid = 0;
+
+ if (RouteManager::DefaultRouteChangedEventType::Updated == eventType)
+ {
+ translatedLuid = route.value().iface.Value;
}
- callback(translatedType, translatedFamily, iface.Value, context);
+ //
+ // Forward to client.
+ //
+
+ callback(translatedEventType, translatedFamily, translatedLuid, context);
};
*registrationHandle = g_RouteManager->registerDefaultRouteChangedCallback(forwarder);
@@ -604,6 +641,8 @@ WinNet_UnregisterDefaultRouteChangedCallback( void *registrationHandle
)
{
+ AutoLockType lock(g_RouteManagerLock);
+
if (nullptr == g_RouteManager)
{
return;
@@ -630,6 +669,8 @@ WINNET_API WinNet_DeactivateRouteManager(
)
{
+ AutoLockType lock(g_RouteManagerLock);
+
try
{
delete g_RouteManager;
diff --git a/windows/winnet/src/winnet/winnet.vcxproj b/windows/winnet/src/winnet/winnet.vcxproj index 8be655436f..5e71a1f733 100644 --- a/windows/winnet/src/winnet/winnet.vcxproj +++ b/windows/winnet/src/winnet/winnet.vcxproj @@ -28,24 +28,28 @@ </ItemGroup> <ItemGroup> <ClCompile Include="networkadaptermonitor.cpp" /> - <ClCompile Include="adapters.cpp" /> <ClCompile Include="dllmain.cpp" /> <ClCompile Include="InterfacePair.cpp" /> <ClCompile Include="interfaceutils.cpp" /> <ClCompile Include="offlinemonitor.cpp" /> <ClCompile Include="NetworkInterfaces.cpp" /> - <ClCompile Include="routemanager.cpp" /> + <ClCompile Include="routing\defaultroutemonitor.cpp" /> + <ClCompile Include="routing\helpers.cpp" /> + <ClCompile Include="routing\routemanager.cpp" /> + <ClCompile Include="routing\types.cpp" /> <ClCompile Include="stdafx.cpp" /> <ClCompile Include="winnet.cpp" /> </ItemGroup> <ItemGroup> <ClInclude Include="networkadaptermonitor.h" /> - <ClInclude Include="adapters.h" /> <ClInclude Include="InterfacePair.h" /> <ClInclude Include="interfaceutils.h" /> <ClInclude Include="offlinemonitor.h" /> <ClInclude Include="NetworkInterfaces.h" /> - <ClInclude Include="routemanager.h" /> + <ClInclude Include="routing\defaultroutemonitor.h" /> + <ClInclude Include="routing\helpers.h" /> + <ClInclude Include="routing\routemanager.h" /> + <ClInclude Include="routing\types.h" /> <ClInclude Include="stdafx.h" /> <ClInclude Include="targetver.h" /> <ClInclude Include="winnet.h" /> @@ -212,7 +216,7 @@ <ConformanceMode>true</ConformanceMode> <RuntimeLibrary>MultiThreadedDebug</RuntimeLibrary> <LanguageStandard>stdcpplatest</LanguageStandard> - <AdditionalIncludeDirectories>$(ProjectDir)..\..\..\windows-libraries\src\%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories> + <AdditionalIncludeDirectories>$(ProjectDir)..\..\..\windows-libraries\src\;$(ProjectDir);%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories> </ClCompile> <Link> <SubSystem>Windows</SubSystem> @@ -282,7 +286,7 @@ <ConformanceMode>true</ConformanceMode> <RuntimeLibrary>MultiThreaded</RuntimeLibrary> <LanguageStandard>stdcpplatest</LanguageStandard> - <AdditionalIncludeDirectories>$(ProjectDir)..\..\..\windows-libraries\src\%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories> + <AdditionalIncludeDirectories>$(ProjectDir)..\..\..\windows-libraries\src\;$(ProjectDir);%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories> </ClCompile> <Link> <SubSystem>Windows</SubSystem> diff --git a/windows/winnet/src/winnet/winnet.vcxproj.filters b/windows/winnet/src/winnet/winnet.vcxproj.filters index e739032cdf..dfe6d29ec7 100644 --- a/windows/winnet/src/winnet/winnet.vcxproj.filters +++ b/windows/winnet/src/winnet/winnet.vcxproj.filters @@ -9,8 +9,18 @@ <ClCompile Include="interfaceutils.cpp" /> <ClCompile Include="networkadaptermonitor.cpp" /> <ClCompile Include="offlinemonitor.cpp" /> - <ClCompile Include="routemanager.cpp" /> - <ClCompile Include="adapters.cpp" /> + <ClCompile Include="routing\types.cpp"> + <Filter>routing</Filter> + </ClCompile> + <ClCompile Include="routing\helpers.cpp"> + <Filter>routing</Filter> + </ClCompile> + <ClCompile Include="routing\defaultroutemonitor.cpp"> + <Filter>routing</Filter> + </ClCompile> + <ClCompile Include="routing\routemanager.cpp"> + <Filter>routing</Filter> + </ClCompile> </ItemGroup> <ItemGroup> <ClInclude Include="stdafx.h" /> @@ -21,8 +31,18 @@ <ClInclude Include="interfaceutils.h" /> <ClInclude Include="networkadaptermonitor.h" /> <ClInclude Include="offlinemonitor.h" /> - <ClInclude Include="routemanager.h" /> - <ClInclude Include="adapters.h" /> + <ClInclude Include="routing\types.h"> + <Filter>routing</Filter> + </ClInclude> + <ClInclude Include="routing\helpers.h"> + <Filter>routing</Filter> + </ClInclude> + <ClInclude Include="routing\defaultroutemonitor.h"> + <Filter>routing</Filter> + </ClInclude> + <ClInclude Include="routing\routemanager.h"> + <Filter>routing</Filter> + </ClInclude> </ItemGroup> <ItemGroup> <None Include="winnet.def" /> @@ -30,4 +50,9 @@ <ItemGroup> <ResourceCompile Include="winnet.rc" /> </ItemGroup> + <ItemGroup> + <Filter Include="routing"> + <UniqueIdentifier>{8df22cc6-597f-4342-bc57-7647393084be}</UniqueIdentifier> + </Filter> + </ItemGroup> </Project>
\ No newline at end of file |
