diff options
| author | Odd Stranne <odd@mullvad.net> | 2020-01-31 13:01:05 +0100 |
|---|---|---|
| committer | Odd Stranne <odd@mullvad.net> | 2020-02-03 13:42:46 +0100 |
| commit | ca0dc2241dd7607f8f07ed2566ae8bfba2b9731f (patch) | |
| tree | da1dd8cfd8eaab3bc159a73229ddfb9f65e8d406 /windows | |
| parent | 3e503aaacb593ed22ddc9594a4f06803283300f9 (diff) | |
| download | mullvadvpn-ca0dc2241dd7607f8f07ed2566ae8bfba2b9731f.tar.xz mullvadvpn-ca0dc2241dd7607f8f07ed2566ae8bfba2b9731f.zip | |
Improve route registration
Diffstat (limited to 'windows')
| -rw-r--r-- | windows/winnet/src/winnet/routing/routemanager.cpp | 58 | ||||
| -rw-r--r-- | windows/winnet/src/winnet/routing/routemanager.h | 29 |
2 files changed, 62 insertions, 25 deletions
diff --git a/windows/winnet/src/winnet/routing/routemanager.cpp b/windows/winnet/src/winnet/routing/routemanager.cpp index edcb594ff1..81ed54f38e 100644 --- a/windows/winnet/src/winnet/routing/routemanager.cpp +++ b/windows/winnet/src/winnet/routing/routemanager.cpp @@ -252,19 +252,20 @@ void RouteManager::addRoutes(const std::vector<Route> &routes) { try { - auto record = findRouteRecord(route); + RouteRecord newRecord{ route, addIntoRoutingTable(route) }; - if (record != m_routes.end()) - { - deleteFromRoutingTable(record->registeredRoute); - eventLog.emplace_back(EventEntry{ EventType::DELETE_ROUTE, *record }); - m_routes.erase(record); - } + eventLog.emplace_back(EventEntry{ EventType::ADD_ROUTE, newRecord }); - const RouteRecord newRecord { route, addIntoRoutingTable(route) }; + auto existingRecord = findRouteRecord(newRecord.registeredRoute); - eventLog.emplace_back(EventEntry{ EventType::ADD_ROUTE, newRecord }); - m_routes.emplace_back(std::move(newRecord)); + if (m_routes.end() == existingRecord) + { + m_routes.emplace_back(std::move(newRecord)); + } + else + { + *existingRecord = std::move(newRecord); + } } catch (...) { @@ -285,11 +286,11 @@ void RouteManager::deleteRoutes(const std::vector<Route> &routes) { try { - auto record = findRouteRecord(route); + const auto record = findRouteRecordFromSpec(route); if (m_routes.end() == record) { - const auto err = std::wstring(L"Request to delete previously unregistered route: ") + const auto err = std::wstring(L"Request to delete unknown route: ") .append(FormatNetwork(route.network())); m_logSink->warning(common::string::ToAnsi(err).c_str()); @@ -298,6 +299,7 @@ void RouteManager::deleteRoutes(const std::vector<Route> &routes) } deleteFromRoutingTable(record->registeredRoute); + eventLog.emplace_back(EventEntry{ EventType::DELETE_ROUTE, *record }); m_routes.erase(record); } @@ -335,17 +337,20 @@ void RouteManager::unregisterDefaultRouteChangedCallback(CallbackHandle handle) } } -std::list<RouteManager::RouteRecord>::iterator RouteManager::findRouteRecord(const Network &network) +std::list<RouteManager::RouteRecord>::iterator RouteManager::findRouteRecord(const RegisteredRoute &route) { - return std::find_if(m_routes.begin(), m_routes.end(), [&network](const auto &candidate) + return std::find_if(m_routes.begin(), m_routes.end(), [&route](const auto &record) { - return EqualAddress(network, candidate.route.network()); + return route == record.registeredRoute; }); } -std::list<RouteManager::RouteRecord>::iterator RouteManager::findRouteRecord(const Route &route) +std::list<RouteManager::RouteRecord>::iterator RouteManager::findRouteRecordFromSpec(const Route &route) { - return findRouteRecord(route.network()); + return std::find_if(m_routes.begin(), m_routes.end(), [&route](const auto &record) + { + return route == record.route; + }); } RouteManager::RegisteredRoute RouteManager::addIntoRoutingTable(const Route &route) @@ -363,12 +368,23 @@ RouteManager::RegisteredRoute RouteManager::addIntoRoutingTable(const Route &rou spec.Protocol = MIB_IPPROTO_NETMGMT; spec.Origin = NlroManual; + auto status = CreateIpForwardEntry2(&spec); + + // + // The return code ERROR_OBJECT_ALREADY_EXISTS means there is already an existing route + // on the same interface, with the same DestinationPrefix and NextHop. // - // Do not treat ERROR_OBJECT_ALREADY_EXISTS as being successful. - // Because it may not take route metric into consideration. + // However, all the other properties of the route may be different. And the properties may + // not have the exact same values as when the route was registered, because windows + // will adjust route properties at time of route insertion as well as later. + // + // The simplest thing in this case is to just overwrite the route. // - const auto status = CreateIpForwardEntry2(&spec); + if (status == ERROR_OBJECT_ALREADY_EXISTS) + { + status = SetIpForwardEntry2(&spec); + } if (NO_ERROR != status) { @@ -439,7 +455,7 @@ void RouteManager::undoEvents(const std::vector<EventEntry> &eventLog) { case EventType::ADD_ROUTE: { - auto record = findRouteRecord(it->record.route); + const auto record = findRouteRecord(it->record.registeredRoute); if (m_routes.end() == record) { diff --git a/windows/winnet/src/winnet/routing/routemanager.h b/windows/winnet/src/winnet/routing/routemanager.h index f8a58b4443..92c712e25d 100644 --- a/windows/winnet/src/winnet/routing/routemanager.h +++ b/windows/winnet/src/winnet/routing/routemanager.h @@ -13,6 +13,7 @@ #include <libcommon/string.h> #include <libcommon/logging/ilogsink.h> #include "defaultroutemonitor.h" +#include "helpers.h" namespace winnet::routing { @@ -62,6 +63,13 @@ private: Network network; NET_LUID luid; NodeAddress nextHop; + + bool operator==(const RegisteredRoute &rhs) const + { + return luid.Value == rhs.luid.Value + && EqualAddress(nextHop, rhs.nextHop) + && EqualAddress(network, rhs.network); + } }; struct RouteRecord @@ -76,11 +84,24 @@ private: std::list<DefaultRouteChangedCallback> m_defaultRouteCallbacks; std::recursive_mutex m_defaultRouteCallbacksLock; - // Find record based on destination and mask. - std::list<RouteRecord>::iterator findRouteRecord(const Network &network); + // + // Find record based on route registration data. + // + // Note: Searching the records and matching on route specification is + // unreliable because of the node attribute on the route. Different node + // specifications can resolve to the same physical node. + // + // (node = exit node = interface) + // + std::list<RouteRecord>::iterator findRouteRecord(const RegisteredRoute &route); - // Note: Same as above! - std::list<RouteRecord>::iterator findRouteRecord(const Route &route); + // + // Find record based on route specification. + // + // Note: Only ever use this to find the registration data for a route + // that was successfully registered previously. + // + std::list<RouteRecord>::iterator findRouteRecordFromSpec(const Route &route); RegisteredRoute addIntoRoutingTable(const Route &route); void restoreIntoRoutingTable(const RegisteredRoute &route); |
