summaryrefslogtreecommitdiffhomepage
path: root/windows
diff options
context:
space:
mode:
Diffstat (limited to 'windows')
-rw-r--r--windows/winfw/src/winfw/fwcontext.cpp25
-rw-r--r--windows/winfw/src/winfw/fwcontext.h16
-rw-r--r--windows/winfw/src/winfw/mullvadguids.cpp30
-rw-r--r--windows/winfw/src/winfw/mullvadguids.h3
-rw-r--r--windows/winfw/src/winfw/rules/permitping.cpp98
-rw-r--r--windows/winfw/src/winfw/rules/permitping.h28
-rw-r--r--windows/winfw/src/winfw/winfw.cpp34
-rw-r--r--windows/winfw/src/winfw/winfw.h21
-rw-r--r--windows/winfw/src/winfw/winfw.vcxproj2
-rw-r--r--windows/winfw/src/winfw/winfw.vcxproj.filters6
-rw-r--r--windows/winnet/src/extras/loader/loader.vcxproj.filters4
-rw-r--r--windows/winnet/src/winnet/interfaceutils.cpp20
-rw-r--r--windows/winnet/src/winnet/interfaceutils.h13
-rw-r--r--windows/winnet/src/winnet/routing/defaultroutemonitor.cpp177
-rw-r--r--windows/winnet/src/winnet/routing/defaultroutemonitor.h69
-rw-r--r--windows/winnet/src/winnet/routing/helpers.cpp275
-rw-r--r--windows/winnet/src/winnet/routing/helpers.h46
-rw-r--r--windows/winnet/src/winnet/routing/routemanager.cpp692
-rw-r--r--windows/winnet/src/winnet/routing/routemanager.h112
-rw-r--r--windows/winnet/src/winnet/routing/types.cpp84
-rw-r--r--windows/winnet/src/winnet/routing/types.h77
-rw-r--r--windows/winnet/src/winnet/winnet.cpp526
-rw-r--r--windows/winnet/src/winnet/winnet.def3
-rw-r--r--windows/winnet/src/winnet/winnet.h145
-rw-r--r--windows/winnet/src/winnet/winnet.vcxproj12
-rw-r--r--windows/winnet/src/winnet/winnet.vcxproj.filters29
26 files changed, 2530 insertions, 17 deletions
diff --git a/windows/winfw/src/winfw/fwcontext.cpp b/windows/winfw/src/winfw/fwcontext.cpp
index e5325f0b9c..49f8793572 100644
--- a/windows/winfw/src/winfw/fwcontext.cpp
+++ b/windows/winfw/src/winfw/fwcontext.cpp
@@ -13,10 +13,10 @@
#include "rules/permitvpnrelay.h"
#include "rules/permitvpntunnel.h"
#include "rules/permitvpntunnelservice.h"
+#include "rules/permitping.h"
#include "rules/restrictdns.h"
#include "libwfp/transaction.h"
#include "libwfp/filterengine.h"
-#include "libwfp/ipaddress.h"
#include <functional>
#include <stdexcept>
#include <utility>
@@ -99,7 +99,12 @@ FwContext::FwContext(uint32_t timeout, const WinFwSettings &settings)
m_baseline = checkpoint;
}
-bool FwContext::applyPolicyConnecting(const WinFwSettings &settings, const WinFwRelay &relay)
+bool FwContext::applyPolicyConnecting
+(
+ const WinFwSettings &settings,
+ const WinFwRelay &relay,
+ const std::optional<PingableHosts> &pingableHosts
+)
{
Ruleset ruleset;
@@ -112,6 +117,22 @@ bool FwContext::applyPolicyConnecting(const WinFwSettings &settings, const WinFw
TranslateProtocol(relay.protocol)
));
+ //
+ // Permit pinging the gateway inside the tunnel.
+ //
+ if (pingableHosts.has_value())
+ {
+ const auto &ph = pingableHosts.value();
+
+ for (const auto &host : ph.hosts)
+ {
+ ruleset.emplace_back(std::make_unique<rules::PermitPing>(
+ ph.tunnelInterfaceAlias,
+ host
+ ));
+ }
+ }
+
return applyRuleset(ruleset);
}
diff --git a/windows/winfw/src/winfw/fwcontext.h b/windows/winfw/src/winfw/fwcontext.h
index 89ef40e1d3..9d5b34c51b 100644
--- a/windows/winfw/src/winfw/fwcontext.h
+++ b/windows/winfw/src/winfw/fwcontext.h
@@ -3,9 +3,11 @@
#include "winfw.h"
#include "sessioncontroller.h"
#include "rules/ifirewallrule.h"
+#include "libwfp/ipaddress.h"
#include <cstdint>
#include <memory>
#include <vector>
+#include <optional>
class FwContext
{
@@ -16,7 +18,19 @@ public:
// This ctor applies the "blocked" policy.
FwContext(uint32_t timeout, const WinFwSettings &settings);
- bool applyPolicyConnecting(const WinFwSettings &settings, const WinFwRelay &relay);
+ struct PingableHosts
+ {
+ std::optional<std::wstring> tunnelInterfaceAlias;
+ std::vector<wfp::IpAddress> hosts;
+ };
+
+ bool applyPolicyConnecting
+ (
+ const WinFwSettings &settings,
+ const WinFwRelay &relay,
+ const std::optional<PingableHosts> &pingableHosts
+ );
+
bool applyPolicyConnected
(
const WinFwSettings &settings,
diff --git a/windows/winfw/src/winfw/mullvadguids.cpp b/windows/winfw/src/winfw/mullvadguids.cpp
index 010d41e44a..e73fac26ed 100644
--- a/windows/winfw/src/winfw/mullvadguids.cpp
+++ b/windows/winfw/src/winfw/mullvadguids.cpp
@@ -59,6 +59,8 @@ DetailedWfpObjectRegistry MullvadGuids::BuildDetailedRegistry()
registry.insert(std::make_pair(WfpObjectType::Filter, FilterPermitNdp_Outbound_Router_Solicitation()));
registry.insert(std::make_pair(WfpObjectType::Filter, FilterPermitNdp_Inbound_Router_Advertisement()));
registry.insert(std::make_pair(WfpObjectType::Filter, FilterPermitNdp_Inbound_Redirect()));
+ registry.insert(std::make_pair(WfpObjectType::Filter, FilterPermitPing_Outbound_Icmpv4()));
+ registry.insert(std::make_pair(WfpObjectType::Filter, FilterPermitPing_Outbound_Icmpv6()));
return registry;
}
@@ -567,3 +569,31 @@ const GUID &MullvadGuids::FilterPermitNdp_Inbound_Redirect()
return g;
}
+
+//static
+const GUID &MullvadGuids::FilterPermitPing_Outbound_Icmpv4()
+{
+ static const GUID g =
+ {
+ 0x2ecf7ff7,
+ 0xc951,
+ 0x4056,
+ { 0xb0, 0xf7, 0x40, 0xa4, 0x5c, 0x7e, 0xb4, 0xc2 }
+ };
+
+ return g;
+}
+
+//static
+const GUID &MullvadGuids::FilterPermitPing_Outbound_Icmpv6()
+{
+ static const GUID g =
+ {
+ 0x3deb8cab,
+ 0x1edb,
+ 0x4aa1,
+ { 0xb2, 0x73, 0xec, 0x61, 0x4f, 0x50, 0xdc, 0x13 }
+ };
+
+ return g;
+}
diff --git a/windows/winfw/src/winfw/mullvadguids.h b/windows/winfw/src/winfw/mullvadguids.h
index d4fb470d90..3c3ca9702b 100644
--- a/windows/winfw/src/winfw/mullvadguids.h
+++ b/windows/winfw/src/winfw/mullvadguids.h
@@ -67,4 +67,7 @@ public:
static const GUID &FilterPermitNdp_Outbound_Router_Solicitation();
static const GUID &FilterPermitNdp_Inbound_Router_Advertisement();
static const GUID &FilterPermitNdp_Inbound_Redirect();
+
+ static const GUID &FilterPermitPing_Outbound_Icmpv4();
+ static const GUID &FilterPermitPing_Outbound_Icmpv6();
};
diff --git a/windows/winfw/src/winfw/rules/permitping.cpp b/windows/winfw/src/winfw/rules/permitping.cpp
new file mode 100644
index 0000000000..f6aed36bf2
--- /dev/null
+++ b/windows/winfw/src/winfw/rules/permitping.cpp
@@ -0,0 +1,98 @@
+#include "stdafx.h"
+#include "permitping.h"
+#include "winfw/mullvadguids.h"
+#include "libwfp/filterbuilder.h"
+#include "libwfp/conditionbuilder.h"
+#include "libwfp/conditions/conditionip.h"
+#include "libwfp/conditions/conditioninterface.h"
+#include "libwfp/conditions/conditionprotocol.h"
+
+
+using namespace wfp::conditions;
+
+namespace rules
+{
+
+PermitPing::PermitPing
+(
+ const std::optional<std::wstring> &interfaceAlias,
+ const wfp::IpAddress &host
+)
+ : m_interfaceAlias(interfaceAlias)
+ , m_host(host)
+{
+}
+
+bool PermitPing::apply(IObjectInstaller &objectInstaller)
+{
+ if (wfp::IpAddress::Type::Ipv4 == m_host.type())
+ {
+ return applyIcmpv4(objectInstaller);
+ }
+
+ return applyIcmpv6(objectInstaller);
+}
+
+bool PermitPing::applyIcmpv4(IObjectInstaller &objectInstaller) const
+{
+ wfp::FilterBuilder filterBuilder;
+
+ //
+ // #1 Permit outbound ICMPv4 to %host% on %interface%
+ //
+
+ filterBuilder
+ .key(MullvadGuids::FilterPermitPing_Outbound_Icmpv4())
+ .name(L"Permit outbound ICMP to specific host (ICMPv4)")
+ .description(L"This filter is part of a rule that permits ping")
+ .provider(MullvadGuids::Provider())
+ .layer(FWPM_LAYER_ALE_AUTH_CONNECT_V4)
+ .sublayer(MullvadGuids::SublayerWhitelist())
+ .weight(wfp::FilterBuilder::WeightClass::Max)
+ .permit();
+
+ wfp::ConditionBuilder conditionBuilder(FWPM_LAYER_ALE_AUTH_CONNECT_V4);
+
+ conditionBuilder.add_condition(ConditionIp::Remote(m_host));
+ conditionBuilder.add_condition(ConditionProtocol::Icmp());
+
+ if (m_interfaceAlias.has_value())
+ {
+ conditionBuilder.add_condition(ConditionInterface::Alias(m_interfaceAlias.value()));
+ }
+
+ return objectInstaller.addFilter(filterBuilder, conditionBuilder);
+}
+
+bool PermitPing::applyIcmpv6(IObjectInstaller &objectInstaller) const
+{
+ wfp::FilterBuilder filterBuilder;
+
+ //
+ // #1 Permit outbound ICMPv6 to %host% on %interface%
+ //
+
+ filterBuilder
+ .key(MullvadGuids::FilterPermitPing_Outbound_Icmpv6())
+ .name(L"Permit outbound ICMP to specific host (ICMPv6)")
+ .description(L"This filter is part of a rule that permits ping")
+ .provider(MullvadGuids::Provider())
+ .layer(FWPM_LAYER_ALE_AUTH_CONNECT_V6)
+ .sublayer(MullvadGuids::SublayerWhitelist())
+ .weight(wfp::FilterBuilder::WeightClass::Max)
+ .permit();
+
+ wfp::ConditionBuilder conditionBuilder(FWPM_LAYER_ALE_AUTH_CONNECT_V6);
+
+ conditionBuilder.add_condition(ConditionIp::Remote(m_host));
+ conditionBuilder.add_condition(ConditionProtocol::IcmpV6());
+
+ if (m_interfaceAlias.has_value())
+ {
+ conditionBuilder.add_condition(ConditionInterface::Alias(m_interfaceAlias.value()));
+ }
+
+ return objectInstaller.addFilter(filterBuilder, conditionBuilder);
+}
+
+}
diff --git a/windows/winfw/src/winfw/rules/permitping.h b/windows/winfw/src/winfw/rules/permitping.h
new file mode 100644
index 0000000000..c8238ceaa8
--- /dev/null
+++ b/windows/winfw/src/winfw/rules/permitping.h
@@ -0,0 +1,28 @@
+#pragma once
+
+#include "ifirewallrule.h"
+#include <libwfp/ipaddress.h>
+#include <string>
+#include <optional>
+
+namespace rules
+{
+
+class PermitPing : public IFirewallRule
+{
+public:
+
+ PermitPing(const std::optional<std::wstring> &interfaceAlias, const wfp::IpAddress &host);
+
+ bool apply(IObjectInstaller &objectInstaller) override;
+
+private:
+
+ const std::optional<std::wstring> m_interfaceAlias;
+ const wfp::IpAddress m_host;
+
+ bool applyIcmpv4(IObjectInstaller &objectInstaller) const;
+ bool applyIcmpv6(IObjectInstaller &objectInstaller) const;
+};
+
+}
diff --git a/windows/winfw/src/winfw/winfw.cpp b/windows/winfw/src/winfw/winfw.cpp
index 7b9ea2dc6b..3065408f3d 100644
--- a/windows/winfw/src/winfw/winfw.cpp
+++ b/windows/winfw/src/winfw/winfw.cpp
@@ -4,6 +4,7 @@
#include "objectpurger.h"
#include <windows.h>
#include <stdexcept>
+#include <optional>
namespace
{
@@ -15,6 +16,34 @@ void * g_errorContext = nullptr;
FwContext *g_fwContext = nullptr;
+std::optional<FwContext::PingableHosts> ConvertPingableHosts(const PingableHosts *pingableHosts)
+{
+ if (nullptr == pingableHosts)
+ {
+ return {};
+ }
+
+ if (nullptr == pingableHosts->hosts
+ || 0 == pingableHosts->numHosts)
+ {
+ throw std::runtime_error("Invalid PingableHosts structure");
+ }
+
+ FwContext::PingableHosts converted;
+
+ if (nullptr != pingableHosts->tunnelInterfaceAlias)
+ {
+ converted.tunnelInterfaceAlias = pingableHosts->tunnelInterfaceAlias;
+ }
+
+ for (size_t i = 0; i < pingableHosts->numHosts; ++i)
+ {
+ converted.hosts.emplace_back(wfp::IpAddress(pingableHosts->hosts[i]));
+ }
+
+ return converted;
+}
+
} // anonymous namespace
WINFW_LINKAGE
@@ -130,7 +159,8 @@ bool
WINFW_API
WinFw_ApplyPolicyConnecting(
const WinFwSettings &settings,
- const WinFwRelay &relay
+ const WinFwRelay &relay,
+ const PingableHosts *pingableHosts
)
{
if (nullptr == g_fwContext)
@@ -140,7 +170,7 @@ WinFw_ApplyPolicyConnecting(
try
{
- return g_fwContext->applyPolicyConnecting(settings, relay);
+ return g_fwContext->applyPolicyConnecting(settings, relay, ConvertPingableHosts(pingableHosts));
}
catch (std::exception &err)
{
diff --git a/windows/winfw/src/winfw/winfw.h b/windows/winfw/src/winfw/winfw.h
index 95e66a608f..6d43b0db4c 100644
--- a/windows/winfw/src/winfw/winfw.h
+++ b/windows/winfw/src/winfw/winfw.h
@@ -105,11 +105,29 @@ WINFW_API
WinFw_Deinitialize();
//
+// PingableHosts:
+//
+// Specifies a set of IP addresses that should be reachable by ICMP when the connecting
+// policy is effective.
+//
+// The interface alias is optional and can be used to restrict the traffic such
+// that it is only allowed on that specific interface.
+//
+typedef struct tag_PingableHosts
+{
+ const wchar_t *tunnelInterfaceAlias;
+ const wchar_t **hosts;
+ size_t numHosts;
+}
+PingableHosts;
+
+//
// ApplyPolicyConnecting:
//
// Apply restrictions in the firewall that block all traffic, except:
// - What is specified by settings
// - Communication with the relay server
+// - ICMP (for ping) to/from tunnel gateway
//
extern "C"
WINFW_LINKAGE
@@ -117,7 +135,8 @@ bool
WINFW_API
WinFw_ApplyPolicyConnecting(
const WinFwSettings &settings,
- const WinFwRelay &relay
+ const WinFwRelay &relay,
+ const PingableHosts *pingableHosts
);
//
diff --git a/windows/winfw/src/winfw/winfw.vcxproj b/windows/winfw/src/winfw/winfw.vcxproj
index 4777503f72..cbabe2f4f7 100644
--- a/windows/winfw/src/winfw/winfw.vcxproj
+++ b/windows/winfw/src/winfw/winfw.vcxproj
@@ -30,6 +30,7 @@
<ClCompile Include="rules\permitlanservice.cpp" />
<ClCompile Include="rules\permitloopback.cpp" />
<ClCompile Include="rules\permitndp.cpp" />
+ <ClCompile Include="rules\permitping.cpp" />
<ClCompile Include="rules\permitvpntunnelservice.cpp" />
<ClCompile Include="rules\permitvpnrelay.cpp" />
<ClCompile Include="rules\permitvpntunnel.cpp" />
@@ -53,6 +54,7 @@
<ClInclude Include="objectpurger.h" />
<ClInclude Include="rules\permitdhcpserver.h" />
<ClInclude Include="rules\permitndp.h" />
+ <ClInclude Include="rules\permitping.h" />
<ClInclude Include="wfpobjecttype.h" />
<ClInclude Include="rules\blockall.h" />
<ClInclude Include="rules\ifirewallrule.h" />
diff --git a/windows/winfw/src/winfw/winfw.vcxproj.filters b/windows/winfw/src/winfw/winfw.vcxproj.filters
index 0319b0214a..a758a1c9ec 100644
--- a/windows/winfw/src/winfw/winfw.vcxproj.filters
+++ b/windows/winfw/src/winfw/winfw.vcxproj.filters
@@ -43,6 +43,9 @@
<ClCompile Include="rules\permitndp.cpp">
<Filter>rules</Filter>
</ClCompile>
+ <ClCompile Include="rules\permitping.cpp">
+ <Filter>rules</Filter>
+ </ClCompile>
</ItemGroup>
<ItemGroup>
<ClInclude Include="stdafx.h" />
@@ -93,6 +96,9 @@
<ClInclude Include="rules\permitndp.h">
<Filter>rules</Filter>
</ClInclude>
+ <ClInclude Include="rules\permitping.h">
+ <Filter>rules</Filter>
+ </ClInclude>
</ItemGroup>
<ItemGroup>
<Filter Include="rules">
diff --git a/windows/winnet/src/extras/loader/loader.vcxproj.filters b/windows/winnet/src/extras/loader/loader.vcxproj.filters
index cd0f4643c7..408a9591b1 100644
--- a/windows/winnet/src/extras/loader/loader.vcxproj.filters
+++ b/windows/winnet/src/extras/loader/loader.vcxproj.filters
@@ -3,9 +3,13 @@
<ItemGroup>
<ClCompile Include="loader.cpp" />
<ClCompile Include="stdafx.cpp" />
+ <ClCompile Include="..\..\winnet\routemanager.cpp" />
+ <ClCompile Include="..\..\winnet\adapters.cpp" />
</ItemGroup>
<ItemGroup>
<ClInclude Include="stdafx.h" />
<ClInclude Include="targetver.h" />
+ <ClInclude Include="..\..\winnet\routemanager.h" />
+ <ClInclude Include="..\..\winnet\adapters.h" />
</ItemGroup>
</Project> \ No newline at end of file
diff --git a/windows/winnet/src/winnet/interfaceutils.cpp b/windows/winnet/src/winnet/interfaceutils.cpp
index babe03eba6..202d9d0724 100644
--- a/windows/winnet/src/winnet/interfaceutils.cpp
+++ b/windows/winnet/src/winnet/interfaceutils.cpp
@@ -2,13 +2,8 @@
#include "interfaceutils.h"
#include "libcommon/error.h"
#include "libcommon/string.h"
-#include <vector>
#include <cstdint>
#include <algorithm>
-#include <winsock2.h>
-#include <iphlpapi.h>
-#include <windows.h>
-
//static
std::set<InterfaceUtils::NetworkAdapter> InterfaceUtils::GetAllAdapters()
@@ -112,3 +107,18 @@ std::wstring InterfaceUtils::GetTapInterfaceAlias()
throw std::runtime_error("Unable to find TAP adapter");
}
+
+//static
+void InterfaceUtils::AddDeviceIpAddresses(NET_LUID device, const std::vector<SOCKADDR_INET> &addresses)
+{
+ for (const auto &address : addresses)
+ {
+ MIB_UNICASTIPADDRESS_ROW row;
+ InitializeUnicastIpAddressEntry(&row);
+
+ row.InterfaceLuid = device;
+ row.Address = address;
+
+ THROW_UNLESS(NO_ERROR, CreateUnicastIpAddressEntry(&row), "Assign IP address on network interface");
+ }
+}
diff --git a/windows/winnet/src/winnet/interfaceutils.h b/windows/winnet/src/winnet/interfaceutils.h
index f5c31963c2..8ab1249a50 100644
--- a/windows/winnet/src/winnet/interfaceutils.h
+++ b/windows/winnet/src/winnet/interfaceutils.h
@@ -2,6 +2,17 @@
#include <string>
#include <set>
+#include <vector>
+
+// Secret include order to get most common networking structs/apis
+// And avoiding compilation errors
+#include <winsock2.h>
+#include <windows.h>
+#include <ws2def.h>
+#include <ws2ipdef.h>
+#include <iphlpapi.h>
+#include <netioapi.h>
+// end
class InterfaceUtils
{
@@ -35,4 +46,6 @@ public:
// Determines alias of primary TAP adapter.
//
static std::wstring GetTapInterfaceAlias();
+
+ static void AddDeviceIpAddresses(NET_LUID device, const std::vector<SOCKADDR_INET> &addresses);
};
diff --git a/windows/winnet/src/winnet/routing/defaultroutemonitor.cpp b/windows/winnet/src/winnet/routing/defaultroutemonitor.cpp
new file mode 100644
index 0000000000..55d7560904
--- /dev/null
+++ b/windows/winnet/src/winnet/routing/defaultroutemonitor.cpp
@@ -0,0 +1,177 @@
+#include "stdafx.h"
+#include <libcommon/error.h>
+#include "defaultroutemonitor.h"
+#include "helpers.h"
+
+namespace winnet::routing
+{
+
+namespace
+{
+
+const uint32_t POINT_TWO_SECOND_BURST = 200;
+const uint32_t TWO_SECOND_INTERFERENCE = 2000;
+
+} // anonymous namespace
+
+DefaultRouteMonitor::DefaultRouteMonitor
+(
+ ADDRESS_FAMILY family,
+ Callback callback,
+ std::shared_ptr<common::logging::ILogSink> logSink
+)
+ : m_family(family)
+ , m_callback(callback)
+ , m_logSink(logSink)
+ , m_evaluateRoutesGuard(std::make_unique<common::BurstGuard>(
+ std::bind(&DefaultRouteMonitor::evaluateRoutes, this),
+ POINT_TWO_SECOND_BURST,
+ TWO_SECOND_INTERFERENCE
+ ))
+{
+ try
+ {
+ m_bestRoute = GetBestDefaultRoute(m_family);
+ }
+ catch (...)
+ {
+ }
+
+ const auto status = NotifyRouteChange2(AF_UNSPEC, RouteChangeCallback, this, FALSE, &m_routeNotificationHandle);
+
+ THROW_UNLESS(NO_ERROR, status, "Register for route table change notifications");
+
+ try
+ {
+ const auto s2 = NotifyIpInterfaceChange(AF_UNSPEC, InterfaceChangeCallback, this,
+ FALSE, &m_interfaceNotificationHandle);
+
+ THROW_UNLESS(NO_ERROR, status, "Register for network interface change notifications");
+ }
+ catch (...)
+ {
+ CancelMibChangeNotify2(m_routeNotificationHandle);
+ throw;
+ }
+}
+
+DefaultRouteMonitor::~DefaultRouteMonitor()
+{
+ //
+ // Cancel notifications to stop triggering the BurstGuard.
+ //
+
+ CancelMibChangeNotify2(m_interfaceNotificationHandle);
+ CancelMibChangeNotify2(m_routeNotificationHandle);
+
+ //
+ // Controlled destruction of BurstGuard to prevent it from calling here
+ // after other member variables have been destructed.
+ //
+
+ m_evaluateRoutesGuard.reset();
+}
+
+//static
+void NETIOAPI_API_ DefaultRouteMonitor::RouteChangeCallback
+(
+ void *context,
+ MIB_IPFORWARD_ROW2 *row,
+ MIB_NOTIFICATION_TYPE
+)
+{
+ //
+ // We're only interested in changes that add/remove/update a default route.
+ //
+
+ if (0 != row->DestinationPrefix.PrefixLength
+ || false == RouteHasGateway(*row))
+ {
+ return;
+ }
+
+ reinterpret_cast<DefaultRouteMonitor *>(context)->m_evaluateRoutesGuard->trigger();
+}
+
+//static
+void NETIOAPI_API_ DefaultRouteMonitor::InterfaceChangeCallback
+(
+ void *context,
+ MIB_IPINTERFACE_ROW *,
+ MIB_NOTIFICATION_TYPE
+)
+{
+ reinterpret_cast<DefaultRouteMonitor *>(context)->m_evaluateRoutesGuard->trigger();
+}
+
+void DefaultRouteMonitor::evaluateRoutes()
+{
+ std::scoped_lock<std::mutex> lock(m_evaluationLock);
+
+ try
+ {
+ evaluateRoutesInner();
+ }
+ catch (const std::exception &ex)
+ {
+ const auto msg = std::string("Failure while evaluating route table: ").append(ex.what());
+ m_logSink->error(msg.c_str());
+ }
+ catch (...)
+ {
+ m_logSink->error("Unspecified failure while evaluating route table");
+ }
+}
+
+void DefaultRouteMonitor::evaluateRoutesInner()
+{
+ std::optional<InterfaceAndGateway> currentBestRoute;
+
+ try
+ {
+ currentBestRoute = GetBestDefaultRoute(m_family);
+ }
+ catch (...)
+ {
+ }
+
+ //
+ // If there was no default route previously.
+ //
+
+ if (false == m_bestRoute.has_value())
+ {
+ if (currentBestRoute.has_value())
+ {
+ m_bestRoute = currentBestRoute;
+ m_callback(EventType::Updated, m_bestRoute);
+ }
+
+ return;
+ }
+
+ //
+ // There used to be a default route.
+ // If there is not currently a default route.
+ //
+
+ if (false == currentBestRoute.has_value())
+ {
+ m_bestRoute.reset();
+ m_callback(EventType::Removed, std::nullopt);
+
+ return;
+ }
+
+ //
+ // The current best route may have changed.
+ //
+
+ if (m_bestRoute.value() != currentBestRoute.value())
+ {
+ m_bestRoute = currentBestRoute;
+ m_callback(EventType::Updated, m_bestRoute);
+ }
+}
+
+}
diff --git a/windows/winnet/src/winnet/routing/defaultroutemonitor.h b/windows/winnet/src/winnet/routing/defaultroutemonitor.h
new file mode 100644
index 0000000000..5575685a82
--- /dev/null
+++ b/windows/winnet/src/winnet/routing/defaultroutemonitor.h
@@ -0,0 +1,69 @@
+#pragma once
+
+#include <ifdef.h>
+#include <ws2def.h>
+#include <functional>
+#include <optional>
+#include <memory>
+#include <mutex>
+#include <libcommon/logging/ilogsink.h>
+#include <libcommon/burstguard.h>
+#include "types.h"
+
+namespace winnet::routing
+{
+
+class DefaultRouteMonitor
+{
+public:
+
+ enum class EventType
+ {
+ // The best default route changed.
+ Updated,
+
+ // No default routes exist.
+ Removed,
+ };
+
+ using Callback = std::function<void
+ (
+ EventType eventType,
+
+ // For update events, data associated with the new best default route.
+ const std::optional<InterfaceAndGateway> &route
+ )>;
+
+ DefaultRouteMonitor(ADDRESS_FAMILY family, Callback callback, std::shared_ptr<common::logging::ILogSink> logSink);
+ ~DefaultRouteMonitor();
+
+ DefaultRouteMonitor(const DefaultRouteMonitor &) = delete;
+ DefaultRouteMonitor(DefaultRouteMonitor &&) = delete;
+ DefaultRouteMonitor &operator=(const DefaultRouteMonitor &) = delete;
+ DefaultRouteMonitor &operator=(DefaultRouteMonitor &&) = delete;
+
+private:
+
+ ADDRESS_FAMILY m_family;
+ Callback m_callback;
+ std::shared_ptr<common::logging::ILogSink> m_logSink;
+
+ // This can't be a plain member variable.
+ // We need to be able to delete it explicitly in order to have a controlled tear down.
+ std::unique_ptr<common::BurstGuard> m_evaluateRoutesGuard;
+
+ std::optional<InterfaceAndGateway> m_bestRoute;
+
+ HANDLE m_routeNotificationHandle;
+ HANDLE m_interfaceNotificationHandle;
+
+ std::mutex m_evaluationLock;
+
+ static void NETIOAPI_API_ RouteChangeCallback(void *context, MIB_IPFORWARD_ROW2 *row, MIB_NOTIFICATION_TYPE notificationType);
+ static void NETIOAPI_API_ InterfaceChangeCallback(void *context, MIB_IPINTERFACE_ROW *row, MIB_NOTIFICATION_TYPE notificationType);
+
+ void evaluateRoutes();
+ void evaluateRoutesInner();
+};
+
+}
diff --git a/windows/winnet/src/winnet/routing/helpers.cpp b/windows/winnet/src/winnet/routing/helpers.cpp
new file mode 100644
index 0000000000..cabf19bce6
--- /dev/null
+++ b/windows/winnet/src/winnet/routing/helpers.cpp
@@ -0,0 +1,275 @@
+#include "stdafx.h"
+#include "helpers.h"
+#include <stdexcept>
+#include <ws2def.h>
+#include <in6addr.h>
+#include <numeric>
+//#include <netioapi.h>
+#include <libcommon/error.h>
+#include <libcommon/memory.h>
+
+namespace winnet::routing
+{
+
+bool EqualAddress(const Network &lhs, const Network &rhs)
+{
+ if (lhs.PrefixLength != rhs.PrefixLength)
+ {
+ return false;
+ }
+
+ return EqualAddress(lhs.Prefix, rhs.Prefix);
+}
+
+bool EqualAddress(const NodeAddress &lhs, const NodeAddress &rhs)
+{
+ if (lhs.si_family != rhs.si_family)
+ {
+ return false;
+ }
+
+ switch (lhs.si_family)
+ {
+ case AF_INET:
+ {
+ return lhs.Ipv4.sin_addr.s_addr == rhs.Ipv4.sin_addr.s_addr;
+ }
+ case AF_INET6:
+ {
+ return 0 == memcmp(&lhs.Ipv6.sin6_addr, &rhs.Ipv6.sin6_addr, sizeof(IN6_ADDR));
+ }
+ default:
+ {
+ throw std::runtime_error("Invalid address family for network address");
+ }
+ }
+}
+
+bool EqualAddress(const SOCKADDR_INET *lhs, const SOCKET_ADDRESS *rhs)
+{
+ if (lhs->si_family != rhs->lpSockaddr->sa_family)
+ {
+ return false;
+ }
+
+ switch (lhs->si_family)
+ {
+ case AF_INET:
+ {
+ auto typedRhs = reinterpret_cast<const SOCKADDR_IN *>(rhs->lpSockaddr);
+ return lhs->Ipv4.sin_addr.s_addr == typedRhs->sin_addr.s_addr;
+ }
+ case AF_INET6:
+ {
+ auto typedRhs = reinterpret_cast<const SOCKADDR_IN6 *>(rhs->lpSockaddr);
+ return 0 == memcmp(lhs->Ipv6.sin6_addr.u.Byte, typedRhs->sin6_addr.u.Byte, 16);
+ }
+ default:
+ {
+ throw std::runtime_error("Missing case handler in switch clause");
+ }
+ }
+}
+
+bool GetAdapterInterface(NET_LUID adapter, ADDRESS_FAMILY addressFamily, MIB_IPINTERFACE_ROW *iface)
+{
+ memset(iface, 0, sizeof(MIB_IPINTERFACE_ROW));
+
+ iface->Family = addressFamily;
+ iface->InterfaceLuid = adapter;
+
+ return NO_ERROR == GetIpInterfaceEntry(iface);
+}
+
+std::vector<AnnotatedRoute> AnnotateRoutes(const std::vector<const MIB_IPFORWARD_ROW2 *> &routes)
+{
+ std::vector<AnnotatedRoute> annotated;
+ annotated.reserve(routes.size());
+
+ for (auto route : routes)
+ {
+ MIB_IPINTERFACE_ROW iface;
+
+ if (false == GetAdapterInterface(route->InterfaceLuid, route->DestinationPrefix.Prefix.si_family, &iface))
+ {
+ continue;
+ }
+
+ annotated.emplace_back
+ (
+ AnnotatedRoute{ route, bool_cast(iface.Connected), route->Metric + iface.Metric }
+ );
+ }
+
+ return annotated;
+}
+
+bool RouteHasGateway(const MIB_IPFORWARD_ROW2 &route)
+{
+ switch (route.NextHop.si_family)
+ {
+ case AF_INET:
+ {
+ return 0 != route.NextHop.Ipv4.sin_addr.s_addr;
+ }
+ case AF_INET6:
+ {
+ const uint8_t *begin = &route.NextHop.Ipv6.sin6_addr.u.Byte[0];
+ const uint8_t *end = begin + 16;
+
+ return 0 != std::accumulate(begin, end, 0);
+ }
+ default:
+ {
+ return false;
+ }
+ };
+}
+
+InterfaceAndGateway GetBestDefaultRoute(ADDRESS_FAMILY family)
+{
+ PMIB_IPFORWARD_TABLE2 table;
+
+ auto status = GetIpForwardTable2(family, &table);
+
+ THROW_UNLESS(NO_ERROR, status, "Acquire route table");
+
+ common::memory::ScopeDestructor sd;
+
+ sd += [table]
+ {
+ FreeMibTable(table);
+ };
+
+ std::vector<const MIB_IPFORWARD_ROW2 *> candidates;
+ candidates.reserve(table->NumEntries);
+
+ //
+ // Enumerate routes looking for: route 0/0 && gateway specified.
+ //
+
+ for (ULONG i = 0; i < table->NumEntries; ++i)
+ {
+ const MIB_IPFORWARD_ROW2 &candidate = table->Table[i];
+
+ if (0 == candidate.DestinationPrefix.PrefixLength
+ && RouteHasGateway(candidate))
+ {
+ candidates.emplace_back(&candidate);
+ }
+ }
+
+ auto annotated = AnnotateRoutes(candidates);
+
+ if (annotated.empty())
+ {
+ throw std::runtime_error("Unable to determine details of default route");
+ }
+
+ //
+ // Sort on (active, effectiveMetric) ascending by metric.
+ //
+
+ std::sort(annotated.begin(), annotated.end(), [](const AnnotatedRoute &lhs, const AnnotatedRoute &rhs)
+ {
+ if (lhs.active == rhs.active)
+ {
+ return lhs.effectiveMetric < rhs.effectiveMetric;
+ }
+
+ return lhs.active && false == rhs.active;
+ });
+
+ //
+ // Ensure the top rated route is active.
+ //
+
+ if (false == annotated[0].active)
+ {
+ throw std::runtime_error("Unable to identify active default route");
+ }
+
+ return InterfaceAndGateway { annotated[0].route->InterfaceLuid, annotated[0].route->NextHop };
+}
+
+bool AdapterInterfaceEnabled(const IP_ADAPTER_ADDRESSES *adapter, ADDRESS_FAMILY family)
+{
+ switch (family)
+ {
+ case AF_INET:
+ {
+ return 0 != adapter->Ipv4Enabled;
+ }
+ case AF_INET6:
+ {
+ return 0 != adapter->Ipv6Enabled;
+ }
+ default:
+ {
+ throw std::runtime_error("Missing case handler in switch clause");
+ }
+ }
+}
+
+std::vector<const SOCKET_ADDRESS *> IsolateGatewayAddresses
+(
+ PIP_ADAPTER_GATEWAY_ADDRESS_LH head,
+ ADDRESS_FAMILY family
+)
+{
+ std::vector<const SOCKET_ADDRESS *> matches;
+
+ for (auto gateway = head; nullptr != gateway; gateway = gateway->Next)
+ {
+ if (family == gateway->Address.lpSockaddr->sa_family)
+ {
+ matches.emplace_back(&gateway->Address);
+ }
+ }
+
+ return matches;
+}
+
+bool AddressPresent(const std::vector<const SOCKET_ADDRESS *> &hay, const SOCKADDR_INET *needle)
+{
+ for (const auto candidate : hay)
+ {
+ if (EqualAddress(needle, candidate))
+ {
+ return true;
+ }
+ }
+
+ return false;
+}
+
+//NodeAddress ConvertSocketAddress(const SOCKET_ADDRESS *sa)
+//{
+// NodeAddress out = { 0 };
+//
+// switch (sa->lpSockaddr->sa_family)
+// {
+// case AF_INET:
+// {
+// out.si_family = AF_INET;
+// out.Ipv4 = *reinterpret_cast<SOCKADDR_IN *>(sa->lpSockaddr);
+//
+// break;
+// }
+// case AF_INET6:
+// {
+// out.si_family = AF_INET6;
+// out.Ipv6 = *reinterpret_cast<SOCKADDR_IN6 *>(sa->lpSockaddr);
+//
+// break;
+// }
+// default:
+// {
+// throw std::runtime_error("Missing case handler in switch clause");
+// }
+// };
+//
+// return out;
+//}
+
+}
diff --git a/windows/winnet/src/winnet/routing/helpers.h b/windows/winnet/src/winnet/routing/helpers.h
new file mode 100644
index 0000000000..3ef5e85b75
--- /dev/null
+++ b/windows/winnet/src/winnet/routing/helpers.h
@@ -0,0 +1,46 @@
+#pragma once
+
+#include "types.h"
+#include <vector>
+
+namespace winnet::routing
+{
+
+bool EqualAddress(const Network &lhs, const Network &rhs);
+bool EqualAddress(const NodeAddress &lhs, const NodeAddress &rhs);
+bool EqualAddress(const SOCKADDR_INET *lhs, const SOCKET_ADDRESS *rhs);
+
+bool GetAdapterInterface(NET_LUID adapter, ADDRESS_FAMILY addressFamily, MIB_IPINTERFACE_ROW *iface);
+
+struct AnnotatedRoute
+{
+ const MIB_IPFORWARD_ROW2 *route;
+ bool active;
+ uint32_t effectiveMetric;
+};
+
+template<typename T>
+bool bool_cast(const T &value)
+{
+ return 0 != value;
+}
+
+std::vector<AnnotatedRoute> AnnotateRoutes(const std::vector<const MIB_IPFORWARD_ROW2 *> &routes);
+
+bool RouteHasGateway(const MIB_IPFORWARD_ROW2 &route);
+
+InterfaceAndGateway GetBestDefaultRoute(ADDRESS_FAMILY family);
+
+bool AdapterInterfaceEnabled(const IP_ADAPTER_ADDRESSES *adapter, ADDRESS_FAMILY family);
+
+std::vector<const SOCKET_ADDRESS *> IsolateGatewayAddresses
+(
+ PIP_ADAPTER_GATEWAY_ADDRESS_LH head,
+ ADDRESS_FAMILY family
+);
+
+bool AddressPresent(const std::vector<const SOCKET_ADDRESS *> &hay, const SOCKADDR_INET *needle);
+
+//NodeAddress ConvertSocketAddress(const SOCKET_ADDRESS *sa);
+
+}
diff --git a/windows/winnet/src/winnet/routing/routemanager.cpp b/windows/winnet/src/winnet/routing/routemanager.cpp
new file mode 100644
index 0000000000..668e64bb68
--- /dev/null
+++ b/windows/winnet/src/winnet/routing/routemanager.cpp
@@ -0,0 +1,692 @@
+#include "stdafx.h"
+#include "routemanager.h"
+#include "helpers.h"
+#include <libcommon/error.h>
+#include <libcommon/memory.h>
+#include <libcommon/string.h>
+#include <libcommon/network/adapters.h>
+#include <vector>
+#include <algorithm>
+#include <numeric>
+#include <sstream>
+#include <stdexcept>
+
+using AutoLockType = std::scoped_lock<std::mutex>;
+using AutoRecursiveLockType = std::scoped_lock<std::recursive_mutex>;
+using namespace std::placeholders;
+
+namespace winnet::routing
+{
+
+namespace
+{
+
+using Adapters = common::network::Adapters;
+
+NET_LUID InterfaceLuidFromGateway(const NodeAddress &gateway)
+{
+ const DWORD adapterFlags = GAA_FLAG_SKIP_ANYCAST | GAA_FLAG_SKIP_MULTICAST | GAA_FLAG_SKIP_DNS_SERVER
+ | GAA_FLAG_SKIP_FRIENDLY_NAME | GAA_FLAG_INCLUDE_GATEWAYS;
+
+ Adapters adapters(gateway.si_family, adapterFlags);
+
+ //
+ // Process adapters to find matching ones.
+ //
+
+ std::vector<const IP_ADAPTER_ADDRESSES *> matches;
+
+ for (auto adapter = adapters.next(); nullptr != adapter; adapter = adapters.next())
+ {
+ if (false == AdapterInterfaceEnabled(adapter, gateway.si_family))
+ {
+ continue;
+ }
+
+ auto gateways = IsolateGatewayAddresses(adapter->FirstGatewayAddress, gateway.si_family);
+
+ if (AddressPresent(gateways, &gateway))
+ {
+ matches.emplace_back(adapter);
+ }
+ }
+
+ if (matches.empty())
+ {
+ throw std::runtime_error("Unable to find network adapter with specified gateway");
+ }
+
+ //
+ // Sort matching interfaces ascending by metric.
+ //
+
+ const bool targetV4 = (AF_INET == gateway.si_family);
+
+ std::sort(matches.begin(), matches.end(), [&targetV4](const IP_ADAPTER_ADDRESSES *lhs, const IP_ADAPTER_ADDRESSES *rhs)
+ {
+ if (targetV4)
+ {
+ return lhs->Ipv4Metric < rhs->Ipv4Metric;
+ }
+
+ return lhs->Ipv6Metric < rhs->Ipv6Metric;
+ });
+
+ //
+ // Select the interface with the best (lowest) metric.
+ //
+
+ return matches[0]->Luid;
+}
+
+bool ParseStringEncodedLuid(const std::wstring &encodedLuid, NET_LUID &luid)
+{
+ //
+ // The `#` is a valid character in adapter names so we use `?` instead.
+ // The LUID is thus prefixed with `?` and hex encoded and left-padded with zeroes.
+ // E.g. `?deadbeefcafebabe` or `?000dbeefcafebabe`.
+ //
+
+ static const size_t StringEncodedLuidLength = 17;
+
+ if (encodedLuid.size() != StringEncodedLuidLength
+ || L'?' != encodedLuid[0])
+ {
+ return false;
+ }
+
+ try
+ {
+ std::wstringstream ss;
+
+ ss << std::hex << &encodedLuid[1];
+ ss >> luid.Value;
+ }
+ catch (...)
+ {
+ const auto ansi = common::string::ToAnsi(encodedLuid);
+ const auto err = std::string("Failed to parse string encoded LUID: ").append(ansi);
+
+ std::throw_with_nested(std::runtime_error(err));
+ }
+
+ return true;
+}
+
+InterfaceAndGateway ResolveNode(ADDRESS_FAMILY family, const std::optional<Node> &optionalNode)
+{
+ //
+ // There are four cases:
+ //
+ // Unspecified node (use interface and gateway of default route).
+ // Node is specified by name.
+ // Node is specified by name and gateway.
+ // Node is specified by gateway.
+ //
+
+ if (false == optionalNode.has_value())
+ {
+ return GetBestDefaultRoute(family);
+ }
+
+ const auto &node = optionalNode.value();
+
+ if (node.deviceName().has_value())
+ {
+ const auto &deviceName = node.deviceName().value();
+ NET_LUID luid;
+
+ if (false == ParseStringEncodedLuid(deviceName, luid)
+ && 0 != ConvertInterfaceAliasToLuid(deviceName.c_str(), &luid))
+ {
+ const auto ansiName = common::string::ToAnsi(deviceName);
+ const auto err = std::string("Unable to derive interface LUID from interface alias: ").append(ansiName);
+
+ throw std::runtime_error(err);
+ }
+
+ auto onLinkProvider = [&family]()
+ {
+ NodeAddress onLink = { 0 };
+ onLink.si_family = family;
+
+ return onLink;
+ };
+
+ return InterfaceAndGateway{ luid, node.gateway().value_or(onLinkProvider()) };
+ }
+
+ //
+ // The node is specified only by gateway.
+ //
+
+ return InterfaceAndGateway{ InterfaceLuidFromGateway(node.gateway().value()), node.gateway().value() };
+}
+
+// TODO: Move to libcommon
+uint32_t ByteSwap(uint32_t val)
+{
+ return
+ (
+ ((val & 0xFF) << 24) |
+ ((val & 0xFF00) << 8) |
+ ((val & 0xFF0000) >> 8) |
+ ((val & 0xFF000000) >> 24)
+ );
+}
+
+std::wstring FormatNetwork(const Network &network)
+{
+ switch (network.Prefix.si_family)
+ {
+ case AF_INET:
+ {
+ return common::string::FormatIpv4(ByteSwap(network.Prefix.Ipv4.sin_addr.s_addr), network.PrefixLength);
+ }
+ case AF_INET6:
+ {
+ return common::string::FormatIpv6(network.Prefix.Ipv6.sin6_addr.u.Byte, network.PrefixLength);
+ }
+ default:
+ {
+ return L"Failed to format network details";
+ }
+ }
+}
+
+} // anonymous namespace
+
+RouteManager::RouteManager(std::shared_ptr<common::logging::ILogSink> logSink)
+ : m_logSink(logSink)
+ , m_routeMonitorV4(std::make_unique<DefaultRouteMonitor>(
+ static_cast<ADDRESS_FAMILY>(AF_INET),
+ std::bind(&RouteManager::defaultRouteChanged, this, static_cast<ADDRESS_FAMILY>(AF_INET), _1, _2),
+ logSink
+ ))
+ , m_routeMonitorV6(std::make_unique<DefaultRouteMonitor>(
+ static_cast<ADDRESS_FAMILY>(AF_INET6),
+ std::bind(&RouteManager::defaultRouteChanged, this, static_cast<ADDRESS_FAMILY>(AF_INET6), _1, _2),
+ logSink
+ ))
+{
+}
+
+RouteManager::~RouteManager()
+{
+ //
+ // Stop callbacks that are triggered by events in Windows from coming in.
+ //
+
+ m_routeMonitorV4.reset();
+ m_routeMonitorV6.reset();
+
+ //
+ // Delete all routes owned by us.
+ //
+
+ for (const auto &record : m_routes)
+ {
+ try
+ {
+ deleteFromRoutingTable(record.registeredRoute);
+ }
+ catch (const std::exception &ex)
+ {
+ std::wstringstream ss;
+
+ ss << L"Failed to delete route as part of cleaning up, Route: "
+ << FormatRegisteredRoute(record.registeredRoute);
+
+ m_logSink->error(common::string::ToAnsi(ss.str()).c_str());
+ m_logSink->error(ex.what());
+ }
+ }
+}
+
+void RouteManager::addRoutes(const std::vector<Route> &routes)
+{
+ AutoLockType lock(m_routesLock);
+
+ std::vector<EventEntry> eventLog;
+
+ for (const auto &route : routes)
+ {
+ try
+ {
+ auto record = findRouteRecord(route);
+
+ if (record != m_routes.end())
+ {
+ deleteFromRoutingTable(record->registeredRoute);
+ eventLog.emplace_back(EventEntry{ EventType::DELETE_ROUTE, *record });
+ m_routes.erase(record);
+ }
+
+ const RouteRecord newRecord { route, addIntoRoutingTable(route) };
+
+ eventLog.emplace_back(EventEntry{ EventType::ADD_ROUTE, newRecord });
+ m_routes.emplace_back(std::move(newRecord));
+ }
+ catch (...)
+ {
+ undoEvents(eventLog);
+
+ std::throw_with_nested(std::runtime_error("Failed during batch insertion of routes"));
+ }
+ }
+}
+
+void RouteManager::addRoute(const Route &route)
+{
+ AutoLockType lock(m_routesLock);
+
+ std::optional<RouteRecord> deletedRecord;
+
+ auto record = findRouteRecord(route);
+
+ if (record != m_routes.end())
+ {
+ try
+ {
+ deleteFromRoutingTable(record->registeredRoute);
+ }
+ catch (...)
+ {
+ std::throw_with_nested(std::runtime_error("Failed to evict old route when adding new route"));
+ }
+
+ deletedRecord = *record;
+ m_routes.erase(record);
+ }
+
+ try
+ {
+ m_routes.emplace_back
+ (
+ RouteRecord{ route, addIntoRoutingTable(route) }
+ );
+ }
+ catch (...)
+ {
+ //
+ // Restore deleted record.
+ //
+
+ if (deletedRecord.has_value())
+ {
+ auto &r = deletedRecord.value();
+
+ try
+ {
+ restoreIntoRoutingTable(r.registeredRoute);
+ m_routes.emplace_back(r);
+ }
+ catch (const std::exception &ex)
+ {
+ const auto err = std::string("Failed to restore evicted route during rollback: ").append(ex.what());
+ m_logSink->error(err.c_str());
+ }
+ }
+
+ //
+ // Just rethrow because the error is from addIntoRoutingTable().
+ //
+
+ throw;
+ }
+}
+
+void RouteManager::deleteRoutes(const std::vector<Route> &routes)
+{
+ AutoLockType lock(m_routesLock);
+
+ std::vector<EventEntry> eventLog;
+
+ for (const auto &route : routes)
+ {
+ try
+ {
+ auto record = findRouteRecord(route);
+
+ if (m_routes.end() == record)
+ {
+ const auto err = std::wstring(L"Request to delete previously unregistered route: ")
+ .append(FormatNetwork(route.network()));
+
+ m_logSink->warning(common::string::ToAnsi(err).c_str());
+
+ continue;
+ }
+
+ deleteFromRoutingTable(record->registeredRoute);
+ eventLog.emplace_back(EventEntry{ EventType::DELETE_ROUTE, *record });
+ m_routes.erase(record);
+ }
+ catch (...)
+ {
+ undoEvents(eventLog);
+
+ std::throw_with_nested(std::runtime_error("Failed during batch removal of routes"));
+ }
+ }
+}
+
+void RouteManager::deleteRoute(const Route &route)
+{
+ AutoLockType lock(m_routesLock);
+
+ auto record = findRouteRecord(route);
+
+ if (m_routes.end() == record)
+ {
+ const auto err = std::wstring(L"Request to delete previously unregistered route: ")
+ .append(FormatNetwork(route.network()));
+
+ m_logSink->warning(common::string::ToAnsi(err).c_str());
+
+ return;
+ }
+
+ deleteFromRoutingTable(record->registeredRoute);
+ m_routes.erase(record);
+}
+
+RouteManager::CallbackHandle RouteManager::registerDefaultRouteChangedCallback(DefaultRouteChangedCallback callback)
+{
+ AutoRecursiveLockType lock(m_defaultRouteCallbacksLock);
+
+ m_defaultRouteCallbacks.emplace_back(callback);
+
+ // Return raw address of record in list.
+ return &m_defaultRouteCallbacks.back();
+}
+
+void RouteManager::unregisterDefaultRouteChangedCallback(CallbackHandle handle)
+{
+ AutoRecursiveLockType lock(m_defaultRouteCallbacksLock);
+
+ for (auto it = m_defaultRouteCallbacks.begin(); it != m_defaultRouteCallbacks.end(); ++it)
+ {
+ // Match on raw address of record.
+ if (&*it == handle)
+ {
+ m_defaultRouteCallbacks.erase(it);
+ return;
+ }
+ }
+}
+
+std::list<RouteManager::RouteRecord>::iterator RouteManager::findRouteRecord(const Network &network)
+{
+ return std::find_if(m_routes.begin(), m_routes.end(), [&network](const auto &candidate)
+ {
+ return EqualAddress(network, candidate.route.network());
+ });
+}
+
+std::list<RouteManager::RouteRecord>::iterator RouteManager::findRouteRecord(const Route &route)
+{
+ return findRouteRecord(route.network());
+}
+
+RouteManager::RegisteredRoute RouteManager::addIntoRoutingTable(const Route &route)
+{
+ const auto node = ResolveNode(route.network().Prefix.si_family, route.node());
+
+ MIB_IPFORWARD_ROW2 spec;
+
+ InitializeIpForwardEntry(&spec);
+
+ spec.InterfaceLuid = node.iface;
+ spec.DestinationPrefix = route.network();
+ spec.NextHop = node.gateway;
+ spec.Metric = 0;
+ spec.Protocol = MIB_IPPROTO_NETMGMT;
+ spec.Origin = NlroManual;
+
+ //
+ // Do not treat ERROR_OBJECT_ALREADY_EXISTS as being successful.
+ // Because it may not take route metric into consideration.
+ //
+
+ THROW_UNLESS(NO_ERROR, CreateIpForwardEntry2(&spec), "Register route in routing table");
+
+ return RegisteredRoute { route.network(), node.iface, node.gateway };
+}
+
+void RouteManager::restoreIntoRoutingTable(const RegisteredRoute &route)
+{
+ MIB_IPFORWARD_ROW2 spec;
+
+ InitializeIpForwardEntry(&spec);
+
+ spec.InterfaceLuid = route.luid;
+ spec.DestinationPrefix = route.network;
+ spec.NextHop = route.nextHop;
+ spec.Metric = 0;
+ spec.Protocol = MIB_IPPROTO_NETMGMT;
+ spec.Origin = NlroManual;
+
+ THROW_UNLESS(NO_ERROR, CreateIpForwardEntry2(&spec), "Register route in routing table");
+}
+
+void RouteManager::deleteFromRoutingTable(const RegisteredRoute &route)
+{
+ MIB_IPFORWARD_ROW2 r = { 0};
+
+ r.InterfaceLuid = route.luid;
+ r.DestinationPrefix = route.network;
+ r.NextHop = route.nextHop;
+
+ auto status = DeleteIpForwardEntry2(&r);
+
+ if (ERROR_NOT_FOUND == status)
+ {
+ status = NO_ERROR;
+
+ const auto err = std::wstring(L"Attempting to delete route which was not present in routing table, " \
+ "ignoring and proceeding. Route: ").append(FormatRegisteredRoute(route));
+
+ m_logSink->warning(common::string::ToAnsi(err).c_str());
+ }
+
+ THROW_UNLESS(NO_ERROR, status, "Delete route in routing table");
+}
+
+void RouteManager::undoEvents(const std::vector<EventEntry> &eventLog)
+{
+ //
+ // Rewind state by processing events in the reverse order.
+ //
+
+ for (auto it = eventLog.rbegin(); it != eventLog.rend(); ++it)
+ {
+ try
+ {
+ switch (it->type)
+ {
+ case EventType::ADD_ROUTE:
+ {
+ auto record = findRouteRecord(it->record.route);
+
+ if (m_routes.end() == record)
+ {
+ throw std::runtime_error("Internal state inconsistency in route manager");
+ }
+
+ deleteFromRoutingTable(record->registeredRoute);
+ m_routes.erase(record);
+
+ break;
+ }
+ case EventType::DELETE_ROUTE:
+ {
+ restoreIntoRoutingTable(it->record.registeredRoute);
+ m_routes.emplace_back(it->record);
+
+ break;
+ }
+ default:
+ {
+ throw std::logic_error("Missing case handler in switch clause");
+ }
+ }
+ }
+ catch (const std::exception &ex)
+ {
+ const auto err = std::string("Attempting to rollback state: ").append(ex.what());
+ m_logSink->error(err.c_str());
+ }
+ }
+}
+
+// static
+std::wstring RouteManager::FormatRegisteredRoute(const RegisteredRoute &route)
+{
+ //
+ // TODO: Fix broken IP formatting
+ // Update FormatIpv4 function with an additional argument to specify network/host byte order.
+ //
+
+ std::wstringstream ss;
+
+ if (AF_INET == route.network.Prefix.si_family)
+ {
+ std::wstring gateway(L"\"On-link\"");
+
+ if (0 != route.nextHop.Ipv4.sin_addr.s_addr)
+ {
+ gateway = common::string::FormatIpv4(ByteSwap(route.nextHop.Ipv4.sin_addr.s_addr));
+ }
+
+ ss << common::string::FormatIpv4(ByteSwap(route.network.Prefix.Ipv4.sin_addr.s_addr), route.network.PrefixLength)
+ << L" with gateway " << gateway
+ << L" on interface with LUID 0x" << std::hex << route.luid.Value;
+ }
+ else if (AF_INET6 == route.network.Prefix.si_family)
+ {
+ std::wstring gateway(L"\"On-link\"");
+
+ const uint8_t *begin = &route.nextHop.Ipv6.sin6_addr.u.Byte[0];
+ const uint8_t *end = begin + 16;
+
+ if (0 != std::accumulate(begin, end, 0))
+ {
+ gateway = common::string::FormatIpv6(route.nextHop.Ipv6.sin6_addr.u.Byte);
+ }
+
+ ss << common::string::FormatIpv6(route.network.Prefix.Ipv6.sin6_addr.u.Byte, route.network.PrefixLength)
+ << L" with gateway " << gateway
+ << L" on interface with LUID 0x" << std::hex << route.luid.Value;
+ }
+ else
+ {
+ ss << L"Failed to format route details";
+ }
+
+ return ss.str();
+}
+
+void RouteManager::defaultRouteChanged(ADDRESS_FAMILY family, DefaultRouteMonitor::EventType eventType,
+ const std::optional<InterfaceAndGateway> &route)
+{
+ //
+ // Forward event to all registered listeners.
+ //
+
+ m_defaultRouteCallbacksLock.lock();
+
+ for (const auto &callback : m_defaultRouteCallbacks)
+ {
+ try
+ {
+ callback(eventType, family, route);
+ }
+ catch (const std::exception &ex)
+ {
+ const auto msg = std::string("Failure in default-route-changed callback: ").append(ex.what());
+ m_logSink->error(msg.c_str());
+ }
+ catch (...)
+ {
+ m_logSink->error("Unspecified failure in default-route-changed callback");
+ }
+ }
+
+ m_defaultRouteCallbacksLock.unlock();
+
+ //
+ // Examine event to determine if best default route has changed.
+ //
+
+ if (DefaultRouteMonitor::EventType::Updated != eventType)
+ {
+ return;
+ }
+
+ //
+ // Examine our routes to see if any of them are policy bound to the best default route.
+ //
+
+ AutoLockType routesLock(m_routesLock);
+
+ using RecordIterator = std::list<RouteRecord>::iterator;
+
+ std::list<RecordIterator> affectedRoutes;
+
+ for (RecordIterator it = m_routes.begin(); it != m_routes.end(); ++it)
+ {
+ if (false == it->route.node().has_value()
+ && family == it->route.network().Prefix.si_family)
+ {
+ affectedRoutes.emplace_back(it);
+ }
+ }
+
+ if (affectedRoutes.empty())
+ {
+ return;
+ }
+
+ //
+ // Update all affected routes.
+ //
+
+ m_logSink->info("Best default route has changed. Refreshing dependent routes");
+
+ for (auto &it : affectedRoutes)
+ {
+ try
+ {
+ deleteFromRoutingTable(it->registeredRoute);
+ }
+ catch (const std::exception &ex)
+ {
+ const auto msg = std::string("Failed to delete route when refreshing " \
+ "existing routes: ").append(ex.what());
+
+ m_logSink->error(msg.c_str());
+
+ continue;
+ }
+
+ it->registeredRoute.luid = route.value().iface;
+ it->registeredRoute.nextHop = route.value().gateway;
+
+ try
+ {
+ restoreIntoRoutingTable(it->registeredRoute);
+ }
+ catch (const std::exception &ex)
+ {
+ const auto msg = std::string("Failed to add route when refreshing " \
+ "existing routes: ").append(ex.what());
+
+ m_logSink->error(msg.c_str());
+
+ continue;
+ }
+ }
+}
+
+}
diff --git a/windows/winnet/src/winnet/routing/routemanager.h b/windows/winnet/src/winnet/routing/routemanager.h
new file mode 100644
index 0000000000..981c8e6834
--- /dev/null
+++ b/windows/winnet/src/winnet/routing/routemanager.h
@@ -0,0 +1,112 @@
+#pragma once
+
+#include <string>
+#include <memory>
+#include <vector>
+#include <list>
+#include <optional>
+#include <mutex>
+#include <functional>
+#include <windows.h>
+#include <ws2def.h>
+#include <ifdef.h>
+#include <libcommon/string.h>
+#include <libcommon/logging/ilogsink.h>
+#include "defaultroutemonitor.h"
+
+namespace winnet::routing
+{
+
+class RouteManager
+{
+public:
+
+ RouteManager(std::shared_ptr<common::logging::ILogSink> logSink);
+ ~RouteManager();
+
+ RouteManager(const RouteManager &) = delete;
+ RouteManager(RouteManager &&) = default;
+ RouteManager &operator=(const RouteManager &) = delete;
+ RouteManager &operator=(RouteManager &&) = delete;
+
+ void addRoutes(const std::vector<Route> &routes);
+ void addRoute(const Route &route);
+
+ void deleteRoutes(const std::vector<Route> &routes);
+ void deleteRoute(const Route &route);
+
+ using DefaultRouteChangedEventType = DefaultRouteMonitor::EventType;
+
+ using DefaultRouteChangedCallback = std::function<void
+ (
+ DefaultRouteChangedEventType eventType,
+ ADDRESS_FAMILY family,
+
+ // For update events, data associated with the new best default route.
+ const std::optional<InterfaceAndGateway> &route
+ )>;
+
+ using CallbackHandle = void*;
+
+ CallbackHandle registerDefaultRouteChangedCallback(DefaultRouteChangedCallback callback);
+ void unregisterDefaultRouteChangedCallback(CallbackHandle handle);
+
+private:
+
+ std::shared_ptr<common::logging::ILogSink> m_logSink;
+
+ std::unique_ptr<DefaultRouteMonitor> m_routeMonitorV4;
+ std::unique_ptr<DefaultRouteMonitor> m_routeMonitorV6;
+
+ // These are the exact details derived from the route specification (`Route`).
+ // They are used when registering and deleting a route in the system.
+ struct RegisteredRoute
+ {
+ Network network;
+ NET_LUID luid;
+ NodeAddress nextHop;
+ };
+
+ struct RouteRecord
+ {
+ Route route;
+ RegisteredRoute registeredRoute;
+ };
+
+ std::list<RouteRecord> m_routes;
+ std::mutex m_routesLock;
+
+ std::list<DefaultRouteChangedCallback> m_defaultRouteCallbacks;
+ std::recursive_mutex m_defaultRouteCallbacksLock;
+
+ // Find record based on destination and mask.
+ std::list<RouteRecord>::iterator findRouteRecord(const Network &network);
+
+ // Note: Same as above!
+ std::list<RouteRecord>::iterator findRouteRecord(const Route &route);
+
+ RegisteredRoute addIntoRoutingTable(const Route &route);
+ void restoreIntoRoutingTable(const RegisteredRoute &route);
+ void deleteFromRoutingTable(const RegisteredRoute &route);
+
+ enum class EventType
+ {
+ ADD_ROUTE,
+ DELETE_ROUTE,
+ };
+
+ struct EventEntry
+ {
+ EventType type;
+ RouteRecord record;
+ };
+
+ void undoEvents(const std::vector<EventEntry> &eventLog);
+
+ static std::wstring FormatRegisteredRoute(const RegisteredRoute &route);
+
+ void defaultRouteChanged(ADDRESS_FAMILY family, DefaultRouteMonitor::EventType eventType,
+ const std::optional<InterfaceAndGateway> &route);
+};
+
+}
diff --git a/windows/winnet/src/winnet/routing/types.cpp b/windows/winnet/src/winnet/routing/types.cpp
new file mode 100644
index 0000000000..ac71c8108f
--- /dev/null
+++ b/windows/winnet/src/winnet/routing/types.cpp
@@ -0,0 +1,84 @@
+#include "stdafx.h"
+#include "types.h"
+#include "helpers.h"
+#include <libcommon/string.h>
+
+namespace winnet::routing
+{
+
+Node::Node(const std::optional<std::wstring> &deviceName, const std::optional<NodeAddress> &gateway)
+ : m_deviceName(deviceName)
+ , m_gateway(gateway)
+{
+ if (false == m_deviceName.has_value() && false == m_gateway.has_value())
+ {
+ throw std::runtime_error("Invalid node definition");
+ }
+
+ if (m_deviceName.has_value())
+ {
+ const auto trimmed = common::string::Trim<>(m_deviceName.value());
+
+ if (trimmed.empty())
+ {
+ throw std::runtime_error("Invalid device name in node definition");
+ }
+
+ m_deviceName = std::move(trimmed);
+ }
+}
+
+bool Node::operator==(const Node &rhs) const
+{
+ if (m_deviceName.has_value())
+ {
+ if (false == rhs.m_deviceName.has_value()
+ || 0 != _wcsicmp(m_deviceName.value().c_str(), rhs.deviceName().value().c_str()))
+ {
+ return false;
+ }
+ }
+
+ if (m_gateway.has_value())
+ {
+ if (false == rhs.m_gateway.has_value()
+ || false == EqualAddress(m_gateway.value(), rhs.gateway().value()))
+ {
+ return false;
+ }
+ }
+
+ return true;
+}
+
+Route::Route(const Network &network, const std::optional<Node> &node)
+ : m_network(network)
+ , m_node(node)
+{
+}
+
+bool Route::operator==(const Route &rhs) const
+{
+ if (m_node.has_value())
+ {
+ return rhs.node().has_value()
+ && EqualAddress(m_network, rhs.network())
+ && m_node.value() == rhs.node().value();
+ }
+
+ return false == rhs.node().has_value()
+ && EqualAddress(m_network, rhs.network());
+}
+
+bool InterfaceAndGateway::operator==(const InterfaceAndGateway &rhs)
+{
+ return iface.Value == rhs.iface.Value
+ && EqualAddress(gateway, rhs.gateway);
+}
+
+bool InterfaceAndGateway::operator!=(const InterfaceAndGateway &rhs)
+{
+ return !(*this == rhs);
+}
+
+}
diff --git a/windows/winnet/src/winnet/routing/types.h b/windows/winnet/src/winnet/routing/types.h
new file mode 100644
index 0000000000..1e132feb00
--- /dev/null
+++ b/windows/winnet/src/winnet/routing/types.h
@@ -0,0 +1,77 @@
+#pragma once
+
+#include <string>
+#include <optional>
+#include <winsock2.h>
+#include <windows.h>
+#include <ws2def.h>
+#include <ws2ipdef.h>
+#include <iphlpapi.h>
+//#include <netioapi.h>
+//#include <functional>
+
+
+namespace winnet::routing
+{
+
+using Network = IP_ADDRESS_PREFIX;
+using NodeAddress = SOCKADDR_INET;
+
+class Node
+{
+public:
+
+ Node(const std::optional<std::wstring> &deviceName, const std::optional<NodeAddress> &gateway);
+
+ const std::optional<std::wstring> &deviceName() const
+ {
+ return m_deviceName;
+ }
+
+ const std::optional<NodeAddress> &gateway() const
+ {
+ return m_gateway;
+ }
+
+ bool operator==(const Node &rhs) const;
+
+private:
+
+ std::optional<std::wstring> m_deviceName;
+ std::optional<NodeAddress> m_gateway;
+};
+
+class Route
+{
+public:
+
+ Route(const Network &network, const std::optional<Node> &node);
+
+ const Network &network() const
+ {
+ return m_network;
+ }
+
+ const std::optional<Node> &node() const
+ {
+ return m_node;
+ }
+
+ bool operator==(const Route &rhs) const;
+
+private:
+
+ Network m_network;
+ std::optional<Node> m_node;
+};
+
+struct InterfaceAndGateway
+{
+ NET_LUID iface;
+ NodeAddress gateway;
+
+ bool operator==(const InterfaceAndGateway &rhs);
+ bool operator!=(const InterfaceAndGateway &rhs);
+};
+
+}
diff --git a/windows/winnet/src/winnet/winnet.cpp b/windows/winnet/src/winnet/winnet.cpp
index 4b006964a6..48d12b5ea3 100644
--- a/windows/winnet/src/winnet/winnet.cpp
+++ b/windows/winnet/src/winnet/winnet.cpp
@@ -3,17 +3,135 @@
#include "NetworkInterfaces.h"
#include "interfaceutils.h"
#include "offlinemonitor.h"
+#include "routing/routemanager.h"
#include "../../shared/logsinkadapter.h"
#include <libcommon/error.h>
+#include <libcommon/network.h>
#include <cstdint>
#include <stdexcept>
#include <memory>
+#include <optional>
+#include <mutex>
+
+using namespace winnet::routing;
+using AutoLockType = std::scoped_lock<std::mutex>;
namespace
{
OfflineMonitor *g_OfflineMonitor = nullptr;
+std::mutex g_RouteManagerLock;
+RouteManager *g_RouteManager = nullptr;
+std::shared_ptr<shared::LogSinkAdapter> g_RouteManagerLogSink;
+
+Network ConvertNetwork(const WINNET_IPNETWORK &in)
+{
+ //
+ // Convert WINNET_IPNETWORK into Network aka IP_ADDRESS_PREFIX
+ //
+
+ Network out{ 0 };
+
+ out.PrefixLength = in.prefix;
+
+ switch (in.type)
+ {
+ case WINNET_IP_TYPE_IPV4:
+ {
+ out.Prefix.si_family = AF_INET;
+ out.Prefix.Ipv4.sin_family = AF_INET;
+ out.Prefix.Ipv4.sin_addr.s_addr = *reinterpret_cast<const uint32_t *>(in.bytes);
+
+ break;
+ }
+ case WINNET_IP_TYPE_IPV6:
+ {
+ out.Prefix.si_family = AF_INET6;
+ out.Prefix.Ipv6.sin6_family = AF_INET6;
+ memcpy(out.Prefix.Ipv6.sin6_addr.u.Byte, in.bytes, 16);
+
+ break;
+ }
+ default:
+ {
+ throw std::runtime_error("Missing case handler in switch clause");
+ }
+ }
+
+ return out;
+}
+
+std::optional<Node> ConvertNode(const WINNET_NODE *in)
+{
+ if (nullptr == in)
+ {
+ return {};
+ }
+
+ if (nullptr == in->deviceName && nullptr == in->gateway)
+ {
+ throw std::runtime_error("Invalid 'WINNET_NODE' definition");
+ }
+
+ std::optional<std::wstring> deviceName;
+ std::optional<NodeAddress> gateway;
+
+ if (nullptr != in->deviceName)
+ {
+ deviceName = in->deviceName;
+ }
+
+ if (nullptr != in->gateway)
+ {
+ NodeAddress gw { 0 };
+
+ switch (in->gateway->type)
+ {
+ case WINNET_IP_TYPE_IPV4:
+ {
+ gw.si_family = AF_INET;
+ gw.Ipv4.sin_addr.s_addr = *reinterpret_cast<const uint32_t *>(in->gateway->bytes);
+
+ break;
+ }
+ case WINNET_IP_TYPE_IPV6:
+ {
+ gw.si_family = AF_INET6;
+ memcpy(&gw.Ipv6.sin6_addr.u.Byte, in->gateway->bytes, 16);
+
+ break;
+ }
+ default:
+ {
+ throw std::logic_error("Invalid gateway type specifier in 'WINNET_NODE' definition");
+ }
+ }
+
+ gateway = gw;
+ }
+
+ return Node(deviceName, gateway);
+}
+
+std::vector<Route> ConvertRoutes(const WINNET_ROUTE *routes, uint32_t numRoutes)
+{
+ std::vector<Route> out;
+
+ out.reserve(numRoutes);
+
+ for (size_t i = 0; i < numRoutes; ++i)
+ {
+ out.emplace_back(Route
+ {
+ ConvertNetwork(routes[i].network),
+ ConvertNode(routes[i].node)
+ });
+ }
+
+ return out;
+}
+
void UnwindAndLog(MullvadLogSink logSink, void *logSinkContext, const std::exception &err)
{
if (nullptr == logSink)
@@ -26,6 +144,49 @@ void UnwindAndLog(MullvadLogSink logSink, void *logSinkContext, const std::excep
common::error::UnwindException(err, logger);
}
+std::vector<SOCKADDR_INET> ConvertAddresses(const WINNET_IP *addresses, uint32_t numAddresses)
+{
+ //
+ // This duplicates the same logic we have above.
+ // TODO: Fix when time permits.
+ //
+
+ std::vector<SOCKADDR_INET> out;
+ out.reserve(numAddresses);
+
+ for (uint32_t i = 0; i < numAddresses; ++i)
+ {
+ const WINNET_IP &from = addresses[i];
+ SOCKADDR_INET to{ 0 };
+
+ switch (from.type)
+ {
+ case WINNET_IP_TYPE_IPV4:
+ {
+ to.si_family = AF_INET;
+ to.Ipv4.sin_addr.s_addr = *reinterpret_cast<const uint32_t *>(from.bytes);
+
+ break;
+ }
+ case WINNET_IP_TYPE_IPV6:
+ {
+ to.si_family = AF_INET6;
+ memcpy(&to.Ipv6.sin6_addr.u.Byte, from.bytes, 16);
+
+ break;
+ }
+ default:
+ {
+ throw std::logic_error("Invalid address family in 'WINNET_IP' definition");
+ }
+ }
+
+ out.push_back(to);
+ }
+
+ return out;
+}
+
} //anonymous namespace
extern "C"
@@ -66,12 +227,12 @@ WinNet_GetTapInterfaceIpv6Status(
{
try
{
- MIB_IPINTERFACE_ROW interface = { 0 };
+ MIB_IPINTERFACE_ROW iface = { 0 };
- interface.InterfaceLuid = NetworkInterfaces::GetInterfaceLuid(InterfaceUtils::GetTapInterfaceAlias());
- interface.Family = AF_INET6;
+ iface.InterfaceLuid = NetworkInterfaces::GetInterfaceLuid(InterfaceUtils::GetTapInterfaceAlias());
+ iface.Family = AF_INET6;
- const auto status = GetIpInterfaceEntry(&interface);
+ const auto status = GetIpInterfaceEntry(&iface);
if (NO_ERROR == status)
{
@@ -201,3 +362,360 @@ WinNet_DeactivateConnectivityMonitor(
{
}
}
+
+extern "C"
+WINNET_LINKAGE
+bool
+WINNET_API
+WinNet_ActivateRouteManager(
+ MullvadLogSink logSink,
+ void *logSinkContext
+)
+{
+ AutoLockType lock(g_RouteManagerLock);
+
+ try
+ {
+ if (nullptr != g_RouteManager)
+ {
+ throw std::runtime_error("Cannot activate route manager twice");
+ }
+
+ g_RouteManagerLogSink = std::make_shared<shared::LogSinkAdapter>(logSink, logSinkContext);
+ g_RouteManager = new RouteManager(g_RouteManagerLogSink);
+
+ return true;
+ }
+ catch (const std::exception &err)
+ {
+ UnwindAndLog(logSink, logSinkContext, err);
+ return false;
+ }
+ catch (...)
+ {
+ return false;
+ }
+}
+
+extern "C"
+WINNET_LINKAGE
+bool
+WINNET_API
+WinNet_AddRoutes(
+ const WINNET_ROUTE *routes,
+ uint32_t numRoutes
+)
+{
+ AutoLockType lock(g_RouteManagerLock);
+
+ if (nullptr == g_RouteManager)
+ {
+ return false;
+ }
+
+ try
+ {
+ g_RouteManager->addRoutes(ConvertRoutes(routes, numRoutes));
+ return true;
+ }
+ catch (const std::exception &err)
+ {
+ common::error::UnwindException(err, g_RouteManagerLogSink);
+ return false;
+ }
+ catch (...)
+ {
+ return false;
+ }
+}
+
+extern "C"
+WINNET_LINKAGE
+bool
+WINNET_API
+WinNet_AddRoute(
+ const WINNET_ROUTE *route
+)
+{
+ AutoLockType lock(g_RouteManagerLock);
+
+ if (nullptr == g_RouteManager)
+ {
+ return false;
+ }
+
+ try
+ {
+ g_RouteManager->addRoute
+ (
+ Route{ ConvertNetwork(route->network), ConvertNode(route->node) }
+ );
+
+ return true;
+ }
+ catch (const std::exception &err)
+ {
+ common::error::UnwindException(err, g_RouteManagerLogSink);
+ return false;
+ }
+ catch (...)
+ {
+ return false;
+ }
+}
+
+extern "C"
+WINNET_LINKAGE
+bool
+WINNET_API
+WinNet_DeleteRoutes(
+ const WINNET_ROUTE *routes,
+ uint32_t numRoutes
+)
+{
+ AutoLockType lock(g_RouteManagerLock);
+
+ if (nullptr == g_RouteManager)
+ {
+ return false;
+ }
+
+ try
+ {
+ g_RouteManager->deleteRoutes(ConvertRoutes(routes, numRoutes));
+ return true;
+ }
+ catch (const std::exception &err)
+ {
+ common::error::UnwindException(err, g_RouteManagerLogSink);
+ return false;
+ }
+ catch (...)
+ {
+ return false;
+ }
+}
+
+extern "C"
+WINNET_LINKAGE
+bool
+WINNET_API
+WinNet_DeleteRoute(
+ const WINNET_ROUTE *route
+)
+{
+ AutoLockType lock(g_RouteManagerLock);
+
+ if (nullptr == g_RouteManager)
+ {
+ return false;
+ }
+
+ try
+ {
+ g_RouteManager->deleteRoute
+ (
+ Route{ ConvertNetwork(route->network), ConvertNode(route->node) }
+ );
+
+ return true;
+ }
+ catch (const std::exception &err)
+ {
+ common::error::UnwindException(err, g_RouteManagerLogSink);
+ return false;
+ }
+ catch (...)
+ {
+ return false;
+ }
+}
+
+//
+// TODO: Move to libcommon.
+//
+struct ValueMapper
+{
+ template<typename T, typename U, std::size_t S>
+ static U map(T t, const std::pair<T, U> (&dictionary)[S])
+ {
+ for (const auto &entry : dictionary)
+ {
+ if (t == entry.first)
+ {
+ return entry.second;
+ }
+ }
+
+ throw std::runtime_error("Could not map between values");
+ }
+};
+
+extern "C"
+WINNET_LINKAGE
+bool
+WINNET_API
+WinNet_RegisterDefaultRouteChangedCallback(
+ WinNetDefaultRouteChangedCallback callback,
+ void *context,
+ void **registrationHandle
+)
+{
+ AutoLockType lock(g_RouteManagerLock);
+
+ if (nullptr == g_RouteManager)
+ {
+ return false;
+ }
+
+ try
+ {
+ auto forwarder = [callback, context](RouteManager::DefaultRouteChangedEventType eventType,
+ ADDRESS_FAMILY family, const std::optional<InterfaceAndGateway> &route)
+ {
+ //
+ // Translate the event type.
+ //
+
+ using from_t = RouteManager::DefaultRouteChangedEventType;
+ using to_t = WINNET_DEFAULT_ROUTE_CHANGED_EVENT_TYPE;
+
+ static const std::pair<from_t, to_t> eventTypeMap[] =
+ {
+ { from_t::Updated, WINNET_DEFAULT_ROUTE_CHANGED_EVENT_TYPE_UPDATED },
+ { from_t::Removed, WINNET_DEFAULT_ROUTE_CHANGED_EVENT_TYPE_REMOVED }
+ };
+
+ const auto translatedEventType = ValueMapper::map<>(eventType, eventTypeMap);
+
+ //
+ // Translate the family type.
+ //
+
+ static const std::pair<ADDRESS_FAMILY, WINNET_IP_FAMILY> familyMap[] =
+ {
+ { static_cast<ADDRESS_FAMILY>(AF_INET), WINNET_IP_FAMILY_V4 },
+ { static_cast<ADDRESS_FAMILY>(AF_INET6), WINNET_IP_FAMILY_V6 }
+ };
+
+ const auto translatedFamily = ValueMapper::map<>(family, familyMap);
+
+ //
+ // Determine which LUID to forward.
+ //
+
+ uint64_t translatedLuid = 0;
+
+ if (RouteManager::DefaultRouteChangedEventType::Updated == eventType)
+ {
+ translatedLuid = route.value().iface.Value;
+ }
+
+ //
+ // Forward to client.
+ //
+
+ callback(translatedEventType, translatedFamily, translatedLuid, context);
+ };
+
+ *registrationHandle = g_RouteManager->registerDefaultRouteChangedCallback(forwarder);
+
+ return true;
+ }
+ catch (const std::exception &err)
+ {
+ common::error::UnwindException(err, g_RouteManagerLogSink);
+ return false;
+ }
+ catch (...)
+ {
+ return false;
+ }
+}
+
+extern "C"
+WINNET_LINKAGE
+void
+WINNET_API
+WinNet_UnregisterDefaultRouteChangedCallback(
+ void *registrationHandle
+)
+{
+ AutoLockType lock(g_RouteManagerLock);
+
+ if (nullptr == g_RouteManager)
+ {
+ return;
+ }
+
+ try
+ {
+ g_RouteManager->unregisterDefaultRouteChangedCallback(registrationHandle);
+ }
+ catch (const std::exception &err)
+ {
+ g_RouteManagerLogSink->error("Failed to unregister default-route-changed callback");
+ common::error::UnwindException(err, g_RouteManagerLogSink);
+ }
+ catch (...)
+ {
+ }
+}
+
+extern "C"
+WINNET_LINKAGE
+void
+WINNET_API
+WinNet_DeactivateRouteManager(
+)
+{
+ AutoLockType lock(g_RouteManagerLock);
+
+ try
+ {
+ delete g_RouteManager;
+ g_RouteManager = nullptr;
+ }
+ catch (...)
+ {
+ }
+}
+
+extern "C"
+WINNET_LINKAGE
+bool
+WINNET_API
+WinNet_AddDeviceIpAddresses(
+ const wchar_t *deviceAlias,
+ const WINNET_IP *addresses,
+ uint32_t numAddresses,
+ MullvadLogSink logSink,
+ void *logSinkContext
+)
+{
+ try
+ {
+ NET_LUID luid;
+
+ if (0 != ConvertInterfaceAliasToLuid(deviceAlias, &luid))
+ {
+ const auto ansiName = common::string::ToAnsi(deviceAlias);
+ const auto err = std::string("Unable to derive interface LUID from interface alias: ").append(ansiName);
+
+ throw std::runtime_error(err);
+ }
+
+ InterfaceUtils::AddDeviceIpAddresses(luid, ConvertAddresses(addresses, numAddresses));
+
+ return true;
+ }
+ catch (const std::exception &err)
+ {
+ UnwindAndLog(logSink, logSinkContext, err);
+ return false;
+ }
+ catch (...)
+ {
+ return false;
+ }
+}
diff --git a/windows/winnet/src/winnet/winnet.def b/windows/winnet/src/winnet/winnet.def
index 04c3f22ee3..b23ae6c854 100644
--- a/windows/winnet/src/winnet/winnet.def
+++ b/windows/winnet/src/winnet/winnet.def
@@ -6,3 +6,6 @@ EXPORTS
WinNet_ReleaseString
WinNet_ActivateConnectivityMonitor
WinNet_DeactivateConnectivityMonitor
+ WinNet_ActivateRouteManager
+ WinNet_DeactivateRouteManager
+ WinNet_AddDeviceIpAddresses
diff --git a/windows/winnet/src/winnet/winnet.h b/windows/winnet/src/winnet/winnet.h
index 9b1af52e36..c7a161c3d8 100644
--- a/windows/winnet/src/winnet/winnet.h
+++ b/windows/winnet/src/winnet/winnet.h
@@ -1,6 +1,7 @@
#pragma once
#include "../../shared/logsink.h"
+#include <stdint.h>
#include <stdbool.h>
#ifndef WINNET_STATIC
@@ -89,3 +90,147 @@ void
WINNET_API
WinNet_DeactivateConnectivityMonitor(
);
+
+enum WINNET_IP_TYPE
+{
+ WINNET_IP_TYPE_IPV4 = 0,
+ WINNET_IP_TYPE_IPV6 = 1,
+};
+
+typedef struct tag_WINNET_IPNETWORK
+{
+ WINNET_IP_TYPE type;
+ uint8_t bytes[16]; // Network byte order.
+ uint8_t prefix;
+}
+WINNET_IPNETWORK;
+
+typedef struct tag_WINNET_IP
+{
+ WINNET_IP_TYPE type;
+ uint8_t bytes[16]; // Network byte order.
+}
+WINNET_IP;
+
+typedef struct tag_WINNET_NODE
+{
+ const WINNET_IP *gateway;
+ const wchar_t *deviceName;
+}
+WINNET_NODE;
+
+typedef struct tag_WINNET_ROUTE
+{
+ WINNET_IPNETWORK network;
+ const WINNET_NODE *node;
+}
+WINNET_ROUTE;
+
+extern "C"
+WINNET_LINKAGE
+bool
+WINNET_API
+WinNet_ActivateRouteManager(
+ MullvadLogSink logSink,
+ void *logSinkContext
+);
+
+extern "C"
+WINNET_LINKAGE
+bool
+WINNET_API
+WinNet_AddRoutes(
+ const WINNET_ROUTE *routes,
+ uint32_t numRoutes
+);
+
+extern "C"
+WINNET_LINKAGE
+bool
+WINNET_API
+WinNet_AddRoute(
+ const WINNET_ROUTE *route
+);
+
+extern "C"
+WINNET_LINKAGE
+bool
+WINNET_API
+WinNet_DeleteRoutes(
+ const WINNET_ROUTE *routes,
+ uint32_t numRoutes
+);
+
+extern "C"
+WINNET_LINKAGE
+bool
+WINNET_API
+WinNet_DeleteRoute(
+ const WINNET_ROUTE *route
+);
+
+enum WINNET_DEFAULT_ROUTE_CHANGED_EVENT_TYPE
+{
+ // Best default route changed.
+ WINNET_DEFAULT_ROUTE_CHANGED_EVENT_TYPE_UPDATED = 0,
+
+ // No default routes exist.
+ WINNET_DEFAULT_ROUTE_CHANGED_EVENT_TYPE_REMOVED = 1,
+};
+
+enum WINNET_IP_FAMILY
+{
+ WINNET_IP_FAMILY_V4 = 0,
+ WINNET_IP_FAMILY_V6 = 1,
+};
+
+typedef void (WINNET_API *WinNetDefaultRouteChangedCallback)
+(
+ WINNET_DEFAULT_ROUTE_CHANGED_EVENT_TYPE eventType,
+
+ // Signals which IP family the event relates to.
+ WINNET_IP_FAMILY family,
+
+ // For update events, signals the interface associated with the new best default route.
+ uint64_t interfaceLuid,
+
+ void *context
+);
+
+extern "C"
+WINNET_LINKAGE
+bool
+WINNET_API
+WinNet_RegisterDefaultRouteChangedCallback(
+ WinNetDefaultRouteChangedCallback callback,
+ void *context,
+ void **registrationHandle
+);
+
+extern "C"
+WINNET_LINKAGE
+void
+WINNET_API
+WinNet_UnregisterDefaultRouteChangedCallback(
+ void *registrationHandle
+);
+
+extern "C"
+WINNET_LINKAGE
+void
+WINNET_API
+WinNet_DeactivateRouteManager(
+);
+
+extern "C"
+WINNET_LINKAGE
+bool
+WINNET_API
+WinNet_AddDeviceIpAddresses(
+ const wchar_t *deviceAlias,
+ const WINNET_IP *addresses,
+ uint32_t numAddresses,
+ MullvadLogSink logSink,
+ void *logSinkContext
+);
+
diff --git a/windows/winnet/src/winnet/winnet.vcxproj b/windows/winnet/src/winnet/winnet.vcxproj
index 192320daaf..5e71a1f733 100644
--- a/windows/winnet/src/winnet/winnet.vcxproj
+++ b/windows/winnet/src/winnet/winnet.vcxproj
@@ -33,6 +33,10 @@
<ClCompile Include="interfaceutils.cpp" />
<ClCompile Include="offlinemonitor.cpp" />
<ClCompile Include="NetworkInterfaces.cpp" />
+ <ClCompile Include="routing\defaultroutemonitor.cpp" />
+ <ClCompile Include="routing\helpers.cpp" />
+ <ClCompile Include="routing\routemanager.cpp" />
+ <ClCompile Include="routing\types.cpp" />
<ClCompile Include="stdafx.cpp" />
<ClCompile Include="winnet.cpp" />
</ItemGroup>
@@ -42,6 +46,10 @@
<ClInclude Include="interfaceutils.h" />
<ClInclude Include="offlinemonitor.h" />
<ClInclude Include="NetworkInterfaces.h" />
+ <ClInclude Include="routing\defaultroutemonitor.h" />
+ <ClInclude Include="routing\helpers.h" />
+ <ClInclude Include="routing\routemanager.h" />
+ <ClInclude Include="routing\types.h" />
<ClInclude Include="stdafx.h" />
<ClInclude Include="targetver.h" />
<ClInclude Include="winnet.h" />
@@ -208,7 +216,7 @@
<ConformanceMode>true</ConformanceMode>
<RuntimeLibrary>MultiThreadedDebug</RuntimeLibrary>
<LanguageStandard>stdcpplatest</LanguageStandard>
- <AdditionalIncludeDirectories>$(ProjectDir)..\..\..\windows-libraries\src\%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories>
+ <AdditionalIncludeDirectories>$(ProjectDir)..\..\..\windows-libraries\src\;$(ProjectDir);%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories>
</ClCompile>
<Link>
<SubSystem>Windows</SubSystem>
@@ -278,7 +286,7 @@
<ConformanceMode>true</ConformanceMode>
<RuntimeLibrary>MultiThreaded</RuntimeLibrary>
<LanguageStandard>stdcpplatest</LanguageStandard>
- <AdditionalIncludeDirectories>$(ProjectDir)..\..\..\windows-libraries\src\%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories>
+ <AdditionalIncludeDirectories>$(ProjectDir)..\..\..\windows-libraries\src\;$(ProjectDir);%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories>
</ClCompile>
<Link>
<SubSystem>Windows</SubSystem>
diff --git a/windows/winnet/src/winnet/winnet.vcxproj.filters b/windows/winnet/src/winnet/winnet.vcxproj.filters
index 9a901d3203..dfe6d29ec7 100644
--- a/windows/winnet/src/winnet/winnet.vcxproj.filters
+++ b/windows/winnet/src/winnet/winnet.vcxproj.filters
@@ -9,6 +9,18 @@
<ClCompile Include="interfaceutils.cpp" />
<ClCompile Include="networkadaptermonitor.cpp" />
<ClCompile Include="offlinemonitor.cpp" />
+ <ClCompile Include="routing\types.cpp">
+ <Filter>routing</Filter>
+ </ClCompile>
+ <ClCompile Include="routing\helpers.cpp">
+ <Filter>routing</Filter>
+ </ClCompile>
+ <ClCompile Include="routing\defaultroutemonitor.cpp">
+ <Filter>routing</Filter>
+ </ClCompile>
+ <ClCompile Include="routing\routemanager.cpp">
+ <Filter>routing</Filter>
+ </ClCompile>
</ItemGroup>
<ItemGroup>
<ClInclude Include="stdafx.h" />
@@ -19,6 +31,18 @@
<ClInclude Include="interfaceutils.h" />
<ClInclude Include="networkadaptermonitor.h" />
<ClInclude Include="offlinemonitor.h" />
+ <ClInclude Include="routing\types.h">
+ <Filter>routing</Filter>
+ </ClInclude>
+ <ClInclude Include="routing\helpers.h">
+ <Filter>routing</Filter>
+ </ClInclude>
+ <ClInclude Include="routing\defaultroutemonitor.h">
+ <Filter>routing</Filter>
+ </ClInclude>
+ <ClInclude Include="routing\routemanager.h">
+ <Filter>routing</Filter>
+ </ClInclude>
</ItemGroup>
<ItemGroup>
<None Include="winnet.def" />
@@ -26,4 +50,9 @@
<ItemGroup>
<ResourceCompile Include="winnet.rc" />
</ItemGroup>
+ <ItemGroup>
+ <Filter Include="routing">
+ <UniqueIdentifier>{8df22cc6-597f-4342-bc57-7647393084be}</UniqueIdentifier>
+ </Filter>
+ </ItemGroup>
</Project> \ No newline at end of file