diff options
| author | David Lönnhager <david.l@mullvad.net> | 2019-11-28 09:20:26 +0100 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2019-12-03 09:17:30 +0100 |
| commit | 8f9f08a676d196172c76de1a568416beeb4da3d0 (patch) | |
| tree | 845051488fb48699a43653acbdc1331354b3f2df | |
| parent | c030edc99b8fa03b1090d9a8060873d6fc7ed16c (diff) | |
| download | mullvadvpn-8f9f08a676d196172c76de1a568416beeb4da3d0.tar.xz mullvadvpn-8f9f08a676d196172c76de1a568416beeb4da3d0.zip | |
driverlogic: Obtain additional information about TAP adapters using SetupAPI and nci.dll
6 files changed, 363 insertions, 127 deletions
diff --git a/windows/nsis-plugins/src/driverlogic/context.cpp b/windows/nsis-plugins/src/driverlogic/context.cpp index 42394e25e3..c7ef8c91cd 100644 --- a/windows/nsis-plugins/src/driverlogic/context.cpp +++ b/windows/nsis-plugins/src/driverlogic/context.cpp @@ -1,5 +1,6 @@ #include "stdafx.h" #include "context.h" +#include "ncicontext.h" #include <libcommon/string.h> #include <libcommon/error.h> @@ -28,59 +29,6 @@ namespace const wchar_t TAP_HARDWARE_ID[] = L"tap0901"; -std::set<Context::NetworkAdapter> GetAllAdapters() -{ - ULONG bufferSize = 0; - - const ULONG flags = GAA_FLAG_SKIP_UNICAST | GAA_FLAG_SKIP_ANYCAST | GAA_FLAG_SKIP_MULTICAST | GAA_FLAG_SKIP_DNS_SERVER; - - auto status = GetAdaptersAddresses(AF_INET, flags, nullptr, nullptr, &bufferSize); - - THROW_UNLESS(ERROR_BUFFER_OVERFLOW, status, "Probe for adapter listing buffer size"); - - // Memory is cheap, this avoids a looping construct. - bufferSize *= 2; - - std::vector<uint8_t> buffer(bufferSize); - - status = GetAdaptersAddresses(AF_INET, flags, nullptr, - reinterpret_cast<PIP_ADAPTER_ADDRESSES>(&buffer[0]), &bufferSize); - - THROW_UNLESS(ERROR_SUCCESS, status, "Retrieve adapter listing"); - - std::set<Context::NetworkAdapter> adapters; - - for (auto it = (PIP_ADAPTER_ADDRESSES)&buffer[0]; nullptr != it; it = it->Next) - { - adapters.emplace(Context::NetworkAdapter(common::string::ToWide(it->AdapterName), - it->Description, it->FriendlyName)); - } - - return adapters; -} - -std::set<Context::NetworkAdapter> GetTapAdapters(const std::set<Context::NetworkAdapter> &adapters) -{ - std::set<Context::NetworkAdapter> tapAdapters; - - for (const auto &adapter : adapters) - { - static const wchar_t name[] = L"TAP-Windows Adapter V9"; - - // - // Compare partial name, because once you start having more TAP adapters - // they're named "TAP-Windows Adapter V9 #2" and so on. - // - - if (0 == adapter.name.compare(0, _countof(name) - 1, name)) - { - tapAdapters.insert(adapter); - } - } - - return tapAdapters; -} - template<typename T> void LogAdapters(const std::wstring &description, const T &adapters) { @@ -143,6 +91,217 @@ std::wstring GetNetCfgInstanceId(HDEVINFO devInfo, const SP_DEVINFO_DATA &devInf return instanceId.data(); } +std::wstring GetDeviceInstanceId( + HDEVINFO devInfo, + SP_DEVINFO_DATA* devInfoData +) +{ + DWORD requiredSize = 0; + + SetupDiGetDeviceInstanceIdW( + devInfo, + devInfoData, + nullptr, + 0, + &requiredSize + ); + + std::vector<wchar_t> deviceInstanceId; + deviceInstanceId.resize(1 + requiredSize * sizeof(wchar_t)); + + const auto status = SetupDiGetDeviceInstanceIdW( + devInfo, + devInfoData, + &deviceInstanceId[0], + deviceInstanceId.size(), + nullptr + ); + THROW_GLE_IF(FALSE, status, "SetupDiGetDeviceInstanceIdW() failed"); + + return deviceInstanceId.data(); +} + +std::wstring GetDeviceStringProperty( + HDEVINFO devInfo, + SP_DEVINFO_DATA *devInfoData, + const DEVPROPKEY *property +) +{ + // + // Obtain required buffer size + // + + DWORD requiredSize = 0; + DEVPROPTYPE type; + + const auto sizeStatus = SetupDiGetDevicePropertyW( + devInfo, + devInfoData, + property, + &type, + nullptr, + 0, + &requiredSize, + 0 + ); + + const DWORD lastError = GetLastError(); + if (FALSE == sizeStatus && ERROR_INSUFFICIENT_BUFFER != lastError) + { + common::error::Throw( + "Error obtaining device property length", + lastError + ); + } + + std::vector<wchar_t> buffer; + buffer.resize(1 + requiredSize / sizeof(wchar_t)); + + // + // Read property + // + + const auto status = SetupDiGetDevicePropertyW( + devInfo, + devInfoData, + property, + &type, + reinterpret_cast<PBYTE>(&buffer[0]), + buffer.size() * sizeof(wchar_t), + nullptr, + 0 + ); + + THROW_GLE_IF(FALSE, status, "Failed to read device property"); + + return buffer.data(); +} + +std::optional<std::wstring> GetDeviceRegistryStringProperty( + HDEVINFO devInfo, + SP_DEVINFO_DATA *devInfoData, + DWORD property +) +{ + // + // Obtain required buffer size + // + + DWORD requiredSize = 0; + + const auto sizeStatus = SetupDiGetDeviceRegistryPropertyW( + devInfo, + devInfoData, + property, + nullptr, + nullptr, + 0, + &requiredSize + ); + + const DWORD lastError = GetLastError(); + if (FALSE == sizeStatus && ERROR_INSUFFICIENT_BUFFER != lastError) + { + if (ERROR_INVALID_DATA == lastError) + { + // ERROR_INVALID_DATA may mean that the property does not exist + // TODO: Check if there may be other causes. + return std::nullopt; + } + THROW_GLE("Error obtaining device property length"); + } + + // + // Read property + // + + std::vector<wchar_t> buffer; + buffer.resize(1 + requiredSize / sizeof(wchar_t)); + + const auto status = SetupDiGetDeviceRegistryPropertyW( + devInfo, + devInfoData, + property, + nullptr, + reinterpret_cast<PBYTE>(&buffer[0]), + buffer.size() * sizeof(wchar_t), + nullptr + ); + + THROW_GLE_IF(FALSE, status, "Failed to read device property"); + + return { buffer.data() }; +} + +std::set<Context::NetworkAdapter> GetTapAdapters() +{ + std::set<Context::NetworkAdapter> adapters; + + HDEVINFO devInfo = SetupDiGetClassDevs( + &GUID_DEVCLASS_NET, + nullptr, + nullptr, + DIGCF_PRESENT + ); + THROW_GLE_IF(INVALID_HANDLE_VALUE, devInfo, "SetupDiGetClassDevs() failed"); + + common::memory::ScopeDestructor scopeDestructor; + scopeDestructor += [devInfo]() + { + SetupDiDestroyDeviceInfoList(devInfo); + }; + + NciContext nci; + + for (int memberIndex = 0; ; memberIndex++) + { + SP_DEVINFO_DATA devInfoData = { 0 }; + devInfoData.cbSize = sizeof(devInfoData); + + if (FALSE == SetupDiEnumDeviceInfo(devInfo, memberIndex, &devInfoData)) + { + if (ERROR_NO_MORE_ITEMS == GetLastError()) + { + // Done + break; + } + THROW_GLE("SetupDiEnumDeviceInfo() failed while enumerating network adapters"); + } + + // + // Check whether this is a TAP adapter + // + + const auto hardwareId = GetDeviceRegistryStringProperty(devInfo, &devInfoData, SPDRP_HARDWAREID); + if (!hardwareId.has_value() + || wcscmp(hardwareId.value().c_str(), TAP_HARDWARE_ID) != 0) + { + continue; + } + + // + // Construct NetworkAdapter + // + + const std::wstring guid = GetNetCfgInstanceId(devInfo, devInfoData); + + IID guidObj = { 0 }; + if (S_OK != IIDFromString(&guid[0], &guidObj)) + { + throw std::runtime_error("IIDFromString() failed"); + } + + adapters.emplace(Context::NetworkAdapter( + guid, + GetDeviceStringProperty(devInfo, &devInfoData, &DEVPKEY_Device_DriverDesc), + nci.getConnectionName(guidObj), + GetDeviceInstanceId(devInfo, &devInfoData) + )); + } + + return adapters; +} + } // anonymous namespace //static @@ -201,15 +360,14 @@ std::optional<Context::NetworkAdapter> Context::FindMullvadAdapter(const std::se Context::BaselineStatus Context::establishBaseline() { - m_baseline = GetAllAdapters(); - const auto tapAdapters = GetTapAdapters(m_baseline); + m_baseline = GetTapAdapters(); - if (tapAdapters.empty()) + if (m_baseline.empty()) { return BaselineStatus::NO_TAP_ADAPTERS_PRESENT; } - if (FindMullvadAdapter(tapAdapters).has_value()) + if (FindMullvadAdapter(m_baseline).has_value()) { return BaselineStatus::MULLVAD_ADAPTER_PRESENT; } @@ -219,19 +377,16 @@ Context::BaselineStatus Context::establishBaseline() void Context::recordCurrentState() { - m_currentState = GetAllAdapters(); + m_currentState = GetTapAdapters(); } Context::NetworkAdapter Context::getNewAdapter() { std::list<NetworkAdapter> added; - const auto baselineTaps = GetTapAdapters(m_baseline); - const auto currentTaps = GetTapAdapters(m_currentState); - - for (const auto &adapter : currentTaps) + for (const auto &adapter : m_currentState) { - if (baselineTaps.end() == baselineTaps.find(adapter)) + if (m_baseline.end() == m_baseline.find(adapter)) { added.push_back(adapter); } @@ -239,8 +394,8 @@ Context::NetworkAdapter Context::getNewAdapter() if (added.size() != 1) { - LogAdapters(L"Enumerable network adapters", m_currentState); - LogAdapters(L"Added TAP adapters", added); + LogAdapters(L"Enumerable network TAP adapters", m_currentState); + LogAdapters(L"New TAP adapters:", added); throw std::runtime_error("Unable to identify recently added TAP adapter"); } @@ -251,7 +406,7 @@ Context::NetworkAdapter Context::getNewAdapter() //static Context::DeletionResult Context::DeleteMullvadAdapter() { - auto tapAdapters = GetTapAdapters(GetAllAdapters()); + auto tapAdapters = GetTapAdapters(); std::optional<NetworkAdapter> mullvadAdapter = FindMullvadAdapter(tapAdapters); if (!mullvadAdapter.has_value()) @@ -276,71 +431,28 @@ Context::DeletionResult Context::DeleteMullvadAdapter() SetupDiDestroyDeviceInfoList(devInfo); }; - SP_DEVINFO_DATA devInfoData; - - std::vector<wchar_t> buffer; - DWORD nameLen; - int numRemainingAdapters = 0; for (int memberIndex = 0; ; memberIndex++) { - devInfoData = { 0 }; + SP_DEVINFO_DATA devInfoData = { 0 }; devInfoData.cbSize = sizeof(devInfoData); if (FALSE == SetupDiEnumDeviceInfo(devInfo, memberIndex, &devInfoData)) { - if (GetLastError() == ERROR_NO_MORE_ITEMS) + if (ERROR_NO_MORE_ITEMS == GetLastError()) { - /* done */ break; } THROW_GLE("Error enumerating network adapters"); } - if (FALSE == SetupDiGetDeviceRegistryPropertyW( - devInfo, - &devInfoData, - SPDRP_HARDWAREID, - nullptr, - nullptr, - 0, - &nameLen - )) - { - const auto status = GetLastError(); - if (ERROR_INSUFFICIENT_BUFFER != status) - { - /* ERROR_INSUFFICIENT_BUFFER is expected */ - if (ERROR_INVALID_DATA == status) - { - /* ERROR_INVALID_DATA may mean that the property does not exist */ - continue; - } - THROW_GLE("Error obtaining network adapter hardware ID length"); - } - } - - buffer.resize(nameLen / sizeof(wchar_t) + 1); - buffer[nameLen / sizeof(wchar_t)] = L'\0'; + const auto hardwareId = GetDeviceRegistryStringProperty(devInfo, &devInfoData, SPDRP_HARDWAREID); - if (FALSE == SetupDiGetDeviceRegistryPropertyW( - devInfo, - &devInfoData, - SPDRP_HARDWAREID, - nullptr, - reinterpret_cast<PBYTE>(buffer.data()), - (buffer.size() - 1) * sizeof(wchar_t), - nullptr - )) + if (hardwareId.has_value() + && wcscmp(TAP_HARDWARE_ID, hardwareId.value().data()) == 0) { - THROW_GLE("Error obtaining network adapter hardware ID"); - } - - if (wcscmp(TAP_HARDWARE_ID, buffer.data()) == 0) - { - std::wstring netCfgInstanceId = GetNetCfgInstanceId(devInfo, devInfoData); - if (netCfgInstanceId.compare(mullvadGuid) != 0) + if (0 != GetNetCfgInstanceId(devInfo, devInfoData).compare(mullvadGuid)) { numRemainingAdapters++; continue; @@ -356,10 +468,7 @@ Context::DeletionResult Context::DeleteMullvadAdapter() } } - if (numRemainingAdapters > 0) - { - return DeletionResult::SOME_REMAINING_TAP_ADAPTERS; - } - - return DeletionResult::NO_REMAINING_TAP_ADAPTERS; + return (numRemainingAdapters > 0) + ? DeletionResult::SOME_REMAINING_TAP_ADAPTERS + : DeletionResult::NO_REMAINING_TAP_ADAPTERS; } diff --git a/windows/nsis-plugins/src/driverlogic/context.h b/windows/nsis-plugins/src/driverlogic/context.h index eabf884ba8..72682a3dbe 100644 --- a/windows/nsis-plugins/src/driverlogic/context.h +++ b/windows/nsis-plugins/src/driverlogic/context.h @@ -17,11 +17,13 @@ public: std::wstring guid; std::wstring name; std::wstring alias; + std::wstring deviceInstanceId; - NetworkAdapter(std::wstring _guid, std::wstring _name, std::wstring _alias) - : guid(_guid) - , name(_name) - , alias(_alias) + NetworkAdapter(std::wstring guid, std::wstring name, std::wstring alias, std::wstring deviceInstanceId) + : guid(guid) + , name(name) + , alias(alias) + , deviceInstanceId(deviceInstanceId) { } diff --git a/windows/nsis-plugins/src/driverlogic/driverlogic.vcxproj b/windows/nsis-plugins/src/driverlogic/driverlogic.vcxproj index 0bcb1a2179..ff268f19d3 100644 --- a/windows/nsis-plugins/src/driverlogic/driverlogic.vcxproj +++ b/windows/nsis-plugins/src/driverlogic/driverlogic.vcxproj @@ -70,7 +70,7 @@ <GenerateDebugInformation>true</GenerateDebugInformation> <ImageHasSafeExceptionHandlers>false</ImageHasSafeExceptionHandlers> <AdditionalLibraryDirectories>$(ProjectDir)../../../../dist-assets/binaries/x86_64-pc-windows-msvc/nsis/;$(SolutionDir)bin\$(Platform)-$(Configuration)\</AdditionalLibraryDirectories> - <AdditionalDependencies>setupapi.lib;iphlpapi.lib;log.lib;libcommon.lib;pluginapi-x86-unicode.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies> + <AdditionalDependencies>setupapi.lib;log.lib;libcommon.lib;pluginapi-x86-unicode.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies> <IgnoreSpecificDefaultLibraries>libc.lib</IgnoreSpecificDefaultLibraries> <ModuleDefinitionFile>driverlogic.def</ModuleDefinitionFile> </Link> @@ -96,13 +96,14 @@ <GenerateDebugInformation>true</GenerateDebugInformation> <ImageHasSafeExceptionHandlers>false</ImageHasSafeExceptionHandlers> <AdditionalLibraryDirectories>$(ProjectDir)../../../../dist-assets/binaries/x86_64-pc-windows-msvc/nsis/;$(SolutionDir)bin\$(Platform)-$(Configuration)\</AdditionalLibraryDirectories> - <AdditionalDependencies>setupapi.lib;iphlpapi.lib;log.lib;libcommon.lib;pluginapi-x86-unicode.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies> + <AdditionalDependencies>setupapi.lib;log.lib;libcommon.lib;pluginapi-x86-unicode.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies> <IgnoreSpecificDefaultLibraries>libc.lib</IgnoreSpecificDefaultLibraries> <ModuleDefinitionFile>driverlogic.def</ModuleDefinitionFile> </Link> </ItemDefinitionGroup> <ItemGroup> <ClInclude Include="context.h" /> + <ClInclude Include="ncicontext.h" /> <ClInclude Include="stdafx.h" /> <ClInclude Include="targetver.h" /> </ItemGroup> @@ -110,6 +111,7 @@ <ClCompile Include="dllmain.cpp" /> <ClCompile Include="driverlogic.cpp" /> <ClCompile Include="context.cpp" /> + <ClCompile Include="ncicontext.cpp" /> <ClCompile Include="stdafx.cpp"> <PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">Create</PrecompiledHeader> <PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">Create</PrecompiledHeader> diff --git a/windows/nsis-plugins/src/driverlogic/driverlogic.vcxproj.filters b/windows/nsis-plugins/src/driverlogic/driverlogic.vcxproj.filters index 2c25d275f1..eff46cba67 100644 --- a/windows/nsis-plugins/src/driverlogic/driverlogic.vcxproj.filters +++ b/windows/nsis-plugins/src/driverlogic/driverlogic.vcxproj.filters @@ -4,12 +4,14 @@ <ClInclude Include="stdafx.h" /> <ClInclude Include="targetver.h" /> <ClInclude Include="context.h" /> + <ClInclude Include="ncicontext.h" /> </ItemGroup> <ItemGroup> <ClCompile Include="dllmain.cpp" /> <ClCompile Include="driverlogic.cpp" /> <ClCompile Include="stdafx.cpp" /> <ClCompile Include="context.cpp" /> + <ClCompile Include="ncicontext.cpp" /> </ItemGroup> <ItemGroup> <None Include="driverlogic.def" /> diff --git a/windows/nsis-plugins/src/driverlogic/ncicontext.cpp b/windows/nsis-plugins/src/driverlogic/ncicontext.cpp new file mode 100644 index 0000000000..988199f470 --- /dev/null +++ b/windows/nsis-plugins/src/driverlogic/ncicontext.cpp @@ -0,0 +1,88 @@ +#include "stdafx.h" +#include "ncicontext.h" +#include <libcommon/error.h> +#include <libcommon/filesystem.h> +#include <filesystem> +#include <stdexcept> + +NciContext::NciContext() +{ + std::wstring systemDir = common::fs::GetKnownFolderPath( + FOLDERID_System, + KF_FLAG_DEFAULT, + nullptr + ); + const auto lsassPath = std::filesystem::path(systemDir).append(L"nci.dll"); + + dllHandle = LoadLibraryW(lsassPath.c_str()); + + if (nullptr == dllHandle) + { + throw std::runtime_error("Failed to load nci.dll"); + } + + m_nciGetConnectionName = reinterpret_cast<nciGetConnectionNameFunc>( + GetProcAddress(dllHandle, "NciGetConnectionName")); + + if (nullptr == m_nciGetConnectionName) + { + FreeLibrary(dllHandle); + throw std::runtime_error("Failed to obtain pointer to nciGetConnectionName"); + } + + m_nciSetConnectionName = reinterpret_cast<nciSetConnectionNameFunc>( + GetProcAddress(dllHandle, "NciSetConnectionName")); + + if (nullptr == m_nciSetConnectionName) + { + FreeLibrary(dllHandle); + throw std::runtime_error("Failed to obtain pointer to nciSetConnectionName"); + } +} + +NciContext::~NciContext() +{ + FreeLibrary(dllHandle); +} + +std::wstring NciContext::getConnectionName(const GUID& guid) +{ + DWORD nameLen = 0; + DWORD status = m_nciGetConnectionName(&guid, nullptr, 0, &nameLen); + + if (0 != status) + { + common::error::Throw( + "NciGetConnectionName() failed", + status + ); + } + + std::wstring buffer; + buffer.resize(nameLen / sizeof(wchar_t)); + + DWORD capacity = static_cast<DWORD>(buffer.capacity() * sizeof(wchar_t)); + status = m_nciGetConnectionName(&guid, &buffer[0], capacity, nullptr); + + if (0 != status) + { + common::error::Throw( + "NciGetConnectionName() failed", + status + ); + } + + return buffer; +} + +void NciContext::setConnectionName(const GUID& guid, const wchar_t* newName) +{ + const auto status = m_nciSetConnectionName(&guid, newName); + if (0 != status) + { + common::error::Throw( + "NciSetConnectionName() failed", + status + ); + } +} diff --git a/windows/nsis-plugins/src/driverlogic/ncicontext.h b/windows/nsis-plugins/src/driverlogic/ncicontext.h new file mode 100644 index 0000000000..00dfa27f04 --- /dev/null +++ b/windows/nsis-plugins/src/driverlogic/ncicontext.h @@ -0,0 +1,33 @@ +#pragma once + +#include <windows.h> +#include <string> + +// +// Interface for nci.dll. +// + +class NciContext +{ + HMODULE dllHandle; + + using nciGetConnectionNameFunc = DWORD(__stdcall*)(const GUID*, wchar_t*, DWORD, DWORD*); + using nciSetConnectionNameFunc = DWORD(__stdcall*)(const GUID*, const wchar_t*); + + nciGetConnectionNameFunc m_nciGetConnectionName; + nciSetConnectionNameFunc m_nciSetConnectionName; + + NciContext(NciContext&) = delete; + NciContext& operator=(NciContext&) = delete; + +public: + + NciContext(); + ~NciContext(); + + NciContext(NciContext&&) = default; + NciContext& operator=(NciContext&&) = default; + + std::wstring getConnectionName(const GUID& guid); + void setConnectionName(const GUID& guid, const wchar_t* newName); +}; |
