summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorOdd Stranne <odd@mullvad.net>2020-02-03 13:43:12 +0100
committerOdd Stranne <odd@mullvad.net>2020-02-03 13:43:12 +0100
commiteb86bf3ea06616e2a6d009f30fc11722791763fb (patch)
tree45713d1a6aa3cfbd3157422f13e5c86f3b839d32
parent3e141b7dfde547bd0491198cdca0d982b824d829 (diff)
parent119b662c0371f3bbbb7ff4836eb8ce077c899c35 (diff)
downloadmullvadvpn-eb86bf3ea06616e2a6d009f30fc11722791763fb.tar.xz
mullvadvpn-eb86bf3ea06616e2a6d009f30fc11722791763fb.zip
Merge branch 'win-improve-route-installation'
-rw-r--r--CHANGELOG.md5
-rw-r--r--talpid-core/src/tunnel/wireguard/wireguard_go.rs2
-rw-r--r--talpid-core/src/winnet.rs85
-rw-r--r--windows/winnet/src/winnet/converters.cpp118
-rw-r--r--windows/winnet/src/winnet/converters.h16
-rw-r--r--windows/winnet/src/winnet/routing/helpers.cpp1
-rw-r--r--windows/winnet/src/winnet/routing/helpers.h1
-rw-r--r--windows/winnet/src/winnet/routing/routemanager.cpp174
-rw-r--r--windows/winnet/src/winnet/routing/routemanager.h32
-rw-r--r--windows/winnet/src/winnet/routing/types.cpp8
-rw-r--r--windows/winnet/src/winnet/routing/types.h3
-rw-r--r--windows/winnet/src/winnet/winnet.cpp215
-rw-r--r--windows/winnet/src/winnet/winnet.h35
-rw-r--r--windows/winnet/src/winnet/winnet.vcxproj2
-rw-r--r--windows/winnet/src/winnet/winnet.vcxproj.filters2
15 files changed, 281 insertions, 418 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md
index d05959ca96..6f36f6569a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -51,7 +51,10 @@ Line wrap the file at 100 chars. Th
#### Windows
- Use a branded TAP driver for OpenVPN to prevent conflicts with other software and solve issues
related to driver upgrades. Also use the NDIS 6 driver on Windows 7.
-
+- Be more aggressive when installing routes, in effect taking ownership of existing duplicate route
+ entries. This allows the daemon to initialize properly even if a previous instance did not have a
+ clean shutdown.
+
### Fixed
- Don't try to replace WireGuard key if account has too many keys already.
- Fix bogus update notification caused by an outdated cache.
diff --git a/talpid-core/src/tunnel/wireguard/wireguard_go.rs b/talpid-core/src/tunnel/wireguard/wireguard_go.rs
index 91af5aa7a2..c569c8727e 100644
--- a/talpid-core/src/tunnel/wireguard/wireguard_go.rs
+++ b/talpid-core/src/tunnel/wireguard/wireguard_go.rs
@@ -160,7 +160,7 @@ impl WgGoTunnel {
#[cfg(target_os = "windows")]
pub unsafe extern "system" fn default_route_changed_callback(
event_type: winnet::WinNetDefaultRouteChangeEventType,
- address_family: winnet::WinNetIpFamily,
+ address_family: winnet::WinNetAddrFamily,
interface_luid: u64,
_ctx: *mut libc::c_void,
) {
diff --git a/talpid-core/src/winnet.rs b/talpid-core/src/winnet.rs
index 08a4937b3c..2eee4b205a 100644
--- a/talpid-core/src/winnet.rs
+++ b/talpid-core/src/winnet.rs
@@ -105,45 +105,25 @@ pub fn get_tap_interface_alias() -> Result<OsString, Error> {
Ok(alias.to_os_string())
}
-#[repr(C)]
-struct WinNetIpType(u32);
-
-const WINNET_IPV4: u32 = 0;
-const WINNET_IPV6: u32 = 1;
-
-impl WinNetIpType {
- pub fn v4() -> Self {
- WinNetIpType(WINNET_IPV4)
- }
-
- pub fn v6() -> Self {
- WinNetIpType(WINNET_IPV6)
- }
-}
-
-
-#[repr(C)]
-pub struct WinNetIpNetwork {
- ip_type: WinNetIpType,
- ip_bytes: [u8; 16],
- prefix: u8,
+#[allow(dead_code)]
+#[repr(u32)]
+pub enum WinNetAddrFamily {
+ IPV4 = 0,
+ IPV6 = 1,
}
-impl From<IpNetwork> for WinNetIpNetwork {
- fn from(network: IpNetwork) -> WinNetIpNetwork {
- let WinNetIp { ip_type, ip_bytes } = WinNetIp::from(network.ip());
- let prefix = network.prefix();
- WinNetIpNetwork {
- ip_type,
- ip_bytes,
- prefix,
+impl WinNetAddrFamily {
+ pub fn to_windows_proto_enum(&self) -> u16 {
+ match self {
+ Self::IPV4 => 2,
+ Self::IPV6 => 23,
}
}
}
#[repr(C)]
pub struct WinNetIp {
- ip_type: WinNetIpType,
+ addr_family: WinNetAddrFamily,
ip_bytes: [u8; 16],
}
@@ -154,7 +134,7 @@ impl From<IpAddr> for WinNetIp {
IpAddr::V4(v4_addr) => {
bytes[..4].copy_from_slice(&v4_addr.octets());
WinNetIp {
- ip_type: WinNetIpType::v4(),
+ addr_family: WinNetAddrFamily::IPV4,
ip_bytes: bytes,
}
}
@@ -162,7 +142,7 @@ impl From<IpAddr> for WinNetIp {
bytes.copy_from_slice(&v6_addr.octets());
WinNetIp {
- ip_type: WinNetIpType::v6(),
+ addr_family: WinNetAddrFamily::IPV6,
ip_bytes: bytes,
}
}
@@ -171,6 +151,21 @@ impl From<IpAddr> for WinNetIp {
}
#[repr(C)]
+pub struct WinNetIpNetwork {
+ prefix: u8,
+ ip: WinNetIp,
+}
+
+impl From<IpNetwork> for WinNetIpNetwork {
+ fn from(network: IpNetwork) -> WinNetIpNetwork {
+ WinNetIpNetwork {
+ prefix: network.prefix(),
+ ip: WinNetIp::from(network.ip()),
+ }
+ }
+}
+
+#[repr(C)]
pub struct WinNetNode {
gateway: *mut WinNetIp,
device_name: *mut u16,
@@ -196,7 +191,6 @@ impl WinNetNode {
}
}
-
fn from_device(name: &str) -> Self {
let device_name = WideCString::from_str(name)
.expect("Failed to convert UTF-8 string to null terminated UCS string")
@@ -234,7 +228,6 @@ impl Drop for WinNetNode {
}
}
-
#[repr(C)]
pub struct WinNetRoute {
gateway: WinNetIpNetwork,
@@ -251,7 +244,7 @@ impl WinNetRoute {
pub fn new(node: WinNetNode, gateway: WinNetIpNetwork) -> Self {
let node = Box::into_raw(Box::new(node));
- WinNetRoute { gateway, node }
+ Self { gateway, node }
}
}
@@ -273,7 +266,7 @@ pub fn activate_routing_manager(routes: &[WinNetRoute]) -> bool {
pub struct WinNetCallbackHandle {
handle: *mut libc::c_void,
- // allows us to keep the context pointer allive.
+ // Allows us to keep the context pointer alive.
_context: Box<dyn std::any::Any>,
}
@@ -292,25 +285,9 @@ pub enum WinNetDefaultRouteChangeEventType {
DefaultRouteRemoved = 1,
}
-#[allow(dead_code)]
-#[repr(u16)]
-pub enum WinNetIpFamily {
- V4 = 0,
- V6 = 1,
-}
-
-impl WinNetIpFamily {
- pub fn to_windows_proto_enum(&self) -> u16 {
- match self {
- Self::V4 => 2,
- Self::V6 => 23,
- }
- }
-}
-
pub type DefaultRouteChangedCallback = unsafe extern "system" fn(
event_type: WinNetDefaultRouteChangeEventType,
- ip_family: WinNetIpFamily,
+ addr_family: WinNetAddrFamily,
interface_luid: u64,
ctx: *mut c_void,
);
diff --git a/windows/winnet/src/winnet/converters.cpp b/windows/winnet/src/winnet/converters.cpp
new file mode 100644
index 0000000000..3584f8e34f
--- /dev/null
+++ b/windows/winnet/src/winnet/converters.cpp
@@ -0,0 +1,118 @@
+#include <stdafx.h>
+#include "converters.h"
+#include <libcommon/error.h>
+#include <cstdint>
+
+using namespace winnet::routing;
+
+namespace
+{
+
+SOCKADDR_INET IpToNative(const WINNET_IP &from)
+{
+ SOCKADDR_INET to = { 0 };
+
+ switch (from.family)
+ {
+ case WINNET_ADDR_FAMILY_IPV4:
+ {
+ to.Ipv4.sin_family = AF_INET;
+ to.Ipv4.sin_addr.s_addr = *reinterpret_cast<const uint32_t*>(from.bytes);
+
+ break;
+ }
+ case WINNET_ADDR_FAMILY_IPV6:
+ {
+ to.Ipv6.sin6_family = AF_INET6;
+ memcpy(to.Ipv6.sin6_addr.u.Byte, from.bytes, 16);
+
+ break;
+ }
+ default:
+ {
+ THROW_ERROR("Invalid network address family");
+ }
+ }
+
+ return to;
+}
+
+} // anonymous namespace
+
+namespace winnet
+{
+
+Network ConvertNetwork(const WINNET_IP_NETWORK &in)
+{
+ //
+ // Convert WINNET_IPNETWORK into Network aka IP_ADDRESS_PREFIX
+ //
+
+ Network out = { 0 };
+
+ out.PrefixLength = in.prefix;
+ out.Prefix = IpToNative(in.addr);
+
+ return out;
+}
+
+std::optional<Node> ConvertNode(const WINNET_NODE *in)
+{
+ if (nullptr == in)
+ {
+ return std::nullopt;
+ }
+
+ if (nullptr == in->deviceName && nullptr == in->gateway)
+ {
+ THROW_ERROR("Invalid 'WINNET_NODE' definition");
+ }
+
+ std::optional<std::wstring> deviceName;
+ std::optional<NodeAddress> gateway;
+
+ if (nullptr != in->deviceName)
+ {
+ deviceName = in->deviceName;
+ }
+
+ if (nullptr != in->gateway)
+ {
+ gateway = IpToNative(*in->gateway);
+ }
+
+ return Node(deviceName, gateway);
+}
+
+std::vector<Route> ConvertRoutes(const WINNET_ROUTE *routes, uint32_t numRoutes)
+{
+ std::vector<Route> out;
+
+ out.reserve(numRoutes);
+
+ for (size_t i = 0; i < numRoutes; ++i)
+ {
+ out.emplace_back(Route
+ {
+ ConvertNetwork(routes[i].network),
+ ConvertNode(routes[i].node)
+ });
+ }
+
+ return out;
+}
+
+std::vector<SOCKADDR_INET> ConvertAddresses(const WINNET_IP *addresses, uint32_t numAddresses)
+{
+ std::vector<SOCKADDR_INET> out;
+ out.reserve(numAddresses);
+
+ for (uint32_t i = 0; i < numAddresses; ++i)
+ {
+ out.emplace_back(IpToNative(addresses[i]));
+ }
+
+ return out;
+}
+
+}
diff --git a/windows/winnet/src/winnet/converters.h b/windows/winnet/src/winnet/converters.h
new file mode 100644
index 0000000000..8a1c59e4da
--- /dev/null
+++ b/windows/winnet/src/winnet/converters.h
@@ -0,0 +1,16 @@
+#pragma once
+
+#include "winnet.h"
+#include "routing/types.h"
+#include <optional>
+#include <vector>
+
+namespace winnet
+{
+
+routing::Network ConvertNetwork(const WINNET_IP_NETWORK &in);
+std::optional<routing::Node> ConvertNode(const WINNET_NODE *in);
+std::vector<routing::Route> ConvertRoutes(const WINNET_ROUTE *routes, uint32_t numRoutes);
+std::vector<SOCKADDR_INET> ConvertAddresses(const WINNET_IP *addresses, uint32_t numAddresses);
+
+}
diff --git a/windows/winnet/src/winnet/routing/helpers.cpp b/windows/winnet/src/winnet/routing/helpers.cpp
index 9b3f5f87b7..2f171c08a1 100644
--- a/windows/winnet/src/winnet/routing/helpers.cpp
+++ b/windows/winnet/src/winnet/routing/helpers.cpp
@@ -3,7 +3,6 @@
#include <ws2def.h>
#include <in6addr.h>
#include <numeric>
-//#include <netioapi.h>
#include <libcommon/error.h>
#include <libcommon/memory.h>
diff --git a/windows/winnet/src/winnet/routing/helpers.h b/windows/winnet/src/winnet/routing/helpers.h
index 3ef5e85b75..37a0acc611 100644
--- a/windows/winnet/src/winnet/routing/helpers.h
+++ b/windows/winnet/src/winnet/routing/helpers.h
@@ -41,6 +41,5 @@ std::vector<const SOCKET_ADDRESS *> IsolateGatewayAddresses
bool AddressPresent(const std::vector<const SOCKET_ADDRESS *> &hay, const SOCKADDR_INET *needle);
-//NodeAddress ConvertSocketAddress(const SOCKET_ADDRESS *sa);
}
diff --git a/windows/winnet/src/winnet/routing/routemanager.cpp b/windows/winnet/src/winnet/routing/routemanager.cpp
index 166dbf8dcf..81ee4e3d96 100644
--- a/windows/winnet/src/winnet/routing/routemanager.cpp
+++ b/windows/winnet/src/winnet/routing/routemanager.cpp
@@ -162,29 +162,19 @@ InterfaceAndGateway ResolveNode(ADDRESS_FAMILY family, const std::optional<Node>
return InterfaceAndGateway{ InterfaceLuidFromGateway(node.gateway().value()), node.gateway().value() };
}
-// TODO: Move to libcommon
-uint32_t ByteSwap(uint32_t val)
-{
- return
- (
- ((val & 0xFF) << 24) |
- ((val & 0xFF00) << 8) |
- ((val & 0xFF0000) >> 8) |
- ((val & 0xFF000000) >> 24)
- );
-}
-
std::wstring FormatNetwork(const Network &network)
{
+ using namespace common::string;
+
switch (network.Prefix.si_family)
{
case AF_INET:
{
- return common::string::FormatIpv4(ByteSwap(network.Prefix.Ipv4.sin_addr.s_addr), network.PrefixLength);
+ return FormatIpv4<AddressOrder::NetworkByteOrder>(network.Prefix.Ipv4.sin_addr.s_addr, network.PrefixLength);
}
case AF_INET6:
{
- return common::string::FormatIpv6(network.Prefix.Ipv6.sin6_addr.u.Byte, network.PrefixLength);
+ return FormatIpv6(network.Prefix.Ipv6.sin6_addr.u.Byte, network.PrefixLength);
}
default:
{
@@ -252,86 +242,27 @@ void RouteManager::addRoutes(const std::vector<Route> &routes)
{
try
{
- auto record = findRouteRecord(route);
-
- if (record != m_routes.end())
- {
- deleteFromRoutingTable(record->registeredRoute);
- eventLog.emplace_back(EventEntry{ EventType::DELETE_ROUTE, *record });
- m_routes.erase(record);
- }
-
- const RouteRecord newRecord { route, addIntoRoutingTable(route) };
+ RouteRecord newRecord{ route, addIntoRoutingTable(route) };
eventLog.emplace_back(EventEntry{ EventType::ADD_ROUTE, newRecord });
- m_routes.emplace_back(std::move(newRecord));
- }
- catch (...)
- {
- undoEvents(eventLog);
-
- THROW_ERROR("Failed during batch insertion of routes");
- }
- }
-}
-
-void RouteManager::addRoute(const Route &route)
-{
- AutoLockType lock(m_routesLock);
- std::optional<RouteRecord> deletedRecord;
+ auto existingRecord = findRouteRecord(newRecord.registeredRoute);
- auto record = findRouteRecord(route);
-
- if (record != m_routes.end())
- {
- try
- {
- deleteFromRoutingTable(record->registeredRoute);
- }
- catch (...)
- {
- THROW_ERROR("Failed to evict old route when adding new route");
- }
-
- deletedRecord = *record;
- m_routes.erase(record);
- }
-
- try
- {
- m_routes.emplace_back
- (
- RouteRecord{ route, addIntoRoutingTable(route) }
- );
- }
- catch (...)
- {
- //
- // Restore deleted record.
- //
-
- if (deletedRecord.has_value())
- {
- auto &r = deletedRecord.value();
-
- try
+ if (m_routes.end() == existingRecord)
{
- restoreIntoRoutingTable(r.registeredRoute);
- m_routes.emplace_back(r);
+ m_routes.emplace_back(std::move(newRecord));
}
- catch (const std::exception &ex)
+ else
{
- const auto err = std::string("Failed to restore evicted route during rollback: ").append(ex.what());
- m_logSink->error(err.c_str());
+ *existingRecord = std::move(newRecord);
}
}
+ catch (...)
+ {
+ undoEvents(eventLog);
- //
- // Just rethrow because the error is from addIntoRoutingTable().
- //
-
- throw;
+ THROW_ERROR("Failed during batch insertion of routes");
+ }
}
}
@@ -345,11 +276,11 @@ void RouteManager::deleteRoutes(const std::vector<Route> &routes)
{
try
{
- auto record = findRouteRecord(route);
+ const auto record = findRouteRecordFromSpec(route);
if (m_routes.end() == record)
{
- const auto err = std::wstring(L"Request to delete previously unregistered route: ")
+ const auto err = std::wstring(L"Request to delete unknown route: ")
.append(FormatNetwork(route.network()));
m_logSink->warning(common::string::ToAnsi(err).c_str());
@@ -358,6 +289,7 @@ void RouteManager::deleteRoutes(const std::vector<Route> &routes)
}
deleteFromRoutingTable(record->registeredRoute);
+
eventLog.emplace_back(EventEntry{ EventType::DELETE_ROUTE, *record });
m_routes.erase(record);
}
@@ -370,26 +302,6 @@ void RouteManager::deleteRoutes(const std::vector<Route> &routes)
}
}
-void RouteManager::deleteRoute(const Route &route)
-{
- AutoLockType lock(m_routesLock);
-
- auto record = findRouteRecord(route);
-
- if (m_routes.end() == record)
- {
- const auto err = std::wstring(L"Request to delete previously unregistered route: ")
- .append(FormatNetwork(route.network()));
-
- m_logSink->warning(common::string::ToAnsi(err).c_str());
-
- return;
- }
-
- deleteFromRoutingTable(record->registeredRoute);
- m_routes.erase(record);
-}
-
RouteManager::CallbackHandle RouteManager::registerDefaultRouteChangedCallback(DefaultRouteChangedCallback callback)
{
AutoRecursiveLockType lock(m_defaultRouteCallbacksLock);
@@ -415,17 +327,20 @@ void RouteManager::unregisterDefaultRouteChangedCallback(CallbackHandle handle)
}
}
-std::list<RouteManager::RouteRecord>::iterator RouteManager::findRouteRecord(const Network &network)
+std::list<RouteManager::RouteRecord>::iterator RouteManager::findRouteRecord(const RegisteredRoute &route)
{
- return std::find_if(m_routes.begin(), m_routes.end(), [&network](const auto &candidate)
+ return std::find_if(m_routes.begin(), m_routes.end(), [&route](const auto &record)
{
- return EqualAddress(network, candidate.route.network());
+ return route == record.registeredRoute;
});
}
-std::list<RouteManager::RouteRecord>::iterator RouteManager::findRouteRecord(const Route &route)
+std::list<RouteManager::RouteRecord>::iterator RouteManager::findRouteRecordFromSpec(const Route &route)
{
- return findRouteRecord(route.network());
+ return std::find_if(m_routes.begin(), m_routes.end(), [&route](const auto &record)
+ {
+ return route == record.route;
+ });
}
RouteManager::RegisteredRoute RouteManager::addIntoRoutingTable(const Route &route)
@@ -443,12 +358,23 @@ RouteManager::RegisteredRoute RouteManager::addIntoRoutingTable(const Route &rou
spec.Protocol = MIB_IPPROTO_NETMGMT;
spec.Origin = NlroManual;
+ auto status = CreateIpForwardEntry2(&spec);
+
+ //
+ // The return code ERROR_OBJECT_ALREADY_EXISTS means there is already an existing route
+ // on the same interface, with the same DestinationPrefix and NextHop.
//
- // Do not treat ERROR_OBJECT_ALREADY_EXISTS as being successful.
- // Because it may not take route metric into consideration.
+ // However, all the other properties of the route may be different. And the properties may
+ // not have the exact same values as when the route was registered, because windows
+ // will adjust route properties at time of route insertion as well as later.
+ //
+ // The simplest thing in this case is to just overwrite the route.
//
- const auto status = CreateIpForwardEntry2(&spec);
+ if (status == ERROR_OBJECT_ALREADY_EXISTS)
+ {
+ status = SetIpForwardEntry2(&spec);
+ }
if (NO_ERROR != status)
{
@@ -519,7 +445,7 @@ void RouteManager::undoEvents(const std::vector<EventEntry> &eventLog)
{
case EventType::ADD_ROUTE:
{
- auto record = findRouteRecord(it->record.route);
+ const auto record = findRouteRecord(it->record.registeredRoute);
if (m_routes.end() == record)
{
@@ -555,10 +481,7 @@ void RouteManager::undoEvents(const std::vector<EventEntry> &eventLog)
// static
std::wstring RouteManager::FormatRegisteredRoute(const RegisteredRoute &route)
{
- //
- // TODO: Fix broken IP formatting
- // Update FormatIpv4 function with an additional argument to specify network/host byte order.
- //
+ using namespace common::string;
std::wstringstream ss;
@@ -568,10 +491,10 @@ std::wstring RouteManager::FormatRegisteredRoute(const RegisteredRoute &route)
if (0 != route.nextHop.Ipv4.sin_addr.s_addr)
{
- gateway = common::string::FormatIpv4(ByteSwap(route.nextHop.Ipv4.sin_addr.s_addr));
+ gateway = FormatIpv4<AddressOrder::NetworkByteOrder>(route.nextHop.Ipv4.sin_addr.s_addr);
}
- ss << common::string::FormatIpv4(ByteSwap(route.network.Prefix.Ipv4.sin_addr.s_addr), route.network.PrefixLength)
+ ss << FormatIpv4<AddressOrder::NetworkByteOrder>(route.network.Prefix.Ipv4.sin_addr.s_addr, route.network.PrefixLength)
<< L" with gateway " << gateway
<< L" on interface with LUID 0x" << std::hex << route.luid.Value;
}
@@ -584,10 +507,10 @@ std::wstring RouteManager::FormatRegisteredRoute(const RegisteredRoute &route)
if (0 != std::accumulate(begin, end, 0))
{
- gateway = common::string::FormatIpv6(route.nextHop.Ipv6.sin6_addr.u.Byte);
+ gateway = FormatIpv6(route.nextHop.Ipv6.sin6_addr.u.Byte);
}
- ss << common::string::FormatIpv6(route.network.Prefix.Ipv6.sin6_addr.u.Byte, route.network.PrefixLength)
+ ss << FormatIpv6(route.network.Prefix.Ipv6.sin6_addr.u.Byte, route.network.PrefixLength)
<< L" with gateway " << gateway
<< L" on interface with LUID 0x" << std::hex << route.luid.Value;
}
@@ -668,6 +591,11 @@ void RouteManager::defaultRouteChanged(ADDRESS_FAMILY family, DefaultRouteMonito
for (auto &it : affectedRoutes)
{
+ //
+ // We can't update the existing route because defining characteristics are being changed.
+ // So removing and adding again is the only option.
+ //
+
try
{
deleteFromRoutingTable(it->registeredRoute);
diff --git a/windows/winnet/src/winnet/routing/routemanager.h b/windows/winnet/src/winnet/routing/routemanager.h
index 981c8e6834..92c712e25d 100644
--- a/windows/winnet/src/winnet/routing/routemanager.h
+++ b/windows/winnet/src/winnet/routing/routemanager.h
@@ -13,6 +13,7 @@
#include <libcommon/string.h>
#include <libcommon/logging/ilogsink.h>
#include "defaultroutemonitor.h"
+#include "helpers.h"
namespace winnet::routing
{
@@ -30,10 +31,7 @@ public:
RouteManager &operator=(RouteManager &&) = delete;
void addRoutes(const std::vector<Route> &routes);
- void addRoute(const Route &route);
-
void deleteRoutes(const std::vector<Route> &routes);
- void deleteRoute(const Route &route);
using DefaultRouteChangedEventType = DefaultRouteMonitor::EventType;
@@ -65,6 +63,13 @@ private:
Network network;
NET_LUID luid;
NodeAddress nextHop;
+
+ bool operator==(const RegisteredRoute &rhs) const
+ {
+ return luid.Value == rhs.luid.Value
+ && EqualAddress(nextHop, rhs.nextHop)
+ && EqualAddress(network, rhs.network);
+ }
};
struct RouteRecord
@@ -79,11 +84,24 @@ private:
std::list<DefaultRouteChangedCallback> m_defaultRouteCallbacks;
std::recursive_mutex m_defaultRouteCallbacksLock;
- // Find record based on destination and mask.
- std::list<RouteRecord>::iterator findRouteRecord(const Network &network);
+ //
+ // Find record based on route registration data.
+ //
+ // Note: Searching the records and matching on route specification is
+ // unreliable because of the node attribute on the route. Different node
+ // specifications can resolve to the same physical node.
+ //
+ // (node = exit node = interface)
+ //
+ std::list<RouteRecord>::iterator findRouteRecord(const RegisteredRoute &route);
- // Note: Same as above!
- std::list<RouteRecord>::iterator findRouteRecord(const Route &route);
+ //
+ // Find record based on route specification.
+ //
+ // Note: Only ever use this to find the registration data for a route
+ // that was successfully registered previously.
+ //
+ std::list<RouteRecord>::iterator findRouteRecordFromSpec(const Route &route);
RegisteredRoute addIntoRoutingTable(const Route &route);
void restoreIntoRoutingTable(const RegisteredRoute &route);
diff --git a/windows/winnet/src/winnet/routing/types.cpp b/windows/winnet/src/winnet/routing/types.cpp
index 5168f96634..9a7c755feb 100644
--- a/windows/winnet/src/winnet/routing/types.cpp
+++ b/windows/winnet/src/winnet/routing/types.cpp
@@ -39,6 +39,10 @@ bool Node::operator==(const Node &rhs) const
return false;
}
}
+ else if (rhs.m_deviceName.has_value())
+ {
+ return false;
+ }
if (m_gateway.has_value())
{
@@ -48,6 +52,10 @@ bool Node::operator==(const Node &rhs) const
return false;
}
}
+ else if (rhs.m_gateway.has_value())
+ {
+ return false;
+ }
return true;
}
diff --git a/windows/winnet/src/winnet/routing/types.h b/windows/winnet/src/winnet/routing/types.h
index 1e132feb00..4b70a5739d 100644
--- a/windows/winnet/src/winnet/routing/types.h
+++ b/windows/winnet/src/winnet/routing/types.h
@@ -7,9 +7,6 @@
#include <ws2def.h>
#include <ws2ipdef.h>
#include <iphlpapi.h>
-//#include <netioapi.h>
-//#include <functional>
-
namespace winnet::routing
{
diff --git a/windows/winnet/src/winnet/winnet.cpp b/windows/winnet/src/winnet/winnet.cpp
index 79ba77a229..7a206635ed 100644
--- a/windows/winnet/src/winnet/winnet.cpp
+++ b/windows/winnet/src/winnet/winnet.cpp
@@ -3,6 +3,7 @@
#include "NetworkInterfaces.h"
#include "offlinemonitor.h"
#include "routing/routemanager.h"
+#include "converters.h"
#include <libshared/logging/logsinkadapter.h>
#include <libshared/logging/unwind.h>
#include <libshared/network/interfaceutils.h>
@@ -28,156 +29,6 @@ std::mutex g_RouteManagerLock;
RouteManager *g_RouteManager = nullptr;
std::shared_ptr<shared::logging::LogSinkAdapter> g_RouteManagerLogSink;
-Network ConvertNetwork(const WINNET_IPNETWORK &in)
-{
- //
- // Convert WINNET_IPNETWORK into Network aka IP_ADDRESS_PREFIX
- //
-
- Network out{ 0 };
-
- out.PrefixLength = in.prefix;
-
- switch (in.type)
- {
- case WINNET_IP_TYPE_IPV4:
- {
- out.Prefix.si_family = AF_INET;
- out.Prefix.Ipv4.sin_family = AF_INET;
- out.Prefix.Ipv4.sin_addr.s_addr = *reinterpret_cast<const uint32_t *>(in.bytes);
-
- break;
- }
- case WINNET_IP_TYPE_IPV6:
- {
- out.Prefix.si_family = AF_INET6;
- out.Prefix.Ipv6.sin6_family = AF_INET6;
- memcpy(out.Prefix.Ipv6.sin6_addr.u.Byte, in.bytes, 16);
-
- break;
- }
- default:
- {
- THROW_ERROR("Missing case handler in switch clause");
- }
- }
-
- return out;
-}
-
-std::optional<Node> ConvertNode(const WINNET_NODE *in)
-{
- if (nullptr == in)
- {
- return {};
- }
-
- if (nullptr == in->deviceName && nullptr == in->gateway)
- {
- THROW_ERROR("Invalid 'WINNET_NODE' definition");
- }
-
- std::optional<std::wstring> deviceName;
- std::optional<NodeAddress> gateway;
-
- if (nullptr != in->deviceName)
- {
- deviceName = in->deviceName;
- }
-
- if (nullptr != in->gateway)
- {
- NodeAddress gw { 0 };
-
- switch (in->gateway->type)
- {
- case WINNET_IP_TYPE_IPV4:
- {
- gw.si_family = AF_INET;
- gw.Ipv4.sin_addr.s_addr = *reinterpret_cast<const uint32_t *>(in->gateway->bytes);
-
- break;
- }
- case WINNET_IP_TYPE_IPV6:
- {
- gw.si_family = AF_INET6;
- memcpy(&gw.Ipv6.sin6_addr.u.Byte, in->gateway->bytes, 16);
-
- break;
- }
- default:
- {
- THROW_ERROR("Invalid gateway type specifier in 'WINNET_NODE' definition");
- }
- }
-
- gateway = gw;
- }
-
- return Node(deviceName, gateway);
-}
-
-std::vector<Route> ConvertRoutes(const WINNET_ROUTE *routes, uint32_t numRoutes)
-{
- std::vector<Route> out;
-
- out.reserve(numRoutes);
-
- for (size_t i = 0; i < numRoutes; ++i)
- {
- out.emplace_back(Route
- {
- ConvertNetwork(routes[i].network),
- ConvertNode(routes[i].node)
- });
- }
-
- return out;
-}
-
-std::vector<SOCKADDR_INET> ConvertAddresses(const WINNET_IP *addresses, uint32_t numAddresses)
-{
- //
- // This duplicates the same logic we have above.
- // TODO: Fix when time permits.
- //
-
- std::vector<SOCKADDR_INET> out;
- out.reserve(numAddresses);
-
- for (uint32_t i = 0; i < numAddresses; ++i)
- {
- const WINNET_IP &from = addresses[i];
- SOCKADDR_INET to{ 0 };
-
- switch (from.type)
- {
- case WINNET_IP_TYPE_IPV4:
- {
- to.si_family = AF_INET;
- to.Ipv4.sin_addr.s_addr = *reinterpret_cast<const uint32_t *>(from.bytes);
-
- break;
- }
- case WINNET_IP_TYPE_IPV6:
- {
- to.si_family = AF_INET6;
- memcpy(&to.Ipv6.sin6_addr.u.Byte, from.bytes, 16);
-
- break;
- }
- default:
- {
- THROW_ERROR("Invalid address family in 'WINNET_IP' definition");
- }
- }
-
- out.push_back(to);
- }
-
- return out;
-}
-
} //anonymous namespace
extern "C"
@@ -406,7 +257,7 @@ WinNet_AddRoutes(
try
{
- g_RouteManager->addRoutes(ConvertRoutes(routes, numRoutes));
+ g_RouteManager->addRoutes(winnet::ConvertRoutes(routes, numRoutes));
return true;
}
catch (const std::exception &err)
@@ -428,31 +279,7 @@ WinNet_AddRoute(
const WINNET_ROUTE *route
)
{
- AutoLockType lock(g_RouteManagerLock);
-
- if (nullptr == g_RouteManager)
- {
- return false;
- }
-
- try
- {
- g_RouteManager->addRoute
- (
- Route{ ConvertNetwork(route->network), ConvertNode(route->node) }
- );
-
- return true;
- }
- catch (const std::exception &err)
- {
- common::error::UnwindException(err, g_RouteManagerLogSink);
- return false;
- }
- catch (...)
- {
- return false;
- }
+ return WinNet_AddRoutes(route, 1);
}
extern "C"
@@ -473,7 +300,7 @@ WinNet_DeleteRoutes(
try
{
- g_RouteManager->deleteRoutes(ConvertRoutes(routes, numRoutes));
+ g_RouteManager->deleteRoutes(winnet::ConvertRoutes(routes, numRoutes));
return true;
}
catch (const std::exception &err)
@@ -495,31 +322,7 @@ WinNet_DeleteRoute(
const WINNET_ROUTE *route
)
{
- AutoLockType lock(g_RouteManagerLock);
-
- if (nullptr == g_RouteManager)
- {
- return false;
- }
-
- try
- {
- g_RouteManager->deleteRoute
- (
- Route{ ConvertNetwork(route->network), ConvertNode(route->node) }
- );
-
- return true;
- }
- catch (const std::exception &err)
- {
- common::error::UnwindException(err, g_RouteManagerLogSink);
- return false;
- }
- catch (...)
- {
- return false;
- }
+ return WinNet_DeleteRoutes(route, 1);
}
extern "C"
@@ -563,10 +366,10 @@ WinNet_RegisterDefaultRouteChangedCallback(
// Translate the family type.
//
- static const std::pair<ADDRESS_FAMILY, WINNET_IP_FAMILY> familyMap[] =
+ static const std::pair<ADDRESS_FAMILY, WINNET_ADDR_FAMILY> familyMap[] =
{
- { static_cast<ADDRESS_FAMILY>(AF_INET), WINNET_IP_FAMILY_V4 },
- { static_cast<ADDRESS_FAMILY>(AF_INET6), WINNET_IP_FAMILY_V6 }
+ { static_cast<ADDRESS_FAMILY>(AF_INET), WINNET_ADDR_FAMILY_IPV4 },
+ { static_cast<ADDRESS_FAMILY>(AF_INET6), WINNET_ADDR_FAMILY_IPV6 }
};
const auto translatedFamily = common::ValueMapper::Map<>(family, familyMap);
@@ -676,7 +479,7 @@ WinNet_AddDeviceIpAddresses(
THROW_ERROR(msg.c_str());
}
- InterfaceUtils::AddDeviceIpAddresses(luid, ConvertAddresses(addresses, numAddresses));
+ InterfaceUtils::AddDeviceIpAddresses(luid, winnet::ConvertAddresses(addresses, numAddresses));
return true;
}
diff --git a/windows/winnet/src/winnet/winnet.h b/windows/winnet/src/winnet/winnet.h
index 40ccf9f421..5afc4052eb 100644
--- a/windows/winnet/src/winnet/winnet.h
+++ b/windows/winnet/src/winnet/winnet.h
@@ -91,26 +91,25 @@ WINNET_API
WinNet_DeactivateConnectivityMonitor(
);
-enum WINNET_IP_TYPE
+enum WINNET_ADDR_FAMILY
{
- WINNET_IP_TYPE_IPV4 = 0,
- WINNET_IP_TYPE_IPV6 = 1,
+ WINNET_ADDR_FAMILY_IPV4 = 0,
+ WINNET_ADDR_FAMILY_IPV6 = 1,
};
-typedef struct tag_WINNET_IPNETWORK
+typedef struct tag_WINNET_IP
{
- WINNET_IP_TYPE type;
+ WINNET_ADDR_FAMILY family;
uint8_t bytes[16]; // Network byte order.
- uint8_t prefix;
}
-WINNET_IPNETWORK;
+WINNET_IP;
-typedef struct tag_WINNET_IP
+typedef struct tag_WINNET_IP_NETWORK
{
- WINNET_IP_TYPE type;
- uint8_t bytes[16]; // Network byte order.
+ uint8_t prefix;
+ WINNET_IP addr;
}
-WINNET_IP;
+WINNET_IP_NETWORK;
typedef struct tag_WINNET_NODE
{
@@ -121,7 +120,7 @@ WINNET_NODE;
typedef struct tag_WINNET_ROUTE
{
- WINNET_IPNETWORK network;
+ WINNET_IP_NETWORK network;
const WINNET_NODE *node;
}
WINNET_ROUTE;
@@ -178,20 +177,14 @@ enum WINNET_DEFAULT_ROUTE_CHANGED_EVENT_TYPE
WINNET_DEFAULT_ROUTE_CHANGED_EVENT_TYPE_REMOVED = 1,
};
-enum WINNET_IP_FAMILY
-{
- WINNET_IP_FAMILY_V4 = 0,
- WINNET_IP_FAMILY_V6 = 1,
-};
-
typedef void (WINNET_API *WinNetDefaultRouteChangedCallback)
(
WINNET_DEFAULT_ROUTE_CHANGED_EVENT_TYPE eventType,
- // Signals which IP family the event relates to.
- WINNET_IP_FAMILY family,
+ // Indicates which IP family the event relates to.
+ WINNET_ADDR_FAMILY family,
- // For update events, signals the interface associated with the new best default route.
+ // For update events, indicates the interface associated with the new best default route.
uint64_t interfaceLuid,
void *context
diff --git a/windows/winnet/src/winnet/winnet.vcxproj b/windows/winnet/src/winnet/winnet.vcxproj
index c21f75b2c1..7b4578d4b1 100644
--- a/windows/winnet/src/winnet/winnet.vcxproj
+++ b/windows/winnet/src/winnet/winnet.vcxproj
@@ -27,6 +27,7 @@
</ProjectConfiguration>
</ItemGroup>
<ItemGroup>
+ <ClCompile Include="converters.cpp" />
<ClCompile Include="networkadaptermonitor.cpp" />
<ClCompile Include="dllmain.cpp" />
<ClCompile Include="InterfacePair.cpp" />
@@ -40,6 +41,7 @@
<ClCompile Include="winnet.cpp" />
</ItemGroup>
<ItemGroup>
+ <ClInclude Include="converters.h" />
<ClInclude Include="networkadaptermonitor.h" />
<ClInclude Include="InterfacePair.h" />
<ClInclude Include="offlinemonitor.h" />
diff --git a/windows/winnet/src/winnet/winnet.vcxproj.filters b/windows/winnet/src/winnet/winnet.vcxproj.filters
index db5a7c8b3a..2d5a039c6b 100644
--- a/windows/winnet/src/winnet/winnet.vcxproj.filters
+++ b/windows/winnet/src/winnet/winnet.vcxproj.filters
@@ -20,6 +20,7 @@
<ClCompile Include="routing\routemanager.cpp">
<Filter>routing</Filter>
</ClCompile>
+ <ClCompile Include="converters.cpp" />
</ItemGroup>
<ItemGroup>
<ClInclude Include="stdafx.h" />
@@ -41,6 +42,7 @@
<ClInclude Include="routing\routemanager.h">
<Filter>routing</Filter>
</ClInclude>
+ <ClInclude Include="converters.h" />
</ItemGroup>
<ItemGroup>
<None Include="winnet.def" />