summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorOdd Stranne <odd@mullvad.net>2021-03-17 18:06:16 +0100
committerOdd Stranne <odd@mullvad.net>2021-07-02 16:31:31 +0200
commit15be4405fcbe845f806d0e2a50c4e948e049d0d5 (patch)
tree869d26d02bad41af6c6bff61d0831ce3705edabc
parent79f52b5adc0d965e1688bb1253a6c782bf74f03f (diff)
downloadmullvadvpn-15be4405fcbe845f806d0e2a50c4e948e049d0d5.tar.xz
mullvadvpn-15be4405fcbe845f806d0e2a50c4e948e049d0d5.zip
Restructure and extend driverlogic
-rw-r--r--windows/driverlogic/driverlogic.vcxproj13
-rw-r--r--windows/driverlogic/driverlogic.vcxproj.filters13
-rw-r--r--windows/driverlogic/src/devenum.cpp75
-rw-r--r--windows/driverlogic/src/devenum.h42
-rw-r--r--windows/driverlogic/src/device.cpp470
-rw-r--r--windows/driverlogic/src/device.h85
-rw-r--r--windows/driverlogic/src/driverlogic.cpp1138
-rw-r--r--windows/driverlogic/src/log.cpp13
-rw-r--r--windows/driverlogic/src/log.h7
-rw-r--r--windows/driverlogic/src/service.cpp134
-rw-r--r--windows/driverlogic/src/service.h7
-rw-r--r--windows/driverlogic/src/util.cpp36
-rw-r--r--windows/driverlogic/src/util.h8
-rw-r--r--windows/driverlogic/src/version.cpp125
-rw-r--r--windows/driverlogic/src/version.h30
-rw-r--r--windows/driverlogic/src/wintun.h64
16 files changed, 1324 insertions, 936 deletions
diff --git a/windows/driverlogic/driverlogic.vcxproj b/windows/driverlogic/driverlogic.vcxproj
index 5d78c90479..b91e97d86c 100644
--- a/windows/driverlogic/driverlogic.vcxproj
+++ b/windows/driverlogic/driverlogic.vcxproj
@@ -96,14 +96,27 @@
</Link>
</ItemDefinitionGroup>
<ItemGroup>
+ <ClCompile Include="src\devenum.cpp" />
+ <ClCompile Include="src\device.cpp" />
<ClCompile Include="src\driverlogic.cpp" />
<ClCompile Include="src\error.cpp" />
+ <ClCompile Include="src\log.cpp" />
+ <ClCompile Include="src\service.cpp" />
<ClCompile Include="src\stdafx.cpp" />
+ <ClCompile Include="src\util.cpp" />
+ <ClCompile Include="src\version.cpp" />
</ItemGroup>
<ItemGroup>
+ <ClInclude Include="src\devenum.h" />
+ <ClInclude Include="src\device.h" />
<ClInclude Include="src\error.h" />
+ <ClInclude Include="src\log.h" />
+ <ClInclude Include="src\service.h" />
<ClInclude Include="src\stdafx.h" />
<ClInclude Include="src\targetver.h" />
+ <ClInclude Include="src\util.h" />
+ <ClInclude Include="src\version.h" />
+ <ClInclude Include="src\wintun.h" />
</ItemGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
<ImportGroup Label="ExtensionTargets">
diff --git a/windows/driverlogic/driverlogic.vcxproj.filters b/windows/driverlogic/driverlogic.vcxproj.filters
index b9c494e241..91ed2267da 100644
--- a/windows/driverlogic/driverlogic.vcxproj.filters
+++ b/windows/driverlogic/driverlogic.vcxproj.filters
@@ -10,10 +10,23 @@
<ClCompile Include="src\driverlogic.cpp" />
<ClCompile Include="src\stdafx.cpp" />
<ClCompile Include="src\error.cpp" />
+ <ClCompile Include="src\device.cpp" />
+ <ClCompile Include="src\service.cpp" />
+ <ClCompile Include="src\log.cpp" />
+ <ClCompile Include="src\version.cpp" />
+ <ClCompile Include="src\util.cpp" />
+ <ClCompile Include="src\devenum.cpp" />
</ItemGroup>
<ItemGroup>
<ClInclude Include="src\stdafx.h" />
<ClInclude Include="src\targetver.h" />
<ClInclude Include="src\error.h" />
+ <ClInclude Include="src\device.h" />
+ <ClInclude Include="src\service.h" />
+ <ClInclude Include="src\log.h" />
+ <ClInclude Include="src\version.h" />
+ <ClInclude Include="src\util.h" />
+ <ClInclude Include="src\wintun.h" />
+ <ClInclude Include="src\devenum.h" />
</ItemGroup>
</Project> \ No newline at end of file
diff --git a/windows/driverlogic/src/devenum.cpp b/windows/driverlogic/src/devenum.cpp
new file mode 100644
index 0000000000..229d7ee493
--- /dev/null
+++ b/windows/driverlogic/src/devenum.cpp
@@ -0,0 +1,75 @@
+#include "stdafx.h"
+#include "devenum.h"
+#include "error.h"
+
+DeviceEnumerator::DeviceEnumerator(const GUID &deviceClass)
+{
+ m_deviceInfoSet = SetupDiGetClassDevsW
+ (
+ &deviceClass,
+ nullptr,
+ nullptr,
+ DIGCF_PRESENT
+ );
+
+ if (INVALID_HANDLE_VALUE == m_deviceInfoSet)
+ {
+ THROW_SETUPAPI_ERROR(GetLastError(), "SetupDiGetClassDevsW");
+ }
+
+ m_nextDeviceIndex = 0;
+ m_exhausted = false;
+}
+
+//static
+std::unique_ptr<DeviceEnumerator> DeviceEnumerator::Create(const GUID& deviceClass, Filter filter)
+{
+ auto enumerator = std::make_unique<DeviceEnumerator>(deviceClass);
+
+ enumerator->setFilter(filter);
+
+ return enumerator;
+}
+
+DeviceEnumerator::~DeviceEnumerator()
+{
+ SetupDiDestroyDeviceInfoList(m_deviceInfoSet);
+}
+
+bool DeviceEnumerator::next(EnumeratedDevice &device)
+{
+ if (m_exhausted)
+ {
+ return false;
+ }
+
+ SP_DEVINFO_DATA deviceInfo { 0 };
+ deviceInfo.cbSize = sizeof(deviceInfo);
+
+ for (;;)
+ {
+ if (FALSE == SetupDiEnumDeviceInfo(m_deviceInfoSet, m_nextDeviceIndex, &deviceInfo))
+ {
+ if (GetLastError() != ERROR_NO_MORE_ITEMS)
+ {
+ THROW_SETUPAPI_ERROR(GetLastError(), "SetupDiEnumDeviceInfo");
+ }
+
+ m_exhausted = true;
+
+ return false;
+ }
+
+ ++m_nextDeviceIndex;
+
+ if (!m_filter || m_filter(m_deviceInfoSet, deviceInfo))
+ {
+ break;
+ }
+ }
+
+ device.deviceInfoSet = m_deviceInfoSet;
+ device.deviceInfo = deviceInfo;
+
+ return true;
+}
diff --git a/windows/driverlogic/src/devenum.h b/windows/driverlogic/src/devenum.h
new file mode 100644
index 0000000000..185bdb27f8
--- /dev/null
+++ b/windows/driverlogic/src/devenum.h
@@ -0,0 +1,42 @@
+#pragma once
+
+#include <windows.h>
+#include <newdev.h>
+#include <functional>
+#include <memory>
+#include "device.h"
+
+class DeviceEnumerator
+{
+public:
+
+ using Filter = std::function<bool(HDEVINFO, const SP_DEVINFO_DATA&)>;
+
+ DeviceEnumerator(const GUID &deviceClass);
+
+ static std::unique_ptr<DeviceEnumerator> Create(const GUID &deviceClass, Filter filter);
+
+ ~DeviceEnumerator();
+
+ DeviceEnumerator(const DeviceEnumerator &) = delete;
+ DeviceEnumerator(DeviceEnumerator &&) = delete;
+ DeviceEnumerator &operator=(const DeviceEnumerator &) = delete;
+ DeviceEnumerator &operator=(DeviceEnumerator &&) = delete;
+
+ void setFilter(Filter filter)
+ {
+ m_filter = filter;
+ }
+
+ bool next(EnumeratedDevice &device);
+
+private:
+
+ HDEVINFO m_deviceInfoSet;
+
+ int m_nextDeviceIndex;
+
+ bool m_exhausted;
+
+ Filter m_filter;
+};
diff --git a/windows/driverlogic/src/device.cpp b/windows/driverlogic/src/device.cpp
new file mode 100644
index 0000000000..e38709f01d
--- /dev/null
+++ b/windows/driverlogic/src/device.cpp
@@ -0,0 +1,470 @@
+#include "stdafx.h"
+#include <winioctl.h>
+#include <newdev.h>
+#include <initguid.h>
+#include <devpkey.h>
+#include <devguid.h>
+#include <libcommon/error.h>
+#include <libcommon/memory.h>
+#include <libcommon/registry/registry.h>
+#include "log.h"
+#include "device.h"
+#include "error.h"
+#include "devenum.h"
+#include <vector>
+#include <sstream>
+#include <functional>
+
+namespace
+{
+
+//
+// Identifiers defined by split tunneling driver.
+//
+
+constexpr wchar_t DeviceSymbolicName[] = L"\\\\.\\MULLVADSPLITTUNNEL";
+
+#define ST_DEVICE_TYPE 0x8000
+
+#define IOCTL_ST_GET_STATE \
+ CTL_CODE(ST_DEVICE_TYPE, 9, METHOD_BUFFERED, FILE_ANY_ACCESS)
+
+#define IOCTL_ST_RESET \
+ CTL_CODE(ST_DEVICE_TYPE, 11, METHOD_NEITHER, FILE_ANY_ACCESS)
+
+constexpr SIZE_T ST_DRIVER_STATE_STARTED = 1;
+
+//
+// Onwards.
+//
+
+void
+ThrowUpdateException
+(
+ DWORD lastError,
+ const char *operation
+)
+{
+ if (ERROR_DEVICE_INSTALLER_NOT_READY == lastError)
+ {
+ bool deviceInstallDisabled = false;
+
+ try
+ {
+ const auto key = common::registry::Registry::OpenKey
+ (
+ HKEY_LOCAL_MACHINE,
+ L"SYSTEM\\CurrentControlSet\\Services\\DeviceInstall\\Parameters"
+ );
+
+ deviceInstallDisabled = (0 != key->readUint32(L"DeviceInstallDisabled"));
+ }
+ catch (...)
+ {
+ }
+
+ if (deviceInstallDisabled)
+ {
+ throw common::error::WindowsException
+ (
+ "Device installs must be enabled to continue. "
+ "Enable them in the Local Group Policy editor, or "
+ "update the registry value DeviceInstallDisabled in "
+ "[HKEY_LOCAL_MACHINE\\SYSTEM\\CurrentControlSet\\Services\\DeviceInstall\\Parameters]",
+ lastError
+ );
+ }
+ }
+
+ THROW_SETUPAPI_ERROR(lastError, operation);
+}
+
+} // anonymous namespace
+
+std::wstring
+GetDeviceStringProperty
+(
+ HDEVINFO deviceInfoSet,
+ const SP_DEVINFO_DATA &deviceInfo,
+ const DEVPROPKEY *property
+)
+{
+ //
+ // Obtain required buffer size
+ //
+
+ DWORD requiredSize = 0;
+ DEVPROPTYPE type;
+
+ const auto sizeStatus = SetupDiGetDevicePropertyW
+ (
+ deviceInfoSet,
+ const_cast<PSP_DEVINFO_DATA>(&deviceInfo),
+ property,
+ &type,
+ nullptr,
+ 0,
+ &requiredSize,
+ 0
+ );
+
+ if (FALSE == sizeStatus)
+ {
+ const auto lastError = GetLastError();
+
+ if (ERROR_INSUFFICIENT_BUFFER != lastError)
+ {
+ THROW_SETUPAPI_ERROR(lastError, "SetupDiGetDevicePropertyW");
+ }
+ }
+
+ std::vector<wchar_t> buffer(requiredSize / sizeof(wchar_t));
+
+ //
+ // Read property
+ //
+
+ const auto status = SetupDiGetDevicePropertyW
+ (
+ deviceInfoSet,
+ const_cast<PSP_DEVINFO_DATA>(&deviceInfo),
+ property,
+ &type,
+ reinterpret_cast<PBYTE>(&buffer[0]),
+ requiredSize,
+ nullptr,
+ 0
+ );
+
+ if (FALSE == status)
+ {
+ THROW_SETUPAPI_ERROR(GetLastError(), "Failed to read device property");
+ }
+
+ return buffer.data();
+}
+
+std::wstring
+GetDeviceNetCfgInstanceId
+(
+ HDEVINFO deviceInfoSet,
+ const SP_DEVINFO_DATA &deviceInfo
+)
+{
+ auto registryKey = SetupDiOpenDevRegKey
+ (
+ deviceInfoSet,
+ const_cast<SP_DEVINFO_DATA *>(&deviceInfo),
+ DICS_FLAG_GLOBAL,
+ 0,
+ DIREG_DRV,
+ KEY_READ
+ );
+
+ if (registryKey == INVALID_HANDLE_VALUE)
+ {
+ THROW_SETUPAPI_ERROR(GetLastError(), "SetupDiOpenDevRegKey");
+ }
+
+ std::vector<wchar_t> buffer(128, L'\0');
+
+ DWORD bufferByteLength = static_cast<DWORD>(buffer.size()) * sizeof(wchar_t);
+
+ const auto status = RegGetValueW
+ (
+ registryKey,
+ nullptr,
+ L"NetCfgInstanceId",
+ RRF_RT_REG_SZ,
+ nullptr,
+ &buffer[0],
+ &bufferByteLength
+ );
+
+ RegCloseKey(registryKey);
+
+ if (ERROR_SUCCESS != status)
+ {
+ THROW_WINDOWS_ERROR(status, "RegGetValueW");
+ }
+
+ //
+ // RegGetValueW ensures the string is null-terminated.
+ //
+
+ return std::wstring(&buffer[0]);
+}
+
+void
+CreateDevice
+(
+ const GUID &classGuid,
+ const std::wstring &deviceName,
+ const std::wstring &deviceHardwareId
+)
+{
+ Log(L"Attempting to create device");
+
+ const auto deviceInfoSet = SetupDiCreateDeviceInfoList(&classGuid, 0);
+
+ if (INVALID_HANDLE_VALUE == deviceInfoSet)
+ {
+ THROW_SETUPAPI_ERROR(GetLastError(), "SetupDiCreateDeviceInfoList");
+ }
+
+ common::memory::ScopeDestructor scopeDestructor;
+
+ scopeDestructor += [&deviceInfoSet]()
+ {
+ SetupDiDestroyDeviceInfoList(deviceInfoSet);
+ };
+
+ SP_DEVINFO_DATA devInfoData {0};
+ devInfoData.cbSize = sizeof(SP_DEVINFO_DATA);
+
+ auto status = SetupDiCreateDeviceInfoW
+ (
+ deviceInfoSet,
+ deviceName.c_str(),
+ &classGuid,
+ nullptr,
+ 0,
+ DICD_GENERATE_ID,
+ &devInfoData
+ );
+
+ if (FALSE == status)
+ {
+ THROW_SETUPAPI_ERROR(GetLastError(), "SetupDiCreateDeviceInfoW");
+ }
+
+ status = SetupDiSetDeviceRegistryPropertyW
+ (
+ deviceInfoSet,
+ &devInfoData,
+ SPDRP_HARDWAREID,
+ reinterpret_cast<const BYTE *>(deviceHardwareId.c_str()),
+ static_cast<DWORD>(deviceHardwareId.size() * sizeof(wchar_t))
+ );
+
+ if (FALSE == status)
+ {
+ THROW_SETUPAPI_ERROR(GetLastError(), "SetupDiSetDeviceRegistryPropertyW");
+ }
+
+ //
+ // Create a devnode in the PnP HW tree
+ //
+ status = SetupDiCallClassInstaller
+ (
+ DIF_REGISTERDEVICE,
+ deviceInfoSet,
+ &devInfoData
+ );
+
+ if (FALSE == status)
+ {
+ THROW_SETUPAPI_ERROR(GetLastError(), "SetupDiCallClassInstaller");
+ }
+
+ Log(L"Created new device successfully");
+}
+
+void
+InstallDriverForDevice
+(
+ const std::wstring &deviceHardwareId,
+ const std::wstring &infPath
+)
+{
+ Log(L"Attempting to install new driver");
+
+ DWORD installFlags = 0;
+ BOOL rebootRequired = FALSE;
+
+ for (;;)
+ {
+ auto result = UpdateDriverForPlugAndPlayDevicesW
+ (
+ nullptr,
+ deviceHardwareId.c_str(),
+ infPath.c_str(),
+ installFlags,
+ &rebootRequired
+ );
+
+ if (FALSE != result)
+ {
+ break;
+ }
+
+ const auto lastError = GetLastError();
+
+ if (ERROR_NO_MORE_ITEMS == lastError
+ && (0 == (installFlags & INSTALLFLAG_FORCE)))
+ {
+ Log(L"Driver installation/update failed. Attempting forced install.");
+ installFlags |= INSTALLFLAG_FORCE;
+
+ continue;
+ }
+
+ ThrowUpdateException(lastError, "UpdateDriverForPlugAndPlayDevicesW");
+ }
+
+ //
+ // Driver successfully installed or updated
+ //
+
+ std::wstringstream ss;
+
+ ss << L"Device driver update complete. Reboot required: "
+ << rebootRequired;
+
+ Log(ss.str());
+}
+
+void
+UninstallDevice
+(
+ const EnumeratedDevice &device
+)
+{
+ Log(L"Uninstalling device");
+
+ BOOL needReboot;
+
+ auto status = DiUninstallDevice
+ (
+ nullptr,
+ device.deviceInfoSet,
+ const_cast<PSP_DEVINFO_DATA>(&device.deviceInfo),
+ 0,
+ &needReboot
+ );
+
+ if (FALSE == status)
+ {
+ THROW_WINDOWS_ERROR(GetLastError(), "DiUninstallDevice");
+ }
+
+ std::wstringstream ss;
+
+ ss << L"Successfully uninstalled device. Reboot required: "
+ << needReboot;
+
+ Log(ss.str());
+}
+
+HANDLE
+OpenSplitTunnelDevice
+(
+)
+{
+ auto handle = CreateFileW(DeviceSymbolicName, GENERIC_READ | GENERIC_WRITE,
+ 0, nullptr, OPEN_EXISTING, FILE_FLAG_OVERLAPPED, nullptr);
+
+ if (handle == INVALID_HANDLE_VALUE)
+ {
+ THROW_WINDOWS_ERROR(GetLastError(), "Open split tunnel device");
+ }
+
+ return handle;
+}
+
+void
+CloseSplitTunnelDevice
+(
+ HANDLE device
+)
+{
+ CloseHandle(device);
+}
+
+void
+SendIoControl
+(
+ HANDLE device,
+ DWORD code,
+ void *inBuffer,
+ DWORD inBufferSize,
+ void *outBuffer,
+ DWORD outBufferSize,
+ DWORD *bytesReturned
+)
+{
+ OVERLAPPED o = { 0 };
+
+ //
+ // Event should not be created on-the-fly.
+ //
+ // Create an event for each thread that needs to send a request
+ // and keep the event around.
+ //
+ o.hEvent = CreateEventW(nullptr, TRUE, FALSE, nullptr);
+
+ auto status = DeviceIoControl(device, code,
+ inBuffer, inBufferSize, outBuffer, outBufferSize, bytesReturned, &o);
+
+ if (FALSE != status)
+ {
+ CloseHandle(o.hEvent);
+
+ return;
+ }
+
+ if (ERROR_IO_PENDING != GetLastError())
+ {
+ const auto err = GetLastError();
+
+ CloseHandle(o.hEvent);
+
+ THROW_WINDOWS_ERROR(err, "DeviceIoControl");
+ }
+
+ DWORD tempBytesReturned = 0;
+
+ status = GetOverlappedResult(device, &o, &tempBytesReturned, TRUE);
+
+ CloseHandle(o.hEvent);
+
+ if (FALSE == status)
+ {
+ THROW_WINDOWS_ERROR(GetLastError(), "GetOverlappedResult");
+ }
+
+ *bytesReturned = tempBytesReturned;
+}
+
+void
+SendIoControlReset
+(
+ HANDLE device
+)
+{
+ DWORD dummy;
+
+ SendIoControl(device, (DWORD)IOCTL_ST_RESET, nullptr, 0, nullptr, 0, &dummy);
+
+ DWORD bytesReturned;
+
+ SIZE_T currentState;
+
+ SendIoControl(device, (DWORD)IOCTL_ST_GET_STATE, nullptr, 0, &currentState, sizeof(currentState), &bytesReturned);
+
+ if (bytesReturned != sizeof(currentState))
+ {
+ throw std::runtime_error("Failed to send reset request to driver");
+ }
+
+ //
+ // If successful, state is ST_DRIVER_STATE_STARTED
+ //
+ // Otherwise, state is probably ST_DRIVER_STATE_ZOMBIE
+ //
+
+ if (currentState != ST_DRIVER_STATE_STARTED)
+ {
+ throw std::runtime_error("Failed to reset driver state");
+ }
+}
diff --git a/windows/driverlogic/src/device.h b/windows/driverlogic/src/device.h
new file mode 100644
index 0000000000..b8bc0e7f41
--- /dev/null
+++ b/windows/driverlogic/src/device.h
@@ -0,0 +1,85 @@
+#pragma once
+
+#include <windows.h>
+#include <string>
+#include <optional>
+#include <setupapi.h>
+
+struct EnumeratedDevice
+{
+ HDEVINFO deviceInfoSet;
+ SP_DEVINFO_DATA deviceInfo;
+};
+
+//
+// Generic functions
+//
+
+std::wstring
+GetDeviceStringProperty
+(
+ HDEVINFO deviceInfoSet,
+ const SP_DEVINFO_DATA &deviceInfo,
+ const DEVPROPKEY *property
+);
+
+std::wstring
+GetDeviceNetCfgInstanceId
+(
+ HDEVINFO deviceInfoSet,
+ const SP_DEVINFO_DATA &deviceInfo
+);
+
+void
+CreateDevice
+(
+ const GUID &classGuid,
+ const std::wstring &deviceName,
+ const std::wstring &deviceHardwareId
+);
+
+void
+InstallDriverForDevice
+(
+ const std::wstring &deviceHardwareId,
+ const std::wstring &infPath
+);
+
+void
+UninstallDevice
+(
+ const EnumeratedDevice &device
+);
+
+//
+// Functions that are specific to our driver/implementation
+//
+
+HANDLE
+OpenSplitTunnelDevice
+(
+);
+
+void
+CloseSplitTunnelDevice
+(
+ HANDLE device
+);
+
+void
+SendIoControl
+(
+ HANDLE device,
+ DWORD code,
+ void *inBuffer,
+ DWORD inBufferSize,
+ void *outBuffer,
+ DWORD outBufferSize,
+ DWORD *bytesReturned
+);
+
+void
+SendIoControlReset
+(
+ HANDLE device
+);
diff --git a/windows/driverlogic/src/driverlogic.cpp b/windows/driverlogic/src/driverlogic.cpp
index 25a3c130b0..93af20b400 100644
--- a/windows/driverlogic/src/driverlogic.cpp
+++ b/windows/driverlogic/src/driverlogic.cpp
@@ -1,1103 +1,369 @@
#include "stdafx.h"
#include "error.h"
-#include <iostream>
-#include <chrono>
-#include <sstream>
+#include "device.h"
+#include "service.h"
+#include "log.h"
+#include "version.h"
+#include "wintun.h"
+#include "devenum.h"
#include <string>
-#include <optional>
-#include <set>
-#include <filesystem>
#include <libcommon/error.h>
-#include <libcommon/guid.h>
#include <libcommon/memory.h>
-#include <libcommon/network/nci.h>
-#include <libcommon/registry/registry.h>
#include <libcommon/string.h>
-#include <setupapi.h>
#include <initguid.h>
-#include <devguid.h>
#include <devpkey.h>
-#include <newdev.h>
-#include <cfgmgr32.h>
+#include <devguid.h>
#include <io.h>
#include <fcntl.h>
-#include <wintun/wintun.h>
-
namespace
{
-constexpr std::chrono::milliseconds REGISTRY_GET_TIMEOUT_MS{ 10000 };
+constexpr wchar_t SPLIT_TUNNEL_HARDWARE_ID[] = L"Root\\mullvad-split-tunnel";
-enum ReturnCodes
+DEFINE_GUID(WFP_CALLOUTS_CLASS_ID,
+ 0x57465043, 0x616C, 0x6C6F, 0x75, 0x74, 0x5F, 0x63, 0x6C, 0x61, 0x73, 0x73);
+
+constexpr wchar_t SPLIT_TUNNEL_DEVICE_NAME[] = L"Mullvad Split Tunnel Device";
+
+enum ReturnCode
{
GENERAL_SUCCESS = 0,
- GENERAL_ERROR = -1,
- ADAPTER_NOT_FOUND = -2
+ GENERAL_ERROR = 1,
+ ST_DRIVER_NONE_INSTALLED = 2,
+ ST_DRIVER_SAME_VERSION_INSTALLED = 3,
+ ST_DRIVER_OLDER_VERSION_INSTALLED = 4,
+ ST_DRIVER_NEWER_VERSION_INSTALLED = 5
};
-struct NetworkAdapter
+class ArgumentContext
{
- std::wstring guid;
- std::wstring name;
- std::wstring alias;
- std::wstring deviceInstanceId;
+public:
- NetworkAdapter(std::wstring guid, std::wstring name, std::wstring alias, std::wstring deviceInstanceId)
- : guid(guid)
- , name(name)
- , alias(alias)
- , deviceInstanceId(deviceInstanceId)
+ ArgumentContext(const std::vector<std::wstring> &args)
+ : m_args(args)
+ , m_remaining(m_args.size())
{
}
- bool operator<(const NetworkAdapter &rhs) const
+ size_t total() const
{
- return _wcsicmp(deviceInstanceId.c_str(), rhs.deviceInstanceId.c_str()) < 0;
+ return m_args.size();
}
-};
-
-void LogAdapters(const std::wstring &description, const std::set<NetworkAdapter> &adapters)
-{
- std::wcout << description << std::endl;
- for (const auto &adapter : adapters)
+ void ensureExactArgumentCount(size_t count) const
{
- std::wcout << L" Adapter\n"
- << L" Guid: " << adapter.guid << L'\n'
- << L" Name: " << adapter.name << L'\n'
- << L" Alias: " << adapter.alias << L'\n'
- << L" Device instance ID: " << adapter.deviceInstanceId
- << std::endl;
- }
-}
-
-void Log(const std::wstring &str)
-{
- std::wcout << str << std::endl;
-}
-
-void LogError(const std::wstring &str)
-{
- std::wcerr << str << std::endl;
-}
-
-std::optional<std::wstring> GetDeviceRegistryStringProperty(
- HDEVINFO devInfo,
- const SP_DEVINFO_DATA &devInfoData,
- DWORD property
-)
-{
- //
- // Obtain required buffer size
- //
-
- DWORD requiredSize = 0;
-
- const auto sizeStatus = SetupDiGetDeviceRegistryPropertyW(
- devInfo,
- const_cast<SP_DEVINFO_DATA*>(&devInfoData),
- property,
- nullptr,
- nullptr,
- 0,
- &requiredSize
- );
-
- const DWORD lastError = GetLastError();
- if (FALSE == sizeStatus && ERROR_INSUFFICIENT_BUFFER != lastError)
- {
- // ERROR_INVALID_DATA may mean that the property does not exist
- // TODO: Check if there may be other causes.
- if (ERROR_INVALID_DATA != lastError)
+ if (m_args.size() != count)
{
- THROW_SETUPAPI_ERROR(lastError, "SetupDiGetDeviceRegistryPropertyW");
+ throw std::runtime_error("Invalid number of arguments");
}
-
- return std::nullopt;
- }
-
- //
- // Read property
- //
-
- std::vector<wchar_t> buffer(requiredSize / sizeof(wchar_t));
-
- const auto status = SetupDiGetDeviceRegistryPropertyW(
- devInfo,
- const_cast<SP_DEVINFO_DATA*>(&devInfoData),
- property,
- nullptr,
- reinterpret_cast<PBYTE>(&buffer[0]),
- requiredSize,
- nullptr
- );
-
- if (FALSE == status)
- {
- THROW_SETUPAPI_ERROR(GetLastError(), "Failed to read device property");
}
- return std::make_optional(buffer.data());
-}
-
-std::wstring GetDeviceStringProperty(
- HDEVINFO devInfo,
- const SP_DEVINFO_DATA &devInfoData,
- const DEVPROPKEY *property
-)
-{
- //
- // Obtain required buffer size
- //
-
- DWORD requiredSize = 0;
- DEVPROPTYPE type;
-
- const auto sizeStatus = SetupDiGetDevicePropertyW(
- devInfo,
- const_cast<SP_DEVINFO_DATA*>(&devInfoData),
- property,
- &type,
- nullptr,
- 0,
- &requiredSize,
- 0
- );
-
- if (FALSE == sizeStatus)
+ const std::wstring &next()
{
- const auto lastError = GetLastError();
-
- if (ERROR_INSUFFICIENT_BUFFER != lastError)
+ if (0 == m_remaining)
{
- THROW_SETUPAPI_ERROR(lastError, "SetupDiGetDevicePropertyW");
+ throw std::runtime_error("Argument missing");
}
- }
-
- std::vector<wchar_t> buffer(requiredSize / sizeof(wchar_t));
- //
- // Read property
- //
+ const auto &str = m_args.at(m_args.size() - m_remaining);
- const auto status = SetupDiGetDevicePropertyW(
- devInfo,
- const_cast<SP_DEVINFO_DATA*>(&devInfoData),
- property,
- &type,
- reinterpret_cast<PBYTE>(&buffer[0]),
- requiredSize,
- nullptr,
- 0
- );
+ --m_remaining;
- if (FALSE == status)
- {
- THROW_SETUPAPI_ERROR(GetLastError(), "Failed to read device property");
+ return str;
}
- return buffer.data();
-}
-
-std::wstring GetDeviceInstanceId(
- HDEVINFO devInfo,
- const SP_DEVINFO_DATA &devInfoData
-)
-{
- DWORD requiredSize = 0;
-
- SetupDiGetDeviceInstanceIdW(
- devInfo,
- const_cast<SP_DEVINFO_DATA*>(&devInfoData),
- nullptr,
- 0,
- &requiredSize
- );
-
- std::vector<wchar_t> deviceInstanceId(1 + requiredSize);
-
- const auto status = SetupDiGetDeviceInstanceIdW(
- devInfo,
- const_cast<SP_DEVINFO_DATA *>(&devInfoData),
- &deviceInstanceId[0],
- requiredSize,
- nullptr
- );
-
- if (FALSE == status)
- {
- THROW_SETUPAPI_ERROR(GetLastError(), "SetupDiGetDeviceInstanceIdW");
- }
+private:
- return deviceInstanceId.data();
-}
+ const std::vector<std::wstring> &m_args;
+ size_t m_remaining;
+};
-bool TryGetRegistryValueTimeout(
- HKEY key,
- const wchar_t *subkey,
- const wchar_t *value,
- DWORD flags,
- DWORD *type,
- void *data,
- DWORD *dataSize
-)
+void ResetDriverState()
{
- HANDLE changeEvent = nullptr;
+ auto deviceHandle = OpenSplitTunnelDevice();
- common::memory::ScopeDestructor scopeDestructor;
- scopeDestructor += [changeEvent]() {
- if (nullptr != changeEvent)
- {
- CloseHandle(changeEvent);
- }
- };
+ common::memory::ScopeDestructor dtor;
- auto initialTime = std::chrono::steady_clock::now();
-
- for (;;)
+ dtor += [deviceHandle]()
{
- const auto status = RegGetValueW(key, subkey, value, flags, type, data, dataSize);
-
- if (ERROR_SUCCESS == status)
- {
- // We're done
- return true;
- }
-
- if (ERROR_FILE_NOT_FOUND != status)
- {
- THROW_WINDOWS_ERROR(status, "RegGetValueW");
- }
-
- if (nullptr == changeEvent)
- {
- changeEvent = CreateEventW(nullptr, FALSE, FALSE, nullptr);
-
- if (nullptr == changeEvent)
- {
- THROW_WINDOWS_ERROR(GetLastError(), "CreateEventW");
- }
- }
-
- //
- // Wait for the registry value to be created
- //
-
- auto currentTime = std::chrono::steady_clock::now();
- auto elapsedTime = std::chrono::duration_cast<std::chrono::milliseconds>(currentTime - initialTime);
- auto timeDelta = (REGISTRY_GET_TIMEOUT_MS - elapsedTime).count();
-
- if (timeDelta <= 0)
- {
- return false;
- }
-
- const auto notifyResult = RegNotifyChangeKeyValue(
- key,
- subkey != nullptr, // Watch subkeys
- REG_NOTIFY_CHANGE_LAST_SET,
- changeEvent,
- TRUE
- );
-
- if (ERROR_SUCCESS != notifyResult)
- {
- THROW_WINDOWS_ERROR(notifyResult, "RegNotifyChangeKeyValue");
- }
-
- const auto waitResult = WaitForSingleObject(changeEvent, static_cast<DWORD>(timeDelta));
- if (WAIT_OBJECT_0 == waitResult)
- {
- // Try again
- continue;
- }
-
- if (WAIT_TIMEOUT != waitResult)
- {
- THROW_WINDOWS_ERROR(GetLastError(), "WaitForSingleObject");
- }
-
- return false;
- }
-}
-
-std::wstring GetNetCfgInstanceId(HDEVINFO devInfo, const SP_DEVINFO_DATA &devInfoData)
-{
- HKEY hNet = SetupDiOpenDevRegKey(
- devInfo,
- const_cast<SP_DEVINFO_DATA *>(&devInfoData),
- DICS_FLAG_GLOBAL,
- 0,
- DIREG_DRV,
- KEY_READ
- );
-
- if (hNet == INVALID_HANDLE_VALUE)
- {
- THROW_SETUPAPI_ERROR(GetLastError(), "SetupDiOpenDevRegKey");
- }
-
- common::memory::ScopeDestructor scopeDestructor;
- scopeDestructor += [hNet]() {
- RegCloseKey(hNet);
+ CloseSplitTunnelDevice(deviceHandle);
};
- std::vector<wchar_t> instanceId(MAX_PATH + 1);
- DWORD strSize = static_cast<DWORD>(instanceId.size() * sizeof(wchar_t));
-
- if (!TryGetRegistryValueTimeout(
- hNet,
- nullptr,
- L"NetCfgInstanceId",
- RRF_RT_REG_SZ,
- nullptr,
- instanceId.data(),
- &strSize
- ))
- {
- THROW_ERROR("Timed out waiting for NetCfgInstanceId.");
- }
-
- return instanceId.data();
+ SendIoControlReset(deviceHandle);
}
-bool DeleteDevice(HDEVINFO devInfo, const SP_DEVINFO_DATA &devInfoData)
+std::unique_ptr<DeviceEnumerator> CreateSplitTunnelDeviceEnumerator()
{
- const auto data = const_cast<SP_DEVINFO_DATA *>(&devInfoData);
-
- wchar_t devId[MAX_DEVICE_ID_LEN];
- if (CR_SUCCESS != CM_Get_Device_IDW(data->DevInst, devId, sizeof(devId) / sizeof(devId[0]), 0))
+ return DeviceEnumerator::Create(WFP_CALLOUTS_CLASS_ID, [](HDEVINFO deviceInfoSet, const SP_DEVINFO_DATA &deviceInfo)
{
- // skip
- return false;
- }
+ auto candidateDeviceName = GetDeviceStringProperty(deviceInfoSet, deviceInfo, &DEVPKEY_NAME);
- SP_REMOVEDEVICE_PARAMS rmdParams = { 0 };
- rmdParams.ClassInstallHeader.cbSize = sizeof(SP_CLASSINSTALL_HEADER);
- rmdParams.ClassInstallHeader.InstallFunction = DIF_REMOVE;
- rmdParams.Scope = DI_REMOVEDEVICE_GLOBAL;
- rmdParams.HwProfile = 0;
-
- auto status = SetupDiSetClassInstallParamsW(devInfo, data, &rmdParams.ClassInstallHeader, sizeof(rmdParams));
- if (FALSE == status)
- {
- THROW_SETUPAPI_ERROR(GetLastError(), "SetupDiSetClassInstallParamsW");
- }
-
- status = SetupDiCallClassInstaller(DIF_REMOVE, devInfo, data);
- if (FALSE == status)
- {
- THROW_SETUPAPI_ERROR(GetLastError(), "SetupDiCallClassInstaller");
- }
-
- return true;
-}
-
-void ForEachNetworkDevice(const std::optional<std::wstring> hwId, std::function<bool(HDEVINFO, const SP_DEVINFO_DATA &)> func)
-{
- HDEVINFO devInfo = SetupDiGetClassDevsW(
- &GUID_DEVCLASS_NET,
- nullptr,
- nullptr,
- DIGCF_PRESENT
- );
-
- if (INVALID_HANDLE_VALUE == devInfo)
- {
- THROW_SETUPAPI_ERROR(GetLastError(), "SetupDiGetClassDevsW");
- }
-
- common::memory::ScopeDestructor cleanupDevList;
- cleanupDevList += [&devInfo]()
- {
- SetupDiDestroyDeviceInfoList(devInfo);
- };
-
- for (int memberIndex = 0; ; memberIndex++)
- {
- SP_DEVINFO_DATA devInfoData = { 0 };
- devInfoData.cbSize = sizeof(devInfoData);
-
- if (FALSE == SetupDiEnumDeviceInfo(devInfo, memberIndex, &devInfoData))
- {
- const auto lastError = GetLastError();
-
- if (ERROR_NO_MORE_ITEMS == lastError)
- {
- break;
- }
-
- THROW_SETUPAPI_ERROR(lastError, "Enumerating network adapters");
- }
-
- if (hwId.has_value())
- {
- try
- {
- const auto hardwareId = GetDeviceRegistryStringProperty(devInfo, devInfoData, SPDRP_HARDWAREID);
-
- if (!hardwareId.has_value() ||
- 0 != hwId->compare(hardwareId.value()))
- {
- continue;
- }
- }
- catch (const std::exception & e)
- {
- //
- // Skip this adapter
- //
-
- std::wstringstream ss;
- ss << L"Skipping virtual adapter due to exception caught while iterating: "
- << common::string::ToWide(e.what());
- LogError(ss.str());
- continue;
- }
- }
-
- if (!func(devInfo, devInfoData))
- {
- break;
- }
- }
-}
-
-std::set<NetworkAdapter> GetNetworkAdapters(const std::optional<std::wstring> hardwareId)
-{
- std::set<NetworkAdapter> adapters;
- common::network::Nci nci;
-
- ForEachNetworkDevice(hardwareId, [&](HDEVINFO devInfo, const SP_DEVINFO_DATA &devInfoData) {
- try
- {
- //
- // Construct NetworkAdapter
- //
-
- const std::wstring guid = GetNetCfgInstanceId(devInfo, devInfoData);
- GUID guidObj = common::Guid::FromString(guid);
-
- adapters.emplace(NetworkAdapter(
- guid,
- GetDeviceStringProperty(devInfo, devInfoData, &DEVPKEY_Device_DriverDesc),
- nci.getConnectionName(guidObj),
- GetDeviceInstanceId(devInfo, devInfoData)
- ));
- }
- catch (const std::exception & e)
- {
- //
- // Skip this adapter
- //
-
- std::wstringstream ss;
- ss << L"Skipping adapter due to exception caught while iterating: "
- << common::string::ToWide(e.what());
- LogError(ss.str());
- }
- return true;
+ return 0 == candidateDeviceName.compare(SPLIT_TUNNEL_DEVICE_NAME);
});
-
- return adapters;
-}
-
-void throwUpdateException(DWORD lastError, const char *operation)
-{
- if (ERROR_DEVICE_INSTALLER_NOT_READY == lastError)
- {
- bool deviceInstallDisabled = false;
-
- try
- {
- const auto key = common::registry::Registry::OpenKey(
- HKEY_LOCAL_MACHINE,
- L"SYSTEM\\CurrentControlSet\\Services\\DeviceInstall\\Parameters"
- );
- deviceInstallDisabled = (0 != key->readUint32(L"DeviceInstallDisabled"));
- }
- catch (...)
- {
- }
-
- if (deviceInstallDisabled)
- {
- throw common::error::WindowsException(
- "Device installs must be enabled to continue. "
- "Enable them in the Local Group Policy editor, or "
- "update the registry value DeviceInstallDisabled in "
- "[HKEY_LOCAL_MACHINE\\SYSTEM\\CurrentControlSet\\Services\\DeviceInstall\\Parameters]",
- lastError
- );
- }
- }
-
- THROW_SETUPAPI_ERROR(lastError, operation);
}
//
-// Broken adapters may use our "Mullvad" name, so find one that is not in use.
-// NOTE: Enumerating adapters first and picking the next free name is not sufficient,
-// because the broken adapter may not be included.
+// CommandSplitTunnelEvaluate()
//
-void RenameAdapter(const std::wstring &guid, const std::wstring &baseName)
-{
- common::network::Nci nci;
-
- try
- {
- nci.setConnectionName(common::Guid::FromString(guid), baseName.c_str());
- return;
- }
- catch (...)
- {
- }
-
- for (int i = 1; i < 10; i++)
- {
- std::wstringstream ss;
- ss << baseName << L"-" << i;
-
- try
- {
- nci.setConnectionName(common::Guid::FromString(guid), ss.str().c_str());
- return;
- }
- catch (...)
- {
- }
- }
-
- THROW_ERROR("Unable to rename network adapter");
-}
-
-void CreateNetDevice(const std::wstring &hardwareId, const std::optional<std::wstring> alias, bool installDeviceDriver)
+// Search for existing device.
+// Evaluate if provided inf can/should be installed.
+//
+ReturnCode CommandSplitTunnelEvaluate(const std::vector<std::wstring> &args)
{
- GUID classGuid = GUID_DEVCLASS_NET;
-
- const auto deviceInfoSet = SetupDiCreateDeviceInfoList(&classGuid, 0);
- if (INVALID_HANDLE_VALUE == deviceInfoSet)
- {
- THROW_SETUPAPI_ERROR(GetLastError(), "SetupDiCreateDeviceInfoList");
- }
+ ArgumentContext argsContext(args);
- common::memory::ScopeDestructor scopeDestructor;
- scopeDestructor += [&deviceInfoSet]()
- {
- SetupDiDestroyDeviceInfoList(deviceInfoSet);
- };
+ argsContext.ensureExactArgumentCount(1);
- SP_DEVINFO_DATA devInfoData;
- devInfoData.cbSize = sizeof(SP_DEVINFO_DATA);
+ const auto infPath = argsContext.next();
- auto status = SetupDiCreateDeviceInfoW(
- deviceInfoSet,
- L"NET",
- &classGuid,
- nullptr,
- 0,
- DICD_GENERATE_ID,
- &devInfoData
- );
+ //
+ // Find first matching device
+ //
- if (FALSE == status)
- {
- THROW_SETUPAPI_ERROR(GetLastError(), "SetupDiCreateDeviceInfoW");
- }
+ auto enumerator = CreateSplitTunnelDeviceEnumerator();
- status = SetupDiSetDeviceRegistryPropertyW(
- deviceInfoSet,
- &devInfoData,
- SPDRP_HARDWAREID,
- reinterpret_cast<const BYTE *>(hardwareId.c_str()),
- static_cast<DWORD>(sizeof(wchar_t) * hardwareId.size())
- );
+ EnumeratedDevice device;
- if (FALSE == status)
+ if (!enumerator->next(device))
{
- THROW_SETUPAPI_ERROR(GetLastError(), "SetupDiSetDeviceRegistryPropertyW");
+ return ReturnCode::ST_DRIVER_NONE_INSTALLED;
}
//
- // Create a devnode in the PnP HW tree
+ // Retrieve driver versions
//
- status = SetupDiCallClassInstaller(
- DIF_REGISTERDEVICE,
- deviceInfoSet,
- &devInfoData
- );
- if (FALSE == status)
- {
- THROW_SETUPAPI_ERROR(GetLastError(), "SetupDiCallClassInstaller");
- }
+ auto existingVersion = GetDriverVersion(device);
+ auto proposedVersion = InfGetDriverVersion(infPath);
- Log(L"Created new network adapter successfully");
+ //
+ // Compare driver versions
+ //
- if (installDeviceDriver)
+ switch (EvaluateDriverUpgrade(existingVersion, proposedVersion))
{
- BOOL rebootRequired = FALSE;
-
- if (FALSE == DiInstallDevice(
- nullptr,
- deviceInfoSet,
- &devInfoData,
- nullptr,
- 0,
- &rebootRequired
- ))
- {
- throwUpdateException(GetLastError(), "DiInstallDevice");
- }
-
- std::wstringstream ss;
- ss << L"Installed driver on device. Reboot required: "
- << rebootRequired;
- Log(ss.str());
+ case DRIVER_UPGRADE_STATUS::WOULD_UPGRADE:
+ return ReturnCode::ST_DRIVER_OLDER_VERSION_INSTALLED;
+ case DRIVER_UPGRADE_STATUS::WOULD_DOWNGRADE:
+ return ReturnCode::ST_DRIVER_NEWER_VERSION_INSTALLED;
+ case DRIVER_UPGRADE_STATUS::WOULD_INSTALL_SAME_VERSION:
+ return ReturnCode::ST_DRIVER_SAME_VERSION_INSTALLED;
+ default:
+ Log(L"Unexpected return value from EvaluateDriverUpgrade()");
}
- if (alias.has_value())
- {
- RenameAdapter(
- GetNetCfgInstanceId(deviceInfoSet, devInfoData),
- alias.value()
- );
- }
+ return ReturnCode::GENERAL_ERROR;
}
-std::wstring FindFreeAdapterAlias(const std::set<NetworkAdapter> &adapters, const std::wstring &baseName)
+ReturnCode CommandSplitTunnelNewInstall(const std::vector<std::wstring> &args)
{
- if (adapters.empty())
- {
- return baseName;
- }
-
- auto findByAlias = [](const std::set<NetworkAdapter> &adapters, const std::wstring &alias)
- {
- const auto it = std::find_if(adapters.begin(), adapters.end(), [&alias](const NetworkAdapter &candidate)
- {
- return 0 == _wcsicmp(candidate.alias.c_str(), alias.c_str());
- });
-
- return it;
- };
-
- const auto foundAdapter = findByAlias(adapters, baseName);
+ ArgumentContext argsContext(args);
- if (adapters.end() == foundAdapter)
- {
- return baseName;
- }
+ argsContext.ensureExactArgumentCount(1);
- for (auto i = 1; i < 100; ++i)
- {
- std::wstringstream ss;
+ const auto infPath = argsContext.next();
- ss << baseName << L"-" << i;
+ CreateDevice(WFP_CALLOUTS_CLASS_ID, SPLIT_TUNNEL_DEVICE_NAME, SPLIT_TUNNEL_HARDWARE_ID);
- const auto alias = ss.str();
- const auto nextAdapter = findByAlias(adapters, alias);
+ InstallDriverForDevice(SPLIT_TUNNEL_HARDWARE_ID, infPath);
- if (adapters.end() == nextAdapter)
- {
- return alias;
- }
- }
-
- THROW_ERROR("Cannot find an unused adapter alias")
+ return ReturnCode::GENERAL_SUCCESS;
}
-std::optional<NetworkAdapter> FindAdapterByAlias(const std::set<NetworkAdapter> &tapAdapters, const std::wstring &baseName)
+//
+// CommandSplitTunnelRemove()
+//
+// Reset driver
+// Uninstall device
+// Stop service
+// Delete service
+//
+ReturnCode CommandSplitTunnelRemove(const std::vector<std::wstring> &args)
{
- if (tapAdapters.empty())
- {
- return std::nullopt;
- }
-
- //
- // Look for TAP adapter with aliases starting with baseName.
- //
-
- auto findByAlias = [](const std::set<NetworkAdapter> &adapters, const std::wstring &alias)
- {
- const auto it = std::find_if(adapters.begin(), adapters.end(), [&alias](const NetworkAdapter &candidate)
- {
- return 0 == _wcsicmp(candidate.alias.c_str(), alias.c_str());
- });
-
- return it;
- };
+ ArgumentContext argsContext(args);
- const auto firstMullvadAdapter = findByAlias(tapAdapters, baseName);
-
- if (tapAdapters.end() != firstMullvadAdapter)
- {
- return { *firstMullvadAdapter };
- }
+ argsContext.ensureExactArgumentCount(0);
//
- // Look for TAP adapter with alias "Mullvad-1", "Mullvad-2", etc.
+ // Find first matching device
//
- for (auto i = 1; i < 10; ++i)
- {
- std::wstringstream ss;
-
- ss << baseName << L"-" << i;
+ auto enumerator = CreateSplitTunnelDeviceEnumerator();
- const auto alias = ss.str();
+ EnumeratedDevice device;
- const auto mullvadAdapter = findByAlias(tapAdapters, alias);
+ if (!enumerator->next(device))
+ {
+ Log(L"Could not find split tunnel device");
- if (tapAdapters.end() != mullvadAdapter)
- {
- return { *mullvadAdapter };
- }
+ return ReturnCode::GENERAL_SUCCESS;
}
- return std::nullopt;
-}
+ ResetDriverState();
-bool RemoveNetDevice(const std::optional<std::wstring> tapHardwareId, const std::wstring &guid)
-{
- bool deletedAdapter = false;
+ UninstallDevice(device);
- ForEachNetworkDevice(tapHardwareId, [&](HDEVINFO devInfo, const SP_DEVINFO_DATA &devInfoData) {
- try
- {
- if (0 == GetNetCfgInstanceId(devInfo, devInfoData).compare(guid))
- {
- deletedAdapter = DeleteDevice(devInfo, devInfoData);
- return false;
- }
- }
- catch (const std::exception & e)
- {
- //
- // Skip this adapter
- //
+ PokeService(L"mullvad-split-tunnel", true, true);
- std::wstringstream ss;
- ss << L"Skipping virtual adapter due to exception caught while iterating: "
- << common::string::ToWide(e.what());
- LogError(ss.str());
- }
- return true;
- });
-
- return deletedAdapter;
+ return ReturnCode::GENERAL_SUCCESS;
}
-void RemoveNetAdapterByAlias(const std::wstring &hardwareId, const std::wstring &baseName)
+//
+// CommandSplitTunnelForceInstall()
+//
+// There's an existing device that needs to be stopped and removed.
+// After this, create a new device and associate the specified inf.
+//
+ReturnCode CommandSplitTunnelForceInstall(const std::vector<std::wstring> &args)
{
- auto tapAdapters = GetNetworkAdapters(hardwareId);
- std::optional<NetworkAdapter> adapter = FindAdapterByAlias(tapAdapters, baseName);
+ auto status = CommandSplitTunnelRemove({});
- if (!adapter.has_value())
+ if (ReturnCode::GENERAL_SUCCESS != status)
{
- return;
+ return status;
}
- const auto guid = adapter.value().guid;
-
- //
- // Enumerate over all network devices with the hardware ID,
- // and delete any adapter whose GUID matches that of the "Mullvad" adapter.
- //
-
- if (!RemoveNetDevice(std::make_optional(hardwareId), guid))
- {
- THROW_ERROR("The virtual adapter could not be removed");
- }
+ return CommandSplitTunnelNewInstall(args);
}
-std::filesystem::path GetCurrentModulePath()
+ReturnCode CommandWintunDeletePool(const std::vector<std::wstring> &args)
{
- std::vector<wchar_t> pathBuffer;
+ ArgumentContext argsContext(args);
+
+ argsContext.ensureExactArgumentCount(1);
- SetLastError(ERROR_SUCCESS);
+ const auto poolName = argsContext.next();
- size_t nextCapacity = MAX_PATH;
- DWORD writtenChars = 0;
+ WintunDll wintun;
- do
+ BOOL rebootRequired;
+
+ if (FALSE == wintun.deletePoolDriver(poolName.c_str(), &rebootRequired))
{
- pathBuffer.resize(nextCapacity);
- writtenChars = GetModuleFileNameW(nullptr, &pathBuffer[0], static_cast<DWORD>(pathBuffer.size()));
+ throw std::runtime_error("Failed to delete wintun pool");
+ }
- if (0 == writtenChars)
- {
- THROW_WINDOWS_ERROR(GetLastError(), "GetModuleFileNameW");
- }
+ std::wstringstream ss;
- nextCapacity = 2 * pathBuffer.size();
- } while (ERROR_INSUFFICIENT_BUFFER == GetLastError());
+ ss << L"Successfully deleted wintun pool. Reboot required: " << rebootRequired;
- pathBuffer.resize(writtenChars);
+ Log(ss.str());
- return std::filesystem::path(pathBuffer.begin(), pathBuffer.end());
+ return ReturnCode::GENERAL_SUCCESS;
}
-class WintunDll
+ReturnCode CommandWintunDeleteAbandonedDevice(const std::vector<std::wstring> &args)
{
-public:
-
- WintunDll() : dllHandle(nullptr)
- {
- auto wintunPath = GetCurrentModulePath().replace_filename(L"wintun.dll");
- dllHandle = LoadLibraryExW(wintunPath.c_str(), nullptr, LOAD_WITH_ALTERED_SEARCH_PATH);
+ ArgumentContext argsContext(args);
- if (nullptr == dllHandle)
- {
- THROW_WINDOWS_ERROR(GetLastError(), "LoadLibraryExW");
- }
+ argsContext.ensureExactArgumentCount(0);
- try
- {
- createAdapter = getProcAddressOrThrow<WINTUN_CREATE_ADAPTER_FUNC>("WintunCreateAdapter");
- openAdapter = getProcAddressOrThrow<WINTUN_OPEN_ADAPTER_FUNC>("WintunOpenAdapter");
- freeAdapter = getProcAddressOrThrow<WINTUN_FREE_ADAPTER_FUNC>("WintunFreeAdapter");
- deletePoolDriver = getProcAddressOrThrow<WINTUN_DELETE_POOL_DRIVER_FUNC>("WintunDeletePoolDriver");
- }
- catch (...)
- {
- FreeLibrary(dllHandle);
- throw;
- }
- }
-
- ~WintunDll()
+ auto enumerator = DeviceEnumerator::Create(GUID_DEVCLASS_NET, [](HDEVINFO deviceInfoSet, const SP_DEVINFO_DATA &deviceInfo)
{
- if (nullptr != dllHandle)
- {
- FreeLibrary(dllHandle);
- }
- }
+ static wchar_t WintunMullvadAdapter[] = L"{AFE43773-E1F8-4EBB-8536-576AB86AFE9A}";
- WINTUN_CREATE_ADAPTER_FUNC createAdapter;
- WINTUN_OPEN_ADAPTER_FUNC openAdapter;
- WINTUN_FREE_ADAPTER_FUNC freeAdapter;
- WINTUN_DELETE_POOL_DRIVER_FUNC deletePoolDriver;
+ auto candidateAdapterGuid = GetDeviceNetCfgInstanceId(deviceInfoSet, deviceInfo);
-private:
+ return 0 == _wcsicmp(candidateAdapterGuid.c_str(), WintunMullvadAdapter);
+ });
+
+ EnumeratedDevice device;
- template<typename T>
- T getProcAddressOrThrow(const char *procName)
+ if (enumerator->next(device))
{
- const T result = reinterpret_cast<T>(GetProcAddress(dllHandle, procName));
- if (nullptr == result)
- {
- THROW_WINDOWS_ERROR(GetLastError(), "GetProcAddress");
- }
- return result;
+ UninstallDevice(device);
}
- HMODULE dllHandle;
-};
+ return GENERAL_SUCCESS;
+}
-int HandleWintunCommands(int argc, const wchar_t *argv[])
-{
- WintunDll wintun;
+} // anonymous namespace
- if (argc < 3)
+int wmain(int argc, const wchar_t *argv[])
+{
+ if (-1 == _setmode(_fileno(stdout), _O_U16TEXT)
+ || -1 == _setmode(_fileno(stderr), _O_U16TEXT))
{
- goto INVALID_ARGUMENTS;
+ Log(L"Failed to set translation mode");
}
- if (0 == _wcsicmp(argv[2], L"create-adapter"))
+ if (argc < 2)
{
- if (argc < 5)
- {
- goto INVALID_ARGUMENTS;
- }
-
- const wchar_t *pool = argv[3];
- const wchar_t *adapter = argv[4];
+ Log(L"Command not specified");
- GUID guidObject;
- const GUID *requestGuid = nullptr;
- if (argc >= 6)
- {
- guidObject = common::Guid::FromString(argv[5]);
- requestGuid = &guidObject;
- }
-
- const auto adapters = GetNetworkAdapters(std::nullopt);
- const auto freeAdapterName = FindFreeAdapterAlias(adapters, adapter);
-
- const auto handle = wintun.createAdapter(
- pool,
- freeAdapterName.c_str(),
- requestGuid,
- nullptr
- );
-
- if (nullptr == handle)
- {
- const auto status = GetLastError();
- if (ERROR_FILE_NOT_FOUND == status)
- {
- return ADAPTER_NOT_FOUND;
- }
- else
- {
- THROW_WINDOWS_ERROR(status, "wintun.createAdapter");
- }
- }
- wintun.freeAdapter(handle);
+ return ReturnCode::GENERAL_ERROR;
}
- else if (0 == _wcsicmp(argv[2], L"delete-pool-driver"))
- {
- if (4 != argc)
- {
- goto INVALID_ARGUMENTS;
- }
-
- const wchar_t *pool = argv[3];
- wintun.deletePoolDriver(pool, nullptr);
- }
- else if (0 == _wcsicmp(argv[2], L"adapter-exists"))
- {
- if (5 != argc)
- {
- goto INVALID_ARGUMENTS;
- }
+ //
+ // Re-package command arguments
+ //
- const wchar_t *pool = argv[3];
- const wchar_t *adapter = argv[4];
+ const std::wstring command = argv[1];
- const auto handle = wintun.openAdapter(pool, adapter);
+ std::vector<std::wstring> arguments;
- if (nullptr == handle)
- {
- const auto status = GetLastError();
- if (ERROR_FILE_NOT_FOUND == status)
- {
- return ADAPTER_NOT_FOUND;
- }
- else
- {
- THROW_WINDOWS_ERROR(status, "wintun.openAdapter");
- }
- }
- wintun.freeAdapter(handle);
- }
- else
+ for (size_t argumentIndex = 2; argumentIndex < argc; ++argumentIndex)
{
- goto INVALID_ARGUMENTS;
+ arguments.emplace_back(argv[argumentIndex]);
}
- return GENERAL_SUCCESS;
-
-INVALID_ARGUMENTS:
-
- LogError(L"Invalid arguments.");
- return GENERAL_ERROR;
-}
-
-} // anonymous namespace
+ //
+ // Declare all handlers
+ //
-int wmain(int argc, const wchar_t * argv[], const wchar_t * [])
-{
- if (-1 == _setmode(_fileno(stdout), _O_U16TEXT)
- || -1 == _setmode(_fileno(stderr), _O_U16TEXT))
+ struct CommandHandler
{
- LogError(L"Failed to set translation mode");
- }
+ std::wstring commandName;
+ std::function<ReturnCode(const std::vector<std::wstring> &)> handler;
+ };
- if (2 > argc)
+ std::vector<CommandHandler> handlers =
{
- goto INVALID_ARGUMENTS;
- }
+ { L"st-evaluate", CommandSplitTunnelEvaluate },
+ { L"st-new-install", CommandSplitTunnelNewInstall },
+ { L"st-force-install", CommandSplitTunnelForceInstall },
+ { L"st-remove", CommandSplitTunnelRemove },
+ { L"wintun-delete-pool-driver", CommandWintunDeletePool },
+ { L"wintun-delete-abandoned-device", CommandWintunDeleteAbandonedDevice }
+ };
+
+ //
+ // Find and invoke matching handler
+ //
- try
+ for (const auto &candidate : handlers)
{
- if (0 == _wcsicmp(argv[1], L"new-device"))
+ if (0 != _wcsicmp(command.c_str(), candidate.commandName.c_str()))
{
- if (4 != argc)
- {
- goto INVALID_ARGUMENTS;
- }
-
- const wchar_t *hardwareId = argv[2];
- const wchar_t *baseName = argv[3];
-
- CreateNetDevice(hardwareId, baseName, true);
+ continue;
}
- else if (0 == _wcsicmp(argv[1], L"remove-device"))
- {
- if (4 != argc)
- {
- goto INVALID_ARGUMENTS;
- }
-
- const wchar_t *hardwareId = argv[2];
- const wchar_t *baseName = argv[3];
- RemoveNetAdapterByAlias(hardwareId, baseName);
- }
- else if (0 == _wcsicmp(argv[1], L"remove-device-by-guid"))
+ try
{
- if (3 != argc)
- {
- goto INVALID_ARGUMENTS;
- }
-
- const wchar_t *guid = argv[2];
-
- if (!RemoveNetDevice(std::nullopt, guid))
- {
- return ADAPTER_NOT_FOUND;
- }
+ return candidate.handler(arguments);
}
- else if (0 == _wcsicmp(argv[1], L"device-exists"))
+ catch (const common::error::WindowsException &e)
{
- if (4 != argc)
- {
- goto INVALID_ARGUMENTS;
- }
-
- const wchar_t *hardwareId = argv[2];
- const wchar_t *baseName = argv[3];
-
- const auto virtualAdapters = GetNetworkAdapters(hardwareId);
- const auto adapter = FindAdapterByAlias(virtualAdapters, baseName);
-
- if (!adapter.has_value())
- {
- return ADAPTER_NOT_FOUND;
- }
+ Log(common::string::ToWide(e.what()));
+ return e.errorCode();
}
- else if (0 == _wcsicmp(argv[1], L"wintun"))
+ catch (const std::exception &e)
{
- return HandleWintunCommands(argc, argv);
+ Log(common::string::ToWide(e.what()));
+ return GENERAL_ERROR;
}
- else
+ catch (...)
{
- goto INVALID_ARGUMENTS;
+ Log(L"Unknown exception was raised/thrown");
+ return GENERAL_ERROR;
}
}
- catch (const common::error::WindowsException &e)
- {
- LogError(common::string::ToWide(e.what()));
- return e.errorCode();
- }
- catch (const std::exception &e)
- {
- LogError(common::string::ToWide(e.what()));
- return GENERAL_ERROR;
- }
- catch (...)
- {
- LogError(L"Unhandled exception.");
- return GENERAL_ERROR;
- }
- return GENERAL_SUCCESS;
-INVALID_ARGUMENTS:
+ //
+ // Could not find matching handler
+ //
- LogError(L"Invalid arguments.");
+ Log(L"Could not find handler for specified command");
return GENERAL_ERROR;
}
diff --git a/windows/driverlogic/src/log.cpp b/windows/driverlogic/src/log.cpp
new file mode 100644
index 0000000000..0840b02cbf
--- /dev/null
+++ b/windows/driverlogic/src/log.cpp
@@ -0,0 +1,13 @@
+#include "stdafx.h"
+#include "log.h"
+#include <iostream>
+
+void Log(const wchar_t *str)
+{
+ std::wcout << str << std::endl;
+}
+
+void Log(const std::wstring &str)
+{
+ Log(str.c_str());
+}
diff --git a/windows/driverlogic/src/log.h b/windows/driverlogic/src/log.h
new file mode 100644
index 0000000000..6b1e67a154
--- /dev/null
+++ b/windows/driverlogic/src/log.h
@@ -0,0 +1,7 @@
+#pragma once
+
+#include <string>
+
+void Log(const wchar_t *str);
+
+void Log(const std::wstring &str);
diff --git a/windows/driverlogic/src/service.cpp b/windows/driverlogic/src/service.cpp
new file mode 100644
index 0000000000..14ed7880fa
--- /dev/null
+++ b/windows/driverlogic/src/service.cpp
@@ -0,0 +1,134 @@
+#include "stdafx.h"
+#include "service.h"
+#include "log.h"
+#include <libcommon/error.h>
+#include <libcommon/memory.h>
+
+#undef min
+#undef max
+#include <chrono>
+
+template<typename TTime = std::chrono::milliseconds>
+class TimeBox
+{
+ // `steady_clock` wraps around every ~292 years.
+ using Clock = std::chrono::steady_clock;
+ using ClockTimePoint = std::chrono::time_point<Clock>;
+
+public:
+
+ TimeBox(typename TTime::rep maxWaitTime)
+ : m_startTime(Clock::now())
+ , m_maxWaitTime(TTime(maxWaitTime))
+ {
+ }
+
+ bool expired() const
+ {
+ const auto now = Clock::now();
+
+ const auto elapsed =
+ (
+ (now < m_startTime)
+ ? (ClockTimePoint::max() - m_startTime) + (now - ClockTimePoint::min())
+ : now - m_startTime
+ );
+
+ return std::chrono::duration_cast<TTime>(elapsed) > m_maxWaitTime;
+ }
+
+private:
+
+ ClockTimePoint m_startTime;
+ TTime m_maxWaitTime;
+};
+
+void WaitUntilServiceStopped(SC_HANDLE service, DWORD maxWaitMs)
+{
+ TimeBox timer(maxWaitMs);
+
+ for (;;)
+ {
+ SERVICE_STATUS_PROCESS ssp;
+
+ DWORD bytesNeeded;
+
+ auto status = QueryServiceStatusEx
+ (
+ service,
+ SC_STATUS_PROCESS_INFO,
+ reinterpret_cast<BYTE*>(&ssp),
+ sizeof(ssp),
+ &bytesNeeded
+ );
+
+ if (status != 0
+ && ssp.dwCurrentState == SERVICE_STOPPED)
+ {
+ return;
+ }
+
+ if (timer.expired())
+ {
+ THROW_ERROR("Failed when waiting for service to stop");
+ }
+
+ Sleep(100);
+ }
+}
+
+void PokeService(const std::wstring &serviceName, bool stopService, bool deleteService)
+{
+ auto serviceManager = OpenSCManagerW(nullptr, SERVICES_ACTIVE_DATABASE, SC_MANAGER_ALL_ACCESS);
+
+ if (serviceManager == NULL)
+ {
+ THROW_WINDOWS_ERROR(GetLastError(), "OpenSCManagerW");
+ }
+
+ common::memory::ScopeDestructor dtor;
+
+ dtor += [serviceManager]()
+ {
+ CloseServiceHandle(serviceManager);
+ };
+
+ auto service = OpenServiceW(serviceManager, serviceName.c_str(), SERVICE_ALL_ACCESS);
+
+ if (service == NULL)
+ {
+ THROW_WINDOWS_ERROR(GetLastError(), "OpenServiceW");
+ }
+
+ dtor += [service]()
+ {
+ CloseServiceHandle(service);
+ };
+
+ if (stopService)
+ {
+ Log(L"Stopping service");
+
+ SERVICE_STATUS ss;
+
+ ControlService(service, SERVICE_CONTROL_STOP, &ss);
+
+ WaitUntilServiceStopped(service, 1000 * 5);
+
+ Log(L"Successfully stopped service");
+ }
+
+ if (deleteService)
+ {
+ Log(L"Deleting service");
+
+ auto status = DeleteService(service);
+
+ if (status == 0)
+ {
+ THROW_WINDOWS_ERROR(GetLastError(), "DeleteService");
+ }
+
+ Log(L"Successfully deleted service");
+ }
+}
diff --git a/windows/driverlogic/src/service.h b/windows/driverlogic/src/service.h
new file mode 100644
index 0000000000..87632faf1b
--- /dev/null
+++ b/windows/driverlogic/src/service.h
@@ -0,0 +1,7 @@
+#pragma once
+
+#include <windows.h>
+
+void WaitUntilServiceStopped(SC_HANDLE service, DWORD maxWaitMs);
+
+void PokeService(const std::wstring &serviceName, bool stopService, bool deleteService);
diff --git a/windows/driverlogic/src/util.cpp b/windows/driverlogic/src/util.cpp
new file mode 100644
index 0000000000..ebce6d82cb
--- /dev/null
+++ b/windows/driverlogic/src/util.cpp
@@ -0,0 +1,36 @@
+#include "stdafx.h"
+#include "util.h"
+#include <windows.h>
+#include <vector>
+#include <libcommon/error.h>
+
+using path = std::filesystem::path;
+
+path
+GetProcessModulePath
+(
+)
+{
+ size_t bufferSize = MAX_PATH;
+
+ std::vector<wchar_t> pathBuffer(bufferSize);
+
+ for (;;)
+ {
+ const auto writtenChars = GetModuleFileNameW(nullptr, &pathBuffer[0], static_cast<DWORD>(pathBuffer.size()));
+
+ if (0 == writtenChars)
+ {
+ THROW_WINDOWS_ERROR(GetLastError(), "GetModuleFileNameW");
+ }
+
+ if (writtenChars != pathBuffer.size())
+ {
+ return path(pathBuffer.begin(), pathBuffer.begin() + writtenChars);
+ }
+
+ bufferSize *= 2;
+
+ pathBuffer.resize(bufferSize);
+ }
+}
diff --git a/windows/driverlogic/src/util.h b/windows/driverlogic/src/util.h
new file mode 100644
index 0000000000..1bdcaf9d15
--- /dev/null
+++ b/windows/driverlogic/src/util.h
@@ -0,0 +1,8 @@
+#pragma once
+
+#include <filesystem>
+
+std::filesystem::path
+GetProcessModulePath
+(
+);
diff --git a/windows/driverlogic/src/version.cpp b/windows/driverlogic/src/version.cpp
new file mode 100644
index 0000000000..d7394e224b
--- /dev/null
+++ b/windows/driverlogic/src/version.cpp
@@ -0,0 +1,125 @@
+#include "stdafx.h"
+#include "version.h"
+#include "device.h"
+#include <setupapi.h>
+#include <initguid.h>
+#include <devpkey.h>
+#include <libcommon/string.h>
+#include <libcommon/memory.h>
+#include <stdexcept>
+
+DRIVER_UPGRADE_STATUS
+EvaluateDriverUpgrade
+(
+ const std::wstring &existingVersion,
+ const std::wstring &proposedVersion
+)
+{
+ //
+ // "x.y.z.a"
+ //
+
+ using namespace common::string;
+
+ auto et = Tokenize(existingVersion, L".");
+ auto pt = Tokenize(proposedVersion, L".");
+
+ auto items = min(et.size(), pt.size());
+
+ for (auto index = 0; index < items; ++index)
+ {
+ auto ev = wcstoul(et[index].c_str(), nullptr, 10);
+ auto pv = wcstoul(pt[index].c_str(), nullptr, 10);
+
+ if (pv > ev)
+ {
+ return DRIVER_UPGRADE_STATUS::WOULD_UPGRADE;
+ }
+
+ if (ev > pv)
+ {
+ return DRIVER_UPGRADE_STATUS::WOULD_DOWNGRADE;
+ }
+ }
+
+ if (pt.size() > et.size())
+ {
+ return DRIVER_UPGRADE_STATUS::WOULD_UPGRADE;
+ }
+
+ if (et.size() > pt.size())
+ {
+ return DRIVER_UPGRADE_STATUS::WOULD_DOWNGRADE;
+ }
+
+ return DRIVER_UPGRADE_STATUS::WOULD_INSTALL_SAME_VERSION;
+}
+
+std::wstring
+InfGetDriverVersion
+(
+ const std::wstring &filePath
+)
+{
+ auto infHandle = SetupOpenInfFileW(filePath.c_str(), nullptr, INF_STYLE_WIN4, nullptr);
+
+ if (infHandle == INVALID_HANDLE_VALUE)
+ {
+ throw std::runtime_error("SetupOpenInfFileW()");
+ }
+
+ common::memory::ScopeDestructor dtor;
+
+ dtor += [infHandle]()
+ {
+ SetupCloseInfFile(infHandle);
+ };
+
+ INFCONTEXT infContext { 0 };
+
+ auto status = SetupFindFirstLineW(infHandle, L"Version", L"DriverVer", &infContext);
+
+ if (status == FALSE)
+ {
+ throw std::runtime_error("SetupFindFirstLineW()");
+ }
+
+ DWORD requiredSize;
+
+ //
+ // This is a multi-value key.
+ // 0 = key, 1 = driver date
+ //
+ const DWORD VersionFieldIndex = 2;
+
+ status = SetupGetStringFieldW(&infContext, VersionFieldIndex, nullptr, 0, &requiredSize);
+
+ if (status == FALSE || requiredSize < 2)
+ {
+ throw std::runtime_error("SetupGetStringFieldW()");
+ }
+
+ std::vector<wchar_t> buffer(requiredSize);
+
+ status = SetupGetStringFieldW(&infContext, VersionFieldIndex,
+ &buffer[0], static_cast<DWORD>(buffer.size()), nullptr);
+
+ if (status == FALSE)
+ {
+ throw std::runtime_error("SetupGetStringFieldW()");
+ }
+
+ // Remove null terminator.
+ buffer.resize(requiredSize - 1);
+
+ return buffer.data();
+}
+
+std::wstring
+GetDriverVersion
+(
+ const EnumeratedDevice &device
+)
+{
+ return GetDeviceStringProperty(device.deviceInfoSet, device.deviceInfo, &DEVPKEY_Device_DriverVersion);
+}
diff --git a/windows/driverlogic/src/version.h b/windows/driverlogic/src/version.h
new file mode 100644
index 0000000000..fc88b3b232
--- /dev/null
+++ b/windows/driverlogic/src/version.h
@@ -0,0 +1,30 @@
+#pragma once
+
+#include <string>
+#include "device.h"
+
+enum class DRIVER_UPGRADE_STATUS
+{
+ WOULD_DOWNGRADE,
+ WOULD_INSTALL_SAME_VERSION,
+ WOULD_UPGRADE
+};
+
+DRIVER_UPGRADE_STATUS
+EvaluateDriverUpgrade
+(
+ const std::wstring &existingVersion,
+ const std::wstring &proposedVersion
+);
+
+std::wstring
+InfGetDriverVersion
+(
+ const std::wstring &filePath
+);
+
+std::wstring
+GetDriverVersion
+(
+ const EnumeratedDevice &device
+);
diff --git a/windows/driverlogic/src/wintun.h b/windows/driverlogic/src/wintun.h
new file mode 100644
index 0000000000..3d81be97e3
--- /dev/null
+++ b/windows/driverlogic/src/wintun.h
@@ -0,0 +1,64 @@
+#pragma once
+
+#include <wintun/wintun.h>
+#include <libcommon/error.h>
+#include "util.h"
+
+class WintunDll
+{
+public:
+
+ WintunDll() : dllHandle(nullptr)
+ {
+ auto wintunPath = GetProcessModulePath().replace_filename(L"wintun.dll");
+ dllHandle = LoadLibraryExW(wintunPath.c_str(), nullptr, LOAD_WITH_ALTERED_SEARCH_PATH);
+
+ if (nullptr == dllHandle)
+ {
+ THROW_WINDOWS_ERROR(GetLastError(), "LoadLibraryExW");
+ }
+
+ try
+ {
+ createAdapter = getProcAddressOrThrow<WINTUN_CREATE_ADAPTER_FUNC>("WintunCreateAdapter");
+ openAdapter = getProcAddressOrThrow<WINTUN_OPEN_ADAPTER_FUNC>("WintunOpenAdapter");
+ freeAdapter = getProcAddressOrThrow<WINTUN_FREE_ADAPTER_FUNC>("WintunFreeAdapter");
+ deletePoolDriver = getProcAddressOrThrow<WINTUN_DELETE_POOL_DRIVER_FUNC>("WintunDeletePoolDriver");
+ }
+ catch (...)
+ {
+ FreeLibrary(dllHandle);
+ throw;
+ }
+ }
+
+ ~WintunDll()
+ {
+ if (nullptr != dllHandle)
+ {
+ FreeLibrary(dllHandle);
+ }
+ }
+
+ WINTUN_CREATE_ADAPTER_FUNC createAdapter;
+ WINTUN_OPEN_ADAPTER_FUNC openAdapter;
+ WINTUN_FREE_ADAPTER_FUNC freeAdapter;
+ WINTUN_DELETE_POOL_DRIVER_FUNC deletePoolDriver;
+
+private:
+
+ template<typename T>
+ T getProcAddressOrThrow(const char *procName)
+ {
+ const T result = reinterpret_cast<T>(GetProcAddress(dllHandle, procName));
+
+ if (nullptr == result)
+ {
+ THROW_WINDOWS_ERROR(GetLastError(), "GetProcAddress");
+ }
+
+ return result;
+ }
+
+ HMODULE dllHandle;
+};