diff options
| author | David Lönnhager <david.l@mullvad.net> | 2020-06-12 11:30:01 +0200 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2020-06-15 09:34:24 +0200 |
| commit | 9bd9c061022163a291d3dc7ae01f7589bfa60848 (patch) | |
| tree | 4404cba34d30681865fc5ed51cb300282cc5ecec /windows/driverlogic/src/driverlogic.cpp | |
| parent | fa4637b71df806f499b0a8d38332027d778318f2 (diff) | |
| download | mullvadvpn-9bd9c061022163a291d3dc7ae01f7589bfa60848.tar.xz mullvadvpn-9bd9c061022163a291d3dc7ae01f7589bfa60848.zip | |
Allow some time for NetCfgInstanceId to be created for the new TAP adapter device
Diffstat (limited to 'windows/driverlogic/src/driverlogic.cpp')
| -rw-r--r-- | windows/driverlogic/src/driverlogic.cpp | 99 |
1 files changed, 92 insertions, 7 deletions
diff --git a/windows/driverlogic/src/driverlogic.cpp b/windows/driverlogic/src/driverlogic.cpp index 3e482c196f..450e08f821 100644 --- a/windows/driverlogic/src/driverlogic.cpp +++ b/windows/driverlogic/src/driverlogic.cpp @@ -1,6 +1,7 @@ #include "stdafx.h" #include "error.h" #include <iostream> +#include <chrono> #include <sstream> #include <string> #include <optional> @@ -28,6 +29,8 @@ constexpr wchar_t DEPRECATED_TAP_HARDWARE_ID[] = L"tap0901"; constexpr wchar_t TAP_HARDWARE_ID[] = L"tapmullvad0901"; constexpr wchar_t TAP_BASE_ALIAS[] = L"Mullvad"; +constexpr std::chrono::milliseconds REGISTRY_GET_TIMEOUT_MS{ 10000 }; + enum ReturnCodes { GENERAL_SUCCESS = 0, @@ -231,6 +234,87 @@ std::wstring GetDeviceInstanceId( return deviceInstanceId.data(); } +bool TryGetRegistryValueTimeout( + HKEY key, + const wchar_t *subkey, + const wchar_t *value, + DWORD flags, + DWORD *type, + void *data, + DWORD *dataSize +) +{ + HANDLE changeEvent = CreateEventW(nullptr, FALSE, FALSE, nullptr); + + if (nullptr == changeEvent) + { + THROW_WINDOWS_ERROR(GetLastError(), "CreateEventW"); + } + + common::memory::ScopeDestructor scopeDestructor; + scopeDestructor += [changeEvent]() { + CloseHandle(changeEvent); + }; + + auto initialTime = std::chrono::steady_clock::now(); + + for (;;) + { + const auto status = RegGetValueW(key, subkey, value, flags, type, data, dataSize); + + if (ERROR_SUCCESS == status) + { + // We're done + return true; + } + + if (ERROR_FILE_NOT_FOUND != status) + { + THROW_WINDOWS_ERROR(status, "RegGetValueW"); + } + + // + // Wait for the registry value to be created + // + + auto currentTime = std::chrono::steady_clock::now(); + auto elapsedTime = std::chrono::duration_cast<std::chrono::milliseconds>(currentTime - initialTime); + auto timeDelta = (REGISTRY_GET_TIMEOUT_MS - elapsedTime).count(); + + if (timeDelta <= 0) + { + return false; + } + + const auto notifyResult = RegNotifyChangeKeyValue( + key, + subkey != nullptr, // Watch subkeys + REG_NOTIFY_CHANGE_LAST_SET, + changeEvent, + TRUE + ); + + if (ERROR_SUCCESS != notifyResult) + { + THROW_WINDOWS_ERROR(notifyResult, "RegNotifyChangeKeyValue"); + } + + const auto waitResult = WaitForSingleObject(changeEvent, static_cast<DWORD>(timeDelta)); + if (WAIT_OBJECT_0 == waitResult) + { + // Try again + continue; + } + + if (WAIT_TIMEOUT != waitResult) + { + THROW_WINDOWS_ERROR(GetLastError(), "WaitForSingleObject"); + } + + return false; + } +} + std::wstring GetNetCfgInstanceId(HDEVINFO devInfo, const SP_DEVINFO_DATA &devInfoData) { HKEY hNet = SetupDiOpenDevRegKey( @@ -247,10 +331,15 @@ std::wstring GetNetCfgInstanceId(HDEVINFO devInfo, const SP_DEVINFO_DATA &devInf THROW_SETUPAPI_ERROR(GetLastError(), "SetupDiOpenDevRegKey"); } + common::memory::ScopeDestructor scopeDestructor; + scopeDestructor += [hNet]() { + RegCloseKey(hNet); + }; + std::vector<wchar_t> instanceId(MAX_PATH + 1); DWORD strSize = static_cast<DWORD>(instanceId.size() * sizeof(wchar_t)); - const auto status = RegGetValueW( + if (!TryGetRegistryValueTimeout( hNet, nullptr, L"NetCfgInstanceId", @@ -258,13 +347,9 @@ std::wstring GetNetCfgInstanceId(HDEVINFO devInfo, const SP_DEVINFO_DATA &devInf nullptr, instanceId.data(), &strSize - ); - - RegCloseKey(hNet); - - if (ERROR_SUCCESS != status) + )) { - THROW_WINDOWS_ERROR(status, "RegGetValueW"); + THROW_ERROR("Timed out waiting for NetCfgInstanceId."); } return instanceId.data(); |
