summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorDavid Lönnhager <david.l@mullvad.net>2020-04-22 20:37:31 +0200
committerDavid Lönnhager <david.l@mullvad.net>2020-04-24 16:51:46 +0200
commit896454f0d01c8e42a9d613299f48078694d17392 (patch)
tree5e2d6eed562e8860b21497f86e5cfa48221f060e
parent874c98fa9e8103696c3f74068441bd0f71b7997e (diff)
downloadmullvadvpn-896454f0d01c8e42a9d613299f48078694d17392.tar.xz
mullvadvpn-896454f0d01c8e42a9d613299f48078694d17392.zip
Add WinNet_EnableIpv6ForAdapter to enable IPv6 for a given adapter
-rw-r--r--windows/winnet/src/winnet/netconfig.cpp259
-rw-r--r--windows/winnet/src/winnet/netconfig.h5
-rw-r--r--windows/winnet/src/winnet/winnet.cpp29
-rw-r--r--windows/winnet/src/winnet/winnet.def1
-rw-r--r--windows/winnet/src/winnet/winnet.h10
-rw-r--r--windows/winnet/src/winnet/winnet.vcxproj2
-rw-r--r--windows/winnet/src/winnet/winnet.vcxproj.filters2
7 files changed, 308 insertions, 0 deletions
diff --git a/windows/winnet/src/winnet/netconfig.cpp b/windows/winnet/src/winnet/netconfig.cpp
new file mode 100644
index 0000000000..2bd96c907f
--- /dev/null
+++ b/windows/winnet/src/winnet/netconfig.cpp
@@ -0,0 +1,259 @@
+#include "stdafx.h"
+#include "netconfig.h"
+#include <stdexcept>
+#include <sstream>
+#include <windows.h>
+#include <netcfgx.h>
+#include <devguid.h>
+#include <libcommon/error.h>
+#include <libcommon/string.h>
+#include <libcommon/memory.h>
+#include <libshared/network/interfaceutils.h>
+
+
+namespace
+{
+
+const wchar_t NETCFG_LOCK_CLIENT_NAME[] = L"MULLVAD";
+constexpr uint16_t NETCFG_LOCK_TIMEOUT = 5000; // milliseconds
+const wchar_t NETCFG_IPV6_COMPONENT_NAME[] = L"MS_TCPIP6";
+
+void SetIpv6BindingForBindName(INetCfg *netCfg, const std::wstring &bindName, bool enable)
+{
+ INetCfgComponent *transactionComponent = nullptr;
+ HRESULT result = netCfg->FindComponent(NETCFG_IPV6_COMPONENT_NAME, &transactionComponent);
+
+ if (S_OK != result)
+ {
+ THROW_ERROR("Failed to obtain transaction component");
+ }
+
+ INetCfgComponentBindings *bindings = nullptr;
+ result = transactionComponent->QueryInterface(
+ IID_INetCfgComponentBindings,
+ reinterpret_cast<void**>(&bindings)
+ );
+
+ transactionComponent->Release();
+ transactionComponent = nullptr;
+
+ if (S_OK != result)
+ {
+ std::wstringstream ss;
+ ss << L"Failed to obtain component bindings for ";
+ ss << NETCFG_IPV6_COMPONENT_NAME;
+ THROW_ERROR(common::string::ToAnsi(ss.str()).c_str());
+ }
+
+ IEnumNetCfgBindingPath *pathsEnum = NULL;
+ result = bindings->EnumBindingPaths(EBP_BELOW, &pathsEnum);
+
+ bindings->Release();
+ bindings = nullptr;
+
+ if (S_OK != result)
+ {
+ THROW_ERROR("Failed to acquire binding path enumerator");
+ }
+
+ common::memory::ScopeDestructor pathsEnumDestructor;
+ pathsEnumDestructor += [&pathsEnum]() {
+ pathsEnum->Release();
+ pathsEnum = nullptr;
+ };
+
+ INetCfgBindingPath *bindingPath = NULL;
+
+ result = pathsEnum->Next(1, &bindingPath, nullptr);
+
+ for (; S_OK == result; result = pathsEnum->Next(1, &bindingPath, nullptr))
+ {
+ common::memory::ScopeDestructor bindingPathDestructor;
+ bindingPathDestructor += [&bindingPath]() {
+ bindingPath->Release();
+ bindingPath = nullptr;
+ };
+
+ IEnumNetCfgBindingInterface *enumInterface = nullptr;
+ HRESULT enumResult = bindingPath->EnumBindingInterfaces(&enumInterface);
+
+ if (S_OK != enumResult)
+ {
+ THROW_ERROR("Failed to acquire binding path interfaces");
+ }
+
+ common::memory::ScopeDestructor interfaceEnumDestructor;
+ interfaceEnumDestructor += [&enumInterface]() {
+ enumInterface->Release();
+ enumInterface = nullptr;
+ };
+
+ INetCfgBindingInterface *iface = nullptr;
+
+ while (S_OK == enumInterface->Next(1, &iface, nullptr))
+ {
+ INetCfgComponent *cfgComponent = nullptr;
+
+ auto status = iface->GetLowerComponent(&cfgComponent);
+
+ iface->Release();
+ iface = nullptr;
+
+ if (S_OK != status)
+ {
+ THROW_ERROR("Failed to acquire binding interface component");
+ }
+
+ wchar_t *componentBindName = 0;
+
+ status = cfgComponent->GetBindName(&componentBindName);
+
+ cfgComponent->Release();
+ cfgComponent = nullptr;
+
+ if (S_OK != status)
+ {
+ THROW_ERROR("Failed to acquire bind name");
+ }
+
+ bool matchesBindName = (0 == _wcsicmp(bindName.c_str(), componentBindName));
+ CoTaskMemFree(componentBindName);
+
+ if (matchesBindName)
+ {
+ //
+ // Apply the changes and exit the function
+ //
+
+ result = bindingPath->Enable(enable);
+ if (S_OK != result)
+ {
+ THROW_ERROR("Failed to set IPv6 status");
+ }
+ netCfg->Apply();
+
+ return;
+ }
+ }
+ }
+}
+
+std::wstring FindAdapterGuidForAlias(const std::wstring &alias)
+{
+ const auto adapters = shared::network::InterfaceUtils::GetAllAdapters(
+ AF_UNSPEC,
+ GAA_FLAG_SKIP_UNICAST | GAA_FLAG_SKIP_ANYCAST | GAA_FLAG_SKIP_MULTICAST
+ );
+ for (auto it = adapters.begin(); it != adapters.end(); ++it)
+ {
+ if (0 == it->alias().compare(alias))
+ {
+ return it->guid();
+ }
+ }
+
+ throw std::runtime_error("Cannot find GUID for given alias");
+}
+
+} // anonymous namespace
+
+
+void EnableIpv6ForAdapter(const std::wstring &alias)
+{
+ std::wstring adapterGuid = FindAdapterGuidForAlias(alias);
+
+ //
+ // Initialize COM
+ //
+
+ HRESULT result = CoInitialize(nullptr);
+
+ if (S_OK != result)
+ {
+ std::stringstream ss;
+ ss << "Failed to initialize COM: " << result;
+ THROW_ERROR(ss.str().c_str());
+ }
+
+ common::memory::ScopeDestructor scopeDest;
+ scopeDest += []() {
+ CoUninitialize();
+ };
+
+ //
+ // Initialize INetCfg
+ //
+
+ INetCfg *netCfg = nullptr;
+ result = CoCreateInstance(
+ CLSID_CNetCfg,
+ nullptr,
+ CLSCTX_INPROC_SERVER,
+ IID_INetCfg,
+ reinterpret_cast<void**>(&netCfg)
+ );
+
+ if (S_OK != result)
+ {
+ std::stringstream ss;
+ ss << "Failed to create INetCfg instance: " << result;
+ THROW_ERROR(ss.str().c_str());
+
+ }
+
+ scopeDest += [&netCfg]() { netCfg->Release(); };
+
+ INetCfgLock *netCfgLock = nullptr;
+ result = netCfg->QueryInterface(IID_INetCfgLock, reinterpret_cast<void**>(&netCfgLock));
+
+ if (S_OK != result)
+ {
+ std::stringstream ss;
+ ss << "Failed to obtain INetCfg lock interface: " << result;
+ THROW_ERROR(ss.str().c_str());
+ }
+
+ scopeDest += [&netCfgLock]() {
+ netCfgLock->Release();
+ };
+
+ wchar_t *blockingApplication = nullptr;
+
+ // NOTE: This should be done before initializing INetCfg
+ result = netCfgLock->AcquireWriteLock(
+ NETCFG_LOCK_TIMEOUT,
+ NETCFG_LOCK_CLIENT_NAME,
+ &blockingApplication
+ );
+
+ if (S_OK != result)
+ {
+ std::wstringstream ss;
+ ss << L"Failed to acquire write lock";
+ if (nullptr != blockingApplication)
+ {
+ ss << L" due to application: " << blockingApplication;
+ }
+ ss << ". (" << result << ")";
+
+ THROW_ERROR(common::string::ToAnsi(ss.str()).c_str());
+ }
+
+ scopeDest += [&]() {
+ CoTaskMemFree(blockingApplication);
+ netCfgLock->ReleaseWriteLock();
+ };
+
+ result = netCfg->Initialize(nullptr);
+
+ if (S_OK != result)
+ {
+ std::stringstream ss;
+ ss << "Failed to initialize INetCfg: " << result;
+ THROW_ERROR(ss.str().c_str());
+ }
+
+ scopeDest += [&netCfg]() { netCfg->Uninitialize(); };
+
+ SetIpv6BindingForBindName(netCfg, adapterGuid, true);
+}
diff --git a/windows/winnet/src/winnet/netconfig.h b/windows/winnet/src/winnet/netconfig.h
new file mode 100644
index 0000000000..a729d98629
--- /dev/null
+++ b/windows/winnet/src/winnet/netconfig.h
@@ -0,0 +1,5 @@
+#pragma once
+
+#include <string>
+
+void EnableIpv6ForAdapter(const std::wstring &alias);
diff --git a/windows/winnet/src/winnet/winnet.cpp b/windows/winnet/src/winnet/winnet.cpp
index 2a96d150ca..5f9d20f583 100644
--- a/windows/winnet/src/winnet/winnet.cpp
+++ b/windows/winnet/src/winnet/winnet.cpp
@@ -4,6 +4,7 @@
#include "offlinemonitor.h"
#include "routing/routemanager.h"
#include "converters.h"
+#include "netconfig.h"
#include <libshared/logging/logsinkadapter.h>
#include <libshared/logging/unwind.h>
#include <libshared/network/interfaceutils.h>
@@ -65,6 +66,34 @@ WinNet_EnsureBestMetric(
extern "C"
WINNET_LINKAGE
+bool
+WINNET_API
+WinNet_EnableIpv6ForAdapter(
+ const wchar_t *deviceAlias,
+ MullvadLogSink logSink,
+ void *logSinkContext
+)
+{
+ try
+ {
+ if (nullptr == deviceAlias)
+ {
+ THROW_ERROR("Invalid argument: deviceAlias");
+ }
+
+ EnableIpv6ForAdapter(deviceAlias);
+ return true;
+ }
+ catch (const std::exception & err)
+ {
+ shared::logging::UnwindAndLog(logSink, logSinkContext, err);
+ return false;
+ }
+ return false;
+}
+
+extern "C"
+WINNET_LINKAGE
WINNET_GTII_STATUS
WINNET_API
WinNet_GetTapInterfaceIpv6Status(
diff --git a/windows/winnet/src/winnet/winnet.def b/windows/winnet/src/winnet/winnet.def
index ecec97959e..29a19dffe5 100644
--- a/windows/winnet/src/winnet/winnet.def
+++ b/windows/winnet/src/winnet/winnet.def
@@ -1,6 +1,7 @@
LIBRARY winnet
EXPORTS
WinNet_EnsureBestMetric
+ WinNet_EnableIpv6ForAdapter
WinNet_GetTapInterfaceIpv6Status
WinNet_GetTapInterfaceAlias
WinNet_ReleaseString
diff --git a/windows/winnet/src/winnet/winnet.h b/windows/winnet/src/winnet/winnet.h
index 5e2a4154a4..885ed8ef90 100644
--- a/windows/winnet/src/winnet/winnet.h
+++ b/windows/winnet/src/winnet/winnet.h
@@ -33,6 +33,16 @@ WinNet_EnsureBestMetric(
void *logSinkContext
);
+extern "C"
+WINNET_LINKAGE
+bool
+WINNET_API
+WinNet_EnableIpv6ForAdapter(
+ const wchar_t *deviceAlias,
+ MullvadLogSink logSink,
+ void *logSinkContext
+);
+
enum WINNET_GTII_STATUS
{
WINNET_GTII_STATUS_ENABLED = 0,
diff --git a/windows/winnet/src/winnet/winnet.vcxproj b/windows/winnet/src/winnet/winnet.vcxproj
index 7b4578d4b1..6876035160 100644
--- a/windows/winnet/src/winnet/winnet.vcxproj
+++ b/windows/winnet/src/winnet/winnet.vcxproj
@@ -28,6 +28,7 @@
</ItemGroup>
<ItemGroup>
<ClCompile Include="converters.cpp" />
+ <ClCompile Include="netconfig.cpp" />
<ClCompile Include="networkadaptermonitor.cpp" />
<ClCompile Include="dllmain.cpp" />
<ClCompile Include="InterfacePair.cpp" />
@@ -42,6 +43,7 @@
</ItemGroup>
<ItemGroup>
<ClInclude Include="converters.h" />
+ <ClInclude Include="netconfig.h" />
<ClInclude Include="networkadaptermonitor.h" />
<ClInclude Include="InterfacePair.h" />
<ClInclude Include="offlinemonitor.h" />
diff --git a/windows/winnet/src/winnet/winnet.vcxproj.filters b/windows/winnet/src/winnet/winnet.vcxproj.filters
index 2d5a039c6b..27f0051bc6 100644
--- a/windows/winnet/src/winnet/winnet.vcxproj.filters
+++ b/windows/winnet/src/winnet/winnet.vcxproj.filters
@@ -21,6 +21,7 @@
<Filter>routing</Filter>
</ClCompile>
<ClCompile Include="converters.cpp" />
+ <ClCompile Include="netconfig.cpp" />
</ItemGroup>
<ItemGroup>
<ClInclude Include="stdafx.h" />
@@ -43,6 +44,7 @@
<Filter>routing</Filter>
</ClInclude>
<ClInclude Include="converters.h" />
+ <ClInclude Include="netconfig.h" />
</ItemGroup>
<ItemGroup>
<None Include="winnet.def" />