diff options
| author | Odd Stranne <odd@mullvad.net> | 2019-06-05 15:45:37 +0200 |
|---|---|---|
| committer | Odd Stranne <odd@mullvad.net> | 2019-11-25 13:46:13 +0100 |
| commit | 9a52ae1c7b04550a2a999a233a1d930a6e27cd36 (patch) | |
| tree | c3a08fdb6aba04dabcf6a092888a77e0a3dee150 | |
| parent | 5da3d5fab49142401d81197490726c8f677bec20 (diff) | |
| download | mullvadvpn-9a52ae1c7b04550a2a999a233a1d930a6e27cd36.tar.xz mullvadvpn-9a52ae1c7b04550a2a999a233a1d930a6e27cd36.zip | |
Add Windows route manager
| -rw-r--r-- | windows/winnet/src/extras/loader/loader.vcxproj.filters | 4 | ||||
| -rw-r--r-- | windows/winnet/src/winnet/adapters.cpp | 81 | ||||
| -rw-r--r-- | windows/winnet/src/winnet/adapters.h | 40 | ||||
| -rw-r--r-- | windows/winnet/src/winnet/routemanager.cpp | 933 | ||||
| -rw-r--r-- | windows/winnet/src/winnet/routemanager.h | 207 | ||||
| -rw-r--r-- | windows/winnet/src/winnet/winnet.cpp | 299 | ||||
| -rw-r--r-- | windows/winnet/src/winnet/winnet.def | 3 | ||||
| -rw-r--r-- | windows/winnet/src/winnet/winnet.h | 88 | ||||
| -rw-r--r-- | windows/winnet/src/winnet/winnet.vcxproj | 4 | ||||
| -rw-r--r-- | windows/winnet/src/winnet/winnet.vcxproj.filters | 4 |
10 files changed, 1659 insertions, 4 deletions
diff --git a/windows/winnet/src/extras/loader/loader.vcxproj.filters b/windows/winnet/src/extras/loader/loader.vcxproj.filters index cd0f4643c7..408a9591b1 100644 --- a/windows/winnet/src/extras/loader/loader.vcxproj.filters +++ b/windows/winnet/src/extras/loader/loader.vcxproj.filters @@ -3,9 +3,13 @@ <ItemGroup> <ClCompile Include="loader.cpp" /> <ClCompile Include="stdafx.cpp" /> + <ClCompile Include="..\..\winnet\routemanager.cpp" /> + <ClCompile Include="..\..\winnet\adapters.cpp" /> </ItemGroup> <ItemGroup> <ClInclude Include="stdafx.h" /> <ClInclude Include="targetver.h" /> + <ClInclude Include="..\..\winnet\routemanager.h" /> + <ClInclude Include="..\..\winnet\adapters.h" /> </ItemGroup> </Project>
\ No newline at end of file diff --git a/windows/winnet/src/winnet/adapters.cpp b/windows/winnet/src/winnet/adapters.cpp new file mode 100644 index 0000000000..a497c3484d --- /dev/null +++ b/windows/winnet/src/winnet/adapters.cpp @@ -0,0 +1,81 @@ +#include "stdafx.h" +#include "adapters.h" +#include "libcommon/error.h" +#include <sstream> +#include <stdexcept> + +const IP_ADAPTER_ADDRESSES *Adapters::next() const +{ + if (nullptr == m_currentEntry) + { + return nullptr; + } + + auto entry = m_currentEntry; + m_currentEntry = m_currentEntry->Next; + + return entry; +} + +Adapters::Adapters(DWORD family, DWORD flags) +{ + std::vector<uint8_t> buffer; + + static const size_t MSDN_RECOMMENDED_STARTING_BUFFER_SIZE = 1024 * 15; + buffer.resize(MSDN_RECOMMENDED_STARTING_BUFFER_SIZE); + + ULONG bufferSize = static_cast<ULONG>(buffer.size()); + auto bufferPointer = reinterpret_cast<IP_ADAPTER_ADDRESSES *>(&buffer[0]); + + // + // Acquire interfaces. + // + + for (;;) + { + const auto status = GetAdaptersAddresses(family, flags, nullptr, bufferPointer, &bufferSize); + + if (ERROR_SUCCESS == status) + { + break; + } + + if (ERROR_NO_DATA == status) + { + m_buffer.clear(); + m_currentEntry = nullptr; + + return; + } + + THROW_UNLESS(ERROR_BUFFER_OVERFLOW, status, "Probe required buffer size for GetAdaptersAddresses"); + + buffer.resize(bufferSize); + bufferPointer = reinterpret_cast<IP_ADAPTER_ADDRESSES *>(&buffer[0]); + } + + // + // Verify structure compatibility. + // The structure has been extended many times. + // + + const auto systemSize = bufferPointer->Length; + const auto codeSize = sizeof(IP_ADAPTER_ADDRESSES); + + if (systemSize < codeSize) + { + std::stringstream ss; + + ss << "Expecting IP_ADAPTER_ADDRESSES to have size " << codeSize << " bytes. " + << "Found structure with size " << systemSize << " bytes."; + + throw std::runtime_error(ss.str()); + } + + // + // Initialize members. + // + + m_buffer = std::move(buffer); + reset(); +} diff --git a/windows/winnet/src/winnet/adapters.h b/windows/winnet/src/winnet/adapters.h new file mode 100644 index 0000000000..e2f2e82e53 --- /dev/null +++ b/windows/winnet/src/winnet/adapters.h @@ -0,0 +1,40 @@ +#pragma once + +#include <vector> +#include <winsock2.h> +#include <windows.h> +#include <iphlpapi.h> + +// +// This is a thin wrapper on top of GetAdaptersAddresses() +// in order to simplify memory management. +// + +class Adapters +{ + std::vector<uint8_t> m_buffer; + mutable const IP_ADAPTER_ADDRESSES *m_currentEntry; + +public: + + Adapters(const Adapters &) = delete; + Adapters &operator=(const Adapters &) = delete; + + Adapters(Adapters &&rhs) + : m_buffer(std::move(rhs.m_buffer)) + , m_currentEntry(rhs.m_currentEntry) + { + } + + Adapters(DWORD family, DWORD flags); + + const IP_ADAPTER_ADDRESSES *next() const; + + void reset() const + { + if (false == m_buffer.empty()) + { + m_currentEntry = reinterpret_cast<const IP_ADAPTER_ADDRESSES *>(&m_buffer[0]); + } + } +}; diff --git a/windows/winnet/src/winnet/routemanager.cpp b/windows/winnet/src/winnet/routemanager.cpp new file mode 100644 index 0000000000..35aef395ba --- /dev/null +++ b/windows/winnet/src/winnet/routemanager.cpp @@ -0,0 +1,933 @@ +#include "stdafx.h" +#include "routemanager.h" +#include "adapters.h" +#include <libcommon/error.h> +#include <libcommon/memory.h> +#include <libcommon/string.h> +#include <vector> +#include <algorithm> +#include <numeric> +#include <sstream> + +using LockType = std::scoped_lock<std::recursive_mutex>; + +namespace +{ + +bool GetAdapterInterface(NET_LUID adapter, ADDRESS_FAMILY addressFamily, MIB_IPINTERFACE_ROW *iface) +{ + memset(iface, 0, sizeof(MIB_IPINTERFACE_ROW)); + + iface->Family = addressFamily; + iface->InterfaceLuid = adapter; + + return NO_ERROR == GetIpInterfaceEntry(iface); +} + +struct AnnotatedRoute +{ + const MIB_IPFORWARD_ROW2 *route; + bool active; + uint32_t effectiveMetric; +}; + +template<typename T> +bool bool_cast(const T &value) +{ + return 0 != value; +} + +std::vector<AnnotatedRoute> AnnotateRoutes(const std::vector<const MIB_IPFORWARD_ROW2 *> &routes) +{ + std::vector<AnnotatedRoute> annotated; + annotated.reserve(routes.size()); + + for (auto route : routes) + { + MIB_IPINTERFACE_ROW iface; + + if (false == GetAdapterInterface(route->InterfaceLuid, route->DestinationPrefix.Prefix.si_family, &iface)) + { + continue; + } + + annotated.emplace_back + ( + AnnotatedRoute{ route, bool_cast(iface.Connected), route->Metric + iface.Metric } + ); + } + + return annotated; +} + +bool RouteHasGateway(const MIB_IPFORWARD_ROW2 &route) +{ + switch (route.NextHop.si_family) + { + case AF_INET: + { + return 0 != route.NextHop.Ipv4.sin_addr.s_addr; + } + case AF_INET6: + { + const uint8_t *begin = &route.NextHop.Ipv6.sin6_addr.u.Byte[0]; + const uint8_t *end = begin + 16; + + return 0 != std::accumulate(begin, end, 0); + } + default: + { + return false; + } + }; +} + +struct InterfaceAndGateway +{ + NET_LUID iface; + routemanager::NodeAddress gateway; +}; + +InterfaceAndGateway ResolveNodeFromDefaultRoute(ADDRESS_FAMILY family) +{ + PMIB_IPFORWARD_TABLE2 table; + + auto status = GetIpForwardTable2(family, &table); + + THROW_UNLESS(NO_ERROR, status, "Acquire route table"); + + common::memory::ScopeDestructor sd; + + sd += [table] + { + FreeMibTable(table); + }; + + std::vector<const MIB_IPFORWARD_ROW2 *> candidates; + candidates.reserve(table->NumEntries); + + // + // Enumerate routes looking for: route 0/0 && gateway specified. + // + + for (ULONG i = 0; i < table->NumEntries; ++i) + { + const MIB_IPFORWARD_ROW2 &candidate = table->Table[i]; + + if (0 == candidate.DestinationPrefix.PrefixLength + && RouteHasGateway(candidate)) + { + candidates.emplace_back(&candidate); + } + } + + auto annotated = AnnotateRoutes(candidates); + + if (annotated.empty()) + { + throw std::runtime_error("Unable to determine details of default route"); + } + + // + // Sort on (active, effectiveMetric) ascending by metric. + // + + std::sort(annotated.begin(), annotated.end(), [](const AnnotatedRoute &lhs, const AnnotatedRoute &rhs) + { + if (lhs.active == rhs.active) + { + return lhs.effectiveMetric < rhs.effectiveMetric; + } + + return lhs.active && false == rhs.active; + }); + + // + // Ensure the top rated route is active. + // + + if (false == annotated[0].active) + { + throw std::runtime_error("Unable to identify active default route"); + } + + return InterfaceAndGateway { annotated[0].route->InterfaceLuid, annotated[0].route->NextHop }; +} + +bool AdapterInterfaceEnabled(const IP_ADAPTER_ADDRESSES *adapter, ADDRESS_FAMILY family) +{ + switch (family) + { + case AF_INET: + { + return 0 != adapter->Ipv4Enabled; + } + case AF_INET6: + { + return 0 != adapter->Ipv6Enabled; + } + default: + { + throw std::runtime_error("Missing case handler in switch clause"); + } + } +} + +std::vector<const SOCKET_ADDRESS *> IsolateGatewayAddresses( + PIP_ADAPTER_GATEWAY_ADDRESS_LH head, ADDRESS_FAMILY family) +{ + std::vector<const SOCKET_ADDRESS *> matches; + + for (auto gateway = head; nullptr != gateway; gateway = gateway->Next) + { + if (family == gateway->Address.lpSockaddr->sa_family) + { + matches.emplace_back(&gateway->Address); + } + } + + return matches; +} + +bool EqualAddress(const SOCKADDR_INET *lhs, const SOCKET_ADDRESS *rhs) +{ + if (lhs->si_family != rhs->lpSockaddr->sa_family) + { + return false; + } + + switch (lhs->si_family) + { + case AF_INET: + { + auto typedRhs = reinterpret_cast<const SOCKADDR_IN *>(rhs->lpSockaddr); + return lhs->Ipv4.sin_addr.s_addr == typedRhs->sin_addr.s_addr; + } + case AF_INET6: + { + auto typedRhs = reinterpret_cast<const SOCKADDR_IN6 *>(rhs->lpSockaddr); + return 0 == memcmp(lhs->Ipv6.sin6_addr.u.Byte, typedRhs->sin6_addr.u.Byte, 16); + } + default: + { + throw std::runtime_error("Missing case handler in switch clause"); + } + } +} + +bool AddressPresent(const std::vector<const SOCKET_ADDRESS *> &hay, const SOCKADDR_INET *needle) +{ + for (const auto candidate : hay) + { + if (EqualAddress(needle, candidate)) + { + return true; + } + } + + return false; +} + +NET_LUID InterfaceLuidFromGateway(const routemanager::NodeAddress &gateway) +{ + const DWORD adapterFlags = GAA_FLAG_SKIP_ANYCAST | GAA_FLAG_SKIP_MULTICAST | GAA_FLAG_SKIP_DNS_SERVER + | GAA_FLAG_SKIP_FRIENDLY_NAME | GAA_FLAG_INCLUDE_GATEWAYS; + + Adapters adapters(gateway.si_family, adapterFlags); + + // + // Process adapters to find matching ones. + // + + std::vector<const IP_ADAPTER_ADDRESSES *> matches; + + for (auto adapter = adapters.next(); nullptr != adapter; adapter = adapters.next()) + { + if (false == AdapterInterfaceEnabled(adapter, gateway.si_family)) + { + continue; + } + + auto gateways = IsolateGatewayAddresses(adapter->FirstGatewayAddress, gateway.si_family); + + if (AddressPresent(gateways, &gateway)) + { + matches.emplace_back(adapter); + } + } + + if (matches.empty()) + { + throw std::runtime_error("Unable to find network adapter with specified gateway"); + } + + // + // Sort matching interfaces ascending by metric. + // + + const bool targetV4 = (AF_INET == gateway.si_family); + + std::sort(matches.begin(), matches.end(), [&targetV4](const IP_ADAPTER_ADDRESSES *lhs, const IP_ADAPTER_ADDRESSES *rhs) + { + if (targetV4) + { + return lhs->Ipv4Metric < rhs->Ipv4Metric; + } + + return lhs->Ipv6Metric < rhs->Ipv6Metric; + }); + + // + // Select the interface with the best (lowest) metric. + // + + return matches[0]->Luid; +} + +InterfaceAndGateway ResolveNode(ADDRESS_FAMILY family, const std::optional<routemanager::Node> &optionalNode) +{ + // + // There are four cases: + // + // Unspecified node (use interface and gateway of default route). + // Node is specified by name. + // Node is specified by name and gateway. + // Node is specified by gateway. + // + + if (false == optionalNode.has_value()) + { + return ResolveNodeFromDefaultRoute(family); + } + + const auto &node = optionalNode.value(); + + if (node.deviceName().has_value()) + { + const auto &deviceName = node.deviceName().value(); + NET_LUID luid; + + // + // Try to parse a string encoded LUID. + // The `#` is a valid character in adapter names so we use `?` instead. + // The LUID is thus prefixed with `?` and hex encoded and left-padded with zeroes + // E.g. `?deadbeefcafebabe` or `?000dbeefcafebabe` + // + + static const size_t StringEncodedLuidLength = 17; + + if (StringEncodedLuidLength == deviceName.size() + && L'?' == deviceName[0]) + { + try + { + std::wstringstream ss; + + ss << std::hex << &deviceName[1]; + ss >> luid.Value; + } + catch (...) + { + const auto ansiName = common::string::ToAnsi(deviceName); + const auto err = std::string("Failed to parse string encoded LUID: ").append(ansiName); + + std::throw_with_nested(std::runtime_error(err)); + } + } + else if (0 != ConvertInterfaceAliasToLuid(deviceName.c_str(), &luid)) + { + const auto ansiName = common::string::ToAnsi(deviceName); + const auto err = std::string("Unable to derive interface LUID from interface alias: ").append(ansiName); + + throw std::runtime_error(err); + } + + auto onLinkProvider = [&family]() + { + routemanager::NodeAddress onLink = { 0 }; + onLink.si_family = family; + + return onLink; + }; + + return InterfaceAndGateway{ luid, node.gateway().value_or(onLinkProvider()) }; + } + + // + // The node is specified only by gateway. + // + + return InterfaceAndGateway{ InterfaceLuidFromGateway(node.gateway().value()), node.gateway().value() }; +} + +routemanager::NodeAddress ConvertSocketAddress(const SOCKET_ADDRESS *sa) +{ + routemanager::NodeAddress out = { 0 }; + + switch (sa->lpSockaddr->sa_family) + { + case AF_INET: + { + out.si_family = AF_INET; + out.Ipv4 = *reinterpret_cast<SOCKADDR_IN *>(sa->lpSockaddr); + + break; + } + case AF_INET6: + { + out.si_family = AF_INET6; + out.Ipv6 = *reinterpret_cast<SOCKADDR_IN6 *>(sa->lpSockaddr); + + break; + } + default: + { + throw std::runtime_error("Missing case handler in switch clause"); + } + }; + + return out; +} + +std::wstring FormatNetwork(const routemanager::Network &network) +{ + switch (network.Prefix.si_family) + { + case AF_INET: + { + return common::string::FormatIpv4(network.Prefix.Ipv4.sin_addr.s_addr, network.PrefixLength); + } + case AF_INET6: + { + return common::string::FormatIpv6(network.Prefix.Ipv6.sin6_addr.u.Byte, network.PrefixLength); + } + default: + { + return L"Failed to format network details"; + } + } +} + +} // anon namespace + +namespace routemanager { + +bool EqualAddress(const Network &lhs, const Network &rhs) +{ + if (lhs.PrefixLength != rhs.PrefixLength) + { + return false; + } + + return EqualAddress(lhs.Prefix, rhs.Prefix); +} + +bool EqualAddress(const NodeAddress &lhs, const NodeAddress &rhs) +{ + if (lhs.si_family != rhs.si_family) + { + return false; + } + + switch (lhs.si_family) + { + case AF_INET: + { + return lhs.Ipv4.sin_addr.s_addr == rhs.Ipv4.sin_addr.s_addr; + } + case AF_INET6: + { + return 0 == memcmp(&lhs.Ipv6.sin6_addr, &rhs.Ipv6.sin6_addr, sizeof(IN6_ADDR)); + } + default: + { + throw std::runtime_error("Invalid address family for network address"); + } + } +} + +RouteManager::RouteManager(std::shared_ptr<common::logging::ILogSink> logSink) + : m_logSink(logSink) +{ + const auto status = NotifyRouteChange2(AF_UNSPEC, RouteChangeCallback, this, FALSE, &m_notificationHandle); + + THROW_UNLESS(NO_ERROR, status, "Register for route table change notifications"); +} + +RouteManager::~RouteManager() +{ + CancelMibChangeNotify2(m_notificationHandle); + + for (const auto &record : m_routes) + { + try + { + deleteFromRoutingTable(record.registeredRoute); + } + catch (const std::exception &ex) + { + std::wstringstream ss; + + ss << L"Failed to delete route as part of cleaning up, Route: " + << FormatRegisteredRoute(record.registeredRoute); + + m_logSink->error(common::string::ToAnsi(ss.str()).c_str()); + m_logSink->error(ex.what()); + } + } +} + +void RouteManager::addRoutes(const std::vector<Route> &routes) +{ + LockType lock(m_routesLock); + + std::vector<EventEntry> eventLog; + + for (const auto &route : routes) + { + try + { + auto record = findRouteRecord(route); + + if (record != m_routes.end()) + { + deleteFromRoutingTable(record->registeredRoute); + eventLog.emplace_back(EventEntry{ EventType::DELETE_ROUTE, *record }); + m_routes.erase(record); + } + + const RouteRecord newRecord { route, addIntoRoutingTable(route) }; + + eventLog.emplace_back(EventEntry{ EventType::ADD_ROUTE, newRecord }); + m_routes.emplace_back(std::move(newRecord)); + } + catch (...) + { + undoEvents(eventLog); + + std::throw_with_nested(std::runtime_error("Failed during batch insertion of routes")); + } + } +} + +void RouteManager::addRoute(const Route &route) +{ + LockType lock(m_routesLock); + + std::optional<RouteRecord> deletedRecord; + + auto record = findRouteRecord(route); + + if (record != m_routes.end()) + { + try + { + deleteFromRoutingTable(record->registeredRoute); + } + catch (...) + { + std::throw_with_nested(std::runtime_error("Failed to evict old route when adding new route")); + } + + deletedRecord = *record; + m_routes.erase(record); + } + + try + { + m_routes.emplace_back + ( + RouteRecord{ route, addIntoRoutingTable(route) } + ); + } + catch (...) + { + // + // Restore deleted record. + // + + if (deletedRecord.has_value()) + { + auto &r = deletedRecord.value(); + + try + { + restoreIntoRoutingTable(r.registeredRoute); + m_routes.emplace_back(r); + } + catch (const std::exception &ex) + { + const auto err = std::string("Failed to restore evicted route during rollback: ").append(ex.what()); + m_logSink->error(err.c_str()); + } + } + + // + // Just rethrow because the error is from addIntoRoutingTable(). + // + + throw; + } +} + +void RouteManager::deleteRoutes(const std::vector<Route> &routes) +{ + LockType lock(m_routesLock); + + std::vector<EventEntry> eventLog; + + for (const auto &route : routes) + { + try + { + auto record = findRouteRecord(route); + + if (m_routes.end() == record) + { + const auto err = std::wstring(L"Request to delete previously unregistered route: ") + .append(FormatNetwork(route.network())); + + m_logSink->warning(common::string::ToAnsi(err).c_str()); + + continue; + } + + deleteFromRoutingTable(record->registeredRoute); + eventLog.emplace_back(EventEntry{ EventType::DELETE_ROUTE, *record }); + m_routes.erase(record); + } + catch (...) + { + undoEvents(eventLog); + + std::throw_with_nested(std::runtime_error("Failed during batch removal of routes")); + } + } +} + +void RouteManager::deleteRoute(const Route &route) +{ + LockType lock(m_routesLock); + + auto record = findRouteRecord(route); + + if (m_routes.end() == record) + { + const auto err = std::wstring(L"Request to delete previously unregistered route: ") + .append(FormatNetwork(route.network())); + + m_logSink->warning(common::string::ToAnsi(err).c_str()); + + return; + } + + deleteFromRoutingTable(record->registeredRoute); + m_routes.erase(record); +} + +std::list<RouteManager::RouteRecord>::iterator RouteManager::findRouteRecord(const Network &network) +{ + return std::find_if(m_routes.begin(), m_routes.end(), [&network](const auto &candidate) + { + return EqualAddress(network, candidate.route.network()); + }); +} + +std::list<RouteManager::RouteRecord>::iterator RouteManager::findRouteRecord(const Route &route) +{ + return findRouteRecord(route.network()); +} + +RouteManager::RegisteredRoute RouteManager::addIntoRoutingTable(const Route &route) +{ + const auto node = ResolveNode(route.network().Prefix.si_family, route.node()); + + MIB_IPFORWARD_ROW2 spec; + + InitializeIpForwardEntry(&spec); + + spec.InterfaceLuid = node.iface; + spec.DestinationPrefix = route.network(); + spec.NextHop = node.gateway; + spec.Metric = 0; + spec.Protocol = MIB_IPPROTO_NETMGMT; + spec.Origin = NlroManual; + + // + // Do not treat ERROR_OBJECT_ALREADY_EXISTS as being successful. + // Because it may not take route metric into consideration. + // + + THROW_UNLESS(NO_ERROR, CreateIpForwardEntry2(&spec), "Register route in routing table"); + + return RegisteredRoute { route.network(), node.iface, node.gateway }; +} + +void RouteManager::restoreIntoRoutingTable(const RegisteredRoute &route) +{ + MIB_IPFORWARD_ROW2 spec; + + InitializeIpForwardEntry(&spec); + + spec.InterfaceLuid = route.luid; + spec.DestinationPrefix = route.network; + spec.NextHop = route.nextHop; + spec.Metric = 0; + spec.Protocol = MIB_IPPROTO_NETMGMT; + spec.Origin = NlroManual; + + THROW_UNLESS(NO_ERROR, CreateIpForwardEntry2(&spec), "Register route in routing table"); +} + +void RouteManager::deleteFromRoutingTable(const RegisteredRoute &route) +{ + MIB_IPFORWARD_ROW2 r = { 0}; + + r.InterfaceLuid = route.luid; + r.DestinationPrefix = route.network; + r.NextHop = route.nextHop; + + auto status = DeleteIpForwardEntry2(&r); + + if (ERROR_NOT_FOUND == status) + { + status = NO_ERROR; + + const auto err = std::wstring(L"Attempting to delete route which was not present in routing table, " \ + "ignoring and proceeding. Route: ").append(FormatRegisteredRoute(route)); + + m_logSink->warning(common::string::ToAnsi(err).c_str()); + } + + THROW_UNLESS(NO_ERROR, status, "Delete route in routing table"); +} + +void RouteManager::undoEvents(const std::vector<EventEntry> &eventLog) +{ + // + // Rewind state by processing events in the reverse order. + // + + for (auto it = eventLog.rbegin(); it != eventLog.rend(); ++it) + { + try + { + switch (it->type) + { + case EventType::ADD_ROUTE: + { + auto record = findRouteRecord(it->record.route); + + if (m_routes.end() == record) + { + throw std::runtime_error("Internal state inconsistency in route manager"); + } + + deleteFromRoutingTable(record->registeredRoute); + m_routes.erase(record); + + break; + } + case EventType::DELETE_ROUTE: + { + restoreIntoRoutingTable(it->record.registeredRoute); + m_routes.emplace_back(it->record); + + break; + } + default: + { + throw std::logic_error("Missing case handler in switch clause"); + } + } + + } + catch (const std::exception &ex) + { + const auto err = std::string("Attempting to rollback state: ").append(ex.what()); + m_logSink->error(err.c_str()); + } + } +} + +//static +void NETIOAPI_API_ +RouteManager::RouteChangeCallback(void *context, MIB_IPFORWARD_ROW2 *row, MIB_NOTIFICATION_TYPE notificationType) +{ + auto instance = reinterpret_cast<RouteManager *>(context); + + try + { + instance->routeChangeCallback(row, notificationType); + } + catch (const std::exception &ex) + { + auto msg = std::string("Failure while processing route change notification: ").append(ex.what()); + instance->m_logSink->error(msg.c_str()); + } + catch (...) + { + instance->m_logSink->error("Unspecified failure while processing route change notification"); + } +} + +void RouteManager::routeChangeCallback(MIB_IPFORWARD_ROW2 *row, MIB_NOTIFICATION_TYPE) +{ + // + // We're only interested in changes that add/remove/update a default route. + // + + if (0 != row->DestinationPrefix.PrefixLength + || false == RouteHasGateway(*row)) + { + return; + } + + // + // Are we managing any static routes that rely on the default route? + // + + const auto family = row->DestinationPrefix.Prefix.si_family; + + LockType lock(m_routesLock); + + using RecordIterator = std::list<RouteRecord>::iterator; + + std::list<RecordIterator> affectedRoutes; + + for (RecordIterator it = m_routes.begin(); it != m_routes.end(); ++it) + { + if (false == it->route.node().has_value() + && family == it->route.network().Prefix.si_family) + { + affectedRoutes.emplace_back(it); + } + } + + if (affectedRoutes.empty()) + { + return; + } + + // + // Assume all of our affected routes are using the same gateway on the same adapter. + // + // Has the current best default route changed? + // + + const auto oldBestRoute = InterfaceAndGateway + { + affectedRoutes.front()->registeredRoute.luid, + affectedRoutes.front()->registeredRoute.nextHop + }; + + InterfaceAndGateway newBestRoute = { 0 }; + + try + { + newBestRoute = ResolveNodeFromDefaultRoute(family); + } + catch (const std::exception &ex) + { + const auto msg = std::string("Failed to resolve default network route. " \ + "Assuming there isn't one: ").append(ex.what()); + + m_logSink->info(msg.c_str()); + + return; + } + + if (oldBestRoute.iface.Value == newBestRoute.iface.Value + && EqualAddress(oldBestRoute.gateway, newBestRoute.gateway)) + { + return; + } + + // + // Best default route has changed. Update affected routes. + // + + m_logSink->info("Default route has changed. Refreshing dependent routes"); + + for (auto route : affectedRoutes) + { + try + { + deleteFromRoutingTable(route->registeredRoute); + } + catch (const std::exception &ex) + { + const auto msg = std::string("Failed to delete route when refreshing " \ + "existing routes: ").append(ex.what()); + + m_logSink->error(msg.c_str()); + + continue; + } + + route->registeredRoute.luid = newBestRoute.iface; + route->registeredRoute.nextHop = newBestRoute.gateway; + + try + { + restoreIntoRoutingTable(route->registeredRoute); + } + catch (const std::exception &ex) + { + const auto msg = std::string("Failed to add route when refreshing " \ + "existing routes: ").append(ex.what()); + + m_logSink->error(msg.c_str()); + + continue; + } + } + + // + // TODO-MAYBE: Notify clients about new default route. + // +} + +// static +std::wstring RouteManager::FormatRegisteredRoute(const RegisteredRoute &route) +{ + std::wstringstream ss; + + if (AF_INET == route.network.Prefix.si_family) + { + std::wstring gateway(L"\"On-link\""); + + if (0 != route.nextHop.Ipv4.sin_addr.s_addr) + { + gateway = common::string::FormatIpv4(route.nextHop.Ipv4.sin_addr.s_addr); + } + + ss << common::string::FormatIpv4(route.network.Prefix.Ipv4.sin_addr.s_addr, route.network.PrefixLength) + << L" with gateway " << gateway + << L" on interface with LUID 0x" << std::hex << route.luid.Value; + } + else if (AF_INET6 == route.network.Prefix.si_family) + { + std::wstring gateway(L"\"On-link\""); + + const uint8_t *begin = &route.nextHop.Ipv6.sin6_addr.u.Byte[0]; + const uint8_t *end = begin + 16; + + if (0 != std::accumulate(begin, end, 0)) + { + gateway = common::string::FormatIpv6(route.nextHop.Ipv6.sin6_addr.u.Byte); + } + + ss << common::string::FormatIpv6(route.network.Prefix.Ipv6.sin6_addr.u.Byte, route.network.PrefixLength) + << L" with gateway " << gateway + << L" on interface with LUID 0x" << std::hex << route.luid.Value; + } + else + { + ss << L"Failed to format route details"; + } + + return ss.str(); +} + +} diff --git a/windows/winnet/src/winnet/routemanager.h b/windows/winnet/src/winnet/routemanager.h new file mode 100644 index 0000000000..6161dcd62a --- /dev/null +++ b/windows/winnet/src/winnet/routemanager.h @@ -0,0 +1,207 @@ +#pragma once + +#include <string> +#include <memory> +#include <vector> +#include <list> +#include <stdexcept> +#include <optional> +#include <mutex> +#include <winsock2.h> +#include <windows.h> +#include <ws2def.h> +#include <ws2ipdef.h> +#include <iphlpapi.h> +#include <netioapi.h> + +// Custom header files below here. +// So broken networking headers don't get confused and break the compilation. +// === +#include <libcommon/string.h> +#include <libcommon/logging/ilogsink.h> + +namespace routemanager { + +using Network = IP_ADDRESS_PREFIX; +using NodeAddress = SOCKADDR_INET; + +bool EqualAddress(const Network &lhs, const Network &rhs); +bool EqualAddress(const NodeAddress &lhs, const NodeAddress &rhs); + +class Node +{ +public: + + Node(const std::optional<std::wstring> &deviceName, const std::optional<NodeAddress> &gateway) + : m_deviceName(deviceName) + , m_gateway(gateway) + { + if (false == m_deviceName.has_value() && false == m_gateway.has_value()) + { + throw std::runtime_error("Invalid node definition"); + } + + if (m_deviceName.has_value()) + { + const auto trimmed = common::string::Trim<>(m_deviceName.value()); + + if (trimmed.empty()) + { + throw std::runtime_error("Invalid device name in node definition"); + } + + m_deviceName = std::move(trimmed); + } + } + + const std::optional<std::wstring> &deviceName() const + { + return m_deviceName; + } + + const std::optional<NodeAddress> &gateway() const + { + return m_gateway; + } + + bool operator==(const Node &rhs) const + { + if (m_deviceName.has_value()) + { + if (false == rhs.m_deviceName.has_value() + || 0 != _wcsicmp(m_deviceName.value().c_str(), rhs.deviceName().value().c_str())) + { + return false; + } + } + + if (m_gateway.has_value()) + { + if (false == rhs.m_gateway.has_value() + || false == EqualAddress(m_gateway.value(), rhs.gateway().value())) + { + return false; + } + } + + return true; + } + +private: + + std::optional<std::wstring> m_deviceName; + std::optional<NodeAddress> m_gateway; +}; + +class Route +{ +public: + + Route(const Network &network, const std::optional<Node> &node) + : m_network(network) + , m_node(node) + { + } + + const Network &network() const + { + return m_network; + } + + const std::optional<Node> &node() const + { + return m_node; + } + + bool operator==(const Route &rhs) const + { + if (m_node.has_value()) + { + return rhs.node().has_value() + && EqualAddress(m_network, rhs.network()) + && m_node.value() == rhs.node().value(); + } + + return false == rhs.node().has_value() + && EqualAddress(m_network, rhs.network()); + } + +private: + + Network m_network; + std::optional<Node> m_node; +}; + +class RouteManager +{ +public: + + RouteManager(std::shared_ptr<common::logging::ILogSink> logSink); + ~RouteManager(); + + RouteManager(const RouteManager &) = delete; + RouteManager &operator=(const RouteManager &) = delete; + RouteManager(RouteManager &&) = default; + + void addRoutes(const std::vector<Route> &routes); + void addRoute(const Route &route); + + void deleteRoutes(const std::vector<Route> &routes); + void deleteRoute(const Route &route); + +private: + + std::shared_ptr<common::logging::ILogSink> m_logSink; + + // These are the exact details derived from the route specification (`Route`). + // They are used when registering and deleting a route in the system. + struct RegisteredRoute + { + Network network; + NET_LUID luid; + NodeAddress nextHop; + }; + + struct RouteRecord + { + Route route; + RegisteredRoute registeredRoute; + }; + + std::list<RouteRecord> m_routes; + + std::recursive_mutex m_routesLock; + + // Find record based on destination and mask. + std::list<RouteRecord>::iterator findRouteRecord(const Network &network); + + // Note: Same as above! + std::list<RouteRecord>::iterator findRouteRecord(const Route &route); + + RegisteredRoute addIntoRoutingTable(const Route &route); + void restoreIntoRoutingTable(const RegisteredRoute &route); + void deleteFromRoutingTable(const RegisteredRoute &route); + + enum class EventType + { + ADD_ROUTE, + DELETE_ROUTE, + }; + + struct EventEntry + { + EventType type; + RouteRecord record; + }; + + void undoEvents(const std::vector<EventEntry> &eventLog); + + HANDLE m_notificationHandle; + + static void NETIOAPI_API_ RouteChangeCallback(void *context, MIB_IPFORWARD_ROW2 *row, MIB_NOTIFICATION_TYPE notificationType); + void routeChangeCallback(MIB_IPFORWARD_ROW2 *row, MIB_NOTIFICATION_TYPE notificationType); + + static std::wstring FormatRegisteredRoute(const RegisteredRoute &route); +}; + +} diff --git a/windows/winnet/src/winnet/winnet.cpp b/windows/winnet/src/winnet/winnet.cpp index 4b006964a6..569df8f591 100644 --- a/windows/winnet/src/winnet/winnet.cpp +++ b/windows/winnet/src/winnet/winnet.cpp @@ -5,15 +5,130 @@ #include "offlinemonitor.h"
#include "../../shared/logsinkadapter.h"
#include <libcommon/error.h>
+#include <libcommon/network.h>
+#include "routemanager.h"
#include <cstdint>
#include <stdexcept>
#include <memory>
+#include <optional>
+
+using namespace routemanager;
namespace
{
OfflineMonitor *g_OfflineMonitor = nullptr;
+RouteManager *g_RouteManager = nullptr;
+std::shared_ptr<shared::LogSinkAdapter> g_RouteManagerLogSink;
+
+Network ConvertNetwork(const WINNET_IPNETWORK &in)
+{
+ //
+ // Convert WINNET_IPNETWORK into Network aka IP_ADDRESS_PREFIX
+ //
+
+ Network out{ 0 };
+
+ out.PrefixLength = in.prefix;
+
+ switch (in.type)
+ {
+ case WINNET_IP_TYPE_IPV4:
+ {
+ out.Prefix.si_family = AF_INET;
+ out.Prefix.Ipv4.sin_family = AF_INET;
+ out.Prefix.Ipv4.sin_addr.s_addr = *reinterpret_cast<const uint32_t *>(in.bytes);
+
+ break;
+ }
+ case WINNET_IP_TYPE_IPV6:
+ {
+ out.Prefix.si_family = AF_INET6;
+ out.Prefix.Ipv6.sin6_family = AF_INET6;
+ memcpy(out.Prefix.Ipv6.sin6_addr.u.Byte, in.bytes, 16);
+
+ break;
+ }
+ default:
+ {
+ throw std::runtime_error("Missing case handler in switch clause");
+ }
+ }
+
+ return out;
+}
+
+std::optional<Node> ConvertNode(const WINNET_NODE *in)
+{
+ if (nullptr == in)
+ {
+ return {};
+ }
+
+ if (nullptr == in->deviceName && nullptr == in->gateway)
+ {
+ throw std::runtime_error("Invalid 'WINNET_NODE' definition");
+ }
+
+ std::optional<std::wstring> deviceName;
+ std::optional<NodeAddress> gateway;
+
+ if (nullptr != in->deviceName)
+ {
+ deviceName = in->deviceName;
+ }
+
+ if (nullptr != in->gateway)
+ {
+ NodeAddress gw { 0 };
+
+ switch (in->gateway->type)
+ {
+ case WINNET_IP_TYPE_IPV4:
+ {
+ gw.si_family = AF_INET;
+ gw.Ipv4.sin_addr.s_addr = *reinterpret_cast<const uint32_t *>(in->gateway->bytes);
+
+ break;
+ }
+ case WINNET_IP_TYPE_IPV6:
+ {
+ gw.si_family = AF_INET6;
+ memcpy(&gw.Ipv6.sin6_addr.u.Byte, in->gateway->bytes, 16);
+
+ break;
+ }
+ default:
+ {
+ throw std::logic_error("Invalid gateway type specifier in 'WINNET_NODE' definition");
+ }
+ }
+
+ gateway = gw;
+ }
+
+ return Node(deviceName, gateway);
+}
+
+std::vector<Route> ConvertRoutes(const WINNET_ROUTE *routes, uint32_t numRoutes)
+{
+ std::vector<Route> out;
+
+ out.reserve(numRoutes);
+
+ for (size_t i = 0; i < numRoutes; ++i)
+ {
+ out.emplace_back(Route
+ {
+ ConvertNetwork(routes[i].network),
+ ConvertNode(routes[i].node)
+ });
+ }
+
+ return out;
+}
+
void UnwindAndLog(MullvadLogSink logSink, void *logSinkContext, const std::exception &err)
{
if (nullptr == logSink)
@@ -66,12 +181,12 @@ WinNet_GetTapInterfaceIpv6Status( {
try
{
- MIB_IPINTERFACE_ROW interface = { 0 };
+ MIB_IPINTERFACE_ROW iface = { 0 };
- interface.InterfaceLuid = NetworkInterfaces::GetInterfaceLuid(InterfaceUtils::GetTapInterfaceAlias());
- interface.Family = AF_INET6;
+ iface.InterfaceLuid = NetworkInterfaces::GetInterfaceLuid(InterfaceUtils::GetTapInterfaceAlias());
+ iface.Family = AF_INET6;
- const auto status = GetIpInterfaceEntry(&interface);
+ const auto status = GetIpInterfaceEntry(&iface);
if (NO_ERROR == status)
{
@@ -201,3 +316,179 @@ WinNet_DeactivateConnectivityMonitor( {
}
}
+
+extern "C"
+WINNET_LINKAGE
+bool
+WINNET_API
+WinNet_ActivateRouteManager(
+ MullvadLogSink logSink,
+ void *logSinkContext
+)
+{
+ try
+ {
+ if (nullptr != g_RouteManager)
+ {
+ throw std::runtime_error("Cannot activate route manager twice");
+ }
+
+ g_RouteManagerLogSink = std::make_shared<shared::LogSinkAdapter>(logSink, logSinkContext);
+ g_RouteManager = new RouteManager(g_RouteManagerLogSink);
+
+ return true;
+ }
+ catch (const std::exception &err)
+ {
+ UnwindAndLog(logSink, logSinkContext, err);
+ return false;
+ }
+ catch (...)
+ {
+ return false;
+ }
+}
+
+extern "C"
+WINNET_LINKAGE
+bool
+WINNET_API
+WinNet_AddRoutes(
+ const WINNET_ROUTE *routes,
+ uint32_t numRoutes
+)
+{
+ if (nullptr == g_RouteManager)
+ {
+ return false;
+ }
+
+ try
+ {
+ g_RouteManager->addRoutes(ConvertRoutes(routes, numRoutes));
+ return true;
+ }
+ catch (const std::exception &err)
+ {
+ common::error::UnwindException(err, g_RouteManagerLogSink);
+ return false;
+ }
+ catch (...)
+ {
+ return false;
+ }
+}
+
+extern "C"
+WINNET_LINKAGE
+bool
+WINNET_API
+WinNet_AddRoute(
+ const WINNET_ROUTE *route
+)
+{
+ if (nullptr == g_RouteManager)
+ {
+ return false;
+ }
+
+ try
+ {
+ g_RouteManager->addRoute
+ (
+ Route{ ConvertNetwork(route->network), ConvertNode(route->node) }
+ );
+
+ return true;
+ }
+ catch (const std::exception &err)
+ {
+ common::error::UnwindException(err, g_RouteManagerLogSink);
+ return false;
+ }
+ catch (...)
+ {
+ return false;
+ }
+}
+
+extern "C"
+WINNET_LINKAGE
+bool
+WINNET_API
+WinNet_DeleteRoutes(
+ const WINNET_ROUTE *routes,
+ uint32_t numRoutes
+)
+{
+ if (nullptr == g_RouteManager)
+ {
+ return false;
+ }
+
+ try
+ {
+ g_RouteManager->deleteRoutes(ConvertRoutes(routes, numRoutes));
+ return true;
+ }
+ catch (const std::exception &err)
+ {
+ common::error::UnwindException(err, g_RouteManagerLogSink);
+ return false;
+ }
+ catch (...)
+ {
+ return false;
+ }
+}
+
+extern "C"
+WINNET_LINKAGE
+bool
+WINNET_API
+WinNet_DeleteRoute(
+ const WINNET_ROUTE *route
+)
+{
+ if (nullptr == g_RouteManager)
+ {
+ return false;
+ }
+
+ try
+ {
+ g_RouteManager->deleteRoute
+ (
+ Route{ ConvertNetwork(route->network), ConvertNode(route->node) }
+ );
+
+ return true;
+ }
+ catch (const std::exception &err)
+ {
+ common::error::UnwindException(err, g_RouteManagerLogSink);
+ return false;
+ }
+ catch (...)
+ {
+ return false;
+ }
+}
+
+extern "C"
+WINNET_LINKAGE
+void
+WINNET_API
+WinNet_DeactivateRouteManager(
+)
+{
+ try
+ {
+ delete g_RouteManager;
+ g_RouteManager = nullptr;
+ }
+ catch (...)
+ {
+ }
+}
+
diff --git a/windows/winnet/src/winnet/winnet.def b/windows/winnet/src/winnet/winnet.def index 04c3f22ee3..05738682e3 100644 --- a/windows/winnet/src/winnet/winnet.def +++ b/windows/winnet/src/winnet/winnet.def @@ -6,3 +6,6 @@ EXPORTS WinNet_ReleaseString WinNet_ActivateConnectivityMonitor WinNet_DeactivateConnectivityMonitor + WinNet_ActivateRouteManager + WinNet_DeactivateRouteManager + diff --git a/windows/winnet/src/winnet/winnet.h b/windows/winnet/src/winnet/winnet.h index 9b1af52e36..1906e77ce6 100644 --- a/windows/winnet/src/winnet/winnet.h +++ b/windows/winnet/src/winnet/winnet.h @@ -1,6 +1,7 @@ #pragma once #include "../../shared/logsink.h" +#include <stdint.h> #include <stdbool.h> #ifndef WINNET_STATIC @@ -89,3 +90,90 @@ void WINNET_API WinNet_DeactivateConnectivityMonitor( ); + + +enum WINNET_IP_TYPE +{ + WINNET_IP_TYPE_IPV4 = 0, + WINNET_IP_TYPE_IPV6 = 1, +}; + +typedef struct tag_WINNET_IPNETWORK +{ + WINNET_IP_TYPE type; + uint8_t bytes[16]; // Network byte order. + uint8_t prefix; +} +WINNET_IPNETWORK; + +typedef struct tag_WINNET_IP +{ + WINNET_IP_TYPE type; + uint8_t bytes[16]; // Network byte order. +} +WINNET_IP; + +typedef struct tag_WINNET_NODE +{ + const WINNET_IP *gateway; + const wchar_t *deviceName; +} +WINNET_NODE; + +typedef struct tag_WINNET_ROUTE +{ + WINNET_IPNETWORK network; + const WINNET_NODE *node; +} +WINNET_ROUTE; + +extern "C" +WINNET_LINKAGE +bool +WINNET_API +WinNet_ActivateRouteManager( + MullvadLogSink logSink, + void *logSinkContext +); + +extern "C" +WINNET_LINKAGE +bool +WINNET_API +WinNet_AddRoutes( + const WINNET_ROUTE *routes, + uint32_t numRoutes +); + +extern "C" +WINNET_LINKAGE +bool +WINNET_API +WinNet_AddRoute( + const WINNET_ROUTE *route +); + +extern "C" +WINNET_LINKAGE +bool +WINNET_API +WinNet_DeleteRoutes( + const WINNET_ROUTE *routes, + uint32_t numRoutes +); + +extern "C" +WINNET_LINKAGE +bool +WINNET_API +WinNet_DeleteRoute( + const WINNET_ROUTE *route +); + +extern "C" +WINNET_LINKAGE +void +WINNET_API +WinNet_DeactivateRouteManager( +); + diff --git a/windows/winnet/src/winnet/winnet.vcxproj b/windows/winnet/src/winnet/winnet.vcxproj index 192320daaf..8be655436f 100644 --- a/windows/winnet/src/winnet/winnet.vcxproj +++ b/windows/winnet/src/winnet/winnet.vcxproj @@ -28,20 +28,24 @@ </ItemGroup> <ItemGroup> <ClCompile Include="networkadaptermonitor.cpp" /> + <ClCompile Include="adapters.cpp" /> <ClCompile Include="dllmain.cpp" /> <ClCompile Include="InterfacePair.cpp" /> <ClCompile Include="interfaceutils.cpp" /> <ClCompile Include="offlinemonitor.cpp" /> <ClCompile Include="NetworkInterfaces.cpp" /> + <ClCompile Include="routemanager.cpp" /> <ClCompile Include="stdafx.cpp" /> <ClCompile Include="winnet.cpp" /> </ItemGroup> <ItemGroup> <ClInclude Include="networkadaptermonitor.h" /> + <ClInclude Include="adapters.h" /> <ClInclude Include="InterfacePair.h" /> <ClInclude Include="interfaceutils.h" /> <ClInclude Include="offlinemonitor.h" /> <ClInclude Include="NetworkInterfaces.h" /> + <ClInclude Include="routemanager.h" /> <ClInclude Include="stdafx.h" /> <ClInclude Include="targetver.h" /> <ClInclude Include="winnet.h" /> diff --git a/windows/winnet/src/winnet/winnet.vcxproj.filters b/windows/winnet/src/winnet/winnet.vcxproj.filters index 9a901d3203..e739032cdf 100644 --- a/windows/winnet/src/winnet/winnet.vcxproj.filters +++ b/windows/winnet/src/winnet/winnet.vcxproj.filters @@ -9,6 +9,8 @@ <ClCompile Include="interfaceutils.cpp" /> <ClCompile Include="networkadaptermonitor.cpp" /> <ClCompile Include="offlinemonitor.cpp" /> + <ClCompile Include="routemanager.cpp" /> + <ClCompile Include="adapters.cpp" /> </ItemGroup> <ItemGroup> <ClInclude Include="stdafx.h" /> @@ -19,6 +21,8 @@ <ClInclude Include="interfaceutils.h" /> <ClInclude Include="networkadaptermonitor.h" /> <ClInclude Include="offlinemonitor.h" /> + <ClInclude Include="routemanager.h" /> + <ClInclude Include="adapters.h" /> </ItemGroup> <ItemGroup> <None Include="winnet.def" /> |
