diff options
| author | David Lönnhager <david.l@mullvad.net> | 2020-02-07 13:07:14 +0100 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2020-02-07 13:07:14 +0100 |
| commit | 0fb314e647ff54cfaa5dea036e4ca4a527fbeeaa (patch) | |
| tree | b5bfb95be5cb7e1fa94558725dc5cfcff1da6def | |
| parent | a6c3608e710503ac29cf51f721f1dc4e8c7e2724 (diff) | |
| parent | 052d425fb7bfea4ec97bc28e2c1959c86c6220ef (diff) | |
| download | mullvadvpn-0fb314e647ff54cfaa5dea036e4ca4a527fbeeaa.tar.xz mullvadvpn-0fb314e647ff54cfaa5dea036e4ca4a527fbeeaa.zip | |
Merge branch 'winnet-metric-update'
| -rw-r--r-- | talpid-core/src/firewall/windows.rs | 2 | ||||
| -rw-r--r-- | talpid-core/src/winnet.rs | 10 | ||||
| -rw-r--r-- | windows/winnet/src/winnet/InterfacePair.cpp | 68 | ||||
| -rw-r--r-- | windows/winnet/src/winnet/InterfacePair.h | 10 | ||||
| -rw-r--r-- | windows/winnet/src/winnet/NetworkInterfaces.cpp | 47 | ||||
| -rw-r--r-- | windows/winnet/src/winnet/NetworkInterfaces.h | 9 | ||||
| -rw-r--r-- | windows/winnet/src/winnet/stdafx.h | 1 | ||||
| -rw-r--r-- | windows/winnet/src/winnet/winnet.cpp | 12 | ||||
| -rw-r--r-- | windows/winnet/src/winnet/winnet.def | 2 | ||||
| -rw-r--r-- | windows/winnet/src/winnet/winnet.h | 12 |
10 files changed, 74 insertions, 99 deletions
diff --git a/talpid-core/src/firewall/windows.rs b/talpid-core/src/firewall/windows.rs index 38e388ecbf..89e880da7e 100644 --- a/talpid-core/src/firewall/windows.rs +++ b/talpid-core/src/firewall/windows.rs @@ -194,7 +194,7 @@ impl Firewall { protocol: WinFwProt::from(endpoint.protocol), }; - let metrics_set = winnet::ensure_top_metric_for_interface(&tunnel_metadata.interface) + let metrics_set = winnet::ensure_best_metric_for_interface(&tunnel_metadata.interface) .map_err(Error::SetTapMetric)?; if metrics_set { diff --git a/talpid-core/src/winnet.rs b/talpid-core/src/winnet.rs index 2eee4b205a..256695eb5c 100644 --- a/talpid-core/src/winnet.rs +++ b/talpid-core/src/winnet.rs @@ -35,12 +35,12 @@ fn logging_context() -> *const u8 { } /// Returns true if metrics were changed, false otherwise -pub fn ensure_top_metric_for_interface(interface_alias: &str) -> Result<bool, Error> { +pub fn ensure_best_metric_for_interface(interface_alias: &str) -> Result<bool, Error> { let interface_alias_ws = WideCString::from_str(interface_alias).map_err(Error::InvalidInterfaceAlias)?; let metric_result = unsafe { - WinNet_EnsureTopMetric( + WinNet_EnsureBestMetric( interface_alias_ws.as_ptr(), Some(log_sink), logging_context(), @@ -56,7 +56,7 @@ pub fn ensure_top_metric_for_interface(interface_alias: &str) -> Result<bool, Er 2 => Err(Error::MetricApplication), // Unexpected value i => { - log::error!("Unexpected return code from WinNet_EnsureTopMetric: {}", i); + log::error!("Unexpected return code from WinNet_EnsureBestMetric: {}", i); Err(Error::MetricApplication) } } @@ -365,8 +365,8 @@ mod api { #[link_name = "WinNet_DeactivateRouteManager"] pub fn WinNet_DeactivateRouteManager(); - #[link_name = "WinNet_EnsureTopMetric"] - pub fn WinNet_EnsureTopMetric( + #[link_name = "WinNet_EnsureBestMetric"] + pub fn WinNet_EnsureBestMetric( tunnel_interface_alias: *const wchar_t, sink: Option<LogSink>, sink_context: *const u8, diff --git a/windows/winnet/src/winnet/InterfacePair.cpp b/windows/winnet/src/winnet/InterfacePair.cpp index 239ea11c53..935bc557aa 100644 --- a/windows/winnet/src/winnet/InterfacePair.cpp +++ b/windows/winnet/src/winnet/InterfacePair.cpp @@ -2,6 +2,7 @@ #include "InterfacePair.h" #include <libcommon/error.h> #include <sstream> +#include <algorithm> #ifndef STATUS_NOT_FOUND #define STATUS_NOT_FOUND 0xC0000225 @@ -11,13 +12,13 @@ InterfacePair::InterfacePair(NET_LUID interface_luid) { IPv4Iface.Family = AF_INET; IPv4Iface.InterfaceLuid = interface_luid; - InitializeInterface(&IPv4Iface); + InitializeInterface(IPv4Iface); IPv6Iface.Family = AF_INET6; IPv6Iface.InterfaceLuid = interface_luid; - InitializeInterface(&IPv6Iface); + InitializeInterface(IPv6Iface); - if (!(HasIPv4() || HasIPv6())) + if (!HasIPv4() && !HasIPv6()) { std::stringstream ss; @@ -28,59 +29,66 @@ InterfacePair::InterfacePair(NET_LUID interface_luid) } } -int InterfacePair::HighestMetric() +int InterfacePair::WorstMetric() { - return IPv6Iface.Metric < IPv4Iface.Metric ? IPv4Iface.Metric : IPv6Iface.Metric; + return std::max(IPv6Iface.Metric, IPv4Iface.Metric); } -void InterfacePair::SetMetric(int metric) +int InterfacePair::BestMetric() { - if (HasIPv4()) - { + return std::min(IPv4Iface.Metric, IPv6Iface.Metric); +} + +void InterfacePair::SetMetric(uint32_t metric) +{ + if (HasIPv4() && (IPv4Iface.UseAutomaticMetric || metric != IPv4Iface.Metric)) + { IPv4Iface.SitePrefixLength = 0; IPv4Iface.Metric = metric; IPv4Iface.UseAutomaticMetric = false; - SetInterface(&IPv4Iface); + SetInterface(IPv4Iface); } - if (HasIPv6()) - { + if (HasIPv6() && (IPv6Iface.UseAutomaticMetric || metric != IPv6Iface.Metric)) + { IPv6Iface.Metric = metric; IPv6Iface.UseAutomaticMetric = false; - SetInterface(&IPv6Iface); + SetInterface(IPv6Iface); } } -void InterfacePair::SetInterface(PMIB_IPINTERFACE_ROW iface) { - - const auto status = SetIpInterfaceEntry(iface); +void InterfacePair::SetInterface(const MIB_IPINTERFACE_ROW &iface) +{ + MIB_IPINTERFACE_ROW row = iface; + const auto status = SetIpInterfaceEntry(&row); - if (status != NO_ERROR) - { - std::stringstream ss; + if (NO_ERROR != status) + { + std::stringstream ss; - ss << "Set metric for " - << (iface->Family == AF_INET ? "IPv4" : "IPv6") - << " on interface with LUID 0x" - << std::hex << iface->InterfaceLuid.Value; + ss << "Set metric for " + << (row.Family == AF_INET ? "IPv4" : "IPv6") + << " on interface with LUID 0x" + << std::hex << row.InterfaceLuid.Value; - THROW_WINDOWS_ERROR(status, ss.str().c_str()); - } + THROW_WINDOWS_ERROR(status, ss.str().c_str()); + } } bool InterfacePair::HasIPv4() { - return IPv4Iface.Family != AF_UNSPEC; + return AF_UNSPEC != IPv4Iface.Family; } bool InterfacePair::HasIPv6() { - return IPv6Iface.Family != AF_UNSPEC; + return AF_UNSPEC != IPv6Iface.Family; } -void InterfacePair::InitializeInterface(PMIB_IPINTERFACE_ROW iface) +//static +void InterfacePair::InitializeInterface(MIB_IPINTERFACE_ROW &iface) { - const auto status = GetIpInterfaceEntry(iface); + const auto status = GetIpInterfaceEntry(&iface); if (NO_ERROR == status) { @@ -89,14 +97,14 @@ void InterfacePair::InitializeInterface(PMIB_IPINTERFACE_ROW iface) if (STATUS_NOT_FOUND == status || ERROR_NOT_FOUND == status) { - iface->Family = AF_UNSPEC; + iface.Family = AF_UNSPEC; } else { std::stringstream ss; ss << "Retrieve info on network interface with LUID 0x" - << std::hex << iface->InterfaceLuid.Value; + << std::hex << iface.InterfaceLuid.Value; THROW_WINDOWS_ERROR(status, ss.str().c_str()); } diff --git a/windows/winnet/src/winnet/InterfacePair.h b/windows/winnet/src/winnet/InterfacePair.h index 9582bac3cd..55a3f59e90 100644 --- a/windows/winnet/src/winnet/InterfacePair.h +++ b/windows/winnet/src/winnet/InterfacePair.h @@ -4,22 +4,24 @@ #include <ws2ipdef.h> #include <iphlpapi.h> #include <netioapi.h> +#include <cstdint> class InterfacePair { public: InterfacePair(NET_LUID interface_luid); - int HighestMetric(); - void SetMetric(int metric); + int BestMetric(); + int WorstMetric(); + void SetMetric(uint32_t metric); private: MIB_IPINTERFACE_ROW IPv4Iface = { 0 }; MIB_IPINTERFACE_ROW IPv6Iface = { 0 }; - void InitializeInterface(PMIB_IPINTERFACE_ROW iface); + static void InitializeInterface(MIB_IPINTERFACE_ROW &iface); bool HasIPv4(); bool HasIPv6(); - void SetInterface(PMIB_IPINTERFACE_ROW iface); + void SetInterface(const MIB_IPINTERFACE_ROW &iface); }; diff --git a/windows/winnet/src/winnet/NetworkInterfaces.cpp b/windows/winnet/src/winnet/NetworkInterfaces.cpp index f364b6b4cd..dc27e95e92 100644 --- a/windows/winnet/src/winnet/NetworkInterfaces.cpp +++ b/windows/winnet/src/winnet/NetworkInterfaces.cpp @@ -7,7 +7,7 @@ #include <sstream> #include <cstdint> -bool NetworkInterfaces::HasHighestMetric(PMIB_IPINTERFACE_ROW targetIface) +bool NetworkInterfaces::HasBestMetric(PMIB_IPINTERFACE_ROW targetIface) { for (unsigned int i = 0; i < mInterfaces->NumEntries; ++i) { @@ -20,39 +20,6 @@ bool NetworkInterfaces::HasHighestMetric(PMIB_IPINTERFACE_ROW targetIface) return true; } -void NetworkInterfaces::EnsureIfaceMetricIsHighest(NET_LUID interfaceLuid) -{ - for (ULONG i = 0; i < mInterfaces->NumEntries; ++i) - { - PMIB_IPINTERFACE_ROW iface = &mInterfaces->Table[i]; - - // Ignoring the target interface. - if (iface->InterfaceLuid.Value == interfaceLuid.Value || iface->UseAutomaticMetric || iface->Metric > MAX_METRIC) - { - continue; - } - - iface->Metric++; - - if (AF_INET == iface->Family) - { - iface->SitePrefixLength = 0; - } - - const auto status = SetIpInterfaceEntry(iface); - - if (NO_ERROR != status) - { - std::stringstream ss; - - ss << "Failed to increment metric for interface with LUID 0x" - << std::hex << iface->InterfaceLuid.Value; - - THROW_WINDOWS_ERROR(status, ss.str().c_str()); - } - } -} - NetworkInterfaces::NetworkInterfaces() { mInterfaces = nullptr; @@ -65,21 +32,19 @@ NetworkInterfaces::NetworkInterfaces() } } -bool NetworkInterfaces::SetTopMetricForInterfacesByAlias(const wchar_t * deviceAlias) +bool NetworkInterfaces::SetBestMetricForInterfacesByAlias(const wchar_t * deviceAlias) { - return SetTopMetricForInterfacesWithLuid(GetInterfaceLuid(deviceAlias)); + return SetBestMetricForInterfacesWithLuid(GetInterfaceLuid(deviceAlias)); } -bool NetworkInterfaces::SetTopMetricForInterfacesWithLuid(NET_LUID targetIfaceId) +bool NetworkInterfaces::SetBestMetricForInterfacesWithLuid(NET_LUID targetIfaceId) { InterfacePair targetInterfaces = InterfacePair(targetIfaceId); - - if (targetInterfaces.HighestMetric() == MAX_METRIC) + if (BEST_METRIC == targetInterfaces.WorstMetric()) { return false; } - - targetInterfaces.SetMetric(MAX_METRIC); + targetInterfaces.SetMetric(BEST_METRIC); return true; } diff --git a/windows/winnet/src/winnet/NetworkInterfaces.h b/windows/winnet/src/winnet/NetworkInterfaces.h index bf1d53dddf..e987939454 100644 --- a/windows/winnet/src/winnet/NetworkInterfaces.h +++ b/windows/winnet/src/winnet/NetworkInterfaces.h @@ -13,20 +13,19 @@ class NetworkInterfaces private: PMIB_IPINTERFACE_TABLE mInterfaces; - bool HasHighestMetric(PMIB_IPINTERFACE_ROW targetIface); + bool HasBestMetric(PMIB_IPINTERFACE_ROW targetIface); public: NetworkInterfaces(const NetworkInterfaces &) = delete; NetworkInterfaces &operator=(const NetworkInterfaces &) = delete; - void EnsureIfaceMetricIsHighest(NET_LUID interfaceLuid); NetworkInterfaces(); - bool SetTopMetricForInterfacesByAlias(const wchar_t *deviceAlias); - bool SetTopMetricForInterfacesWithLuid(NET_LUID targetIface); + bool SetBestMetricForInterfacesByAlias(const wchar_t *deviceAlias); + bool SetBestMetricForInterfacesWithLuid(NET_LUID targetIface); ~NetworkInterfaces(); static NET_LUID GetInterfaceLuid(const std::wstring &interfaceAlias); const MIB_IPINTERFACE_ROW *GetInterface(NET_LUID interfaceLuid, ADDRESS_FAMILY interfaceFamily); }; -const static uint32_t MAX_METRIC = 1; +constexpr static uint32_t BEST_METRIC = 1; diff --git a/windows/winnet/src/winnet/stdafx.h b/windows/winnet/src/winnet/stdafx.h index 254cb49b0d..3115e02cd4 100644 --- a/windows/winnet/src/winnet/stdafx.h +++ b/windows/winnet/src/winnet/stdafx.h @@ -10,6 +10,7 @@ #include "targetver.h"
+#define NOMINMAX
#define WIN32_LEAN_AND_MEAN // Exclude rarely-used stuff from Windows headers
// Windows Header Files:
#include <windows.h>
diff --git a/windows/winnet/src/winnet/winnet.cpp b/windows/winnet/src/winnet/winnet.cpp index 7a206635ed..b71484bf7c 100644 --- a/windows/winnet/src/winnet/winnet.cpp +++ b/windows/winnet/src/winnet/winnet.cpp @@ -33,9 +33,9 @@ std::shared_ptr<shared::logging::LogSinkAdapter> g_RouteManagerLogSink; extern "C"
WINNET_LINKAGE
-WINNET_ETM_STATUS
+WINNET_EBM_STATUS
WINNET_API
-WinNet_EnsureTopMetric(
+WinNet_EnsureBestMetric(
const wchar_t *deviceAlias,
MullvadLogSink logSink,
void *logSinkContext
@@ -44,17 +44,17 @@ WinNet_EnsureTopMetric( try
{
NetworkInterfaces interfaces;
- bool metrics_set = interfaces.SetTopMetricForInterfacesByAlias(deviceAlias);
- return metrics_set ? WINNET_ETM_STATUS_METRIC_SET : WINNET_ETM_STATUS_METRIC_NO_CHANGE;
+ return interfaces.SetBestMetricForInterfacesByAlias(deviceAlias) ?
+ WINNET_EBM_STATUS_METRIC_SET : WINNET_EBM_STATUS_METRIC_NO_CHANGE;
}
catch (const std::exception &err)
{
shared::logging::UnwindAndLog(logSink, logSinkContext, err);
- return WINNET_ETM_STATUS_FAILURE;
+ return WINNET_EBM_STATUS_FAILURE;
}
catch (...)
{
- return WINNET_ETM_STATUS_FAILURE;
+ return WINNET_EBM_STATUS_FAILURE;
}
};
diff --git a/windows/winnet/src/winnet/winnet.def b/windows/winnet/src/winnet/winnet.def index b23ae6c854..ecec97959e 100644 --- a/windows/winnet/src/winnet/winnet.def +++ b/windows/winnet/src/winnet/winnet.def @@ -1,6 +1,6 @@ LIBRARY winnet EXPORTS - WinNet_EnsureTopMetric + WinNet_EnsureBestMetric WinNet_GetTapInterfaceIpv6Status WinNet_GetTapInterfaceAlias WinNet_ReleaseString diff --git a/windows/winnet/src/winnet/winnet.h b/windows/winnet/src/winnet/winnet.h index 5afc4052eb..5e2a4154a4 100644 --- a/windows/winnet/src/winnet/winnet.h +++ b/windows/winnet/src/winnet/winnet.h @@ -16,18 +16,18 @@ #define WINNET_API __stdcall -enum WINNET_ETM_STATUS +enum WINNET_EBM_STATUS { - WINNET_ETM_STATUS_METRIC_NO_CHANGE = 0, - WINNET_ETM_STATUS_METRIC_SET = 1, - WINNET_ETM_STATUS_FAILURE = 2, + WINNET_EBM_STATUS_METRIC_NO_CHANGE = 0, + WINNET_EBM_STATUS_METRIC_SET = 1, + WINNET_EBM_STATUS_FAILURE = 2, }; extern "C" WINNET_LINKAGE -WINNET_ETM_STATUS +WINNET_EBM_STATUS WINNET_API -WinNet_EnsureTopMetric( +WinNet_EnsureBestMetric( const wchar_t *deviceAlias, MullvadLogSink logSink, void *logSinkContext |
