diff options
| author | Odd Stranne <odd@mullvad.net> | 2021-03-17 18:06:16 +0100 |
|---|---|---|
| committer | Odd Stranne <odd@mullvad.net> | 2021-07-02 16:31:31 +0200 |
| commit | 15be4405fcbe845f806d0e2a50c4e948e049d0d5 (patch) | |
| tree | 869d26d02bad41af6c6bff61d0831ce3705edabc | |
| parent | 79f52b5adc0d965e1688bb1253a6c782bf74f03f (diff) | |
| download | mullvadvpn-15be4405fcbe845f806d0e2a50c4e948e049d0d5.tar.xz mullvadvpn-15be4405fcbe845f806d0e2a50c4e948e049d0d5.zip | |
Restructure and extend driverlogic
| -rw-r--r-- | windows/driverlogic/driverlogic.vcxproj | 13 | ||||
| -rw-r--r-- | windows/driverlogic/driverlogic.vcxproj.filters | 13 | ||||
| -rw-r--r-- | windows/driverlogic/src/devenum.cpp | 75 | ||||
| -rw-r--r-- | windows/driverlogic/src/devenum.h | 42 | ||||
| -rw-r--r-- | windows/driverlogic/src/device.cpp | 470 | ||||
| -rw-r--r-- | windows/driverlogic/src/device.h | 85 | ||||
| -rw-r--r-- | windows/driverlogic/src/driverlogic.cpp | 1138 | ||||
| -rw-r--r-- | windows/driverlogic/src/log.cpp | 13 | ||||
| -rw-r--r-- | windows/driverlogic/src/log.h | 7 | ||||
| -rw-r--r-- | windows/driverlogic/src/service.cpp | 134 | ||||
| -rw-r--r-- | windows/driverlogic/src/service.h | 7 | ||||
| -rw-r--r-- | windows/driverlogic/src/util.cpp | 36 | ||||
| -rw-r--r-- | windows/driverlogic/src/util.h | 8 | ||||
| -rw-r--r-- | windows/driverlogic/src/version.cpp | 125 | ||||
| -rw-r--r-- | windows/driverlogic/src/version.h | 30 | ||||
| -rw-r--r-- | windows/driverlogic/src/wintun.h | 64 |
16 files changed, 1324 insertions, 936 deletions
diff --git a/windows/driverlogic/driverlogic.vcxproj b/windows/driverlogic/driverlogic.vcxproj index 5d78c90479..b91e97d86c 100644 --- a/windows/driverlogic/driverlogic.vcxproj +++ b/windows/driverlogic/driverlogic.vcxproj @@ -96,14 +96,27 @@ </Link> </ItemDefinitionGroup> <ItemGroup> + <ClCompile Include="src\devenum.cpp" /> + <ClCompile Include="src\device.cpp" /> <ClCompile Include="src\driverlogic.cpp" /> <ClCompile Include="src\error.cpp" /> + <ClCompile Include="src\log.cpp" /> + <ClCompile Include="src\service.cpp" /> <ClCompile Include="src\stdafx.cpp" /> + <ClCompile Include="src\util.cpp" /> + <ClCompile Include="src\version.cpp" /> </ItemGroup> <ItemGroup> + <ClInclude Include="src\devenum.h" /> + <ClInclude Include="src\device.h" /> <ClInclude Include="src\error.h" /> + <ClInclude Include="src\log.h" /> + <ClInclude Include="src\service.h" /> <ClInclude Include="src\stdafx.h" /> <ClInclude Include="src\targetver.h" /> + <ClInclude Include="src\util.h" /> + <ClInclude Include="src\version.h" /> + <ClInclude Include="src\wintun.h" /> </ItemGroup> <Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" /> <ImportGroup Label="ExtensionTargets"> diff --git a/windows/driverlogic/driverlogic.vcxproj.filters b/windows/driverlogic/driverlogic.vcxproj.filters index b9c494e241..91ed2267da 100644 --- a/windows/driverlogic/driverlogic.vcxproj.filters +++ b/windows/driverlogic/driverlogic.vcxproj.filters @@ -10,10 +10,23 @@ <ClCompile Include="src\driverlogic.cpp" /> <ClCompile Include="src\stdafx.cpp" /> <ClCompile Include="src\error.cpp" /> + <ClCompile Include="src\device.cpp" /> + <ClCompile Include="src\service.cpp" /> + <ClCompile Include="src\log.cpp" /> + <ClCompile Include="src\version.cpp" /> + <ClCompile Include="src\util.cpp" /> + <ClCompile Include="src\devenum.cpp" /> </ItemGroup> <ItemGroup> <ClInclude Include="src\stdafx.h" /> <ClInclude Include="src\targetver.h" /> <ClInclude Include="src\error.h" /> + <ClInclude Include="src\device.h" /> + <ClInclude Include="src\service.h" /> + <ClInclude Include="src\log.h" /> + <ClInclude Include="src\version.h" /> + <ClInclude Include="src\util.h" /> + <ClInclude Include="src\wintun.h" /> + <ClInclude Include="src\devenum.h" /> </ItemGroup> </Project>
\ No newline at end of file diff --git a/windows/driverlogic/src/devenum.cpp b/windows/driverlogic/src/devenum.cpp new file mode 100644 index 0000000000..229d7ee493 --- /dev/null +++ b/windows/driverlogic/src/devenum.cpp @@ -0,0 +1,75 @@ +#include "stdafx.h" +#include "devenum.h" +#include "error.h" + +DeviceEnumerator::DeviceEnumerator(const GUID &deviceClass) +{ + m_deviceInfoSet = SetupDiGetClassDevsW + ( + &deviceClass, + nullptr, + nullptr, + DIGCF_PRESENT + ); + + if (INVALID_HANDLE_VALUE == m_deviceInfoSet) + { + THROW_SETUPAPI_ERROR(GetLastError(), "SetupDiGetClassDevsW"); + } + + m_nextDeviceIndex = 0; + m_exhausted = false; +} + +//static +std::unique_ptr<DeviceEnumerator> DeviceEnumerator::Create(const GUID& deviceClass, Filter filter) +{ + auto enumerator = std::make_unique<DeviceEnumerator>(deviceClass); + + enumerator->setFilter(filter); + + return enumerator; +} + +DeviceEnumerator::~DeviceEnumerator() +{ + SetupDiDestroyDeviceInfoList(m_deviceInfoSet); +} + +bool DeviceEnumerator::next(EnumeratedDevice &device) +{ + if (m_exhausted) + { + return false; + } + + SP_DEVINFO_DATA deviceInfo { 0 }; + deviceInfo.cbSize = sizeof(deviceInfo); + + for (;;) + { + if (FALSE == SetupDiEnumDeviceInfo(m_deviceInfoSet, m_nextDeviceIndex, &deviceInfo)) + { + if (GetLastError() != ERROR_NO_MORE_ITEMS) + { + THROW_SETUPAPI_ERROR(GetLastError(), "SetupDiEnumDeviceInfo"); + } + + m_exhausted = true; + + return false; + } + + ++m_nextDeviceIndex; + + if (!m_filter || m_filter(m_deviceInfoSet, deviceInfo)) + { + break; + } + } + + device.deviceInfoSet = m_deviceInfoSet; + device.deviceInfo = deviceInfo; + + return true; +} diff --git a/windows/driverlogic/src/devenum.h b/windows/driverlogic/src/devenum.h new file mode 100644 index 0000000000..185bdb27f8 --- /dev/null +++ b/windows/driverlogic/src/devenum.h @@ -0,0 +1,42 @@ +#pragma once + +#include <windows.h> +#include <newdev.h> +#include <functional> +#include <memory> +#include "device.h" + +class DeviceEnumerator +{ +public: + + using Filter = std::function<bool(HDEVINFO, const SP_DEVINFO_DATA&)>; + + DeviceEnumerator(const GUID &deviceClass); + + static std::unique_ptr<DeviceEnumerator> Create(const GUID &deviceClass, Filter filter); + + ~DeviceEnumerator(); + + DeviceEnumerator(const DeviceEnumerator &) = delete; + DeviceEnumerator(DeviceEnumerator &&) = delete; + DeviceEnumerator &operator=(const DeviceEnumerator &) = delete; + DeviceEnumerator &operator=(DeviceEnumerator &&) = delete; + + void setFilter(Filter filter) + { + m_filter = filter; + } + + bool next(EnumeratedDevice &device); + +private: + + HDEVINFO m_deviceInfoSet; + + int m_nextDeviceIndex; + + bool m_exhausted; + + Filter m_filter; +}; diff --git a/windows/driverlogic/src/device.cpp b/windows/driverlogic/src/device.cpp new file mode 100644 index 0000000000..e38709f01d --- /dev/null +++ b/windows/driverlogic/src/device.cpp @@ -0,0 +1,470 @@ +#include "stdafx.h" +#include <winioctl.h> +#include <newdev.h> +#include <initguid.h> +#include <devpkey.h> +#include <devguid.h> +#include <libcommon/error.h> +#include <libcommon/memory.h> +#include <libcommon/registry/registry.h> +#include "log.h" +#include "device.h" +#include "error.h" +#include "devenum.h" +#include <vector> +#include <sstream> +#include <functional> + +namespace +{ + +// +// Identifiers defined by split tunneling driver. +// + +constexpr wchar_t DeviceSymbolicName[] = L"\\\\.\\MULLVADSPLITTUNNEL"; + +#define ST_DEVICE_TYPE 0x8000 + +#define IOCTL_ST_GET_STATE \ + CTL_CODE(ST_DEVICE_TYPE, 9, METHOD_BUFFERED, FILE_ANY_ACCESS) + +#define IOCTL_ST_RESET \ + CTL_CODE(ST_DEVICE_TYPE, 11, METHOD_NEITHER, FILE_ANY_ACCESS) + +constexpr SIZE_T ST_DRIVER_STATE_STARTED = 1; + +// +// Onwards. +// + +void +ThrowUpdateException +( + DWORD lastError, + const char *operation +) +{ + if (ERROR_DEVICE_INSTALLER_NOT_READY == lastError) + { + bool deviceInstallDisabled = false; + + try + { + const auto key = common::registry::Registry::OpenKey + ( + HKEY_LOCAL_MACHINE, + L"SYSTEM\\CurrentControlSet\\Services\\DeviceInstall\\Parameters" + ); + + deviceInstallDisabled = (0 != key->readUint32(L"DeviceInstallDisabled")); + } + catch (...) + { + } + + if (deviceInstallDisabled) + { + throw common::error::WindowsException + ( + "Device installs must be enabled to continue. " + "Enable them in the Local Group Policy editor, or " + "update the registry value DeviceInstallDisabled in " + "[HKEY_LOCAL_MACHINE\\SYSTEM\\CurrentControlSet\\Services\\DeviceInstall\\Parameters]", + lastError + ); + } + } + + THROW_SETUPAPI_ERROR(lastError, operation); +} + +} // anonymous namespace + +std::wstring +GetDeviceStringProperty +( + HDEVINFO deviceInfoSet, + const SP_DEVINFO_DATA &deviceInfo, + const DEVPROPKEY *property +) +{ + // + // Obtain required buffer size + // + + DWORD requiredSize = 0; + DEVPROPTYPE type; + + const auto sizeStatus = SetupDiGetDevicePropertyW + ( + deviceInfoSet, + const_cast<PSP_DEVINFO_DATA>(&deviceInfo), + property, + &type, + nullptr, + 0, + &requiredSize, + 0 + ); + + if (FALSE == sizeStatus) + { + const auto lastError = GetLastError(); + + if (ERROR_INSUFFICIENT_BUFFER != lastError) + { + THROW_SETUPAPI_ERROR(lastError, "SetupDiGetDevicePropertyW"); + } + } + + std::vector<wchar_t> buffer(requiredSize / sizeof(wchar_t)); + + // + // Read property + // + + const auto status = SetupDiGetDevicePropertyW + ( + deviceInfoSet, + const_cast<PSP_DEVINFO_DATA>(&deviceInfo), + property, + &type, + reinterpret_cast<PBYTE>(&buffer[0]), + requiredSize, + nullptr, + 0 + ); + + if (FALSE == status) + { + THROW_SETUPAPI_ERROR(GetLastError(), "Failed to read device property"); + } + + return buffer.data(); +} + +std::wstring +GetDeviceNetCfgInstanceId +( + HDEVINFO deviceInfoSet, + const SP_DEVINFO_DATA &deviceInfo +) +{ + auto registryKey = SetupDiOpenDevRegKey + ( + deviceInfoSet, + const_cast<SP_DEVINFO_DATA *>(&deviceInfo), + DICS_FLAG_GLOBAL, + 0, + DIREG_DRV, + KEY_READ + ); + + if (registryKey == INVALID_HANDLE_VALUE) + { + THROW_SETUPAPI_ERROR(GetLastError(), "SetupDiOpenDevRegKey"); + } + + std::vector<wchar_t> buffer(128, L'\0'); + + DWORD bufferByteLength = static_cast<DWORD>(buffer.size()) * sizeof(wchar_t); + + const auto status = RegGetValueW + ( + registryKey, + nullptr, + L"NetCfgInstanceId", + RRF_RT_REG_SZ, + nullptr, + &buffer[0], + &bufferByteLength + ); + + RegCloseKey(registryKey); + + if (ERROR_SUCCESS != status) + { + THROW_WINDOWS_ERROR(status, "RegGetValueW"); + } + + // + // RegGetValueW ensures the string is null-terminated. + // + + return std::wstring(&buffer[0]); +} + +void +CreateDevice +( + const GUID &classGuid, + const std::wstring &deviceName, + const std::wstring &deviceHardwareId +) +{ + Log(L"Attempting to create device"); + + const auto deviceInfoSet = SetupDiCreateDeviceInfoList(&classGuid, 0); + + if (INVALID_HANDLE_VALUE == deviceInfoSet) + { + THROW_SETUPAPI_ERROR(GetLastError(), "SetupDiCreateDeviceInfoList"); + } + + common::memory::ScopeDestructor scopeDestructor; + + scopeDestructor += [&deviceInfoSet]() + { + SetupDiDestroyDeviceInfoList(deviceInfoSet); + }; + + SP_DEVINFO_DATA devInfoData {0}; + devInfoData.cbSize = sizeof(SP_DEVINFO_DATA); + + auto status = SetupDiCreateDeviceInfoW + ( + deviceInfoSet, + deviceName.c_str(), + &classGuid, + nullptr, + 0, + DICD_GENERATE_ID, + &devInfoData + ); + + if (FALSE == status) + { + THROW_SETUPAPI_ERROR(GetLastError(), "SetupDiCreateDeviceInfoW"); + } + + status = SetupDiSetDeviceRegistryPropertyW + ( + deviceInfoSet, + &devInfoData, + SPDRP_HARDWAREID, + reinterpret_cast<const BYTE *>(deviceHardwareId.c_str()), + static_cast<DWORD>(deviceHardwareId.size() * sizeof(wchar_t)) + ); + + if (FALSE == status) + { + THROW_SETUPAPI_ERROR(GetLastError(), "SetupDiSetDeviceRegistryPropertyW"); + } + + // + // Create a devnode in the PnP HW tree + // + status = SetupDiCallClassInstaller + ( + DIF_REGISTERDEVICE, + deviceInfoSet, + &devInfoData + ); + + if (FALSE == status) + { + THROW_SETUPAPI_ERROR(GetLastError(), "SetupDiCallClassInstaller"); + } + + Log(L"Created new device successfully"); +} + +void +InstallDriverForDevice +( + const std::wstring &deviceHardwareId, + const std::wstring &infPath +) +{ + Log(L"Attempting to install new driver"); + + DWORD installFlags = 0; + BOOL rebootRequired = FALSE; + + for (;;) + { + auto result = UpdateDriverForPlugAndPlayDevicesW + ( + nullptr, + deviceHardwareId.c_str(), + infPath.c_str(), + installFlags, + &rebootRequired + ); + + if (FALSE != result) + { + break; + } + + const auto lastError = GetLastError(); + + if (ERROR_NO_MORE_ITEMS == lastError + && (0 == (installFlags & INSTALLFLAG_FORCE))) + { + Log(L"Driver installation/update failed. Attempting forced install."); + installFlags |= INSTALLFLAG_FORCE; + + continue; + } + + ThrowUpdateException(lastError, "UpdateDriverForPlugAndPlayDevicesW"); + } + + // + // Driver successfully installed or updated + // + + std::wstringstream ss; + + ss << L"Device driver update complete. Reboot required: " + << rebootRequired; + + Log(ss.str()); +} + +void +UninstallDevice +( + const EnumeratedDevice &device +) +{ + Log(L"Uninstalling device"); + + BOOL needReboot; + + auto status = DiUninstallDevice + ( + nullptr, + device.deviceInfoSet, + const_cast<PSP_DEVINFO_DATA>(&device.deviceInfo), + 0, + &needReboot + ); + + if (FALSE == status) + { + THROW_WINDOWS_ERROR(GetLastError(), "DiUninstallDevice"); + } + + std::wstringstream ss; + + ss << L"Successfully uninstalled device. Reboot required: " + << needReboot; + + Log(ss.str()); +} + +HANDLE +OpenSplitTunnelDevice +( +) +{ + auto handle = CreateFileW(DeviceSymbolicName, GENERIC_READ | GENERIC_WRITE, + 0, nullptr, OPEN_EXISTING, FILE_FLAG_OVERLAPPED, nullptr); + + if (handle == INVALID_HANDLE_VALUE) + { + THROW_WINDOWS_ERROR(GetLastError(), "Open split tunnel device"); + } + + return handle; +} + +void +CloseSplitTunnelDevice +( + HANDLE device +) +{ + CloseHandle(device); +} + +void +SendIoControl +( + HANDLE device, + DWORD code, + void *inBuffer, + DWORD inBufferSize, + void *outBuffer, + DWORD outBufferSize, + DWORD *bytesReturned +) +{ + OVERLAPPED o = { 0 }; + + // + // Event should not be created on-the-fly. + // + // Create an event for each thread that needs to send a request + // and keep the event around. + // + o.hEvent = CreateEventW(nullptr, TRUE, FALSE, nullptr); + + auto status = DeviceIoControl(device, code, + inBuffer, inBufferSize, outBuffer, outBufferSize, bytesReturned, &o); + + if (FALSE != status) + { + CloseHandle(o.hEvent); + + return; + } + + if (ERROR_IO_PENDING != GetLastError()) + { + const auto err = GetLastError(); + + CloseHandle(o.hEvent); + + THROW_WINDOWS_ERROR(err, "DeviceIoControl"); + } + + DWORD tempBytesReturned = 0; + + status = GetOverlappedResult(device, &o, &tempBytesReturned, TRUE); + + CloseHandle(o.hEvent); + + if (FALSE == status) + { + THROW_WINDOWS_ERROR(GetLastError(), "GetOverlappedResult"); + } + + *bytesReturned = tempBytesReturned; +} + +void +SendIoControlReset +( + HANDLE device +) +{ + DWORD dummy; + + SendIoControl(device, (DWORD)IOCTL_ST_RESET, nullptr, 0, nullptr, 0, &dummy); + + DWORD bytesReturned; + + SIZE_T currentState; + + SendIoControl(device, (DWORD)IOCTL_ST_GET_STATE, nullptr, 0, ¤tState, sizeof(currentState), &bytesReturned); + + if (bytesReturned != sizeof(currentState)) + { + throw std::runtime_error("Failed to send reset request to driver"); + } + + // + // If successful, state is ST_DRIVER_STATE_STARTED + // + // Otherwise, state is probably ST_DRIVER_STATE_ZOMBIE + // + + if (currentState != ST_DRIVER_STATE_STARTED) + { + throw std::runtime_error("Failed to reset driver state"); + } +} diff --git a/windows/driverlogic/src/device.h b/windows/driverlogic/src/device.h new file mode 100644 index 0000000000..b8bc0e7f41 --- /dev/null +++ b/windows/driverlogic/src/device.h @@ -0,0 +1,85 @@ +#pragma once + +#include <windows.h> +#include <string> +#include <optional> +#include <setupapi.h> + +struct EnumeratedDevice +{ + HDEVINFO deviceInfoSet; + SP_DEVINFO_DATA deviceInfo; +}; + +// +// Generic functions +// + +std::wstring +GetDeviceStringProperty +( + HDEVINFO deviceInfoSet, + const SP_DEVINFO_DATA &deviceInfo, + const DEVPROPKEY *property +); + +std::wstring +GetDeviceNetCfgInstanceId +( + HDEVINFO deviceInfoSet, + const SP_DEVINFO_DATA &deviceInfo +); + +void +CreateDevice +( + const GUID &classGuid, + const std::wstring &deviceName, + const std::wstring &deviceHardwareId +); + +void +InstallDriverForDevice +( + const std::wstring &deviceHardwareId, + const std::wstring &infPath +); + +void +UninstallDevice +( + const EnumeratedDevice &device +); + +// +// Functions that are specific to our driver/implementation +// + +HANDLE +OpenSplitTunnelDevice +( +); + +void +CloseSplitTunnelDevice +( + HANDLE device +); + +void +SendIoControl +( + HANDLE device, + DWORD code, + void *inBuffer, + DWORD inBufferSize, + void *outBuffer, + DWORD outBufferSize, + DWORD *bytesReturned +); + +void +SendIoControlReset +( + HANDLE device +); diff --git a/windows/driverlogic/src/driverlogic.cpp b/windows/driverlogic/src/driverlogic.cpp index 25a3c130b0..93af20b400 100644 --- a/windows/driverlogic/src/driverlogic.cpp +++ b/windows/driverlogic/src/driverlogic.cpp @@ -1,1103 +1,369 @@ #include "stdafx.h" #include "error.h" -#include <iostream> -#include <chrono> -#include <sstream> +#include "device.h" +#include "service.h" +#include "log.h" +#include "version.h" +#include "wintun.h" +#include "devenum.h" #include <string> -#include <optional> -#include <set> -#include <filesystem> #include <libcommon/error.h> -#include <libcommon/guid.h> #include <libcommon/memory.h> -#include <libcommon/network/nci.h> -#include <libcommon/registry/registry.h> #include <libcommon/string.h> -#include <setupapi.h> #include <initguid.h> -#include <devguid.h> #include <devpkey.h> -#include <newdev.h> -#include <cfgmgr32.h> +#include <devguid.h> #include <io.h> #include <fcntl.h> -#include <wintun/wintun.h> - namespace { -constexpr std::chrono::milliseconds REGISTRY_GET_TIMEOUT_MS{ 10000 }; +constexpr wchar_t SPLIT_TUNNEL_HARDWARE_ID[] = L"Root\\mullvad-split-tunnel"; -enum ReturnCodes +DEFINE_GUID(WFP_CALLOUTS_CLASS_ID, + 0x57465043, 0x616C, 0x6C6F, 0x75, 0x74, 0x5F, 0x63, 0x6C, 0x61, 0x73, 0x73); + +constexpr wchar_t SPLIT_TUNNEL_DEVICE_NAME[] = L"Mullvad Split Tunnel Device"; + +enum ReturnCode { GENERAL_SUCCESS = 0, - GENERAL_ERROR = -1, - ADAPTER_NOT_FOUND = -2 + GENERAL_ERROR = 1, + ST_DRIVER_NONE_INSTALLED = 2, + ST_DRIVER_SAME_VERSION_INSTALLED = 3, + ST_DRIVER_OLDER_VERSION_INSTALLED = 4, + ST_DRIVER_NEWER_VERSION_INSTALLED = 5 }; -struct NetworkAdapter +class ArgumentContext { - std::wstring guid; - std::wstring name; - std::wstring alias; - std::wstring deviceInstanceId; +public: - NetworkAdapter(std::wstring guid, std::wstring name, std::wstring alias, std::wstring deviceInstanceId) - : guid(guid) - , name(name) - , alias(alias) - , deviceInstanceId(deviceInstanceId) + ArgumentContext(const std::vector<std::wstring> &args) + : m_args(args) + , m_remaining(m_args.size()) { } - bool operator<(const NetworkAdapter &rhs) const + size_t total() const { - return _wcsicmp(deviceInstanceId.c_str(), rhs.deviceInstanceId.c_str()) < 0; + return m_args.size(); } -}; - -void LogAdapters(const std::wstring &description, const std::set<NetworkAdapter> &adapters) -{ - std::wcout << description << std::endl; - for (const auto &adapter : adapters) + void ensureExactArgumentCount(size_t count) const { - std::wcout << L" Adapter\n" - << L" Guid: " << adapter.guid << L'\n' - << L" Name: " << adapter.name << L'\n' - << L" Alias: " << adapter.alias << L'\n' - << L" Device instance ID: " << adapter.deviceInstanceId - << std::endl; - } -} - -void Log(const std::wstring &str) -{ - std::wcout << str << std::endl; -} - -void LogError(const std::wstring &str) -{ - std::wcerr << str << std::endl; -} - -std::optional<std::wstring> GetDeviceRegistryStringProperty( - HDEVINFO devInfo, - const SP_DEVINFO_DATA &devInfoData, - DWORD property -) -{ - // - // Obtain required buffer size - // - - DWORD requiredSize = 0; - - const auto sizeStatus = SetupDiGetDeviceRegistryPropertyW( - devInfo, - const_cast<SP_DEVINFO_DATA*>(&devInfoData), - property, - nullptr, - nullptr, - 0, - &requiredSize - ); - - const DWORD lastError = GetLastError(); - if (FALSE == sizeStatus && ERROR_INSUFFICIENT_BUFFER != lastError) - { - // ERROR_INVALID_DATA may mean that the property does not exist - // TODO: Check if there may be other causes. - if (ERROR_INVALID_DATA != lastError) + if (m_args.size() != count) { - THROW_SETUPAPI_ERROR(lastError, "SetupDiGetDeviceRegistryPropertyW"); + throw std::runtime_error("Invalid number of arguments"); } - - return std::nullopt; - } - - // - // Read property - // - - std::vector<wchar_t> buffer(requiredSize / sizeof(wchar_t)); - - const auto status = SetupDiGetDeviceRegistryPropertyW( - devInfo, - const_cast<SP_DEVINFO_DATA*>(&devInfoData), - property, - nullptr, - reinterpret_cast<PBYTE>(&buffer[0]), - requiredSize, - nullptr - ); - - if (FALSE == status) - { - THROW_SETUPAPI_ERROR(GetLastError(), "Failed to read device property"); } - return std::make_optional(buffer.data()); -} - -std::wstring GetDeviceStringProperty( - HDEVINFO devInfo, - const SP_DEVINFO_DATA &devInfoData, - const DEVPROPKEY *property -) -{ - // - // Obtain required buffer size - // - - DWORD requiredSize = 0; - DEVPROPTYPE type; - - const auto sizeStatus = SetupDiGetDevicePropertyW( - devInfo, - const_cast<SP_DEVINFO_DATA*>(&devInfoData), - property, - &type, - nullptr, - 0, - &requiredSize, - 0 - ); - - if (FALSE == sizeStatus) + const std::wstring &next() { - const auto lastError = GetLastError(); - - if (ERROR_INSUFFICIENT_BUFFER != lastError) + if (0 == m_remaining) { - THROW_SETUPAPI_ERROR(lastError, "SetupDiGetDevicePropertyW"); + throw std::runtime_error("Argument missing"); } - } - - std::vector<wchar_t> buffer(requiredSize / sizeof(wchar_t)); - // - // Read property - // + const auto &str = m_args.at(m_args.size() - m_remaining); - const auto status = SetupDiGetDevicePropertyW( - devInfo, - const_cast<SP_DEVINFO_DATA*>(&devInfoData), - property, - &type, - reinterpret_cast<PBYTE>(&buffer[0]), - requiredSize, - nullptr, - 0 - ); + --m_remaining; - if (FALSE == status) - { - THROW_SETUPAPI_ERROR(GetLastError(), "Failed to read device property"); + return str; } - return buffer.data(); -} - -std::wstring GetDeviceInstanceId( - HDEVINFO devInfo, - const SP_DEVINFO_DATA &devInfoData -) -{ - DWORD requiredSize = 0; - - SetupDiGetDeviceInstanceIdW( - devInfo, - const_cast<SP_DEVINFO_DATA*>(&devInfoData), - nullptr, - 0, - &requiredSize - ); - - std::vector<wchar_t> deviceInstanceId(1 + requiredSize); - - const auto status = SetupDiGetDeviceInstanceIdW( - devInfo, - const_cast<SP_DEVINFO_DATA *>(&devInfoData), - &deviceInstanceId[0], - requiredSize, - nullptr - ); - - if (FALSE == status) - { - THROW_SETUPAPI_ERROR(GetLastError(), "SetupDiGetDeviceInstanceIdW"); - } +private: - return deviceInstanceId.data(); -} + const std::vector<std::wstring> &m_args; + size_t m_remaining; +}; -bool TryGetRegistryValueTimeout( - HKEY key, - const wchar_t *subkey, - const wchar_t *value, - DWORD flags, - DWORD *type, - void *data, - DWORD *dataSize -) +void ResetDriverState() { - HANDLE changeEvent = nullptr; + auto deviceHandle = OpenSplitTunnelDevice(); - common::memory::ScopeDestructor scopeDestructor; - scopeDestructor += [changeEvent]() { - if (nullptr != changeEvent) - { - CloseHandle(changeEvent); - } - }; + common::memory::ScopeDestructor dtor; - auto initialTime = std::chrono::steady_clock::now(); - - for (;;) + dtor += [deviceHandle]() { - const auto status = RegGetValueW(key, subkey, value, flags, type, data, dataSize); - - if (ERROR_SUCCESS == status) - { - // We're done - return true; - } - - if (ERROR_FILE_NOT_FOUND != status) - { - THROW_WINDOWS_ERROR(status, "RegGetValueW"); - } - - if (nullptr == changeEvent) - { - changeEvent = CreateEventW(nullptr, FALSE, FALSE, nullptr); - - if (nullptr == changeEvent) - { - THROW_WINDOWS_ERROR(GetLastError(), "CreateEventW"); - } - } - - // - // Wait for the registry value to be created - // - - auto currentTime = std::chrono::steady_clock::now(); - auto elapsedTime = std::chrono::duration_cast<std::chrono::milliseconds>(currentTime - initialTime); - auto timeDelta = (REGISTRY_GET_TIMEOUT_MS - elapsedTime).count(); - - if (timeDelta <= 0) - { - return false; - } - - const auto notifyResult = RegNotifyChangeKeyValue( - key, - subkey != nullptr, // Watch subkeys - REG_NOTIFY_CHANGE_LAST_SET, - changeEvent, - TRUE - ); - - if (ERROR_SUCCESS != notifyResult) - { - THROW_WINDOWS_ERROR(notifyResult, "RegNotifyChangeKeyValue"); - } - - const auto waitResult = WaitForSingleObject(changeEvent, static_cast<DWORD>(timeDelta)); - if (WAIT_OBJECT_0 == waitResult) - { - // Try again - continue; - } - - if (WAIT_TIMEOUT != waitResult) - { - THROW_WINDOWS_ERROR(GetLastError(), "WaitForSingleObject"); - } - - return false; - } -} - -std::wstring GetNetCfgInstanceId(HDEVINFO devInfo, const SP_DEVINFO_DATA &devInfoData) -{ - HKEY hNet = SetupDiOpenDevRegKey( - devInfo, - const_cast<SP_DEVINFO_DATA *>(&devInfoData), - DICS_FLAG_GLOBAL, - 0, - DIREG_DRV, - KEY_READ - ); - - if (hNet == INVALID_HANDLE_VALUE) - { - THROW_SETUPAPI_ERROR(GetLastError(), "SetupDiOpenDevRegKey"); - } - - common::memory::ScopeDestructor scopeDestructor; - scopeDestructor += [hNet]() { - RegCloseKey(hNet); + CloseSplitTunnelDevice(deviceHandle); }; - std::vector<wchar_t> instanceId(MAX_PATH + 1); - DWORD strSize = static_cast<DWORD>(instanceId.size() * sizeof(wchar_t)); - - if (!TryGetRegistryValueTimeout( - hNet, - nullptr, - L"NetCfgInstanceId", - RRF_RT_REG_SZ, - nullptr, - instanceId.data(), - &strSize - )) - { - THROW_ERROR("Timed out waiting for NetCfgInstanceId."); - } - - return instanceId.data(); + SendIoControlReset(deviceHandle); } -bool DeleteDevice(HDEVINFO devInfo, const SP_DEVINFO_DATA &devInfoData) +std::unique_ptr<DeviceEnumerator> CreateSplitTunnelDeviceEnumerator() { - const auto data = const_cast<SP_DEVINFO_DATA *>(&devInfoData); - - wchar_t devId[MAX_DEVICE_ID_LEN]; - if (CR_SUCCESS != CM_Get_Device_IDW(data->DevInst, devId, sizeof(devId) / sizeof(devId[0]), 0)) + return DeviceEnumerator::Create(WFP_CALLOUTS_CLASS_ID, [](HDEVINFO deviceInfoSet, const SP_DEVINFO_DATA &deviceInfo) { - // skip - return false; - } + auto candidateDeviceName = GetDeviceStringProperty(deviceInfoSet, deviceInfo, &DEVPKEY_NAME); - SP_REMOVEDEVICE_PARAMS rmdParams = { 0 }; - rmdParams.ClassInstallHeader.cbSize = sizeof(SP_CLASSINSTALL_HEADER); - rmdParams.ClassInstallHeader.InstallFunction = DIF_REMOVE; - rmdParams.Scope = DI_REMOVEDEVICE_GLOBAL; - rmdParams.HwProfile = 0; - - auto status = SetupDiSetClassInstallParamsW(devInfo, data, &rmdParams.ClassInstallHeader, sizeof(rmdParams)); - if (FALSE == status) - { - THROW_SETUPAPI_ERROR(GetLastError(), "SetupDiSetClassInstallParamsW"); - } - - status = SetupDiCallClassInstaller(DIF_REMOVE, devInfo, data); - if (FALSE == status) - { - THROW_SETUPAPI_ERROR(GetLastError(), "SetupDiCallClassInstaller"); - } - - return true; -} - -void ForEachNetworkDevice(const std::optional<std::wstring> hwId, std::function<bool(HDEVINFO, const SP_DEVINFO_DATA &)> func) -{ - HDEVINFO devInfo = SetupDiGetClassDevsW( - &GUID_DEVCLASS_NET, - nullptr, - nullptr, - DIGCF_PRESENT - ); - - if (INVALID_HANDLE_VALUE == devInfo) - { - THROW_SETUPAPI_ERROR(GetLastError(), "SetupDiGetClassDevsW"); - } - - common::memory::ScopeDestructor cleanupDevList; - cleanupDevList += [&devInfo]() - { - SetupDiDestroyDeviceInfoList(devInfo); - }; - - for (int memberIndex = 0; ; memberIndex++) - { - SP_DEVINFO_DATA devInfoData = { 0 }; - devInfoData.cbSize = sizeof(devInfoData); - - if (FALSE == SetupDiEnumDeviceInfo(devInfo, memberIndex, &devInfoData)) - { - const auto lastError = GetLastError(); - - if (ERROR_NO_MORE_ITEMS == lastError) - { - break; - } - - THROW_SETUPAPI_ERROR(lastError, "Enumerating network adapters"); - } - - if (hwId.has_value()) - { - try - { - const auto hardwareId = GetDeviceRegistryStringProperty(devInfo, devInfoData, SPDRP_HARDWAREID); - - if (!hardwareId.has_value() || - 0 != hwId->compare(hardwareId.value())) - { - continue; - } - } - catch (const std::exception & e) - { - // - // Skip this adapter - // - - std::wstringstream ss; - ss << L"Skipping virtual adapter due to exception caught while iterating: " - << common::string::ToWide(e.what()); - LogError(ss.str()); - continue; - } - } - - if (!func(devInfo, devInfoData)) - { - break; - } - } -} - -std::set<NetworkAdapter> GetNetworkAdapters(const std::optional<std::wstring> hardwareId) -{ - std::set<NetworkAdapter> adapters; - common::network::Nci nci; - - ForEachNetworkDevice(hardwareId, [&](HDEVINFO devInfo, const SP_DEVINFO_DATA &devInfoData) { - try - { - // - // Construct NetworkAdapter - // - - const std::wstring guid = GetNetCfgInstanceId(devInfo, devInfoData); - GUID guidObj = common::Guid::FromString(guid); - - adapters.emplace(NetworkAdapter( - guid, - GetDeviceStringProperty(devInfo, devInfoData, &DEVPKEY_Device_DriverDesc), - nci.getConnectionName(guidObj), - GetDeviceInstanceId(devInfo, devInfoData) - )); - } - catch (const std::exception & e) - { - // - // Skip this adapter - // - - std::wstringstream ss; - ss << L"Skipping adapter due to exception caught while iterating: " - << common::string::ToWide(e.what()); - LogError(ss.str()); - } - return true; + return 0 == candidateDeviceName.compare(SPLIT_TUNNEL_DEVICE_NAME); }); - - return adapters; -} - -void throwUpdateException(DWORD lastError, const char *operation) -{ - if (ERROR_DEVICE_INSTALLER_NOT_READY == lastError) - { - bool deviceInstallDisabled = false; - - try - { - const auto key = common::registry::Registry::OpenKey( - HKEY_LOCAL_MACHINE, - L"SYSTEM\\CurrentControlSet\\Services\\DeviceInstall\\Parameters" - ); - deviceInstallDisabled = (0 != key->readUint32(L"DeviceInstallDisabled")); - } - catch (...) - { - } - - if (deviceInstallDisabled) - { - throw common::error::WindowsException( - "Device installs must be enabled to continue. " - "Enable them in the Local Group Policy editor, or " - "update the registry value DeviceInstallDisabled in " - "[HKEY_LOCAL_MACHINE\\SYSTEM\\CurrentControlSet\\Services\\DeviceInstall\\Parameters]", - lastError - ); - } - } - - THROW_SETUPAPI_ERROR(lastError, operation); } // -// Broken adapters may use our "Mullvad" name, so find one that is not in use. -// NOTE: Enumerating adapters first and picking the next free name is not sufficient, -// because the broken adapter may not be included. +// CommandSplitTunnelEvaluate() // -void RenameAdapter(const std::wstring &guid, const std::wstring &baseName) -{ - common::network::Nci nci; - - try - { - nci.setConnectionName(common::Guid::FromString(guid), baseName.c_str()); - return; - } - catch (...) - { - } - - for (int i = 1; i < 10; i++) - { - std::wstringstream ss; - ss << baseName << L"-" << i; - - try - { - nci.setConnectionName(common::Guid::FromString(guid), ss.str().c_str()); - return; - } - catch (...) - { - } - } - - THROW_ERROR("Unable to rename network adapter"); -} - -void CreateNetDevice(const std::wstring &hardwareId, const std::optional<std::wstring> alias, bool installDeviceDriver) +// Search for existing device. +// Evaluate if provided inf can/should be installed. +// +ReturnCode CommandSplitTunnelEvaluate(const std::vector<std::wstring> &args) { - GUID classGuid = GUID_DEVCLASS_NET; - - const auto deviceInfoSet = SetupDiCreateDeviceInfoList(&classGuid, 0); - if (INVALID_HANDLE_VALUE == deviceInfoSet) - { - THROW_SETUPAPI_ERROR(GetLastError(), "SetupDiCreateDeviceInfoList"); - } + ArgumentContext argsContext(args); - common::memory::ScopeDestructor scopeDestructor; - scopeDestructor += [&deviceInfoSet]() - { - SetupDiDestroyDeviceInfoList(deviceInfoSet); - }; + argsContext.ensureExactArgumentCount(1); - SP_DEVINFO_DATA devInfoData; - devInfoData.cbSize = sizeof(SP_DEVINFO_DATA); + const auto infPath = argsContext.next(); - auto status = SetupDiCreateDeviceInfoW( - deviceInfoSet, - L"NET", - &classGuid, - nullptr, - 0, - DICD_GENERATE_ID, - &devInfoData - ); + // + // Find first matching device + // - if (FALSE == status) - { - THROW_SETUPAPI_ERROR(GetLastError(), "SetupDiCreateDeviceInfoW"); - } + auto enumerator = CreateSplitTunnelDeviceEnumerator(); - status = SetupDiSetDeviceRegistryPropertyW( - deviceInfoSet, - &devInfoData, - SPDRP_HARDWAREID, - reinterpret_cast<const BYTE *>(hardwareId.c_str()), - static_cast<DWORD>(sizeof(wchar_t) * hardwareId.size()) - ); + EnumeratedDevice device; - if (FALSE == status) + if (!enumerator->next(device)) { - THROW_SETUPAPI_ERROR(GetLastError(), "SetupDiSetDeviceRegistryPropertyW"); + return ReturnCode::ST_DRIVER_NONE_INSTALLED; } // - // Create a devnode in the PnP HW tree + // Retrieve driver versions // - status = SetupDiCallClassInstaller( - DIF_REGISTERDEVICE, - deviceInfoSet, - &devInfoData - ); - if (FALSE == status) - { - THROW_SETUPAPI_ERROR(GetLastError(), "SetupDiCallClassInstaller"); - } + auto existingVersion = GetDriverVersion(device); + auto proposedVersion = InfGetDriverVersion(infPath); - Log(L"Created new network adapter successfully"); + // + // Compare driver versions + // - if (installDeviceDriver) + switch (EvaluateDriverUpgrade(existingVersion, proposedVersion)) { - BOOL rebootRequired = FALSE; - - if (FALSE == DiInstallDevice( - nullptr, - deviceInfoSet, - &devInfoData, - nullptr, - 0, - &rebootRequired - )) - { - throwUpdateException(GetLastError(), "DiInstallDevice"); - } - - std::wstringstream ss; - ss << L"Installed driver on device. Reboot required: " - << rebootRequired; - Log(ss.str()); + case DRIVER_UPGRADE_STATUS::WOULD_UPGRADE: + return ReturnCode::ST_DRIVER_OLDER_VERSION_INSTALLED; + case DRIVER_UPGRADE_STATUS::WOULD_DOWNGRADE: + return ReturnCode::ST_DRIVER_NEWER_VERSION_INSTALLED; + case DRIVER_UPGRADE_STATUS::WOULD_INSTALL_SAME_VERSION: + return ReturnCode::ST_DRIVER_SAME_VERSION_INSTALLED; + default: + Log(L"Unexpected return value from EvaluateDriverUpgrade()"); } - if (alias.has_value()) - { - RenameAdapter( - GetNetCfgInstanceId(deviceInfoSet, devInfoData), - alias.value() - ); - } + return ReturnCode::GENERAL_ERROR; } -std::wstring FindFreeAdapterAlias(const std::set<NetworkAdapter> &adapters, const std::wstring &baseName) +ReturnCode CommandSplitTunnelNewInstall(const std::vector<std::wstring> &args) { - if (adapters.empty()) - { - return baseName; - } - - auto findByAlias = [](const std::set<NetworkAdapter> &adapters, const std::wstring &alias) - { - const auto it = std::find_if(adapters.begin(), adapters.end(), [&alias](const NetworkAdapter &candidate) - { - return 0 == _wcsicmp(candidate.alias.c_str(), alias.c_str()); - }); - - return it; - }; - - const auto foundAdapter = findByAlias(adapters, baseName); + ArgumentContext argsContext(args); - if (adapters.end() == foundAdapter) - { - return baseName; - } + argsContext.ensureExactArgumentCount(1); - for (auto i = 1; i < 100; ++i) - { - std::wstringstream ss; + const auto infPath = argsContext.next(); - ss << baseName << L"-" << i; + CreateDevice(WFP_CALLOUTS_CLASS_ID, SPLIT_TUNNEL_DEVICE_NAME, SPLIT_TUNNEL_HARDWARE_ID); - const auto alias = ss.str(); - const auto nextAdapter = findByAlias(adapters, alias); + InstallDriverForDevice(SPLIT_TUNNEL_HARDWARE_ID, infPath); - if (adapters.end() == nextAdapter) - { - return alias; - } - } - - THROW_ERROR("Cannot find an unused adapter alias") + return ReturnCode::GENERAL_SUCCESS; } -std::optional<NetworkAdapter> FindAdapterByAlias(const std::set<NetworkAdapter> &tapAdapters, const std::wstring &baseName) +// +// CommandSplitTunnelRemove() +// +// Reset driver +// Uninstall device +// Stop service +// Delete service +// +ReturnCode CommandSplitTunnelRemove(const std::vector<std::wstring> &args) { - if (tapAdapters.empty()) - { - return std::nullopt; - } - - // - // Look for TAP adapter with aliases starting with baseName. - // - - auto findByAlias = [](const std::set<NetworkAdapter> &adapters, const std::wstring &alias) - { - const auto it = std::find_if(adapters.begin(), adapters.end(), [&alias](const NetworkAdapter &candidate) - { - return 0 == _wcsicmp(candidate.alias.c_str(), alias.c_str()); - }); - - return it; - }; + ArgumentContext argsContext(args); - const auto firstMullvadAdapter = findByAlias(tapAdapters, baseName); - - if (tapAdapters.end() != firstMullvadAdapter) - { - return { *firstMullvadAdapter }; - } + argsContext.ensureExactArgumentCount(0); // - // Look for TAP adapter with alias "Mullvad-1", "Mullvad-2", etc. + // Find first matching device // - for (auto i = 1; i < 10; ++i) - { - std::wstringstream ss; - - ss << baseName << L"-" << i; + auto enumerator = CreateSplitTunnelDeviceEnumerator(); - const auto alias = ss.str(); + EnumeratedDevice device; - const auto mullvadAdapter = findByAlias(tapAdapters, alias); + if (!enumerator->next(device)) + { + Log(L"Could not find split tunnel device"); - if (tapAdapters.end() != mullvadAdapter) - { - return { *mullvadAdapter }; - } + return ReturnCode::GENERAL_SUCCESS; } - return std::nullopt; -} + ResetDriverState(); -bool RemoveNetDevice(const std::optional<std::wstring> tapHardwareId, const std::wstring &guid) -{ - bool deletedAdapter = false; + UninstallDevice(device); - ForEachNetworkDevice(tapHardwareId, [&](HDEVINFO devInfo, const SP_DEVINFO_DATA &devInfoData) { - try - { - if (0 == GetNetCfgInstanceId(devInfo, devInfoData).compare(guid)) - { - deletedAdapter = DeleteDevice(devInfo, devInfoData); - return false; - } - } - catch (const std::exception & e) - { - // - // Skip this adapter - // + PokeService(L"mullvad-split-tunnel", true, true); - std::wstringstream ss; - ss << L"Skipping virtual adapter due to exception caught while iterating: " - << common::string::ToWide(e.what()); - LogError(ss.str()); - } - return true; - }); - - return deletedAdapter; + return ReturnCode::GENERAL_SUCCESS; } -void RemoveNetAdapterByAlias(const std::wstring &hardwareId, const std::wstring &baseName) +// +// CommandSplitTunnelForceInstall() +// +// There's an existing device that needs to be stopped and removed. +// After this, create a new device and associate the specified inf. +// +ReturnCode CommandSplitTunnelForceInstall(const std::vector<std::wstring> &args) { - auto tapAdapters = GetNetworkAdapters(hardwareId); - std::optional<NetworkAdapter> adapter = FindAdapterByAlias(tapAdapters, baseName); + auto status = CommandSplitTunnelRemove({}); - if (!adapter.has_value()) + if (ReturnCode::GENERAL_SUCCESS != status) { - return; + return status; } - const auto guid = adapter.value().guid; - - // - // Enumerate over all network devices with the hardware ID, - // and delete any adapter whose GUID matches that of the "Mullvad" adapter. - // - - if (!RemoveNetDevice(std::make_optional(hardwareId), guid)) - { - THROW_ERROR("The virtual adapter could not be removed"); - } + return CommandSplitTunnelNewInstall(args); } -std::filesystem::path GetCurrentModulePath() +ReturnCode CommandWintunDeletePool(const std::vector<std::wstring> &args) { - std::vector<wchar_t> pathBuffer; + ArgumentContext argsContext(args); + + argsContext.ensureExactArgumentCount(1); - SetLastError(ERROR_SUCCESS); + const auto poolName = argsContext.next(); - size_t nextCapacity = MAX_PATH; - DWORD writtenChars = 0; + WintunDll wintun; - do + BOOL rebootRequired; + + if (FALSE == wintun.deletePoolDriver(poolName.c_str(), &rebootRequired)) { - pathBuffer.resize(nextCapacity); - writtenChars = GetModuleFileNameW(nullptr, &pathBuffer[0], static_cast<DWORD>(pathBuffer.size())); + throw std::runtime_error("Failed to delete wintun pool"); + } - if (0 == writtenChars) - { - THROW_WINDOWS_ERROR(GetLastError(), "GetModuleFileNameW"); - } + std::wstringstream ss; - nextCapacity = 2 * pathBuffer.size(); - } while (ERROR_INSUFFICIENT_BUFFER == GetLastError()); + ss << L"Successfully deleted wintun pool. Reboot required: " << rebootRequired; - pathBuffer.resize(writtenChars); + Log(ss.str()); - return std::filesystem::path(pathBuffer.begin(), pathBuffer.end()); + return ReturnCode::GENERAL_SUCCESS; } -class WintunDll +ReturnCode CommandWintunDeleteAbandonedDevice(const std::vector<std::wstring> &args) { -public: - - WintunDll() : dllHandle(nullptr) - { - auto wintunPath = GetCurrentModulePath().replace_filename(L"wintun.dll"); - dllHandle = LoadLibraryExW(wintunPath.c_str(), nullptr, LOAD_WITH_ALTERED_SEARCH_PATH); + ArgumentContext argsContext(args); - if (nullptr == dllHandle) - { - THROW_WINDOWS_ERROR(GetLastError(), "LoadLibraryExW"); - } + argsContext.ensureExactArgumentCount(0); - try - { - createAdapter = getProcAddressOrThrow<WINTUN_CREATE_ADAPTER_FUNC>("WintunCreateAdapter"); - openAdapter = getProcAddressOrThrow<WINTUN_OPEN_ADAPTER_FUNC>("WintunOpenAdapter"); - freeAdapter = getProcAddressOrThrow<WINTUN_FREE_ADAPTER_FUNC>("WintunFreeAdapter"); - deletePoolDriver = getProcAddressOrThrow<WINTUN_DELETE_POOL_DRIVER_FUNC>("WintunDeletePoolDriver"); - } - catch (...) - { - FreeLibrary(dllHandle); - throw; - } - } - - ~WintunDll() + auto enumerator = DeviceEnumerator::Create(GUID_DEVCLASS_NET, [](HDEVINFO deviceInfoSet, const SP_DEVINFO_DATA &deviceInfo) { - if (nullptr != dllHandle) - { - FreeLibrary(dllHandle); - } - } + static wchar_t WintunMullvadAdapter[] = L"{AFE43773-E1F8-4EBB-8536-576AB86AFE9A}"; - WINTUN_CREATE_ADAPTER_FUNC createAdapter; - WINTUN_OPEN_ADAPTER_FUNC openAdapter; - WINTUN_FREE_ADAPTER_FUNC freeAdapter; - WINTUN_DELETE_POOL_DRIVER_FUNC deletePoolDriver; + auto candidateAdapterGuid = GetDeviceNetCfgInstanceId(deviceInfoSet, deviceInfo); -private: + return 0 == _wcsicmp(candidateAdapterGuid.c_str(), WintunMullvadAdapter); + }); + + EnumeratedDevice device; - template<typename T> - T getProcAddressOrThrow(const char *procName) + if (enumerator->next(device)) { - const T result = reinterpret_cast<T>(GetProcAddress(dllHandle, procName)); - if (nullptr == result) - { - THROW_WINDOWS_ERROR(GetLastError(), "GetProcAddress"); - } - return result; + UninstallDevice(device); } - HMODULE dllHandle; -}; + return GENERAL_SUCCESS; +} -int HandleWintunCommands(int argc, const wchar_t *argv[]) -{ - WintunDll wintun; +} // anonymous namespace - if (argc < 3) +int wmain(int argc, const wchar_t *argv[]) +{ + if (-1 == _setmode(_fileno(stdout), _O_U16TEXT) + || -1 == _setmode(_fileno(stderr), _O_U16TEXT)) { - goto INVALID_ARGUMENTS; + Log(L"Failed to set translation mode"); } - if (0 == _wcsicmp(argv[2], L"create-adapter")) + if (argc < 2) { - if (argc < 5) - { - goto INVALID_ARGUMENTS; - } - - const wchar_t *pool = argv[3]; - const wchar_t *adapter = argv[4]; + Log(L"Command not specified"); - GUID guidObject; - const GUID *requestGuid = nullptr; - if (argc >= 6) - { - guidObject = common::Guid::FromString(argv[5]); - requestGuid = &guidObject; - } - - const auto adapters = GetNetworkAdapters(std::nullopt); - const auto freeAdapterName = FindFreeAdapterAlias(adapters, adapter); - - const auto handle = wintun.createAdapter( - pool, - freeAdapterName.c_str(), - requestGuid, - nullptr - ); - - if (nullptr == handle) - { - const auto status = GetLastError(); - if (ERROR_FILE_NOT_FOUND == status) - { - return ADAPTER_NOT_FOUND; - } - else - { - THROW_WINDOWS_ERROR(status, "wintun.createAdapter"); - } - } - wintun.freeAdapter(handle); + return ReturnCode::GENERAL_ERROR; } - else if (0 == _wcsicmp(argv[2], L"delete-pool-driver")) - { - if (4 != argc) - { - goto INVALID_ARGUMENTS; - } - - const wchar_t *pool = argv[3]; - wintun.deletePoolDriver(pool, nullptr); - } - else if (0 == _wcsicmp(argv[2], L"adapter-exists")) - { - if (5 != argc) - { - goto INVALID_ARGUMENTS; - } + // + // Re-package command arguments + // - const wchar_t *pool = argv[3]; - const wchar_t *adapter = argv[4]; + const std::wstring command = argv[1]; - const auto handle = wintun.openAdapter(pool, adapter); + std::vector<std::wstring> arguments; - if (nullptr == handle) - { - const auto status = GetLastError(); - if (ERROR_FILE_NOT_FOUND == status) - { - return ADAPTER_NOT_FOUND; - } - else - { - THROW_WINDOWS_ERROR(status, "wintun.openAdapter"); - } - } - wintun.freeAdapter(handle); - } - else + for (size_t argumentIndex = 2; argumentIndex < argc; ++argumentIndex) { - goto INVALID_ARGUMENTS; + arguments.emplace_back(argv[argumentIndex]); } - return GENERAL_SUCCESS; - -INVALID_ARGUMENTS: - - LogError(L"Invalid arguments."); - return GENERAL_ERROR; -} - -} // anonymous namespace + // + // Declare all handlers + // -int wmain(int argc, const wchar_t * argv[], const wchar_t * []) -{ - if (-1 == _setmode(_fileno(stdout), _O_U16TEXT) - || -1 == _setmode(_fileno(stderr), _O_U16TEXT)) + struct CommandHandler { - LogError(L"Failed to set translation mode"); - } + std::wstring commandName; + std::function<ReturnCode(const std::vector<std::wstring> &)> handler; + }; - if (2 > argc) + std::vector<CommandHandler> handlers = { - goto INVALID_ARGUMENTS; - } + { L"st-evaluate", CommandSplitTunnelEvaluate }, + { L"st-new-install", CommandSplitTunnelNewInstall }, + { L"st-force-install", CommandSplitTunnelForceInstall }, + { L"st-remove", CommandSplitTunnelRemove }, + { L"wintun-delete-pool-driver", CommandWintunDeletePool }, + { L"wintun-delete-abandoned-device", CommandWintunDeleteAbandonedDevice } + }; + + // + // Find and invoke matching handler + // - try + for (const auto &candidate : handlers) { - if (0 == _wcsicmp(argv[1], L"new-device")) + if (0 != _wcsicmp(command.c_str(), candidate.commandName.c_str())) { - if (4 != argc) - { - goto INVALID_ARGUMENTS; - } - - const wchar_t *hardwareId = argv[2]; - const wchar_t *baseName = argv[3]; - - CreateNetDevice(hardwareId, baseName, true); + continue; } - else if (0 == _wcsicmp(argv[1], L"remove-device")) - { - if (4 != argc) - { - goto INVALID_ARGUMENTS; - } - - const wchar_t *hardwareId = argv[2]; - const wchar_t *baseName = argv[3]; - RemoveNetAdapterByAlias(hardwareId, baseName); - } - else if (0 == _wcsicmp(argv[1], L"remove-device-by-guid")) + try { - if (3 != argc) - { - goto INVALID_ARGUMENTS; - } - - const wchar_t *guid = argv[2]; - - if (!RemoveNetDevice(std::nullopt, guid)) - { - return ADAPTER_NOT_FOUND; - } + return candidate.handler(arguments); } - else if (0 == _wcsicmp(argv[1], L"device-exists")) + catch (const common::error::WindowsException &e) { - if (4 != argc) - { - goto INVALID_ARGUMENTS; - } - - const wchar_t *hardwareId = argv[2]; - const wchar_t *baseName = argv[3]; - - const auto virtualAdapters = GetNetworkAdapters(hardwareId); - const auto adapter = FindAdapterByAlias(virtualAdapters, baseName); - - if (!adapter.has_value()) - { - return ADAPTER_NOT_FOUND; - } + Log(common::string::ToWide(e.what())); + return e.errorCode(); } - else if (0 == _wcsicmp(argv[1], L"wintun")) + catch (const std::exception &e) { - return HandleWintunCommands(argc, argv); + Log(common::string::ToWide(e.what())); + return GENERAL_ERROR; } - else + catch (...) { - goto INVALID_ARGUMENTS; + Log(L"Unknown exception was raised/thrown"); + return GENERAL_ERROR; } } - catch (const common::error::WindowsException &e) - { - LogError(common::string::ToWide(e.what())); - return e.errorCode(); - } - catch (const std::exception &e) - { - LogError(common::string::ToWide(e.what())); - return GENERAL_ERROR; - } - catch (...) - { - LogError(L"Unhandled exception."); - return GENERAL_ERROR; - } - return GENERAL_SUCCESS; -INVALID_ARGUMENTS: + // + // Could not find matching handler + // - LogError(L"Invalid arguments."); + Log(L"Could not find handler for specified command"); return GENERAL_ERROR; } diff --git a/windows/driverlogic/src/log.cpp b/windows/driverlogic/src/log.cpp new file mode 100644 index 0000000000..0840b02cbf --- /dev/null +++ b/windows/driverlogic/src/log.cpp @@ -0,0 +1,13 @@ +#include "stdafx.h" +#include "log.h" +#include <iostream> + +void Log(const wchar_t *str) +{ + std::wcout << str << std::endl; +} + +void Log(const std::wstring &str) +{ + Log(str.c_str()); +} diff --git a/windows/driverlogic/src/log.h b/windows/driverlogic/src/log.h new file mode 100644 index 0000000000..6b1e67a154 --- /dev/null +++ b/windows/driverlogic/src/log.h @@ -0,0 +1,7 @@ +#pragma once + +#include <string> + +void Log(const wchar_t *str); + +void Log(const std::wstring &str); diff --git a/windows/driverlogic/src/service.cpp b/windows/driverlogic/src/service.cpp new file mode 100644 index 0000000000..14ed7880fa --- /dev/null +++ b/windows/driverlogic/src/service.cpp @@ -0,0 +1,134 @@ +#include "stdafx.h" +#include "service.h" +#include "log.h" +#include <libcommon/error.h> +#include <libcommon/memory.h> + +#undef min +#undef max +#include <chrono> + +template<typename TTime = std::chrono::milliseconds> +class TimeBox +{ + // `steady_clock` wraps around every ~292 years. + using Clock = std::chrono::steady_clock; + using ClockTimePoint = std::chrono::time_point<Clock>; + +public: + + TimeBox(typename TTime::rep maxWaitTime) + : m_startTime(Clock::now()) + , m_maxWaitTime(TTime(maxWaitTime)) + { + } + + bool expired() const + { + const auto now = Clock::now(); + + const auto elapsed = + ( + (now < m_startTime) + ? (ClockTimePoint::max() - m_startTime) + (now - ClockTimePoint::min()) + : now - m_startTime + ); + + return std::chrono::duration_cast<TTime>(elapsed) > m_maxWaitTime; + } + +private: + + ClockTimePoint m_startTime; + TTime m_maxWaitTime; +}; + +void WaitUntilServiceStopped(SC_HANDLE service, DWORD maxWaitMs) +{ + TimeBox timer(maxWaitMs); + + for (;;) + { + SERVICE_STATUS_PROCESS ssp; + + DWORD bytesNeeded; + + auto status = QueryServiceStatusEx + ( + service, + SC_STATUS_PROCESS_INFO, + reinterpret_cast<BYTE*>(&ssp), + sizeof(ssp), + &bytesNeeded + ); + + if (status != 0 + && ssp.dwCurrentState == SERVICE_STOPPED) + { + return; + } + + if (timer.expired()) + { + THROW_ERROR("Failed when waiting for service to stop"); + } + + Sleep(100); + } +} + +void PokeService(const std::wstring &serviceName, bool stopService, bool deleteService) +{ + auto serviceManager = OpenSCManagerW(nullptr, SERVICES_ACTIVE_DATABASE, SC_MANAGER_ALL_ACCESS); + + if (serviceManager == NULL) + { + THROW_WINDOWS_ERROR(GetLastError(), "OpenSCManagerW"); + } + + common::memory::ScopeDestructor dtor; + + dtor += [serviceManager]() + { + CloseServiceHandle(serviceManager); + }; + + auto service = OpenServiceW(serviceManager, serviceName.c_str(), SERVICE_ALL_ACCESS); + + if (service == NULL) + { + THROW_WINDOWS_ERROR(GetLastError(), "OpenServiceW"); + } + + dtor += [service]() + { + CloseServiceHandle(service); + }; + + if (stopService) + { + Log(L"Stopping service"); + + SERVICE_STATUS ss; + + ControlService(service, SERVICE_CONTROL_STOP, &ss); + + WaitUntilServiceStopped(service, 1000 * 5); + + Log(L"Successfully stopped service"); + } + + if (deleteService) + { + Log(L"Deleting service"); + + auto status = DeleteService(service); + + if (status == 0) + { + THROW_WINDOWS_ERROR(GetLastError(), "DeleteService"); + } + + Log(L"Successfully deleted service"); + } +} diff --git a/windows/driverlogic/src/service.h b/windows/driverlogic/src/service.h new file mode 100644 index 0000000000..87632faf1b --- /dev/null +++ b/windows/driverlogic/src/service.h @@ -0,0 +1,7 @@ +#pragma once + +#include <windows.h> + +void WaitUntilServiceStopped(SC_HANDLE service, DWORD maxWaitMs); + +void PokeService(const std::wstring &serviceName, bool stopService, bool deleteService); diff --git a/windows/driverlogic/src/util.cpp b/windows/driverlogic/src/util.cpp new file mode 100644 index 0000000000..ebce6d82cb --- /dev/null +++ b/windows/driverlogic/src/util.cpp @@ -0,0 +1,36 @@ +#include "stdafx.h" +#include "util.h" +#include <windows.h> +#include <vector> +#include <libcommon/error.h> + +using path = std::filesystem::path; + +path +GetProcessModulePath +( +) +{ + size_t bufferSize = MAX_PATH; + + std::vector<wchar_t> pathBuffer(bufferSize); + + for (;;) + { + const auto writtenChars = GetModuleFileNameW(nullptr, &pathBuffer[0], static_cast<DWORD>(pathBuffer.size())); + + if (0 == writtenChars) + { + THROW_WINDOWS_ERROR(GetLastError(), "GetModuleFileNameW"); + } + + if (writtenChars != pathBuffer.size()) + { + return path(pathBuffer.begin(), pathBuffer.begin() + writtenChars); + } + + bufferSize *= 2; + + pathBuffer.resize(bufferSize); + } +} diff --git a/windows/driverlogic/src/util.h b/windows/driverlogic/src/util.h new file mode 100644 index 0000000000..1bdcaf9d15 --- /dev/null +++ b/windows/driverlogic/src/util.h @@ -0,0 +1,8 @@ +#pragma once + +#include <filesystem> + +std::filesystem::path +GetProcessModulePath +( +); diff --git a/windows/driverlogic/src/version.cpp b/windows/driverlogic/src/version.cpp new file mode 100644 index 0000000000..d7394e224b --- /dev/null +++ b/windows/driverlogic/src/version.cpp @@ -0,0 +1,125 @@ +#include "stdafx.h" +#include "version.h" +#include "device.h" +#include <setupapi.h> +#include <initguid.h> +#include <devpkey.h> +#include <libcommon/string.h> +#include <libcommon/memory.h> +#include <stdexcept> + +DRIVER_UPGRADE_STATUS +EvaluateDriverUpgrade +( + const std::wstring &existingVersion, + const std::wstring &proposedVersion +) +{ + // + // "x.y.z.a" + // + + using namespace common::string; + + auto et = Tokenize(existingVersion, L"."); + auto pt = Tokenize(proposedVersion, L"."); + + auto items = min(et.size(), pt.size()); + + for (auto index = 0; index < items; ++index) + { + auto ev = wcstoul(et[index].c_str(), nullptr, 10); + auto pv = wcstoul(pt[index].c_str(), nullptr, 10); + + if (pv > ev) + { + return DRIVER_UPGRADE_STATUS::WOULD_UPGRADE; + } + + if (ev > pv) + { + return DRIVER_UPGRADE_STATUS::WOULD_DOWNGRADE; + } + } + + if (pt.size() > et.size()) + { + return DRIVER_UPGRADE_STATUS::WOULD_UPGRADE; + } + + if (et.size() > pt.size()) + { + return DRIVER_UPGRADE_STATUS::WOULD_DOWNGRADE; + } + + return DRIVER_UPGRADE_STATUS::WOULD_INSTALL_SAME_VERSION; +} + +std::wstring +InfGetDriverVersion +( + const std::wstring &filePath +) +{ + auto infHandle = SetupOpenInfFileW(filePath.c_str(), nullptr, INF_STYLE_WIN4, nullptr); + + if (infHandle == INVALID_HANDLE_VALUE) + { + throw std::runtime_error("SetupOpenInfFileW()"); + } + + common::memory::ScopeDestructor dtor; + + dtor += [infHandle]() + { + SetupCloseInfFile(infHandle); + }; + + INFCONTEXT infContext { 0 }; + + auto status = SetupFindFirstLineW(infHandle, L"Version", L"DriverVer", &infContext); + + if (status == FALSE) + { + throw std::runtime_error("SetupFindFirstLineW()"); + } + + DWORD requiredSize; + + // + // This is a multi-value key. + // 0 = key, 1 = driver date + // + const DWORD VersionFieldIndex = 2; + + status = SetupGetStringFieldW(&infContext, VersionFieldIndex, nullptr, 0, &requiredSize); + + if (status == FALSE || requiredSize < 2) + { + throw std::runtime_error("SetupGetStringFieldW()"); + } + + std::vector<wchar_t> buffer(requiredSize); + + status = SetupGetStringFieldW(&infContext, VersionFieldIndex, + &buffer[0], static_cast<DWORD>(buffer.size()), nullptr); + + if (status == FALSE) + { + throw std::runtime_error("SetupGetStringFieldW()"); + } + + // Remove null terminator. + buffer.resize(requiredSize - 1); + + return buffer.data(); +} + +std::wstring +GetDriverVersion +( + const EnumeratedDevice &device +) +{ + return GetDeviceStringProperty(device.deviceInfoSet, device.deviceInfo, &DEVPKEY_Device_DriverVersion); +} diff --git a/windows/driverlogic/src/version.h b/windows/driverlogic/src/version.h new file mode 100644 index 0000000000..fc88b3b232 --- /dev/null +++ b/windows/driverlogic/src/version.h @@ -0,0 +1,30 @@ +#pragma once + +#include <string> +#include "device.h" + +enum class DRIVER_UPGRADE_STATUS +{ + WOULD_DOWNGRADE, + WOULD_INSTALL_SAME_VERSION, + WOULD_UPGRADE +}; + +DRIVER_UPGRADE_STATUS +EvaluateDriverUpgrade +( + const std::wstring &existingVersion, + const std::wstring &proposedVersion +); + +std::wstring +InfGetDriverVersion +( + const std::wstring &filePath +); + +std::wstring +GetDriverVersion +( + const EnumeratedDevice &device +); diff --git a/windows/driverlogic/src/wintun.h b/windows/driverlogic/src/wintun.h new file mode 100644 index 0000000000..3d81be97e3 --- /dev/null +++ b/windows/driverlogic/src/wintun.h @@ -0,0 +1,64 @@ +#pragma once + +#include <wintun/wintun.h> +#include <libcommon/error.h> +#include "util.h" + +class WintunDll +{ +public: + + WintunDll() : dllHandle(nullptr) + { + auto wintunPath = GetProcessModulePath().replace_filename(L"wintun.dll"); + dllHandle = LoadLibraryExW(wintunPath.c_str(), nullptr, LOAD_WITH_ALTERED_SEARCH_PATH); + + if (nullptr == dllHandle) + { + THROW_WINDOWS_ERROR(GetLastError(), "LoadLibraryExW"); + } + + try + { + createAdapter = getProcAddressOrThrow<WINTUN_CREATE_ADAPTER_FUNC>("WintunCreateAdapter"); + openAdapter = getProcAddressOrThrow<WINTUN_OPEN_ADAPTER_FUNC>("WintunOpenAdapter"); + freeAdapter = getProcAddressOrThrow<WINTUN_FREE_ADAPTER_FUNC>("WintunFreeAdapter"); + deletePoolDriver = getProcAddressOrThrow<WINTUN_DELETE_POOL_DRIVER_FUNC>("WintunDeletePoolDriver"); + } + catch (...) + { + FreeLibrary(dllHandle); + throw; + } + } + + ~WintunDll() + { + if (nullptr != dllHandle) + { + FreeLibrary(dllHandle); + } + } + + WINTUN_CREATE_ADAPTER_FUNC createAdapter; + WINTUN_OPEN_ADAPTER_FUNC openAdapter; + WINTUN_FREE_ADAPTER_FUNC freeAdapter; + WINTUN_DELETE_POOL_DRIVER_FUNC deletePoolDriver; + +private: + + template<typename T> + T getProcAddressOrThrow(const char *procName) + { + const T result = reinterpret_cast<T>(GetProcAddress(dllHandle, procName)); + + if (nullptr == result) + { + THROW_WINDOWS_ERROR(GetLastError(), "GetProcAddress"); + } + + return result; + } + + HMODULE dllHandle; +}; |
