diff options
| author | Odd Stranne <odd@mullvad.net> | 2019-11-25 14:23:36 +0100 |
|---|---|---|
| committer | Odd Stranne <odd@mullvad.net> | 2019-11-25 14:23:36 +0100 |
| commit | 67a86af237d3305f84bb7044aa1bdf5e1122e81f (patch) | |
| tree | f910fc098a52d5466de50224d8cd66c203aac4fa /windows | |
| parent | dc6d5d8e87738f919b8f924b3381a1954097138b (diff) | |
| parent | e4f46afbe027cb8ddbb45a66ce014af9acbc54b6 (diff) | |
| download | mullvadvpn-67a86af237d3305f84bb7044aa1bdf5e1122e81f.tar.xz mullvadvpn-67a86af237d3305f84bb7044aa1bdf5e1122e81f.zip | |
Merge branch 'win-wireguard'
Diffstat (limited to 'windows')
26 files changed, 2530 insertions, 17 deletions
diff --git a/windows/winfw/src/winfw/fwcontext.cpp b/windows/winfw/src/winfw/fwcontext.cpp index e5325f0b9c..49f8793572 100644 --- a/windows/winfw/src/winfw/fwcontext.cpp +++ b/windows/winfw/src/winfw/fwcontext.cpp @@ -13,10 +13,10 @@ #include "rules/permitvpnrelay.h" #include "rules/permitvpntunnel.h" #include "rules/permitvpntunnelservice.h" +#include "rules/permitping.h" #include "rules/restrictdns.h" #include "libwfp/transaction.h" #include "libwfp/filterengine.h" -#include "libwfp/ipaddress.h" #include <functional> #include <stdexcept> #include <utility> @@ -99,7 +99,12 @@ FwContext::FwContext(uint32_t timeout, const WinFwSettings &settings) m_baseline = checkpoint; } -bool FwContext::applyPolicyConnecting(const WinFwSettings &settings, const WinFwRelay &relay) +bool FwContext::applyPolicyConnecting +( + const WinFwSettings &settings, + const WinFwRelay &relay, + const std::optional<PingableHosts> &pingableHosts +) { Ruleset ruleset; @@ -112,6 +117,22 @@ bool FwContext::applyPolicyConnecting(const WinFwSettings &settings, const WinFw TranslateProtocol(relay.protocol) )); + // + // Permit pinging the gateway inside the tunnel. + // + if (pingableHosts.has_value()) + { + const auto &ph = pingableHosts.value(); + + for (const auto &host : ph.hosts) + { + ruleset.emplace_back(std::make_unique<rules::PermitPing>( + ph.tunnelInterfaceAlias, + host + )); + } + } + return applyRuleset(ruleset); } diff --git a/windows/winfw/src/winfw/fwcontext.h b/windows/winfw/src/winfw/fwcontext.h index 89ef40e1d3..9d5b34c51b 100644 --- a/windows/winfw/src/winfw/fwcontext.h +++ b/windows/winfw/src/winfw/fwcontext.h @@ -3,9 +3,11 @@ #include "winfw.h" #include "sessioncontroller.h" #include "rules/ifirewallrule.h" +#include "libwfp/ipaddress.h" #include <cstdint> #include <memory> #include <vector> +#include <optional> class FwContext { @@ -16,7 +18,19 @@ public: // This ctor applies the "blocked" policy. FwContext(uint32_t timeout, const WinFwSettings &settings); - bool applyPolicyConnecting(const WinFwSettings &settings, const WinFwRelay &relay); + struct PingableHosts + { + std::optional<std::wstring> tunnelInterfaceAlias; + std::vector<wfp::IpAddress> hosts; + }; + + bool applyPolicyConnecting + ( + const WinFwSettings &settings, + const WinFwRelay &relay, + const std::optional<PingableHosts> &pingableHosts + ); + bool applyPolicyConnected ( const WinFwSettings &settings, diff --git a/windows/winfw/src/winfw/mullvadguids.cpp b/windows/winfw/src/winfw/mullvadguids.cpp index 010d41e44a..e73fac26ed 100644 --- a/windows/winfw/src/winfw/mullvadguids.cpp +++ b/windows/winfw/src/winfw/mullvadguids.cpp @@ -59,6 +59,8 @@ DetailedWfpObjectRegistry MullvadGuids::BuildDetailedRegistry() registry.insert(std::make_pair(WfpObjectType::Filter, FilterPermitNdp_Outbound_Router_Solicitation())); registry.insert(std::make_pair(WfpObjectType::Filter, FilterPermitNdp_Inbound_Router_Advertisement())); registry.insert(std::make_pair(WfpObjectType::Filter, FilterPermitNdp_Inbound_Redirect())); + registry.insert(std::make_pair(WfpObjectType::Filter, FilterPermitPing_Outbound_Icmpv4())); + registry.insert(std::make_pair(WfpObjectType::Filter, FilterPermitPing_Outbound_Icmpv6())); return registry; } @@ -567,3 +569,31 @@ const GUID &MullvadGuids::FilterPermitNdp_Inbound_Redirect() return g; } + +//static +const GUID &MullvadGuids::FilterPermitPing_Outbound_Icmpv4() +{ + static const GUID g = + { + 0x2ecf7ff7, + 0xc951, + 0x4056, + { 0xb0, 0xf7, 0x40, 0xa4, 0x5c, 0x7e, 0xb4, 0xc2 } + }; + + return g; +} + +//static +const GUID &MullvadGuids::FilterPermitPing_Outbound_Icmpv6() +{ + static const GUID g = + { + 0x3deb8cab, + 0x1edb, + 0x4aa1, + { 0xb2, 0x73, 0xec, 0x61, 0x4f, 0x50, 0xdc, 0x13 } + }; + + return g; +} diff --git a/windows/winfw/src/winfw/mullvadguids.h b/windows/winfw/src/winfw/mullvadguids.h index d4fb470d90..3c3ca9702b 100644 --- a/windows/winfw/src/winfw/mullvadguids.h +++ b/windows/winfw/src/winfw/mullvadguids.h @@ -67,4 +67,7 @@ public: static const GUID &FilterPermitNdp_Outbound_Router_Solicitation(); static const GUID &FilterPermitNdp_Inbound_Router_Advertisement(); static const GUID &FilterPermitNdp_Inbound_Redirect(); + + static const GUID &FilterPermitPing_Outbound_Icmpv4(); + static const GUID &FilterPermitPing_Outbound_Icmpv6(); }; diff --git a/windows/winfw/src/winfw/rules/permitping.cpp b/windows/winfw/src/winfw/rules/permitping.cpp new file mode 100644 index 0000000000..f6aed36bf2 --- /dev/null +++ b/windows/winfw/src/winfw/rules/permitping.cpp @@ -0,0 +1,98 @@ +#include "stdafx.h" +#include "permitping.h" +#include "winfw/mullvadguids.h" +#include "libwfp/filterbuilder.h" +#include "libwfp/conditionbuilder.h" +#include "libwfp/conditions/conditionip.h" +#include "libwfp/conditions/conditioninterface.h" +#include "libwfp/conditions/conditionprotocol.h" + + +using namespace wfp::conditions; + +namespace rules +{ + +PermitPing::PermitPing +( + const std::optional<std::wstring> &interfaceAlias, + const wfp::IpAddress &host +) + : m_interfaceAlias(interfaceAlias) + , m_host(host) +{ +} + +bool PermitPing::apply(IObjectInstaller &objectInstaller) +{ + if (wfp::IpAddress::Type::Ipv4 == m_host.type()) + { + return applyIcmpv4(objectInstaller); + } + + return applyIcmpv6(objectInstaller); +} + +bool PermitPing::applyIcmpv4(IObjectInstaller &objectInstaller) const +{ + wfp::FilterBuilder filterBuilder; + + // + // #1 Permit outbound ICMPv4 to %host% on %interface% + // + + filterBuilder + .key(MullvadGuids::FilterPermitPing_Outbound_Icmpv4()) + .name(L"Permit outbound ICMP to specific host (ICMPv4)") + .description(L"This filter is part of a rule that permits ping") + .provider(MullvadGuids::Provider()) + .layer(FWPM_LAYER_ALE_AUTH_CONNECT_V4) + .sublayer(MullvadGuids::SublayerWhitelist()) + .weight(wfp::FilterBuilder::WeightClass::Max) + .permit(); + + wfp::ConditionBuilder conditionBuilder(FWPM_LAYER_ALE_AUTH_CONNECT_V4); + + conditionBuilder.add_condition(ConditionIp::Remote(m_host)); + conditionBuilder.add_condition(ConditionProtocol::Icmp()); + + if (m_interfaceAlias.has_value()) + { + conditionBuilder.add_condition(ConditionInterface::Alias(m_interfaceAlias.value())); + } + + return objectInstaller.addFilter(filterBuilder, conditionBuilder); +} + +bool PermitPing::applyIcmpv6(IObjectInstaller &objectInstaller) const +{ + wfp::FilterBuilder filterBuilder; + + // + // #1 Permit outbound ICMPv6 to %host% on %interface% + // + + filterBuilder + .key(MullvadGuids::FilterPermitPing_Outbound_Icmpv6()) + .name(L"Permit outbound ICMP to specific host (ICMPv6)") + .description(L"This filter is part of a rule that permits ping") + .provider(MullvadGuids::Provider()) + .layer(FWPM_LAYER_ALE_AUTH_CONNECT_V6) + .sublayer(MullvadGuids::SublayerWhitelist()) + .weight(wfp::FilterBuilder::WeightClass::Max) + .permit(); + + wfp::ConditionBuilder conditionBuilder(FWPM_LAYER_ALE_AUTH_CONNECT_V6); + + conditionBuilder.add_condition(ConditionIp::Remote(m_host)); + conditionBuilder.add_condition(ConditionProtocol::IcmpV6()); + + if (m_interfaceAlias.has_value()) + { + conditionBuilder.add_condition(ConditionInterface::Alias(m_interfaceAlias.value())); + } + + return objectInstaller.addFilter(filterBuilder, conditionBuilder); +} + +} diff --git a/windows/winfw/src/winfw/rules/permitping.h b/windows/winfw/src/winfw/rules/permitping.h new file mode 100644 index 0000000000..c8238ceaa8 --- /dev/null +++ b/windows/winfw/src/winfw/rules/permitping.h @@ -0,0 +1,28 @@ +#pragma once + +#include "ifirewallrule.h" +#include <libwfp/ipaddress.h> +#include <string> +#include <optional> + +namespace rules +{ + +class PermitPing : public IFirewallRule +{ +public: + + PermitPing(const std::optional<std::wstring> &interfaceAlias, const wfp::IpAddress &host); + + bool apply(IObjectInstaller &objectInstaller) override; + +private: + + const std::optional<std::wstring> m_interfaceAlias; + const wfp::IpAddress m_host; + + bool applyIcmpv4(IObjectInstaller &objectInstaller) const; + bool applyIcmpv6(IObjectInstaller &objectInstaller) const; +}; + +} diff --git a/windows/winfw/src/winfw/winfw.cpp b/windows/winfw/src/winfw/winfw.cpp index 7b9ea2dc6b..3065408f3d 100644 --- a/windows/winfw/src/winfw/winfw.cpp +++ b/windows/winfw/src/winfw/winfw.cpp @@ -4,6 +4,7 @@ #include "objectpurger.h" #include <windows.h> #include <stdexcept> +#include <optional> namespace { @@ -15,6 +16,34 @@ void * g_errorContext = nullptr; FwContext *g_fwContext = nullptr; +std::optional<FwContext::PingableHosts> ConvertPingableHosts(const PingableHosts *pingableHosts) +{ + if (nullptr == pingableHosts) + { + return {}; + } + + if (nullptr == pingableHosts->hosts + || 0 == pingableHosts->numHosts) + { + throw std::runtime_error("Invalid PingableHosts structure"); + } + + FwContext::PingableHosts converted; + + if (nullptr != pingableHosts->tunnelInterfaceAlias) + { + converted.tunnelInterfaceAlias = pingableHosts->tunnelInterfaceAlias; + } + + for (size_t i = 0; i < pingableHosts->numHosts; ++i) + { + converted.hosts.emplace_back(wfp::IpAddress(pingableHosts->hosts[i])); + } + + return converted; +} + } // anonymous namespace WINFW_LINKAGE @@ -130,7 +159,8 @@ bool WINFW_API WinFw_ApplyPolicyConnecting( const WinFwSettings &settings, - const WinFwRelay &relay + const WinFwRelay &relay, + const PingableHosts *pingableHosts ) { if (nullptr == g_fwContext) @@ -140,7 +170,7 @@ WinFw_ApplyPolicyConnecting( try { - return g_fwContext->applyPolicyConnecting(settings, relay); + return g_fwContext->applyPolicyConnecting(settings, relay, ConvertPingableHosts(pingableHosts)); } catch (std::exception &err) { diff --git a/windows/winfw/src/winfw/winfw.h b/windows/winfw/src/winfw/winfw.h index 95e66a608f..6d43b0db4c 100644 --- a/windows/winfw/src/winfw/winfw.h +++ b/windows/winfw/src/winfw/winfw.h @@ -105,11 +105,29 @@ WINFW_API WinFw_Deinitialize(); // +// PingableHosts: +// +// Specifies a set of IP addresses that should be reachable by ICMP when the connecting +// policy is effective. +// +// The interface alias is optional and can be used to restrict the traffic such +// that it is only allowed on that specific interface. +// +typedef struct tag_PingableHosts +{ + const wchar_t *tunnelInterfaceAlias; + const wchar_t **hosts; + size_t numHosts; +} +PingableHosts; + +// // ApplyPolicyConnecting: // // Apply restrictions in the firewall that block all traffic, except: // - What is specified by settings // - Communication with the relay server +// - ICMP (for ping) to/from tunnel gateway // extern "C" WINFW_LINKAGE @@ -117,7 +135,8 @@ bool WINFW_API WinFw_ApplyPolicyConnecting( const WinFwSettings &settings, - const WinFwRelay &relay + const WinFwRelay &relay, + const PingableHosts *pingableHosts ); // diff --git a/windows/winfw/src/winfw/winfw.vcxproj b/windows/winfw/src/winfw/winfw.vcxproj index 4777503f72..cbabe2f4f7 100644 --- a/windows/winfw/src/winfw/winfw.vcxproj +++ b/windows/winfw/src/winfw/winfw.vcxproj @@ -30,6 +30,7 @@ <ClCompile Include="rules\permitlanservice.cpp" /> <ClCompile Include="rules\permitloopback.cpp" /> <ClCompile Include="rules\permitndp.cpp" /> + <ClCompile Include="rules\permitping.cpp" /> <ClCompile Include="rules\permitvpntunnelservice.cpp" /> <ClCompile Include="rules\permitvpnrelay.cpp" /> <ClCompile Include="rules\permitvpntunnel.cpp" /> @@ -53,6 +54,7 @@ <ClInclude Include="objectpurger.h" /> <ClInclude Include="rules\permitdhcpserver.h" /> <ClInclude Include="rules\permitndp.h" /> + <ClInclude Include="rules\permitping.h" /> <ClInclude Include="wfpobjecttype.h" /> <ClInclude Include="rules\blockall.h" /> <ClInclude Include="rules\ifirewallrule.h" /> diff --git a/windows/winfw/src/winfw/winfw.vcxproj.filters b/windows/winfw/src/winfw/winfw.vcxproj.filters index 0319b0214a..a758a1c9ec 100644 --- a/windows/winfw/src/winfw/winfw.vcxproj.filters +++ b/windows/winfw/src/winfw/winfw.vcxproj.filters @@ -43,6 +43,9 @@ <ClCompile Include="rules\permitndp.cpp"> <Filter>rules</Filter> </ClCompile> + <ClCompile Include="rules\permitping.cpp"> + <Filter>rules</Filter> + </ClCompile> </ItemGroup> <ItemGroup> <ClInclude Include="stdafx.h" /> @@ -93,6 +96,9 @@ <ClInclude Include="rules\permitndp.h"> <Filter>rules</Filter> </ClInclude> + <ClInclude Include="rules\permitping.h"> + <Filter>rules</Filter> + </ClInclude> </ItemGroup> <ItemGroup> <Filter Include="rules"> diff --git a/windows/winnet/src/extras/loader/loader.vcxproj.filters b/windows/winnet/src/extras/loader/loader.vcxproj.filters index cd0f4643c7..408a9591b1 100644 --- a/windows/winnet/src/extras/loader/loader.vcxproj.filters +++ b/windows/winnet/src/extras/loader/loader.vcxproj.filters @@ -3,9 +3,13 @@ <ItemGroup> <ClCompile Include="loader.cpp" /> <ClCompile Include="stdafx.cpp" /> + <ClCompile Include="..\..\winnet\routemanager.cpp" /> + <ClCompile Include="..\..\winnet\adapters.cpp" /> </ItemGroup> <ItemGroup> <ClInclude Include="stdafx.h" /> <ClInclude Include="targetver.h" /> + <ClInclude Include="..\..\winnet\routemanager.h" /> + <ClInclude Include="..\..\winnet\adapters.h" /> </ItemGroup> </Project>
\ No newline at end of file diff --git a/windows/winnet/src/winnet/interfaceutils.cpp b/windows/winnet/src/winnet/interfaceutils.cpp index babe03eba6..202d9d0724 100644 --- a/windows/winnet/src/winnet/interfaceutils.cpp +++ b/windows/winnet/src/winnet/interfaceutils.cpp @@ -2,13 +2,8 @@ #include "interfaceutils.h" #include "libcommon/error.h" #include "libcommon/string.h" -#include <vector> #include <cstdint> #include <algorithm> -#include <winsock2.h> -#include <iphlpapi.h> -#include <windows.h> - //static std::set<InterfaceUtils::NetworkAdapter> InterfaceUtils::GetAllAdapters() @@ -112,3 +107,18 @@ std::wstring InterfaceUtils::GetTapInterfaceAlias() throw std::runtime_error("Unable to find TAP adapter"); } + +//static +void InterfaceUtils::AddDeviceIpAddresses(NET_LUID device, const std::vector<SOCKADDR_INET> &addresses) +{ + for (const auto &address : addresses) + { + MIB_UNICASTIPADDRESS_ROW row; + InitializeUnicastIpAddressEntry(&row); + + row.InterfaceLuid = device; + row.Address = address; + + THROW_UNLESS(NO_ERROR, CreateUnicastIpAddressEntry(&row), "Assign IP address on network interface"); + } +} diff --git a/windows/winnet/src/winnet/interfaceutils.h b/windows/winnet/src/winnet/interfaceutils.h index f5c31963c2..8ab1249a50 100644 --- a/windows/winnet/src/winnet/interfaceutils.h +++ b/windows/winnet/src/winnet/interfaceutils.h @@ -2,6 +2,17 @@ #include <string> #include <set> +#include <vector> + +// Secret include order to get most common networking structs/apis +// And avoiding compilation errors +#include <winsock2.h> +#include <windows.h> +#include <ws2def.h> +#include <ws2ipdef.h> +#include <iphlpapi.h> +#include <netioapi.h> +// end class InterfaceUtils { @@ -35,4 +46,6 @@ public: // Determines alias of primary TAP adapter. // static std::wstring GetTapInterfaceAlias(); + + static void AddDeviceIpAddresses(NET_LUID device, const std::vector<SOCKADDR_INET> &addresses); }; diff --git a/windows/winnet/src/winnet/routing/defaultroutemonitor.cpp b/windows/winnet/src/winnet/routing/defaultroutemonitor.cpp new file mode 100644 index 0000000000..55d7560904 --- /dev/null +++ b/windows/winnet/src/winnet/routing/defaultroutemonitor.cpp @@ -0,0 +1,177 @@ +#include "stdafx.h" +#include <libcommon/error.h> +#include "defaultroutemonitor.h" +#include "helpers.h" + +namespace winnet::routing +{ + +namespace +{ + +const uint32_t POINT_TWO_SECOND_BURST = 200; +const uint32_t TWO_SECOND_INTERFERENCE = 2000; + +} // anonymous namespace + +DefaultRouteMonitor::DefaultRouteMonitor +( + ADDRESS_FAMILY family, + Callback callback, + std::shared_ptr<common::logging::ILogSink> logSink +) + : m_family(family) + , m_callback(callback) + , m_logSink(logSink) + , m_evaluateRoutesGuard(std::make_unique<common::BurstGuard>( + std::bind(&DefaultRouteMonitor::evaluateRoutes, this), + POINT_TWO_SECOND_BURST, + TWO_SECOND_INTERFERENCE + )) +{ + try + { + m_bestRoute = GetBestDefaultRoute(m_family); + } + catch (...) + { + } + + const auto status = NotifyRouteChange2(AF_UNSPEC, RouteChangeCallback, this, FALSE, &m_routeNotificationHandle); + + THROW_UNLESS(NO_ERROR, status, "Register for route table change notifications"); + + try + { + const auto s2 = NotifyIpInterfaceChange(AF_UNSPEC, InterfaceChangeCallback, this, + FALSE, &m_interfaceNotificationHandle); + + THROW_UNLESS(NO_ERROR, status, "Register for network interface change notifications"); + } + catch (...) + { + CancelMibChangeNotify2(m_routeNotificationHandle); + throw; + } +} + +DefaultRouteMonitor::~DefaultRouteMonitor() +{ + // + // Cancel notifications to stop triggering the BurstGuard. + // + + CancelMibChangeNotify2(m_interfaceNotificationHandle); + CancelMibChangeNotify2(m_routeNotificationHandle); + + // + // Controlled destruction of BurstGuard to prevent it from calling here + // after other member variables have been destructed. + // + + m_evaluateRoutesGuard.reset(); +} + +//static +void NETIOAPI_API_ DefaultRouteMonitor::RouteChangeCallback +( + void *context, + MIB_IPFORWARD_ROW2 *row, + MIB_NOTIFICATION_TYPE +) +{ + // + // We're only interested in changes that add/remove/update a default route. + // + + if (0 != row->DestinationPrefix.PrefixLength + || false == RouteHasGateway(*row)) + { + return; + } + + reinterpret_cast<DefaultRouteMonitor *>(context)->m_evaluateRoutesGuard->trigger(); +} + +//static +void NETIOAPI_API_ DefaultRouteMonitor::InterfaceChangeCallback +( + void *context, + MIB_IPINTERFACE_ROW *, + MIB_NOTIFICATION_TYPE +) +{ + reinterpret_cast<DefaultRouteMonitor *>(context)->m_evaluateRoutesGuard->trigger(); +} + +void DefaultRouteMonitor::evaluateRoutes() +{ + std::scoped_lock<std::mutex> lock(m_evaluationLock); + + try + { + evaluateRoutesInner(); + } + catch (const std::exception &ex) + { + const auto msg = std::string("Failure while evaluating route table: ").append(ex.what()); + m_logSink->error(msg.c_str()); + } + catch (...) + { + m_logSink->error("Unspecified failure while evaluating route table"); + } +} + +void DefaultRouteMonitor::evaluateRoutesInner() +{ + std::optional<InterfaceAndGateway> currentBestRoute; + + try + { + currentBestRoute = GetBestDefaultRoute(m_family); + } + catch (...) + { + } + + // + // If there was no default route previously. + // + + if (false == m_bestRoute.has_value()) + { + if (currentBestRoute.has_value()) + { + m_bestRoute = currentBestRoute; + m_callback(EventType::Updated, m_bestRoute); + } + + return; + } + + // + // There used to be a default route. + // If there is not currently a default route. + // + + if (false == currentBestRoute.has_value()) + { + m_bestRoute.reset(); + m_callback(EventType::Removed, std::nullopt); + + return; + } + + // + // The current best route may have changed. + // + + if (m_bestRoute.value() != currentBestRoute.value()) + { + m_bestRoute = currentBestRoute; + m_callback(EventType::Updated, m_bestRoute); + } +} + +} diff --git a/windows/winnet/src/winnet/routing/defaultroutemonitor.h b/windows/winnet/src/winnet/routing/defaultroutemonitor.h new file mode 100644 index 0000000000..5575685a82 --- /dev/null +++ b/windows/winnet/src/winnet/routing/defaultroutemonitor.h @@ -0,0 +1,69 @@ +#pragma once + +#include <ifdef.h> +#include <ws2def.h> +#include <functional> +#include <optional> +#include <memory> +#include <mutex> +#include <libcommon/logging/ilogsink.h> +#include <libcommon/burstguard.h> +#include "types.h" + +namespace winnet::routing +{ + +class DefaultRouteMonitor +{ +public: + + enum class EventType + { + // The best default route changed. + Updated, + + // No default routes exist. + Removed, + }; + + using Callback = std::function<void + ( + EventType eventType, + + // For update events, data associated with the new best default route. + const std::optional<InterfaceAndGateway> &route + )>; + + DefaultRouteMonitor(ADDRESS_FAMILY family, Callback callback, std::shared_ptr<common::logging::ILogSink> logSink); + ~DefaultRouteMonitor(); + + DefaultRouteMonitor(const DefaultRouteMonitor &) = delete; + DefaultRouteMonitor(DefaultRouteMonitor &&) = delete; + DefaultRouteMonitor &operator=(const DefaultRouteMonitor &) = delete; + DefaultRouteMonitor &operator=(DefaultRouteMonitor &&) = delete; + +private: + + ADDRESS_FAMILY m_family; + Callback m_callback; + std::shared_ptr<common::logging::ILogSink> m_logSink; + + // This can't be a plain member variable. + // We need to be able to delete it explicitly in order to have a controlled tear down. + std::unique_ptr<common::BurstGuard> m_evaluateRoutesGuard; + + std::optional<InterfaceAndGateway> m_bestRoute; + + HANDLE m_routeNotificationHandle; + HANDLE m_interfaceNotificationHandle; + + std::mutex m_evaluationLock; + + static void NETIOAPI_API_ RouteChangeCallback(void *context, MIB_IPFORWARD_ROW2 *row, MIB_NOTIFICATION_TYPE notificationType); + static void NETIOAPI_API_ InterfaceChangeCallback(void *context, MIB_IPINTERFACE_ROW *row, MIB_NOTIFICATION_TYPE notificationType); + + void evaluateRoutes(); + void evaluateRoutesInner(); +}; + +} diff --git a/windows/winnet/src/winnet/routing/helpers.cpp b/windows/winnet/src/winnet/routing/helpers.cpp new file mode 100644 index 0000000000..cabf19bce6 --- /dev/null +++ b/windows/winnet/src/winnet/routing/helpers.cpp @@ -0,0 +1,275 @@ +#include "stdafx.h" +#include "helpers.h" +#include <stdexcept> +#include <ws2def.h> +#include <in6addr.h> +#include <numeric> +//#include <netioapi.h> +#include <libcommon/error.h> +#include <libcommon/memory.h> + +namespace winnet::routing +{ + +bool EqualAddress(const Network &lhs, const Network &rhs) +{ + if (lhs.PrefixLength != rhs.PrefixLength) + { + return false; + } + + return EqualAddress(lhs.Prefix, rhs.Prefix); +} + +bool EqualAddress(const NodeAddress &lhs, const NodeAddress &rhs) +{ + if (lhs.si_family != rhs.si_family) + { + return false; + } + + switch (lhs.si_family) + { + case AF_INET: + { + return lhs.Ipv4.sin_addr.s_addr == rhs.Ipv4.sin_addr.s_addr; + } + case AF_INET6: + { + return 0 == memcmp(&lhs.Ipv6.sin6_addr, &rhs.Ipv6.sin6_addr, sizeof(IN6_ADDR)); + } + default: + { + throw std::runtime_error("Invalid address family for network address"); + } + } +} + +bool EqualAddress(const SOCKADDR_INET *lhs, const SOCKET_ADDRESS *rhs) +{ + if (lhs->si_family != rhs->lpSockaddr->sa_family) + { + return false; + } + + switch (lhs->si_family) + { + case AF_INET: + { + auto typedRhs = reinterpret_cast<const SOCKADDR_IN *>(rhs->lpSockaddr); + return lhs->Ipv4.sin_addr.s_addr == typedRhs->sin_addr.s_addr; + } + case AF_INET6: + { + auto typedRhs = reinterpret_cast<const SOCKADDR_IN6 *>(rhs->lpSockaddr); + return 0 == memcmp(lhs->Ipv6.sin6_addr.u.Byte, typedRhs->sin6_addr.u.Byte, 16); + } + default: + { + throw std::runtime_error("Missing case handler in switch clause"); + } + } +} + +bool GetAdapterInterface(NET_LUID adapter, ADDRESS_FAMILY addressFamily, MIB_IPINTERFACE_ROW *iface) +{ + memset(iface, 0, sizeof(MIB_IPINTERFACE_ROW)); + + iface->Family = addressFamily; + iface->InterfaceLuid = adapter; + + return NO_ERROR == GetIpInterfaceEntry(iface); +} + +std::vector<AnnotatedRoute> AnnotateRoutes(const std::vector<const MIB_IPFORWARD_ROW2 *> &routes) +{ + std::vector<AnnotatedRoute> annotated; + annotated.reserve(routes.size()); + + for (auto route : routes) + { + MIB_IPINTERFACE_ROW iface; + + if (false == GetAdapterInterface(route->InterfaceLuid, route->DestinationPrefix.Prefix.si_family, &iface)) + { + continue; + } + + annotated.emplace_back + ( + AnnotatedRoute{ route, bool_cast(iface.Connected), route->Metric + iface.Metric } + ); + } + + return annotated; +} + +bool RouteHasGateway(const MIB_IPFORWARD_ROW2 &route) +{ + switch (route.NextHop.si_family) + { + case AF_INET: + { + return 0 != route.NextHop.Ipv4.sin_addr.s_addr; + } + case AF_INET6: + { + const uint8_t *begin = &route.NextHop.Ipv6.sin6_addr.u.Byte[0]; + const uint8_t *end = begin + 16; + + return 0 != std::accumulate(begin, end, 0); + } + default: + { + return false; + } + }; +} + +InterfaceAndGateway GetBestDefaultRoute(ADDRESS_FAMILY family) +{ + PMIB_IPFORWARD_TABLE2 table; + + auto status = GetIpForwardTable2(family, &table); + + THROW_UNLESS(NO_ERROR, status, "Acquire route table"); + + common::memory::ScopeDestructor sd; + + sd += [table] + { + FreeMibTable(table); + }; + + std::vector<const MIB_IPFORWARD_ROW2 *> candidates; + candidates.reserve(table->NumEntries); + + // + // Enumerate routes looking for: route 0/0 && gateway specified. + // + + for (ULONG i = 0; i < table->NumEntries; ++i) + { + const MIB_IPFORWARD_ROW2 &candidate = table->Table[i]; + + if (0 == candidate.DestinationPrefix.PrefixLength + && RouteHasGateway(candidate)) + { + candidates.emplace_back(&candidate); + } + } + + auto annotated = AnnotateRoutes(candidates); + + if (annotated.empty()) + { + throw std::runtime_error("Unable to determine details of default route"); + } + + // + // Sort on (active, effectiveMetric) ascending by metric. + // + + std::sort(annotated.begin(), annotated.end(), [](const AnnotatedRoute &lhs, const AnnotatedRoute &rhs) + { + if (lhs.active == rhs.active) + { + return lhs.effectiveMetric < rhs.effectiveMetric; + } + + return lhs.active && false == rhs.active; + }); + + // + // Ensure the top rated route is active. + // + + if (false == annotated[0].active) + { + throw std::runtime_error("Unable to identify active default route"); + } + + return InterfaceAndGateway { annotated[0].route->InterfaceLuid, annotated[0].route->NextHop }; +} + +bool AdapterInterfaceEnabled(const IP_ADAPTER_ADDRESSES *adapter, ADDRESS_FAMILY family) +{ + switch (family) + { + case AF_INET: + { + return 0 != adapter->Ipv4Enabled; + } + case AF_INET6: + { + return 0 != adapter->Ipv6Enabled; + } + default: + { + throw std::runtime_error("Missing case handler in switch clause"); + } + } +} + +std::vector<const SOCKET_ADDRESS *> IsolateGatewayAddresses +( + PIP_ADAPTER_GATEWAY_ADDRESS_LH head, + ADDRESS_FAMILY family +) +{ + std::vector<const SOCKET_ADDRESS *> matches; + + for (auto gateway = head; nullptr != gateway; gateway = gateway->Next) + { + if (family == gateway->Address.lpSockaddr->sa_family) + { + matches.emplace_back(&gateway->Address); + } + } + + return matches; +} + +bool AddressPresent(const std::vector<const SOCKET_ADDRESS *> &hay, const SOCKADDR_INET *needle) +{ + for (const auto candidate : hay) + { + if (EqualAddress(needle, candidate)) + { + return true; + } + } + + return false; +} + +//NodeAddress ConvertSocketAddress(const SOCKET_ADDRESS *sa) +//{ +// NodeAddress out = { 0 }; +// +// switch (sa->lpSockaddr->sa_family) +// { +// case AF_INET: +// { +// out.si_family = AF_INET; +// out.Ipv4 = *reinterpret_cast<SOCKADDR_IN *>(sa->lpSockaddr); +// +// break; +// } +// case AF_INET6: +// { +// out.si_family = AF_INET6; +// out.Ipv6 = *reinterpret_cast<SOCKADDR_IN6 *>(sa->lpSockaddr); +// +// break; +// } +// default: +// { +// throw std::runtime_error("Missing case handler in switch clause"); +// } +// }; +// +// return out; +//} + +} diff --git a/windows/winnet/src/winnet/routing/helpers.h b/windows/winnet/src/winnet/routing/helpers.h new file mode 100644 index 0000000000..3ef5e85b75 --- /dev/null +++ b/windows/winnet/src/winnet/routing/helpers.h @@ -0,0 +1,46 @@ +#pragma once + +#include "types.h" +#include <vector> + +namespace winnet::routing +{ + +bool EqualAddress(const Network &lhs, const Network &rhs); +bool EqualAddress(const NodeAddress &lhs, const NodeAddress &rhs); +bool EqualAddress(const SOCKADDR_INET *lhs, const SOCKET_ADDRESS *rhs); + +bool GetAdapterInterface(NET_LUID adapter, ADDRESS_FAMILY addressFamily, MIB_IPINTERFACE_ROW *iface); + +struct AnnotatedRoute +{ + const MIB_IPFORWARD_ROW2 *route; + bool active; + uint32_t effectiveMetric; +}; + +template<typename T> +bool bool_cast(const T &value) +{ + return 0 != value; +} + +std::vector<AnnotatedRoute> AnnotateRoutes(const std::vector<const MIB_IPFORWARD_ROW2 *> &routes); + +bool RouteHasGateway(const MIB_IPFORWARD_ROW2 &route); + +InterfaceAndGateway GetBestDefaultRoute(ADDRESS_FAMILY family); + +bool AdapterInterfaceEnabled(const IP_ADAPTER_ADDRESSES *adapter, ADDRESS_FAMILY family); + +std::vector<const SOCKET_ADDRESS *> IsolateGatewayAddresses +( + PIP_ADAPTER_GATEWAY_ADDRESS_LH head, + ADDRESS_FAMILY family +); + +bool AddressPresent(const std::vector<const SOCKET_ADDRESS *> &hay, const SOCKADDR_INET *needle); + +//NodeAddress ConvertSocketAddress(const SOCKET_ADDRESS *sa); + +} diff --git a/windows/winnet/src/winnet/routing/routemanager.cpp b/windows/winnet/src/winnet/routing/routemanager.cpp new file mode 100644 index 0000000000..668e64bb68 --- /dev/null +++ b/windows/winnet/src/winnet/routing/routemanager.cpp @@ -0,0 +1,692 @@ +#include "stdafx.h" +#include "routemanager.h" +#include "helpers.h" +#include <libcommon/error.h> +#include <libcommon/memory.h> +#include <libcommon/string.h> +#include <libcommon/network/adapters.h> +#include <vector> +#include <algorithm> +#include <numeric> +#include <sstream> +#include <stdexcept> + +using AutoLockType = std::scoped_lock<std::mutex>; +using AutoRecursiveLockType = std::scoped_lock<std::recursive_mutex>; +using namespace std::placeholders; + +namespace winnet::routing +{ + +namespace +{ + +using Adapters = common::network::Adapters; + +NET_LUID InterfaceLuidFromGateway(const NodeAddress &gateway) +{ + const DWORD adapterFlags = GAA_FLAG_SKIP_ANYCAST | GAA_FLAG_SKIP_MULTICAST | GAA_FLAG_SKIP_DNS_SERVER + | GAA_FLAG_SKIP_FRIENDLY_NAME | GAA_FLAG_INCLUDE_GATEWAYS; + + Adapters adapters(gateway.si_family, adapterFlags); + + // + // Process adapters to find matching ones. + // + + std::vector<const IP_ADAPTER_ADDRESSES *> matches; + + for (auto adapter = adapters.next(); nullptr != adapter; adapter = adapters.next()) + { + if (false == AdapterInterfaceEnabled(adapter, gateway.si_family)) + { + continue; + } + + auto gateways = IsolateGatewayAddresses(adapter->FirstGatewayAddress, gateway.si_family); + + if (AddressPresent(gateways, &gateway)) + { + matches.emplace_back(adapter); + } + } + + if (matches.empty()) + { + throw std::runtime_error("Unable to find network adapter with specified gateway"); + } + + // + // Sort matching interfaces ascending by metric. + // + + const bool targetV4 = (AF_INET == gateway.si_family); + + std::sort(matches.begin(), matches.end(), [&targetV4](const IP_ADAPTER_ADDRESSES *lhs, const IP_ADAPTER_ADDRESSES *rhs) + { + if (targetV4) + { + return lhs->Ipv4Metric < rhs->Ipv4Metric; + } + + return lhs->Ipv6Metric < rhs->Ipv6Metric; + }); + + // + // Select the interface with the best (lowest) metric. + // + + return matches[0]->Luid; +} + +bool ParseStringEncodedLuid(const std::wstring &encodedLuid, NET_LUID &luid) +{ + // + // The `#` is a valid character in adapter names so we use `?` instead. + // The LUID is thus prefixed with `?` and hex encoded and left-padded with zeroes. + // E.g. `?deadbeefcafebabe` or `?000dbeefcafebabe`. + // + + static const size_t StringEncodedLuidLength = 17; + + if (encodedLuid.size() != StringEncodedLuidLength + || L'?' != encodedLuid[0]) + { + return false; + } + + try + { + std::wstringstream ss; + + ss << std::hex << &encodedLuid[1]; + ss >> luid.Value; + } + catch (...) + { + const auto ansi = common::string::ToAnsi(encodedLuid); + const auto err = std::string("Failed to parse string encoded LUID: ").append(ansi); + + std::throw_with_nested(std::runtime_error(err)); + } + + return true; +} + +InterfaceAndGateway ResolveNode(ADDRESS_FAMILY family, const std::optional<Node> &optionalNode) +{ + // + // There are four cases: + // + // Unspecified node (use interface and gateway of default route). + // Node is specified by name. + // Node is specified by name and gateway. + // Node is specified by gateway. + // + + if (false == optionalNode.has_value()) + { + return GetBestDefaultRoute(family); + } + + const auto &node = optionalNode.value(); + + if (node.deviceName().has_value()) + { + const auto &deviceName = node.deviceName().value(); + NET_LUID luid; + + if (false == ParseStringEncodedLuid(deviceName, luid) + && 0 != ConvertInterfaceAliasToLuid(deviceName.c_str(), &luid)) + { + const auto ansiName = common::string::ToAnsi(deviceName); + const auto err = std::string("Unable to derive interface LUID from interface alias: ").append(ansiName); + + throw std::runtime_error(err); + } + + auto onLinkProvider = [&family]() + { + NodeAddress onLink = { 0 }; + onLink.si_family = family; + + return onLink; + }; + + return InterfaceAndGateway{ luid, node.gateway().value_or(onLinkProvider()) }; + } + + // + // The node is specified only by gateway. + // + + return InterfaceAndGateway{ InterfaceLuidFromGateway(node.gateway().value()), node.gateway().value() }; +} + +// TODO: Move to libcommon +uint32_t ByteSwap(uint32_t val) +{ + return + ( + ((val & 0xFF) << 24) | + ((val & 0xFF00) << 8) | + ((val & 0xFF0000) >> 8) | + ((val & 0xFF000000) >> 24) + ); +} + +std::wstring FormatNetwork(const Network &network) +{ + switch (network.Prefix.si_family) + { + case AF_INET: + { + return common::string::FormatIpv4(ByteSwap(network.Prefix.Ipv4.sin_addr.s_addr), network.PrefixLength); + } + case AF_INET6: + { + return common::string::FormatIpv6(network.Prefix.Ipv6.sin6_addr.u.Byte, network.PrefixLength); + } + default: + { + return L"Failed to format network details"; + } + } +} + +} // anonymous namespace + +RouteManager::RouteManager(std::shared_ptr<common::logging::ILogSink> logSink) + : m_logSink(logSink) + , m_routeMonitorV4(std::make_unique<DefaultRouteMonitor>( + static_cast<ADDRESS_FAMILY>(AF_INET), + std::bind(&RouteManager::defaultRouteChanged, this, static_cast<ADDRESS_FAMILY>(AF_INET), _1, _2), + logSink + )) + , m_routeMonitorV6(std::make_unique<DefaultRouteMonitor>( + static_cast<ADDRESS_FAMILY>(AF_INET6), + std::bind(&RouteManager::defaultRouteChanged, this, static_cast<ADDRESS_FAMILY>(AF_INET6), _1, _2), + logSink + )) +{ +} + +RouteManager::~RouteManager() +{ + // + // Stop callbacks that are triggered by events in Windows from coming in. + // + + m_routeMonitorV4.reset(); + m_routeMonitorV6.reset(); + + // + // Delete all routes owned by us. + // + + for (const auto &record : m_routes) + { + try + { + deleteFromRoutingTable(record.registeredRoute); + } + catch (const std::exception &ex) + { + std::wstringstream ss; + + ss << L"Failed to delete route as part of cleaning up, Route: " + << FormatRegisteredRoute(record.registeredRoute); + + m_logSink->error(common::string::ToAnsi(ss.str()).c_str()); + m_logSink->error(ex.what()); + } + } +} + +void RouteManager::addRoutes(const std::vector<Route> &routes) +{ + AutoLockType lock(m_routesLock); + + std::vector<EventEntry> eventLog; + + for (const auto &route : routes) + { + try + { + auto record = findRouteRecord(route); + + if (record != m_routes.end()) + { + deleteFromRoutingTable(record->registeredRoute); + eventLog.emplace_back(EventEntry{ EventType::DELETE_ROUTE, *record }); + m_routes.erase(record); + } + + const RouteRecord newRecord { route, addIntoRoutingTable(route) }; + + eventLog.emplace_back(EventEntry{ EventType::ADD_ROUTE, newRecord }); + m_routes.emplace_back(std::move(newRecord)); + } + catch (...) + { + undoEvents(eventLog); + + std::throw_with_nested(std::runtime_error("Failed during batch insertion of routes")); + } + } +} + +void RouteManager::addRoute(const Route &route) +{ + AutoLockType lock(m_routesLock); + + std::optional<RouteRecord> deletedRecord; + + auto record = findRouteRecord(route); + + if (record != m_routes.end()) + { + try + { + deleteFromRoutingTable(record->registeredRoute); + } + catch (...) + { + std::throw_with_nested(std::runtime_error("Failed to evict old route when adding new route")); + } + + deletedRecord = *record; + m_routes.erase(record); + } + + try + { + m_routes.emplace_back + ( + RouteRecord{ route, addIntoRoutingTable(route) } + ); + } + catch (...) + { + // + // Restore deleted record. + // + + if (deletedRecord.has_value()) + { + auto &r = deletedRecord.value(); + + try + { + restoreIntoRoutingTable(r.registeredRoute); + m_routes.emplace_back(r); + } + catch (const std::exception &ex) + { + const auto err = std::string("Failed to restore evicted route during rollback: ").append(ex.what()); + m_logSink->error(err.c_str()); + } + } + + // + // Just rethrow because the error is from addIntoRoutingTable(). + // + + throw; + } +} + +void RouteManager::deleteRoutes(const std::vector<Route> &routes) +{ + AutoLockType lock(m_routesLock); + + std::vector<EventEntry> eventLog; + + for (const auto &route : routes) + { + try + { + auto record = findRouteRecord(route); + + if (m_routes.end() == record) + { + const auto err = std::wstring(L"Request to delete previously unregistered route: ") + .append(FormatNetwork(route.network())); + + m_logSink->warning(common::string::ToAnsi(err).c_str()); + + continue; + } + + deleteFromRoutingTable(record->registeredRoute); + eventLog.emplace_back(EventEntry{ EventType::DELETE_ROUTE, *record }); + m_routes.erase(record); + } + catch (...) + { + undoEvents(eventLog); + + std::throw_with_nested(std::runtime_error("Failed during batch removal of routes")); + } + } +} + +void RouteManager::deleteRoute(const Route &route) +{ + AutoLockType lock(m_routesLock); + + auto record = findRouteRecord(route); + + if (m_routes.end() == record) + { + const auto err = std::wstring(L"Request to delete previously unregistered route: ") + .append(FormatNetwork(route.network())); + + m_logSink->warning(common::string::ToAnsi(err).c_str()); + + return; + } + + deleteFromRoutingTable(record->registeredRoute); + m_routes.erase(record); +} + +RouteManager::CallbackHandle RouteManager::registerDefaultRouteChangedCallback(DefaultRouteChangedCallback callback) +{ + AutoRecursiveLockType lock(m_defaultRouteCallbacksLock); + + m_defaultRouteCallbacks.emplace_back(callback); + + // Return raw address of record in list. + return &m_defaultRouteCallbacks.back(); +} + +void RouteManager::unregisterDefaultRouteChangedCallback(CallbackHandle handle) +{ + AutoRecursiveLockType lock(m_defaultRouteCallbacksLock); + + for (auto it = m_defaultRouteCallbacks.begin(); it != m_defaultRouteCallbacks.end(); ++it) + { + // Match on raw address of record. + if (&*it == handle) + { + m_defaultRouteCallbacks.erase(it); + return; + } + } +} + +std::list<RouteManager::RouteRecord>::iterator RouteManager::findRouteRecord(const Network &network) +{ + return std::find_if(m_routes.begin(), m_routes.end(), [&network](const auto &candidate) + { + return EqualAddress(network, candidate.route.network()); + }); +} + +std::list<RouteManager::RouteRecord>::iterator RouteManager::findRouteRecord(const Route &route) +{ + return findRouteRecord(route.network()); +} + +RouteManager::RegisteredRoute RouteManager::addIntoRoutingTable(const Route &route) +{ + const auto node = ResolveNode(route.network().Prefix.si_family, route.node()); + + MIB_IPFORWARD_ROW2 spec; + + InitializeIpForwardEntry(&spec); + + spec.InterfaceLuid = node.iface; + spec.DestinationPrefix = route.network(); + spec.NextHop = node.gateway; + spec.Metric = 0; + spec.Protocol = MIB_IPPROTO_NETMGMT; + spec.Origin = NlroManual; + + // + // Do not treat ERROR_OBJECT_ALREADY_EXISTS as being successful. + // Because it may not take route metric into consideration. + // + + THROW_UNLESS(NO_ERROR, CreateIpForwardEntry2(&spec), "Register route in routing table"); + + return RegisteredRoute { route.network(), node.iface, node.gateway }; +} + +void RouteManager::restoreIntoRoutingTable(const RegisteredRoute &route) +{ + MIB_IPFORWARD_ROW2 spec; + + InitializeIpForwardEntry(&spec); + + spec.InterfaceLuid = route.luid; + spec.DestinationPrefix = route.network; + spec.NextHop = route.nextHop; + spec.Metric = 0; + spec.Protocol = MIB_IPPROTO_NETMGMT; + spec.Origin = NlroManual; + + THROW_UNLESS(NO_ERROR, CreateIpForwardEntry2(&spec), "Register route in routing table"); +} + +void RouteManager::deleteFromRoutingTable(const RegisteredRoute &route) +{ + MIB_IPFORWARD_ROW2 r = { 0}; + + r.InterfaceLuid = route.luid; + r.DestinationPrefix = route.network; + r.NextHop = route.nextHop; + + auto status = DeleteIpForwardEntry2(&r); + + if (ERROR_NOT_FOUND == status) + { + status = NO_ERROR; + + const auto err = std::wstring(L"Attempting to delete route which was not present in routing table, " \ + "ignoring and proceeding. Route: ").append(FormatRegisteredRoute(route)); + + m_logSink->warning(common::string::ToAnsi(err).c_str()); + } + + THROW_UNLESS(NO_ERROR, status, "Delete route in routing table"); +} + +void RouteManager::undoEvents(const std::vector<EventEntry> &eventLog) +{ + // + // Rewind state by processing events in the reverse order. + // + + for (auto it = eventLog.rbegin(); it != eventLog.rend(); ++it) + { + try + { + switch (it->type) + { + case EventType::ADD_ROUTE: + { + auto record = findRouteRecord(it->record.route); + + if (m_routes.end() == record) + { + throw std::runtime_error("Internal state inconsistency in route manager"); + } + + deleteFromRoutingTable(record->registeredRoute); + m_routes.erase(record); + + break; + } + case EventType::DELETE_ROUTE: + { + restoreIntoRoutingTable(it->record.registeredRoute); + m_routes.emplace_back(it->record); + + break; + } + default: + { + throw std::logic_error("Missing case handler in switch clause"); + } + } + } + catch (const std::exception &ex) + { + const auto err = std::string("Attempting to rollback state: ").append(ex.what()); + m_logSink->error(err.c_str()); + } + } +} + +// static +std::wstring RouteManager::FormatRegisteredRoute(const RegisteredRoute &route) +{ + // + // TODO: Fix broken IP formatting + // Update FormatIpv4 function with an additional argument to specify network/host byte order. + // + + std::wstringstream ss; + + if (AF_INET == route.network.Prefix.si_family) + { + std::wstring gateway(L"\"On-link\""); + + if (0 != route.nextHop.Ipv4.sin_addr.s_addr) + { + gateway = common::string::FormatIpv4(ByteSwap(route.nextHop.Ipv4.sin_addr.s_addr)); + } + + ss << common::string::FormatIpv4(ByteSwap(route.network.Prefix.Ipv4.sin_addr.s_addr), route.network.PrefixLength) + << L" with gateway " << gateway + << L" on interface with LUID 0x" << std::hex << route.luid.Value; + } + else if (AF_INET6 == route.network.Prefix.si_family) + { + std::wstring gateway(L"\"On-link\""); + + const uint8_t *begin = &route.nextHop.Ipv6.sin6_addr.u.Byte[0]; + const uint8_t *end = begin + 16; + + if (0 != std::accumulate(begin, end, 0)) + { + gateway = common::string::FormatIpv6(route.nextHop.Ipv6.sin6_addr.u.Byte); + } + + ss << common::string::FormatIpv6(route.network.Prefix.Ipv6.sin6_addr.u.Byte, route.network.PrefixLength) + << L" with gateway " << gateway + << L" on interface with LUID 0x" << std::hex << route.luid.Value; + } + else + { + ss << L"Failed to format route details"; + } + + return ss.str(); +} + +void RouteManager::defaultRouteChanged(ADDRESS_FAMILY family, DefaultRouteMonitor::EventType eventType, + const std::optional<InterfaceAndGateway> &route) +{ + // + // Forward event to all registered listeners. + // + + m_defaultRouteCallbacksLock.lock(); + + for (const auto &callback : m_defaultRouteCallbacks) + { + try + { + callback(eventType, family, route); + } + catch (const std::exception &ex) + { + const auto msg = std::string("Failure in default-route-changed callback: ").append(ex.what()); + m_logSink->error(msg.c_str()); + } + catch (...) + { + m_logSink->error("Unspecified failure in default-route-changed callback"); + } + } + + m_defaultRouteCallbacksLock.unlock(); + + // + // Examine event to determine if best default route has changed. + // + + if (DefaultRouteMonitor::EventType::Updated != eventType) + { + return; + } + + // + // Examine our routes to see if any of them are policy bound to the best default route. + // + + AutoLockType routesLock(m_routesLock); + + using RecordIterator = std::list<RouteRecord>::iterator; + + std::list<RecordIterator> affectedRoutes; + + for (RecordIterator it = m_routes.begin(); it != m_routes.end(); ++it) + { + if (false == it->route.node().has_value() + && family == it->route.network().Prefix.si_family) + { + affectedRoutes.emplace_back(it); + } + } + + if (affectedRoutes.empty()) + { + return; + } + + // + // Update all affected routes. + // + + m_logSink->info("Best default route has changed. Refreshing dependent routes"); + + for (auto &it : affectedRoutes) + { + try + { + deleteFromRoutingTable(it->registeredRoute); + } + catch (const std::exception &ex) + { + const auto msg = std::string("Failed to delete route when refreshing " \ + "existing routes: ").append(ex.what()); + + m_logSink->error(msg.c_str()); + + continue; + } + + it->registeredRoute.luid = route.value().iface; + it->registeredRoute.nextHop = route.value().gateway; + + try + { + restoreIntoRoutingTable(it->registeredRoute); + } + catch (const std::exception &ex) + { + const auto msg = std::string("Failed to add route when refreshing " \ + "existing routes: ").append(ex.what()); + + m_logSink->error(msg.c_str()); + + continue; + } + } +} + +} diff --git a/windows/winnet/src/winnet/routing/routemanager.h b/windows/winnet/src/winnet/routing/routemanager.h new file mode 100644 index 0000000000..981c8e6834 --- /dev/null +++ b/windows/winnet/src/winnet/routing/routemanager.h @@ -0,0 +1,112 @@ +#pragma once + +#include <string> +#include <memory> +#include <vector> +#include <list> +#include <optional> +#include <mutex> +#include <functional> +#include <windows.h> +#include <ws2def.h> +#include <ifdef.h> +#include <libcommon/string.h> +#include <libcommon/logging/ilogsink.h> +#include "defaultroutemonitor.h" + +namespace winnet::routing +{ + +class RouteManager +{ +public: + + RouteManager(std::shared_ptr<common::logging::ILogSink> logSink); + ~RouteManager(); + + RouteManager(const RouteManager &) = delete; + RouteManager(RouteManager &&) = default; + RouteManager &operator=(const RouteManager &) = delete; + RouteManager &operator=(RouteManager &&) = delete; + + void addRoutes(const std::vector<Route> &routes); + void addRoute(const Route &route); + + void deleteRoutes(const std::vector<Route> &routes); + void deleteRoute(const Route &route); + + using DefaultRouteChangedEventType = DefaultRouteMonitor::EventType; + + using DefaultRouteChangedCallback = std::function<void + ( + DefaultRouteChangedEventType eventType, + ADDRESS_FAMILY family, + + // For update events, data associated with the new best default route. + const std::optional<InterfaceAndGateway> &route + )>; + + using CallbackHandle = void*; + + CallbackHandle registerDefaultRouteChangedCallback(DefaultRouteChangedCallback callback); + void unregisterDefaultRouteChangedCallback(CallbackHandle handle); + +private: + + std::shared_ptr<common::logging::ILogSink> m_logSink; + + std::unique_ptr<DefaultRouteMonitor> m_routeMonitorV4; + std::unique_ptr<DefaultRouteMonitor> m_routeMonitorV6; + + // These are the exact details derived from the route specification (`Route`). + // They are used when registering and deleting a route in the system. + struct RegisteredRoute + { + Network network; + NET_LUID luid; + NodeAddress nextHop; + }; + + struct RouteRecord + { + Route route; + RegisteredRoute registeredRoute; + }; + + std::list<RouteRecord> m_routes; + std::mutex m_routesLock; + + std::list<DefaultRouteChangedCallback> m_defaultRouteCallbacks; + std::recursive_mutex m_defaultRouteCallbacksLock; + + // Find record based on destination and mask. + std::list<RouteRecord>::iterator findRouteRecord(const Network &network); + + // Note: Same as above! + std::list<RouteRecord>::iterator findRouteRecord(const Route &route); + + RegisteredRoute addIntoRoutingTable(const Route &route); + void restoreIntoRoutingTable(const RegisteredRoute &route); + void deleteFromRoutingTable(const RegisteredRoute &route); + + enum class EventType + { + ADD_ROUTE, + DELETE_ROUTE, + }; + + struct EventEntry + { + EventType type; + RouteRecord record; + }; + + void undoEvents(const std::vector<EventEntry> &eventLog); + + static std::wstring FormatRegisteredRoute(const RegisteredRoute &route); + + void defaultRouteChanged(ADDRESS_FAMILY family, DefaultRouteMonitor::EventType eventType, + const std::optional<InterfaceAndGateway> &route); +}; + +} diff --git a/windows/winnet/src/winnet/routing/types.cpp b/windows/winnet/src/winnet/routing/types.cpp new file mode 100644 index 0000000000..ac71c8108f --- /dev/null +++ b/windows/winnet/src/winnet/routing/types.cpp @@ -0,0 +1,84 @@ +#include "stdafx.h" +#include "types.h" +#include "helpers.h" +#include <libcommon/string.h> + +namespace winnet::routing +{ + +Node::Node(const std::optional<std::wstring> &deviceName, const std::optional<NodeAddress> &gateway) + : m_deviceName(deviceName) + , m_gateway(gateway) +{ + if (false == m_deviceName.has_value() && false == m_gateway.has_value()) + { + throw std::runtime_error("Invalid node definition"); + } + + if (m_deviceName.has_value()) + { + const auto trimmed = common::string::Trim<>(m_deviceName.value()); + + if (trimmed.empty()) + { + throw std::runtime_error("Invalid device name in node definition"); + } + + m_deviceName = std::move(trimmed); + } +} + +bool Node::operator==(const Node &rhs) const +{ + if (m_deviceName.has_value()) + { + if (false == rhs.m_deviceName.has_value() + || 0 != _wcsicmp(m_deviceName.value().c_str(), rhs.deviceName().value().c_str())) + { + return false; + } + } + + if (m_gateway.has_value()) + { + if (false == rhs.m_gateway.has_value() + || false == EqualAddress(m_gateway.value(), rhs.gateway().value())) + { + return false; + } + } + + return true; +} + +Route::Route(const Network &network, const std::optional<Node> &node) + : m_network(network) + , m_node(node) +{ +} + +bool Route::operator==(const Route &rhs) const +{ + if (m_node.has_value()) + { + return rhs.node().has_value() + && EqualAddress(m_network, rhs.network()) + && m_node.value() == rhs.node().value(); + } + + return false == rhs.node().has_value() + && EqualAddress(m_network, rhs.network()); +} + +bool InterfaceAndGateway::operator==(const InterfaceAndGateway &rhs) +{ + return iface.Value == rhs.iface.Value + && EqualAddress(gateway, rhs.gateway); +} + +bool InterfaceAndGateway::operator!=(const InterfaceAndGateway &rhs) +{ + return !(*this == rhs); +} + +} diff --git a/windows/winnet/src/winnet/routing/types.h b/windows/winnet/src/winnet/routing/types.h new file mode 100644 index 0000000000..1e132feb00 --- /dev/null +++ b/windows/winnet/src/winnet/routing/types.h @@ -0,0 +1,77 @@ +#pragma once + +#include <string> +#include <optional> +#include <winsock2.h> +#include <windows.h> +#include <ws2def.h> +#include <ws2ipdef.h> +#include <iphlpapi.h> +//#include <netioapi.h> +//#include <functional> + + +namespace winnet::routing +{ + +using Network = IP_ADDRESS_PREFIX; +using NodeAddress = SOCKADDR_INET; + +class Node +{ +public: + + Node(const std::optional<std::wstring> &deviceName, const std::optional<NodeAddress> &gateway); + + const std::optional<std::wstring> &deviceName() const + { + return m_deviceName; + } + + const std::optional<NodeAddress> &gateway() const + { + return m_gateway; + } + + bool operator==(const Node &rhs) const; + +private: + + std::optional<std::wstring> m_deviceName; + std::optional<NodeAddress> m_gateway; +}; + +class Route +{ +public: + + Route(const Network &network, const std::optional<Node> &node); + + const Network &network() const + { + return m_network; + } + + const std::optional<Node> &node() const + { + return m_node; + } + + bool operator==(const Route &rhs) const; + +private: + + Network m_network; + std::optional<Node> m_node; +}; + +struct InterfaceAndGateway +{ + NET_LUID iface; + NodeAddress gateway; + + bool operator==(const InterfaceAndGateway &rhs); + bool operator!=(const InterfaceAndGateway &rhs); +}; + +} diff --git a/windows/winnet/src/winnet/winnet.cpp b/windows/winnet/src/winnet/winnet.cpp index 4b006964a6..48d12b5ea3 100644 --- a/windows/winnet/src/winnet/winnet.cpp +++ b/windows/winnet/src/winnet/winnet.cpp @@ -3,17 +3,135 @@ #include "NetworkInterfaces.h"
#include "interfaceutils.h"
#include "offlinemonitor.h"
+#include "routing/routemanager.h"
#include "../../shared/logsinkadapter.h"
#include <libcommon/error.h>
+#include <libcommon/network.h>
#include <cstdint>
#include <stdexcept>
#include <memory>
+#include <optional>
+#include <mutex>
+
+using namespace winnet::routing;
+using AutoLockType = std::scoped_lock<std::mutex>;
namespace
{
OfflineMonitor *g_OfflineMonitor = nullptr;
+std::mutex g_RouteManagerLock;
+RouteManager *g_RouteManager = nullptr;
+std::shared_ptr<shared::LogSinkAdapter> g_RouteManagerLogSink;
+
+Network ConvertNetwork(const WINNET_IPNETWORK &in)
+{
+ //
+ // Convert WINNET_IPNETWORK into Network aka IP_ADDRESS_PREFIX
+ //
+
+ Network out{ 0 };
+
+ out.PrefixLength = in.prefix;
+
+ switch (in.type)
+ {
+ case WINNET_IP_TYPE_IPV4:
+ {
+ out.Prefix.si_family = AF_INET;
+ out.Prefix.Ipv4.sin_family = AF_INET;
+ out.Prefix.Ipv4.sin_addr.s_addr = *reinterpret_cast<const uint32_t *>(in.bytes);
+
+ break;
+ }
+ case WINNET_IP_TYPE_IPV6:
+ {
+ out.Prefix.si_family = AF_INET6;
+ out.Prefix.Ipv6.sin6_family = AF_INET6;
+ memcpy(out.Prefix.Ipv6.sin6_addr.u.Byte, in.bytes, 16);
+
+ break;
+ }
+ default:
+ {
+ throw std::runtime_error("Missing case handler in switch clause");
+ }
+ }
+
+ return out;
+}
+
+std::optional<Node> ConvertNode(const WINNET_NODE *in)
+{
+ if (nullptr == in)
+ {
+ return {};
+ }
+
+ if (nullptr == in->deviceName && nullptr == in->gateway)
+ {
+ throw std::runtime_error("Invalid 'WINNET_NODE' definition");
+ }
+
+ std::optional<std::wstring> deviceName;
+ std::optional<NodeAddress> gateway;
+
+ if (nullptr != in->deviceName)
+ {
+ deviceName = in->deviceName;
+ }
+
+ if (nullptr != in->gateway)
+ {
+ NodeAddress gw { 0 };
+
+ switch (in->gateway->type)
+ {
+ case WINNET_IP_TYPE_IPV4:
+ {
+ gw.si_family = AF_INET;
+ gw.Ipv4.sin_addr.s_addr = *reinterpret_cast<const uint32_t *>(in->gateway->bytes);
+
+ break;
+ }
+ case WINNET_IP_TYPE_IPV6:
+ {
+ gw.si_family = AF_INET6;
+ memcpy(&gw.Ipv6.sin6_addr.u.Byte, in->gateway->bytes, 16);
+
+ break;
+ }
+ default:
+ {
+ throw std::logic_error("Invalid gateway type specifier in 'WINNET_NODE' definition");
+ }
+ }
+
+ gateway = gw;
+ }
+
+ return Node(deviceName, gateway);
+}
+
+std::vector<Route> ConvertRoutes(const WINNET_ROUTE *routes, uint32_t numRoutes)
+{
+ std::vector<Route> out;
+
+ out.reserve(numRoutes);
+
+ for (size_t i = 0; i < numRoutes; ++i)
+ {
+ out.emplace_back(Route
+ {
+ ConvertNetwork(routes[i].network),
+ ConvertNode(routes[i].node)
+ });
+ }
+
+ return out;
+}
+
void UnwindAndLog(MullvadLogSink logSink, void *logSinkContext, const std::exception &err)
{
if (nullptr == logSink)
@@ -26,6 +144,49 @@ void UnwindAndLog(MullvadLogSink logSink, void *logSinkContext, const std::excep common::error::UnwindException(err, logger);
}
+std::vector<SOCKADDR_INET> ConvertAddresses(const WINNET_IP *addresses, uint32_t numAddresses)
+{
+ //
+ // This duplicates the same logic we have above.
+ // TODO: Fix when time permits.
+ //
+
+ std::vector<SOCKADDR_INET> out;
+ out.reserve(numAddresses);
+
+ for (uint32_t i = 0; i < numAddresses; ++i)
+ {
+ const WINNET_IP &from = addresses[i];
+ SOCKADDR_INET to{ 0 };
+
+ switch (from.type)
+ {
+ case WINNET_IP_TYPE_IPV4:
+ {
+ to.si_family = AF_INET;
+ to.Ipv4.sin_addr.s_addr = *reinterpret_cast<const uint32_t *>(from.bytes);
+
+ break;
+ }
+ case WINNET_IP_TYPE_IPV6:
+ {
+ to.si_family = AF_INET6;
+ memcpy(&to.Ipv6.sin6_addr.u.Byte, from.bytes, 16);
+
+ break;
+ }
+ default:
+ {
+ throw std::logic_error("Invalid address family in 'WINNET_IP' definition");
+ }
+ }
+
+ out.push_back(to);
+ }
+
+ return out;
+}
+
} //anonymous namespace
extern "C"
@@ -66,12 +227,12 @@ WinNet_GetTapInterfaceIpv6Status( {
try
{
- MIB_IPINTERFACE_ROW interface = { 0 };
+ MIB_IPINTERFACE_ROW iface = { 0 };
- interface.InterfaceLuid = NetworkInterfaces::GetInterfaceLuid(InterfaceUtils::GetTapInterfaceAlias());
- interface.Family = AF_INET6;
+ iface.InterfaceLuid = NetworkInterfaces::GetInterfaceLuid(InterfaceUtils::GetTapInterfaceAlias());
+ iface.Family = AF_INET6;
- const auto status = GetIpInterfaceEntry(&interface);
+ const auto status = GetIpInterfaceEntry(&iface);
if (NO_ERROR == status)
{
@@ -201,3 +362,360 @@ WinNet_DeactivateConnectivityMonitor( {
}
}
+
+extern "C"
+WINNET_LINKAGE
+bool
+WINNET_API
+WinNet_ActivateRouteManager(
+ MullvadLogSink logSink,
+ void *logSinkContext
+)
+{
+ AutoLockType lock(g_RouteManagerLock);
+
+ try
+ {
+ if (nullptr != g_RouteManager)
+ {
+ throw std::runtime_error("Cannot activate route manager twice");
+ }
+
+ g_RouteManagerLogSink = std::make_shared<shared::LogSinkAdapter>(logSink, logSinkContext);
+ g_RouteManager = new RouteManager(g_RouteManagerLogSink);
+
+ return true;
+ }
+ catch (const std::exception &err)
+ {
+ UnwindAndLog(logSink, logSinkContext, err);
+ return false;
+ }
+ catch (...)
+ {
+ return false;
+ }
+}
+
+extern "C"
+WINNET_LINKAGE
+bool
+WINNET_API
+WinNet_AddRoutes(
+ const WINNET_ROUTE *routes,
+ uint32_t numRoutes
+)
+{
+ AutoLockType lock(g_RouteManagerLock);
+
+ if (nullptr == g_RouteManager)
+ {
+ return false;
+ }
+
+ try
+ {
+ g_RouteManager->addRoutes(ConvertRoutes(routes, numRoutes));
+ return true;
+ }
+ catch (const std::exception &err)
+ {
+ common::error::UnwindException(err, g_RouteManagerLogSink);
+ return false;
+ }
+ catch (...)
+ {
+ return false;
+ }
+}
+
+extern "C"
+WINNET_LINKAGE
+bool
+WINNET_API
+WinNet_AddRoute(
+ const WINNET_ROUTE *route
+)
+{
+ AutoLockType lock(g_RouteManagerLock);
+
+ if (nullptr == g_RouteManager)
+ {
+ return false;
+ }
+
+ try
+ {
+ g_RouteManager->addRoute
+ (
+ Route{ ConvertNetwork(route->network), ConvertNode(route->node) }
+ );
+
+ return true;
+ }
+ catch (const std::exception &err)
+ {
+ common::error::UnwindException(err, g_RouteManagerLogSink);
+ return false;
+ }
+ catch (...)
+ {
+ return false;
+ }
+}
+
+extern "C"
+WINNET_LINKAGE
+bool
+WINNET_API
+WinNet_DeleteRoutes(
+ const WINNET_ROUTE *routes,
+ uint32_t numRoutes
+)
+{
+ AutoLockType lock(g_RouteManagerLock);
+
+ if (nullptr == g_RouteManager)
+ {
+ return false;
+ }
+
+ try
+ {
+ g_RouteManager->deleteRoutes(ConvertRoutes(routes, numRoutes));
+ return true;
+ }
+ catch (const std::exception &err)
+ {
+ common::error::UnwindException(err, g_RouteManagerLogSink);
+ return false;
+ }
+ catch (...)
+ {
+ return false;
+ }
+}
+
+extern "C"
+WINNET_LINKAGE
+bool
+WINNET_API
+WinNet_DeleteRoute(
+ const WINNET_ROUTE *route
+)
+{
+ AutoLockType lock(g_RouteManagerLock);
+
+ if (nullptr == g_RouteManager)
+ {
+ return false;
+ }
+
+ try
+ {
+ g_RouteManager->deleteRoute
+ (
+ Route{ ConvertNetwork(route->network), ConvertNode(route->node) }
+ );
+
+ return true;
+ }
+ catch (const std::exception &err)
+ {
+ common::error::UnwindException(err, g_RouteManagerLogSink);
+ return false;
+ }
+ catch (...)
+ {
+ return false;
+ }
+}
+
+//
+// TODO: Move to libcommon.
+//
+struct ValueMapper
+{
+ template<typename T, typename U, std::size_t S>
+ static U map(T t, const std::pair<T, U> (&dictionary)[S])
+ {
+ for (const auto &entry : dictionary)
+ {
+ if (t == entry.first)
+ {
+ return entry.second;
+ }
+ }
+
+ throw std::runtime_error("Could not map between values");
+ }
+};
+
+extern "C"
+WINNET_LINKAGE
+bool
+WINNET_API
+WinNet_RegisterDefaultRouteChangedCallback(
+ WinNetDefaultRouteChangedCallback callback,
+ void *context,
+ void **registrationHandle
+)
+{
+ AutoLockType lock(g_RouteManagerLock);
+
+ if (nullptr == g_RouteManager)
+ {
+ return false;
+ }
+
+ try
+ {
+ auto forwarder = [callback, context](RouteManager::DefaultRouteChangedEventType eventType,
+ ADDRESS_FAMILY family, const std::optional<InterfaceAndGateway> &route)
+ {
+ //
+ // Translate the event type.
+ //
+
+ using from_t = RouteManager::DefaultRouteChangedEventType;
+ using to_t = WINNET_DEFAULT_ROUTE_CHANGED_EVENT_TYPE;
+
+ static const std::pair<from_t, to_t> eventTypeMap[] =
+ {
+ { from_t::Updated, WINNET_DEFAULT_ROUTE_CHANGED_EVENT_TYPE_UPDATED },
+ { from_t::Removed, WINNET_DEFAULT_ROUTE_CHANGED_EVENT_TYPE_REMOVED }
+ };
+
+ const auto translatedEventType = ValueMapper::map<>(eventType, eventTypeMap);
+
+ //
+ // Translate the family type.
+ //
+
+ static const std::pair<ADDRESS_FAMILY, WINNET_IP_FAMILY> familyMap[] =
+ {
+ { static_cast<ADDRESS_FAMILY>(AF_INET), WINNET_IP_FAMILY_V4 },
+ { static_cast<ADDRESS_FAMILY>(AF_INET6), WINNET_IP_FAMILY_V6 }
+ };
+
+ const auto translatedFamily = ValueMapper::map<>(family, familyMap);
+
+ //
+ // Determine which LUID to forward.
+ //
+
+ uint64_t translatedLuid = 0;
+
+ if (RouteManager::DefaultRouteChangedEventType::Updated == eventType)
+ {
+ translatedLuid = route.value().iface.Value;
+ }
+
+ //
+ // Forward to client.
+ //
+
+ callback(translatedEventType, translatedFamily, translatedLuid, context);
+ };
+
+ *registrationHandle = g_RouteManager->registerDefaultRouteChangedCallback(forwarder);
+
+ return true;
+ }
+ catch (const std::exception &err)
+ {
+ common::error::UnwindException(err, g_RouteManagerLogSink);
+ return false;
+ }
+ catch (...)
+ {
+ return false;
+ }
+}
+
+extern "C"
+WINNET_LINKAGE
+void
+WINNET_API
+WinNet_UnregisterDefaultRouteChangedCallback(
+ void *registrationHandle
+)
+{
+ AutoLockType lock(g_RouteManagerLock);
+
+ if (nullptr == g_RouteManager)
+ {
+ return;
+ }
+
+ try
+ {
+ g_RouteManager->unregisterDefaultRouteChangedCallback(registrationHandle);
+ }
+ catch (const std::exception &err)
+ {
+ g_RouteManagerLogSink->error("Failed to unregister default-route-changed callback");
+ common::error::UnwindException(err, g_RouteManagerLogSink);
+ }
+ catch (...)
+ {
+ }
+}
+
+extern "C"
+WINNET_LINKAGE
+void
+WINNET_API
+WinNet_DeactivateRouteManager(
+)
+{
+ AutoLockType lock(g_RouteManagerLock);
+
+ try
+ {
+ delete g_RouteManager;
+ g_RouteManager = nullptr;
+ }
+ catch (...)
+ {
+ }
+}
+
+extern "C"
+WINNET_LINKAGE
+bool
+WINNET_API
+WinNet_AddDeviceIpAddresses(
+ const wchar_t *deviceAlias,
+ const WINNET_IP *addresses,
+ uint32_t numAddresses,
+ MullvadLogSink logSink,
+ void *logSinkContext
+)
+{
+ try
+ {
+ NET_LUID luid;
+
+ if (0 != ConvertInterfaceAliasToLuid(deviceAlias, &luid))
+ {
+ const auto ansiName = common::string::ToAnsi(deviceAlias);
+ const auto err = std::string("Unable to derive interface LUID from interface alias: ").append(ansiName);
+
+ throw std::runtime_error(err);
+ }
+
+ InterfaceUtils::AddDeviceIpAddresses(luid, ConvertAddresses(addresses, numAddresses));
+
+ return true;
+ }
+ catch (const std::exception &err)
+ {
+ UnwindAndLog(logSink, logSinkContext, err);
+ return false;
+ }
+ catch (...)
+ {
+ return false;
+ }
+}
diff --git a/windows/winnet/src/winnet/winnet.def b/windows/winnet/src/winnet/winnet.def index 04c3f22ee3..b23ae6c854 100644 --- a/windows/winnet/src/winnet/winnet.def +++ b/windows/winnet/src/winnet/winnet.def @@ -6,3 +6,6 @@ EXPORTS WinNet_ReleaseString WinNet_ActivateConnectivityMonitor WinNet_DeactivateConnectivityMonitor + WinNet_ActivateRouteManager + WinNet_DeactivateRouteManager + WinNet_AddDeviceIpAddresses diff --git a/windows/winnet/src/winnet/winnet.h b/windows/winnet/src/winnet/winnet.h index 9b1af52e36..c7a161c3d8 100644 --- a/windows/winnet/src/winnet/winnet.h +++ b/windows/winnet/src/winnet/winnet.h @@ -1,6 +1,7 @@ #pragma once #include "../../shared/logsink.h" +#include <stdint.h> #include <stdbool.h> #ifndef WINNET_STATIC @@ -89,3 +90,147 @@ void WINNET_API WinNet_DeactivateConnectivityMonitor( ); + +enum WINNET_IP_TYPE +{ + WINNET_IP_TYPE_IPV4 = 0, + WINNET_IP_TYPE_IPV6 = 1, +}; + +typedef struct tag_WINNET_IPNETWORK +{ + WINNET_IP_TYPE type; + uint8_t bytes[16]; // Network byte order. + uint8_t prefix; +} +WINNET_IPNETWORK; + +typedef struct tag_WINNET_IP +{ + WINNET_IP_TYPE type; + uint8_t bytes[16]; // Network byte order. +} +WINNET_IP; + +typedef struct tag_WINNET_NODE +{ + const WINNET_IP *gateway; + const wchar_t *deviceName; +} +WINNET_NODE; + +typedef struct tag_WINNET_ROUTE +{ + WINNET_IPNETWORK network; + const WINNET_NODE *node; +} +WINNET_ROUTE; + +extern "C" +WINNET_LINKAGE +bool +WINNET_API +WinNet_ActivateRouteManager( + MullvadLogSink logSink, + void *logSinkContext +); + +extern "C" +WINNET_LINKAGE +bool +WINNET_API +WinNet_AddRoutes( + const WINNET_ROUTE *routes, + uint32_t numRoutes +); + +extern "C" +WINNET_LINKAGE +bool +WINNET_API +WinNet_AddRoute( + const WINNET_ROUTE *route +); + +extern "C" +WINNET_LINKAGE +bool +WINNET_API +WinNet_DeleteRoutes( + const WINNET_ROUTE *routes, + uint32_t numRoutes +); + +extern "C" +WINNET_LINKAGE +bool +WINNET_API +WinNet_DeleteRoute( + const WINNET_ROUTE *route +); + +enum WINNET_DEFAULT_ROUTE_CHANGED_EVENT_TYPE +{ + // Best default route changed. + WINNET_DEFAULT_ROUTE_CHANGED_EVENT_TYPE_UPDATED = 0, + + // No default routes exist. + WINNET_DEFAULT_ROUTE_CHANGED_EVENT_TYPE_REMOVED = 1, +}; + +enum WINNET_IP_FAMILY +{ + WINNET_IP_FAMILY_V4 = 0, + WINNET_IP_FAMILY_V6 = 1, +}; + +typedef void (WINNET_API *WinNetDefaultRouteChangedCallback) +( + WINNET_DEFAULT_ROUTE_CHANGED_EVENT_TYPE eventType, + + // Signals which IP family the event relates to. + WINNET_IP_FAMILY family, + + // For update events, signals the interface associated with the new best default route. + uint64_t interfaceLuid, + + void *context +); + +extern "C" +WINNET_LINKAGE +bool +WINNET_API +WinNet_RegisterDefaultRouteChangedCallback( + WinNetDefaultRouteChangedCallback callback, + void *context, + void **registrationHandle +); + +extern "C" +WINNET_LINKAGE +void +WINNET_API +WinNet_UnregisterDefaultRouteChangedCallback( + void *registrationHandle +); + +extern "C" +WINNET_LINKAGE +void +WINNET_API +WinNet_DeactivateRouteManager( +); + +extern "C" +WINNET_LINKAGE +bool +WINNET_API +WinNet_AddDeviceIpAddresses( + const wchar_t *deviceAlias, + const WINNET_IP *addresses, + uint32_t numAddresses, + MullvadLogSink logSink, + void *logSinkContext +); + diff --git a/windows/winnet/src/winnet/winnet.vcxproj b/windows/winnet/src/winnet/winnet.vcxproj index 192320daaf..5e71a1f733 100644 --- a/windows/winnet/src/winnet/winnet.vcxproj +++ b/windows/winnet/src/winnet/winnet.vcxproj @@ -33,6 +33,10 @@ <ClCompile Include="interfaceutils.cpp" /> <ClCompile Include="offlinemonitor.cpp" /> <ClCompile Include="NetworkInterfaces.cpp" /> + <ClCompile Include="routing\defaultroutemonitor.cpp" /> + <ClCompile Include="routing\helpers.cpp" /> + <ClCompile Include="routing\routemanager.cpp" /> + <ClCompile Include="routing\types.cpp" /> <ClCompile Include="stdafx.cpp" /> <ClCompile Include="winnet.cpp" /> </ItemGroup> @@ -42,6 +46,10 @@ <ClInclude Include="interfaceutils.h" /> <ClInclude Include="offlinemonitor.h" /> <ClInclude Include="NetworkInterfaces.h" /> + <ClInclude Include="routing\defaultroutemonitor.h" /> + <ClInclude Include="routing\helpers.h" /> + <ClInclude Include="routing\routemanager.h" /> + <ClInclude Include="routing\types.h" /> <ClInclude Include="stdafx.h" /> <ClInclude Include="targetver.h" /> <ClInclude Include="winnet.h" /> @@ -208,7 +216,7 @@ <ConformanceMode>true</ConformanceMode> <RuntimeLibrary>MultiThreadedDebug</RuntimeLibrary> <LanguageStandard>stdcpplatest</LanguageStandard> - <AdditionalIncludeDirectories>$(ProjectDir)..\..\..\windows-libraries\src\%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories> + <AdditionalIncludeDirectories>$(ProjectDir)..\..\..\windows-libraries\src\;$(ProjectDir);%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories> </ClCompile> <Link> <SubSystem>Windows</SubSystem> @@ -278,7 +286,7 @@ <ConformanceMode>true</ConformanceMode> <RuntimeLibrary>MultiThreaded</RuntimeLibrary> <LanguageStandard>stdcpplatest</LanguageStandard> - <AdditionalIncludeDirectories>$(ProjectDir)..\..\..\windows-libraries\src\%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories> + <AdditionalIncludeDirectories>$(ProjectDir)..\..\..\windows-libraries\src\;$(ProjectDir);%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories> </ClCompile> <Link> <SubSystem>Windows</SubSystem> diff --git a/windows/winnet/src/winnet/winnet.vcxproj.filters b/windows/winnet/src/winnet/winnet.vcxproj.filters index 9a901d3203..dfe6d29ec7 100644 --- a/windows/winnet/src/winnet/winnet.vcxproj.filters +++ b/windows/winnet/src/winnet/winnet.vcxproj.filters @@ -9,6 +9,18 @@ <ClCompile Include="interfaceutils.cpp" /> <ClCompile Include="networkadaptermonitor.cpp" /> <ClCompile Include="offlinemonitor.cpp" /> + <ClCompile Include="routing\types.cpp"> + <Filter>routing</Filter> + </ClCompile> + <ClCompile Include="routing\helpers.cpp"> + <Filter>routing</Filter> + </ClCompile> + <ClCompile Include="routing\defaultroutemonitor.cpp"> + <Filter>routing</Filter> + </ClCompile> + <ClCompile Include="routing\routemanager.cpp"> + <Filter>routing</Filter> + </ClCompile> </ItemGroup> <ItemGroup> <ClInclude Include="stdafx.h" /> @@ -19,6 +31,18 @@ <ClInclude Include="interfaceutils.h" /> <ClInclude Include="networkadaptermonitor.h" /> <ClInclude Include="offlinemonitor.h" /> + <ClInclude Include="routing\types.h"> + <Filter>routing</Filter> + </ClInclude> + <ClInclude Include="routing\helpers.h"> + <Filter>routing</Filter> + </ClInclude> + <ClInclude Include="routing\defaultroutemonitor.h"> + <Filter>routing</Filter> + </ClInclude> + <ClInclude Include="routing\routemanager.h"> + <Filter>routing</Filter> + </ClInclude> </ItemGroup> <ItemGroup> <None Include="winnet.def" /> @@ -26,4 +50,9 @@ <ItemGroup> <ResourceCompile Include="winnet.rc" /> </ItemGroup> + <ItemGroup> + <Filter Include="routing"> + <UniqueIdentifier>{8df22cc6-597f-4342-bc57-7647393084be}</UniqueIdentifier> + </Filter> + </ItemGroup> </Project>
\ No newline at end of file |
