summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorDavid Lönnhager <david.l@mullvad.net>2020-02-07 13:07:14 +0100
committerDavid Lönnhager <david.l@mullvad.net>2020-02-07 13:07:14 +0100
commit0fb314e647ff54cfaa5dea036e4ca4a527fbeeaa (patch)
treeb5bfb95be5cb7e1fa94558725dc5cfcff1da6def
parenta6c3608e710503ac29cf51f721f1dc4e8c7e2724 (diff)
parent052d425fb7bfea4ec97bc28e2c1959c86c6220ef (diff)
downloadmullvadvpn-0fb314e647ff54cfaa5dea036e4ca4a527fbeeaa.tar.xz
mullvadvpn-0fb314e647ff54cfaa5dea036e4ca4a527fbeeaa.zip
Merge branch 'winnet-metric-update'
-rw-r--r--talpid-core/src/firewall/windows.rs2
-rw-r--r--talpid-core/src/winnet.rs10
-rw-r--r--windows/winnet/src/winnet/InterfacePair.cpp68
-rw-r--r--windows/winnet/src/winnet/InterfacePair.h10
-rw-r--r--windows/winnet/src/winnet/NetworkInterfaces.cpp47
-rw-r--r--windows/winnet/src/winnet/NetworkInterfaces.h9
-rw-r--r--windows/winnet/src/winnet/stdafx.h1
-rw-r--r--windows/winnet/src/winnet/winnet.cpp12
-rw-r--r--windows/winnet/src/winnet/winnet.def2
-rw-r--r--windows/winnet/src/winnet/winnet.h12
10 files changed, 74 insertions, 99 deletions
diff --git a/talpid-core/src/firewall/windows.rs b/talpid-core/src/firewall/windows.rs
index 38e388ecbf..89e880da7e 100644
--- a/talpid-core/src/firewall/windows.rs
+++ b/talpid-core/src/firewall/windows.rs
@@ -194,7 +194,7 @@ impl Firewall {
protocol: WinFwProt::from(endpoint.protocol),
};
- let metrics_set = winnet::ensure_top_metric_for_interface(&tunnel_metadata.interface)
+ let metrics_set = winnet::ensure_best_metric_for_interface(&tunnel_metadata.interface)
.map_err(Error::SetTapMetric)?;
if metrics_set {
diff --git a/talpid-core/src/winnet.rs b/talpid-core/src/winnet.rs
index 2eee4b205a..256695eb5c 100644
--- a/talpid-core/src/winnet.rs
+++ b/talpid-core/src/winnet.rs
@@ -35,12 +35,12 @@ fn logging_context() -> *const u8 {
}
/// Returns true if metrics were changed, false otherwise
-pub fn ensure_top_metric_for_interface(interface_alias: &str) -> Result<bool, Error> {
+pub fn ensure_best_metric_for_interface(interface_alias: &str) -> Result<bool, Error> {
let interface_alias_ws =
WideCString::from_str(interface_alias).map_err(Error::InvalidInterfaceAlias)?;
let metric_result = unsafe {
- WinNet_EnsureTopMetric(
+ WinNet_EnsureBestMetric(
interface_alias_ws.as_ptr(),
Some(log_sink),
logging_context(),
@@ -56,7 +56,7 @@ pub fn ensure_top_metric_for_interface(interface_alias: &str) -> Result<bool, Er
2 => Err(Error::MetricApplication),
// Unexpected value
i => {
- log::error!("Unexpected return code from WinNet_EnsureTopMetric: {}", i);
+ log::error!("Unexpected return code from WinNet_EnsureBestMetric: {}", i);
Err(Error::MetricApplication)
}
}
@@ -365,8 +365,8 @@ mod api {
#[link_name = "WinNet_DeactivateRouteManager"]
pub fn WinNet_DeactivateRouteManager();
- #[link_name = "WinNet_EnsureTopMetric"]
- pub fn WinNet_EnsureTopMetric(
+ #[link_name = "WinNet_EnsureBestMetric"]
+ pub fn WinNet_EnsureBestMetric(
tunnel_interface_alias: *const wchar_t,
sink: Option<LogSink>,
sink_context: *const u8,
diff --git a/windows/winnet/src/winnet/InterfacePair.cpp b/windows/winnet/src/winnet/InterfacePair.cpp
index 239ea11c53..935bc557aa 100644
--- a/windows/winnet/src/winnet/InterfacePair.cpp
+++ b/windows/winnet/src/winnet/InterfacePair.cpp
@@ -2,6 +2,7 @@
#include "InterfacePair.h"
#include <libcommon/error.h>
#include <sstream>
+#include <algorithm>
#ifndef STATUS_NOT_FOUND
#define STATUS_NOT_FOUND 0xC0000225
@@ -11,13 +12,13 @@ InterfacePair::InterfacePair(NET_LUID interface_luid)
{
IPv4Iface.Family = AF_INET;
IPv4Iface.InterfaceLuid = interface_luid;
- InitializeInterface(&IPv4Iface);
+ InitializeInterface(IPv4Iface);
IPv6Iface.Family = AF_INET6;
IPv6Iface.InterfaceLuid = interface_luid;
- InitializeInterface(&IPv6Iface);
+ InitializeInterface(IPv6Iface);
- if (!(HasIPv4() || HasIPv6()))
+ if (!HasIPv4() && !HasIPv6())
{
std::stringstream ss;
@@ -28,59 +29,66 @@ InterfacePair::InterfacePair(NET_LUID interface_luid)
}
}
-int InterfacePair::HighestMetric()
+int InterfacePair::WorstMetric()
{
- return IPv6Iface.Metric < IPv4Iface.Metric ? IPv4Iface.Metric : IPv6Iface.Metric;
+ return std::max(IPv6Iface.Metric, IPv4Iface.Metric);
}
-void InterfacePair::SetMetric(int metric)
+int InterfacePair::BestMetric()
{
- if (HasIPv4())
- {
+ return std::min(IPv4Iface.Metric, IPv6Iface.Metric);
+}
+
+void InterfacePair::SetMetric(uint32_t metric)
+{
+ if (HasIPv4() && (IPv4Iface.UseAutomaticMetric || metric != IPv4Iface.Metric))
+ {
IPv4Iface.SitePrefixLength = 0;
IPv4Iface.Metric = metric;
IPv4Iface.UseAutomaticMetric = false;
- SetInterface(&IPv4Iface);
+ SetInterface(IPv4Iface);
}
- if (HasIPv6())
- {
+ if (HasIPv6() && (IPv6Iface.UseAutomaticMetric || metric != IPv6Iface.Metric))
+ {
IPv6Iface.Metric = metric;
IPv6Iface.UseAutomaticMetric = false;
- SetInterface(&IPv6Iface);
+ SetInterface(IPv6Iface);
}
}
-void InterfacePair::SetInterface(PMIB_IPINTERFACE_ROW iface) {
-
- const auto status = SetIpInterfaceEntry(iface);
+void InterfacePair::SetInterface(const MIB_IPINTERFACE_ROW &iface)
+{
+ MIB_IPINTERFACE_ROW row = iface;
+ const auto status = SetIpInterfaceEntry(&row);
- if (status != NO_ERROR)
- {
- std::stringstream ss;
+ if (NO_ERROR != status)
+ {
+ std::stringstream ss;
- ss << "Set metric for "
- << (iface->Family == AF_INET ? "IPv4" : "IPv6")
- << " on interface with LUID 0x"
- << std::hex << iface->InterfaceLuid.Value;
+ ss << "Set metric for "
+ << (row.Family == AF_INET ? "IPv4" : "IPv6")
+ << " on interface with LUID 0x"
+ << std::hex << row.InterfaceLuid.Value;
- THROW_WINDOWS_ERROR(status, ss.str().c_str());
- }
+ THROW_WINDOWS_ERROR(status, ss.str().c_str());
+ }
}
bool InterfacePair::HasIPv4()
{
- return IPv4Iface.Family != AF_UNSPEC;
+ return AF_UNSPEC != IPv4Iface.Family;
}
bool InterfacePair::HasIPv6()
{
- return IPv6Iface.Family != AF_UNSPEC;
+ return AF_UNSPEC != IPv6Iface.Family;
}
-void InterfacePair::InitializeInterface(PMIB_IPINTERFACE_ROW iface)
+//static
+void InterfacePair::InitializeInterface(MIB_IPINTERFACE_ROW &iface)
{
- const auto status = GetIpInterfaceEntry(iface);
+ const auto status = GetIpInterfaceEntry(&iface);
if (NO_ERROR == status)
{
@@ -89,14 +97,14 @@ void InterfacePair::InitializeInterface(PMIB_IPINTERFACE_ROW iface)
if (STATUS_NOT_FOUND == status || ERROR_NOT_FOUND == status)
{
- iface->Family = AF_UNSPEC;
+ iface.Family = AF_UNSPEC;
}
else
{
std::stringstream ss;
ss << "Retrieve info on network interface with LUID 0x"
- << std::hex << iface->InterfaceLuid.Value;
+ << std::hex << iface.InterfaceLuid.Value;
THROW_WINDOWS_ERROR(status, ss.str().c_str());
}
diff --git a/windows/winnet/src/winnet/InterfacePair.h b/windows/winnet/src/winnet/InterfacePair.h
index 9582bac3cd..55a3f59e90 100644
--- a/windows/winnet/src/winnet/InterfacePair.h
+++ b/windows/winnet/src/winnet/InterfacePair.h
@@ -4,22 +4,24 @@
#include <ws2ipdef.h>
#include <iphlpapi.h>
#include <netioapi.h>
+#include <cstdint>
class InterfacePair
{
public:
InterfacePair(NET_LUID interface_luid);
- int HighestMetric();
- void SetMetric(int metric);
+ int BestMetric();
+ int WorstMetric();
+ void SetMetric(uint32_t metric);
private:
MIB_IPINTERFACE_ROW IPv4Iface = { 0 };
MIB_IPINTERFACE_ROW IPv6Iface = { 0 };
- void InitializeInterface(PMIB_IPINTERFACE_ROW iface);
+ static void InitializeInterface(MIB_IPINTERFACE_ROW &iface);
bool HasIPv4();
bool HasIPv6();
- void SetInterface(PMIB_IPINTERFACE_ROW iface);
+ void SetInterface(const MIB_IPINTERFACE_ROW &iface);
};
diff --git a/windows/winnet/src/winnet/NetworkInterfaces.cpp b/windows/winnet/src/winnet/NetworkInterfaces.cpp
index f364b6b4cd..dc27e95e92 100644
--- a/windows/winnet/src/winnet/NetworkInterfaces.cpp
+++ b/windows/winnet/src/winnet/NetworkInterfaces.cpp
@@ -7,7 +7,7 @@
#include <sstream>
#include <cstdint>
-bool NetworkInterfaces::HasHighestMetric(PMIB_IPINTERFACE_ROW targetIface)
+bool NetworkInterfaces::HasBestMetric(PMIB_IPINTERFACE_ROW targetIface)
{
for (unsigned int i = 0; i < mInterfaces->NumEntries; ++i)
{
@@ -20,39 +20,6 @@ bool NetworkInterfaces::HasHighestMetric(PMIB_IPINTERFACE_ROW targetIface)
return true;
}
-void NetworkInterfaces::EnsureIfaceMetricIsHighest(NET_LUID interfaceLuid)
-{
- for (ULONG i = 0; i < mInterfaces->NumEntries; ++i)
- {
- PMIB_IPINTERFACE_ROW iface = &mInterfaces->Table[i];
-
- // Ignoring the target interface.
- if (iface->InterfaceLuid.Value == interfaceLuid.Value || iface->UseAutomaticMetric || iface->Metric > MAX_METRIC)
- {
- continue;
- }
-
- iface->Metric++;
-
- if (AF_INET == iface->Family)
- {
- iface->SitePrefixLength = 0;
- }
-
- const auto status = SetIpInterfaceEntry(iface);
-
- if (NO_ERROR != status)
- {
- std::stringstream ss;
-
- ss << "Failed to increment metric for interface with LUID 0x"
- << std::hex << iface->InterfaceLuid.Value;
-
- THROW_WINDOWS_ERROR(status, ss.str().c_str());
- }
- }
-}
-
NetworkInterfaces::NetworkInterfaces()
{
mInterfaces = nullptr;
@@ -65,21 +32,19 @@ NetworkInterfaces::NetworkInterfaces()
}
}
-bool NetworkInterfaces::SetTopMetricForInterfacesByAlias(const wchar_t * deviceAlias)
+bool NetworkInterfaces::SetBestMetricForInterfacesByAlias(const wchar_t * deviceAlias)
{
- return SetTopMetricForInterfacesWithLuid(GetInterfaceLuid(deviceAlias));
+ return SetBestMetricForInterfacesWithLuid(GetInterfaceLuid(deviceAlias));
}
-bool NetworkInterfaces::SetTopMetricForInterfacesWithLuid(NET_LUID targetIfaceId)
+bool NetworkInterfaces::SetBestMetricForInterfacesWithLuid(NET_LUID targetIfaceId)
{
InterfacePair targetInterfaces = InterfacePair(targetIfaceId);
-
- if (targetInterfaces.HighestMetric() == MAX_METRIC)
+ if (BEST_METRIC == targetInterfaces.WorstMetric())
{
return false;
}
-
- targetInterfaces.SetMetric(MAX_METRIC);
+ targetInterfaces.SetMetric(BEST_METRIC);
return true;
}
diff --git a/windows/winnet/src/winnet/NetworkInterfaces.h b/windows/winnet/src/winnet/NetworkInterfaces.h
index bf1d53dddf..e987939454 100644
--- a/windows/winnet/src/winnet/NetworkInterfaces.h
+++ b/windows/winnet/src/winnet/NetworkInterfaces.h
@@ -13,20 +13,19 @@ class NetworkInterfaces
private:
PMIB_IPINTERFACE_TABLE mInterfaces;
- bool HasHighestMetric(PMIB_IPINTERFACE_ROW targetIface);
+ bool HasBestMetric(PMIB_IPINTERFACE_ROW targetIface);
public:
NetworkInterfaces(const NetworkInterfaces &) = delete;
NetworkInterfaces &operator=(const NetworkInterfaces &) = delete;
- void EnsureIfaceMetricIsHighest(NET_LUID interfaceLuid);
NetworkInterfaces();
- bool SetTopMetricForInterfacesByAlias(const wchar_t *deviceAlias);
- bool SetTopMetricForInterfacesWithLuid(NET_LUID targetIface);
+ bool SetBestMetricForInterfacesByAlias(const wchar_t *deviceAlias);
+ bool SetBestMetricForInterfacesWithLuid(NET_LUID targetIface);
~NetworkInterfaces();
static NET_LUID GetInterfaceLuid(const std::wstring &interfaceAlias);
const MIB_IPINTERFACE_ROW *GetInterface(NET_LUID interfaceLuid, ADDRESS_FAMILY interfaceFamily);
};
-const static uint32_t MAX_METRIC = 1;
+constexpr static uint32_t BEST_METRIC = 1;
diff --git a/windows/winnet/src/winnet/stdafx.h b/windows/winnet/src/winnet/stdafx.h
index 254cb49b0d..3115e02cd4 100644
--- a/windows/winnet/src/winnet/stdafx.h
+++ b/windows/winnet/src/winnet/stdafx.h
@@ -10,6 +10,7 @@
#include "targetver.h"
+#define NOMINMAX
#define WIN32_LEAN_AND_MEAN // Exclude rarely-used stuff from Windows headers
// Windows Header Files:
#include <windows.h>
diff --git a/windows/winnet/src/winnet/winnet.cpp b/windows/winnet/src/winnet/winnet.cpp
index 7a206635ed..b71484bf7c 100644
--- a/windows/winnet/src/winnet/winnet.cpp
+++ b/windows/winnet/src/winnet/winnet.cpp
@@ -33,9 +33,9 @@ std::shared_ptr<shared::logging::LogSinkAdapter> g_RouteManagerLogSink;
extern "C"
WINNET_LINKAGE
-WINNET_ETM_STATUS
+WINNET_EBM_STATUS
WINNET_API
-WinNet_EnsureTopMetric(
+WinNet_EnsureBestMetric(
const wchar_t *deviceAlias,
MullvadLogSink logSink,
void *logSinkContext
@@ -44,17 +44,17 @@ WinNet_EnsureTopMetric(
try
{
NetworkInterfaces interfaces;
- bool metrics_set = interfaces.SetTopMetricForInterfacesByAlias(deviceAlias);
- return metrics_set ? WINNET_ETM_STATUS_METRIC_SET : WINNET_ETM_STATUS_METRIC_NO_CHANGE;
+ return interfaces.SetBestMetricForInterfacesByAlias(deviceAlias) ?
+ WINNET_EBM_STATUS_METRIC_SET : WINNET_EBM_STATUS_METRIC_NO_CHANGE;
}
catch (const std::exception &err)
{
shared::logging::UnwindAndLog(logSink, logSinkContext, err);
- return WINNET_ETM_STATUS_FAILURE;
+ return WINNET_EBM_STATUS_FAILURE;
}
catch (...)
{
- return WINNET_ETM_STATUS_FAILURE;
+ return WINNET_EBM_STATUS_FAILURE;
}
};
diff --git a/windows/winnet/src/winnet/winnet.def b/windows/winnet/src/winnet/winnet.def
index b23ae6c854..ecec97959e 100644
--- a/windows/winnet/src/winnet/winnet.def
+++ b/windows/winnet/src/winnet/winnet.def
@@ -1,6 +1,6 @@
LIBRARY winnet
EXPORTS
- WinNet_EnsureTopMetric
+ WinNet_EnsureBestMetric
WinNet_GetTapInterfaceIpv6Status
WinNet_GetTapInterfaceAlias
WinNet_ReleaseString
diff --git a/windows/winnet/src/winnet/winnet.h b/windows/winnet/src/winnet/winnet.h
index 5afc4052eb..5e2a4154a4 100644
--- a/windows/winnet/src/winnet/winnet.h
+++ b/windows/winnet/src/winnet/winnet.h
@@ -16,18 +16,18 @@
#define WINNET_API __stdcall
-enum WINNET_ETM_STATUS
+enum WINNET_EBM_STATUS
{
- WINNET_ETM_STATUS_METRIC_NO_CHANGE = 0,
- WINNET_ETM_STATUS_METRIC_SET = 1,
- WINNET_ETM_STATUS_FAILURE = 2,
+ WINNET_EBM_STATUS_METRIC_NO_CHANGE = 0,
+ WINNET_EBM_STATUS_METRIC_SET = 1,
+ WINNET_EBM_STATUS_FAILURE = 2,
};
extern "C"
WINNET_LINKAGE
-WINNET_ETM_STATUS
+WINNET_EBM_STATUS
WINNET_API
-WinNet_EnsureTopMetric(
+WinNet_EnsureBestMetric(
const wchar_t *deviceAlias,
MullvadLogSink logSink,
void *logSinkContext