diff options
| author | Odd Stranne <odd@mullvad.net> | 2018-09-12 13:43:14 +0200 |
|---|---|---|
| committer | Odd Stranne <odd@mullvad.net> | 2018-09-12 13:43:14 +0200 |
| commit | d309f3af318255faa7e37b91791d6f222992ee42 (patch) | |
| tree | 12b6e4eb342636ddfed281d7a28805ea8052cada /windows | |
| parent | 4cea1a683f2abfb4ee6b0549bdcf9db2517bb400 (diff) | |
| parent | 0152d49cc6bc54e717db7186d8162508776c3135 (diff) | |
| download | mullvadvpn-d309f3af318255faa7e37b91791d6f222992ee42.tar.xz mullvadvpn-d309f3af318255faa7e37b91791d6f222992ee42.zip | |
Merge branch 'windns-improve-logging'
Diffstat (limited to 'windows')
| -rw-r--r-- | windows/windns/src/extras/loader/loader.cpp | 13 | ||||
| -rw-r--r-- | windows/windns/src/windns/configmanager.cpp | 12 | ||||
| -rw-r--r-- | windows/windns/src/windns/configmanager.h | 13 | ||||
| -rw-r--r-- | windows/windns/src/windns/confineoperation.cpp | 98 | ||||
| -rw-r--r-- | windows/windns/src/windns/confineoperation.h | 25 | ||||
| -rw-r--r-- | windows/windns/src/windns/iclientsinkproxy.h | 13 | ||||
| -rw-r--r-- | windows/windns/src/windns/netconfigeventsink.cpp | 53 | ||||
| -rw-r--r-- | windows/windns/src/windns/netconfigeventsink.h | 8 | ||||
| -rw-r--r-- | windows/windns/src/windns/netsh.cpp | 91 | ||||
| -rw-r--r-- | windows/windns/src/windns/netsh.h | 23 | ||||
| -rw-r--r-- | windows/windns/src/windns/windns.cpp | 109 | ||||
| -rw-r--r-- | windows/windns/src/windns/windns.h | 2 | ||||
| -rw-r--r-- | windows/windns/src/windns/windns.vcxproj | 3 | ||||
| -rw-r--r-- | windows/windns/src/windns/windns.vcxproj.filters | 3 | ||||
| -rw-r--r-- | windows/windns/src/windns/windnscontext.cpp | 61 | ||||
| -rw-r--r-- | windows/windns/src/windns/windnscontext.h | 6 |
16 files changed, 382 insertions, 151 deletions
diff --git a/windows/windns/src/extras/loader/loader.cpp b/windows/windns/src/extras/loader/loader.cpp index 1f7b5a6f97..46d8a7030c 100644 --- a/windows/windns/src/extras/loader/loader.cpp +++ b/windows/windns/src/extras/loader/loader.cpp @@ -7,9 +7,14 @@ #include <vector> #include <windows.h> -void WINDNS_API ErrorSink(const char *errorMessage, void *context) +void WINDNS_API ErrorSink(const char *errorMessage, const char **details, uint32_t numDetails, void *context) { std::cout << "WINDNS Error: " << errorMessage << std::endl; + + for (uint32_t i = 0; i < numDetails; ++i) + { + std::cout << " " << details[i] << std::endl; + } } void WINDNS_API ConfigSink(const void *configData, uint32_t dataLength, void *context) @@ -83,14 +88,14 @@ int main() { common::trace::Trace::RegisterSink(new common::trace::ConsoleTraceSink); + std::wcout << L"WinDns_Initialize: " << WinDns_Initialize(ErrorSink, nullptr) << std::endl; + if (Ask(L"Perform recovery?")) { Recover(); return 0; } - std::wcout << L"WinDns_Initialize: " << WinDns_Initialize(ErrorSink, nullptr) << std::endl; - const wchar_t *servers[] = { L"8.8.8.8" @@ -117,4 +122,4 @@ int main() std::wcout << L"WinDns_Deinitialize: " << WinDns_Deinitialize() << std::endl; return 0; -}
\ No newline at end of file +} diff --git a/windows/windns/src/windns/configmanager.cpp b/windows/windns/src/windns/configmanager.cpp index c9fdde1142..498e58a381 100644 --- a/windows/windns/src/windns/configmanager.cpp +++ b/windows/windns/src/windns/configmanager.cpp @@ -8,10 +8,10 @@ ConfigManager::ConfigManager ( const std::vector<std::wstring> &servers, - const ConfigSinkInfo &configSinkInfo + IClientSinkProxy *clientSinkProxy ) : m_servers(servers) - , m_configSinkInfo(configSinkInfo) + , m_clientSinkProxy(clientSinkProxy) { } @@ -31,12 +31,6 @@ void ConfigManager::updateServers(const std::vector<std::wstring> &servers) m_servers = servers; } -void ConfigManager::updateConfigSink(const ConfigSinkInfo &configSinkInfo) -{ - XTRACE(L"Updating config sink"); - m_configSinkInfo = configSinkInfo; -} - const std::vector<std::wstring> &ConfigManager::getServers() const { return m_servers; @@ -135,5 +129,5 @@ void ConfigManager::exportConfigs() auto data = s.blob(); - m_configSinkInfo.sink(&data[0], static_cast<uint32_t>(data.size()), m_configSinkInfo.context); + m_clientSinkProxy->config(&data[0], static_cast<uint32_t>(data.size())); } diff --git a/windows/windns/src/windns/configmanager.h b/windows/windns/src/windns/configmanager.h index 157f148901..c6f14f7ea9 100644 --- a/windows/windns/src/windns/configmanager.h +++ b/windows/windns/src/windns/configmanager.h @@ -1,7 +1,7 @@ #pragma once #include "interfaceconfig.h" -#include "clientsinkinfo.h" +#include "iclientsinkproxy.h" #include <map> #include <string> #include <mutex> @@ -50,7 +50,7 @@ public: ConfigManager ( const std::vector<std::wstring> &servers, - const ConfigSinkInfo &configSinkInfo + IClientSinkProxy *clientSinkProxy ); // @@ -61,16 +61,11 @@ public: void unlock(); // - // Notify the ConfigManager that servers used when overriding DNS settings have changed. + // Establish the set of servers to use when overriding DNS settings. // void updateServers(const std::vector<std::wstring> &servers); // - // Update the callback used for persisting settings. - // - void updateConfigSink(const ConfigSinkInfo &configSinkInfo); - - // // Get the current set of servers used for overriding DNS settings. // const std::vector<std::wstring> &getServers() const; @@ -96,7 +91,7 @@ private: std::mutex m_mutex; std::vector<std::wstring> m_servers; - ConfigSinkInfo m_configSinkInfo; + IClientSinkProxy *m_clientSinkProxy; // // Organize configs based on their system assigned index. diff --git a/windows/windns/src/windns/confineoperation.cpp b/windows/windns/src/windns/confineoperation.cpp new file mode 100644 index 0000000000..5c16de71bf --- /dev/null +++ b/windows/windns/src/windns/confineoperation.cpp @@ -0,0 +1,98 @@ +#include "stdafx.h" +#include "confineoperation.h" +#include "netsh.h" + +bool ConfineOperation +( + const char *literalOperation, + std::function<void(const char *, const char **, uint32_t)> errorCallback, + std::function<void()> operation +) +{ + try + { + operation(); + return true; + } + catch (NetShError &err) + { + auto raw = CreateRawStringArray(err.details()); + + const char **details = reinterpret_cast<const char **>(&raw[0]); + uint32_t numDetails = static_cast<uint32_t>(err.details().size()); + + if (0 == numDetails) + { + details = nullptr; + } + + const auto what = std::string(literalOperation).append(": ").append(err.what()); + + errorCallback(what.c_str(), details, numDetails); + + return false; + } + catch (std::exception &err) + { + const auto what = std::string(literalOperation).append(": ").append(err.what()); + + errorCallback(what.c_str(), nullptr, 0); + + return false; + } + catch (...) + { + const auto what = std::string(literalOperation).append(": Unspecified failure"); + + errorCallback(what.c_str(), nullptr, 0); + + return false; + } +} + +std::vector<uint8_t> CreateRawStringArray(const std::vector<std::string> &arr) +{ + // + // Return a buffer containing a nullptr if there are no items in the array. + // This enables clients of this function to address the pointer table. + // + + if (arr.empty()) + { + return std::vector<uint8_t>(sizeof(char *), 0); + } + + // + // Determine total size needed. + // + + size_t bufferSize = 0; + + for (const auto &str : arr) + { + bufferSize += sizeof(char *); + bufferSize += (str.size() + 1); + } + + // + // Copy strings and populate pointer table. + // + + std::vector<uint8_t> buffer(bufferSize, 0); + + char **pointerTable = reinterpret_cast<char**>(&buffer[0]); + char *data = reinterpret_cast<char*>(&buffer[0] + (sizeof(char*) * arr.size())); + + for (const auto &str : arr) + { + const auto fullStringSize = str.size() + 1; + + *pointerTable = data; + memcpy(data, str.c_str(), fullStringSize); + + ++pointerTable; + data += fullStringSize; + } + + return buffer; +} diff --git a/windows/windns/src/windns/confineoperation.h b/windows/windns/src/windns/confineoperation.h new file mode 100644 index 0000000000..21fe7dd996 --- /dev/null +++ b/windows/windns/src/windns/confineoperation.h @@ -0,0 +1,25 @@ +#pragma once + +#include <functional> +#include <vector> +#include <string> +#include <cstdint> + +bool ConfineOperation +( + const char *literalOperation, + std::function<void(const char *, const char **, uint32_t)> errorCallback, + std::function<void()> operation +); + +// +// The returned buffer looks like this: +// +// string pointer 1 +// string pointer 2 +// string pointer n +// string 1 +// string 2 +// string n +// +std::vector<uint8_t> CreateRawStringArray(const std::vector<std::string> &arr); diff --git a/windows/windns/src/windns/iclientsinkproxy.h b/windows/windns/src/windns/iclientsinkproxy.h new file mode 100644 index 0000000000..9270a12c50 --- /dev/null +++ b/windows/windns/src/windns/iclientsinkproxy.h @@ -0,0 +1,13 @@ +#pragma once + +#include <cstdint> + +struct IClientSinkProxy +{ + virtual ~IClientSinkProxy() = 0 + { + } + + virtual void error(const char *errorMessage, const char **details, uint32_t numDetails) = 0; + virtual void config(const void *configData, uint32_t dataLength) = 0; +}; diff --git a/windows/windns/src/windns/netconfigeventsink.cpp b/windows/windns/src/windns/netconfigeventsink.cpp index e4ef7ae1ce..5b8d4efd07 100644 --- a/windows/windns/src/windns/netconfigeventsink.cpp +++ b/windows/windns/src/windns/netconfigeventsink.cpp @@ -1,34 +1,51 @@ #include "stdafx.h" #include "netconfigeventsink.h" -#include "windns/netconfighelpers.h" +#include "netconfighelpers.h" +#include "netsh.h" +#include "confineoperation.h" +#include <functional> using namespace common; -NetConfigEventSink::NetConfigEventSink(std::shared_ptr<wmi::IConnection> connection, std::shared_ptr<ConfigManager> configManager) +NetConfigEventSink::NetConfigEventSink +( + std::shared_ptr<wmi::IConnection> connection, + std::shared_ptr<ConfigManager> configManager, + IClientSinkProxy *clientSinkProxy +) : m_connection(connection) , m_configManager(configManager) + , m_clientSinkProxy(clientSinkProxy) { } void NetConfigEventSink::update(CComPtr<IWbemClassObject> previous, CComPtr<IWbemClassObject> target) { - InterfaceConfig previousConfig(previous); - InterfaceConfig targetConfig(target); - - ConfigManager::Mutex mutex(*m_configManager); + auto forwardError = [this](const char *errorMessage, const char **details, uint32_t numDetails) + { + m_clientSinkProxy->error(errorMessage, details, numDetails); + }; - // - // This is OK because the config manager will reject updates - // that set our DNS servers. - // - if (ConfigManager::UpdateStatus::DnsApproved == m_configManager->updateConfig(previousConfig, targetConfig)) + ConfineOperation("Process adapter update event", forwardError, [&]() { - return; - } + InterfaceConfig previousConfig(previous); + InterfaceConfig targetConfig(target); + + ConfigManager::Mutex mutex(*m_configManager); + + // + // This is OK because the config manager will reject updates + // that set our DNS servers. + // + if (ConfigManager::UpdateStatus::DnsApproved == m_configManager->updateConfig(previousConfig, targetConfig)) + { + return; + } - // - // The update was initiated from an external source. - // Override current settings to enforce our selected DNS servers. - // - nchelpers::SetDnsServers(targetConfig.interfaceIndex(), m_configManager->getServers()); + // + // The update was initiated from an external source. + // Override current settings to enforce our selected DNS servers. + // + nchelpers::SetDnsServers(targetConfig.interfaceIndex(), m_configManager->getServers()); + }); } diff --git a/windows/windns/src/windns/netconfigeventsink.h b/windows/windns/src/windns/netconfigeventsink.h index 342278c08c..c3fd9242f7 100644 --- a/windows/windns/src/windns/netconfigeventsink.h +++ b/windows/windns/src/windns/netconfigeventsink.h @@ -2,14 +2,16 @@ #include "libcommon/wmi/ieventsink.h" #include "libcommon/wmi/iconnection.h" -#include "windns/configmanager.h" +#include "configmanager.h" +#include "iclientsinkproxy.h" #include <memory> class NetConfigEventSink : public common::wmi::IModificationEventSink { public: - NetConfigEventSink(std::shared_ptr<common::wmi::IConnection> connection, std::shared_ptr<ConfigManager> configManager); + NetConfigEventSink(std::shared_ptr<common::wmi::IConnection> connection, + std::shared_ptr<ConfigManager> configManager, IClientSinkProxy *clientSinkProxy); void update(CComPtr<IWbemClassObject> previous, CComPtr<IWbemClassObject> target) override; @@ -17,4 +19,6 @@ private: std::shared_ptr<common::wmi::IConnection> m_connection; std::shared_ptr<ConfigManager> m_configManager; + + IClientSinkProxy *m_clientSinkProxy; }; diff --git a/windows/windns/src/windns/netsh.cpp b/windows/windns/src/windns/netsh.cpp index af65ace2a9..7c031db87e 100644 --- a/windows/windns/src/windns/netsh.cpp +++ b/windows/windns/src/windns/netsh.cpp @@ -1,21 +1,92 @@ #include "stdafx.h" #include "netsh.h" #include "libcommon/applicationrunner.h" +#include "libcommon/string.h" +#include "libcommon/filesystem.h" #include <sstream> #include <stdexcept> +#include <experimental/filesystem> namespace { +std::wstring g_NetShPath; + +void InitializePath() +{ + if (false == g_NetShPath.empty()) + { + return; + } + + const auto system32 = common::fs::GetKnownFolderPath(FOLDERID_System, 0, nullptr); + + g_NetShPath = std::experimental::filesystem::path(system32).append(L"netsh.exe"); +} + +const std::wstring &NetShPath() +{ + InitializePath(); + + return g_NetShPath; +} + +std::vector<std::string> BlockToRows(const std::string &textBlock) +{ + // + // TODO: Formalize and move to libcommon. + // There is a recurring need to split a text block into lines, ignoring blank lines. + // + // Also, changing the encoding back and forth is terribly wasteful. + // Should look into replacing all of this with Boost some day. + // + + const auto wideTextBlock = common::string::ToWide(textBlock); + const auto wideRows = common::string::Tokenize(wideTextBlock, L"\r\n"); + + std::vector<std::string> result; + + result.reserve(wideRows.size()); + + std::transform(wideRows.begin(), wideRows.end(), std::back_inserter(result), [](const std::wstring &str) + { + return common::string::ToAnsi(str); + }); + + return result; +} + +__declspec(noreturn) void ThrowWithDetails(std::string &&error, common::ApplicationRunner &netsh) +{ + std::vector<std::string> details { "Failed to capture output from 'netsh'" }; + + std::string output; + + static const size_t MAX_CHARS = 2048; + static const size_t TIMEOUT_MILLISECONDS = 2000; + + if (netsh.read(output, MAX_CHARS, TIMEOUT_MILLISECONDS)) + { + auto outputRows = BlockToRows(output); + + if (false == outputRows.empty()) + { + details = std::move(outputRows); + } + } + + throw NetShError(std::move(error), std::move(details)); +} + void ValidateShellOut(common::ApplicationRunner &netsh) { - static const uint32_t TIMEOUT_TWO_SECONDS = 2000; + static const size_t TIMEOUT_MILLISECONDS = 2000; DWORD returnCode; - if (false == netsh.join(returnCode, TIMEOUT_TWO_SECONDS)) + if (false == netsh.join(returnCode, TIMEOUT_MILLISECONDS)) { - throw std::runtime_error("'netsh' did not complete in a timely manner"); + ThrowWithDetails("'netsh' did not complete in a timely manner", netsh); } if (returnCode != 0) @@ -24,7 +95,7 @@ void ValidateShellOut(common::ApplicationRunner &netsh) ss << "'netsh' failed the requested operation. Error: " << returnCode; - throw std::runtime_error(ss.str()); + ThrowWithDetails(ss.str(), netsh); } } @@ -47,7 +118,7 @@ void NetSh::SetIpv4PrimaryDns(uint32_t interfaceIndex, std::wstring server) << server << L" validate=no"; - auto netsh = common::ApplicationRunner::StartWithoutConsole(L"netsh.exe", ss.str()); + auto netsh = common::ApplicationRunner::StartWithoutConsole(NetShPath(), ss.str()); ValidateShellOut(*netsh); } @@ -69,7 +140,7 @@ void NetSh::SetIpv4SecondaryDns(uint32_t interfaceIndex, std::wstring server) << server << L" index=2 validate=no"; - auto netsh = common::ApplicationRunner::StartWithoutConsole(L"netsh.exe", ss.str()); + auto netsh = common::ApplicationRunner::StartWithoutConsole(NetShPath(), ss.str()); ValidateShellOut(*netsh); } @@ -89,7 +160,7 @@ void NetSh::SetIpv4Dhcp(uint32_t interfaceIndex) << interfaceIndex << L" source=dhcp"; - auto netsh = common::ApplicationRunner::StartWithoutConsole(L"netsh.exe", ss.str()); + auto netsh = common::ApplicationRunner::StartWithoutConsole(NetShPath(), ss.str()); ValidateShellOut(*netsh); } @@ -111,7 +182,7 @@ void NetSh::SetIpv6PrimaryDns(uint32_t interfaceIndex, std::wstring server) << server << L" validate=no"; - auto netsh = common::ApplicationRunner::StartWithoutConsole(L"netsh.exe", ss.str()); + auto netsh = common::ApplicationRunner::StartWithoutConsole(NetShPath(), ss.str()); ValidateShellOut(*netsh); } @@ -133,7 +204,7 @@ void NetSh::SetIpv6SecondaryDns(uint32_t interfaceIndex, std::wstring server) << server << L" index=2 validate=no"; - auto netsh = common::ApplicationRunner::StartWithoutConsole(L"netsh.exe", ss.str()); + auto netsh = common::ApplicationRunner::StartWithoutConsole(NetShPath(), ss.str()); ValidateShellOut(*netsh); } @@ -153,7 +224,7 @@ void NetSh::SetIpv6Dhcp(uint32_t interfaceIndex) << interfaceIndex << L" source=dhcp"; - auto netsh = common::ApplicationRunner::StartWithoutConsole(L"netsh.exe", ss.str()); + auto netsh = common::ApplicationRunner::StartWithoutConsole(NetShPath(), ss.str()); ValidateShellOut(*netsh); } diff --git a/windows/windns/src/windns/netsh.h b/windows/windns/src/windns/netsh.h index d299cf6f87..7aa5b800f0 100644 --- a/windows/windns/src/windns/netsh.h +++ b/windows/windns/src/windns/netsh.h @@ -2,6 +2,7 @@ #include <string> #include <cstdint> +#include <stdexcept> class NetSh { @@ -24,3 +25,25 @@ private: NetSh(); }; + +class NetShError : public std::exception +{ +public: + + NetShError(std::string &&error, std::vector<std::string> &&details) + : std::exception(error.c_str()) + , m_error(std::move(error)) + , m_details(std::move(details)) + { + } + + const std::vector<std::string> &details() + { + return m_details; + } + +private: + + const std::string m_error; + const std::vector<std::string> m_details; +}; diff --git a/windows/windns/src/windns/windns.cpp b/windows/windns/src/windns/windns.cpp index 48ad863dfa..662432495a 100644 --- a/windows/windns/src/windns/windns.cpp +++ b/windows/windns/src/windns/windns.cpp @@ -2,9 +2,10 @@ #include "windns.h" #include "windnscontext.h" #include "clientsinkinfo.h" -#include "libcommon/serialization/deserializer.h" #include "interfaceconfig.h" #include "netconfighelpers.h" +#include "confineoperation.h" +#include "libcommon/serialization/deserializer.h" #include <vector> #include <string> @@ -28,6 +29,14 @@ std::vector<std::wstring> MakeStringArray(const wchar_t **strings, uint32_t numS return v; } +void ForwardError(const char *errorMessage, const char **details, uint32_t numDetails) +{ + if (nullptr != g_ErrorSink) + { + g_ErrorSink(errorMessage, details, numDetails, g_ErrorContext); + } +} + } // anonymous namespace WINDNS_LINKAGE @@ -46,25 +55,10 @@ WinDns_Initialize( g_ErrorSink = errorSink; g_ErrorContext = errorContext; - try + return ConfineOperation("Initialize", ForwardError, []() { g_Context = new WinDnsContext; - } - catch (std::exception &err) - { - if (nullptr != g_ErrorSink) - { - g_ErrorSink(err.what(), g_ErrorContext); - } - - return false; - } - catch (...) - { - return false; - } - - return true; + }); } WINDNS_LINKAGE @@ -101,7 +95,7 @@ WinDns_Set( return false; } - try + return ConfineOperation("Enforce DNS settings", ForwardError, [&]() { ClientSinkInfo sinkInfo; @@ -109,22 +103,7 @@ WinDns_Set( sinkInfo.configSinkInfo = ConfigSinkInfo{ configSink, configContext }; g_Context->set(MakeStringArray(servers, numServers), sinkInfo); - } - catch (std::exception &err) - { - if (nullptr != g_ErrorSink) - { - g_ErrorSink(err.what(), g_ErrorContext); - } - - return false; - } - catch (...) - { - return false; - } - - return true; + }); } WINDNS_LINKAGE @@ -138,25 +117,10 @@ WinDns_Reset( return true; } - try + return ConfineOperation("Reset DNS settings", ForwardError, []() { g_Context->reset(); - } - catch (std::exception &err) - { - if (nullptr != g_ErrorSink) - { - g_ErrorSink(err.what(), g_ErrorContext); - } - - return false; - } - catch (...) - { - return false; - } - - return true; + }); } WINDNS_LINKAGE @@ -169,7 +133,7 @@ WinDns_Recover( { std::vector<InterfaceConfig> configs; - try + const auto status = ConfineOperation("Deserialize recovery data", ForwardError, [&]() { common::serialization::Deserializer d(reinterpret_cast<const uint8_t *>(configData), dataLength); @@ -177,7 +141,7 @@ WinDns_Recover( if (numConfigs > 50) { - return false; + throw std::runtime_error("Too many configuration entries"); } configs.reserve(numConfigs); @@ -186,48 +150,27 @@ WinDns_Recover( { configs.emplace_back(InterfaceConfig(d)); } - } - catch (std::exception &err) - { - if (nullptr != g_ErrorSink) - { - auto msg = std::string("Failed to deserialize recovery data: ").append(err.what()); - - g_ErrorSink(msg.c_str(), g_ErrorContext); - } + }); - return false; - } - catch (...) + if (false == status) { return false; } - if (configs.empty()) - { - return true; - } + // + // Try to restore each config and update 'success' if any update fails. + // bool success = true; for (const auto &config : configs) { - try + const auto adapterStatus = ConfineOperation("Restore adapter DNS settings", ForwardError, [&config]() { nchelpers::RevertDnsServers(config); - } - catch (std::exception &err) - { - if (nullptr != g_ErrorSink) - { - auto msg = std::string("Failed to restore interface settings: ").append(err.what()); + }); - g_ErrorSink(msg.c_str(), g_ErrorContext); - } - - success = false; - } - catch (...) + if (false == adapterStatus) { success = false; } diff --git a/windows/windns/src/windns/windns.h b/windows/windns/src/windns/windns.h index 734719ca80..ab4ec9a539 100644 --- a/windows/windns/src/windns/windns.h +++ b/windows/windns/src/windns/windns.h @@ -17,7 +17,7 @@ // Functions /////////////////////////////////////////////////////////////////////////////// -typedef void (WINDNS_API *WinDnsErrorSink)(const char *errorMessage, void *context); +typedef void (WINDNS_API *WinDnsErrorSink)(const char *errorMessage, const char **details, uint32_t numDetails, void *context); typedef void (WINDNS_API *WinDnsConfigSink)(const void *configData, uint32_t dataLength, void *context); // diff --git a/windows/windns/src/windns/windns.vcxproj b/windows/windns/src/windns/windns.vcxproj index a5fb31bb55..8b8bbb124a 100644 --- a/windows/windns/src/windns/windns.vcxproj +++ b/windows/windns/src/windns/windns.vcxproj @@ -177,6 +177,8 @@ <ItemGroup> <ClInclude Include="clientsinkinfo.h" /> <ClInclude Include="configmanager.h" /> + <ClInclude Include="confineoperation.h" /> + <ClInclude Include="iclientsinkproxy.h" /> <ClInclude Include="interfaceconfig.h" /> <ClInclude Include="netconfigeventsink.h" /> <ClInclude Include="netconfighelpers.h" /> @@ -189,6 +191,7 @@ </ItemGroup> <ItemGroup> <ClCompile Include="configmanager.cpp" /> + <ClCompile Include="confineoperation.cpp" /> <ClCompile Include="dllmain.cpp" /> <ClCompile Include="interfaceconfig.cpp" /> <ClCompile Include="netconfigeventsink.cpp" /> diff --git a/windows/windns/src/windns/windns.vcxproj.filters b/windows/windns/src/windns/windns.vcxproj.filters index 3c10de37d9..57e9eeaef0 100644 --- a/windows/windns/src/windns/windns.vcxproj.filters +++ b/windows/windns/src/windns/windns.vcxproj.filters @@ -12,6 +12,8 @@ <ClInclude Include="netsh.h" /> <ClInclude Include="interfaceconfig.h" /> <ClInclude Include="types.h" /> + <ClInclude Include="iclientsinkproxy.h" /> + <ClInclude Include="confineoperation.h" /> </ItemGroup> <ItemGroup> <ClCompile Include="dllmain.cpp" /> @@ -23,6 +25,7 @@ <ClCompile Include="netconfighelpers.cpp" /> <ClCompile Include="netsh.cpp" /> <ClCompile Include="interfaceconfig.cpp" /> + <ClCompile Include="confineoperation.cpp" /> </ItemGroup> <ItemGroup> <ResourceCompile Include="windns.rc" /> diff --git a/windows/windns/src/windns/windnscontext.cpp b/windows/windns/src/windns/windnscontext.cpp index 17265f0b65..29cb034566 100644 --- a/windows/windns/src/windns/windnscontext.cpp +++ b/windows/windns/src/windns/windnscontext.cpp @@ -3,6 +3,8 @@ #include "libcommon/wmi/connection.h" #include "netconfigeventsink.h" #include "netconfighelpers.h" +#include "confineoperation.h" +#include <functional> using namespace common; @@ -17,13 +19,6 @@ WinDnsContext::~WinDnsContext() { reset(); } - catch (std::exception &err) - { - if (nullptr != m_sinkInfo.errorSinkInfo.sink) - { - m_sinkInfo.errorSinkInfo.sink(err.what(), m_sinkInfo.errorSinkInfo.context); - } - } catch (...) { } @@ -35,13 +30,13 @@ void WinDnsContext::set(const std::vector<std::wstring> &servers, const ClientSi if (nullptr == m_notification) { - m_configManager = std::make_shared<ConfigManager>(servers, m_sinkInfo.configSinkInfo); + m_configManager = std::make_shared<ConfigManager>(servers, this); // // Register interface configuration monitoring. // - auto eventSink = std::make_shared<NetConfigEventSink>(m_connection, m_configManager); + auto eventSink = std::make_shared<NetConfigEventSink>(m_connection, m_configManager, this); auto eventDispatcher = CComPtr<wmi::IEventDispatcher>(new wmi::ModificationEventDispatcher(eventSink)); m_notification = std::make_unique<wmi::Notification>(m_connection, eventDispatcher); @@ -60,7 +55,6 @@ void WinDnsContext::set(const std::vector<std::wstring> &servers, const ClientSi ConfigManager::Mutex mutex(*m_configManager); m_configManager->updateServers(servers); - m_configManager->updateConfigSink(m_sinkInfo.configSinkInfo); } // @@ -86,13 +80,52 @@ void WinDnsContext::reset() m_notification = nullptr; // - // Revert configs - // Safe to do without a mutex guarding the config manager + // Reset adapter configs. // + // Safe to do without a mutex guarding the config manager. + // + // Try to reset as many adapters as possible, even if one or more fails to reset. + // + + bool success = true; - m_configManager->processConfigs([&](const InterfaceConfig &config) + auto forwardError = std::bind(&WinDnsContext::error, this, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3); + + m_configManager->processConfigs([&success, &forwardError](const InterfaceConfig &config) { - nchelpers::RevertDnsServers(config); + const auto adapterStatus = ConfineOperation("Reset adapter DNS configuration", forwardError, [&config]() + { + nchelpers::RevertDnsServers(config); + }); + + if (false == adapterStatus) + { + success = false; + } + return true; }); + + if (false == success) + { + throw std::runtime_error("Resetting DNS failed for one or more adapters"); + } +} + +// IClientSinkProxy +void WinDnsContext::error(const char *errorMessage, const char **details, uint32_t numDetails) +{ + if (nullptr != m_sinkInfo.errorSinkInfo.sink) + { + m_sinkInfo.errorSinkInfo.sink(errorMessage, details, numDetails, m_sinkInfo.errorSinkInfo.context); + } +} + +// IClientSinkProxy +void WinDnsContext::config(const void *configData, uint32_t dataLength) +{ + if (nullptr != m_sinkInfo.configSinkInfo.sink) + { + m_sinkInfo.configSinkInfo.sink(configData, dataLength, m_sinkInfo.configSinkInfo.context); + } } diff --git a/windows/windns/src/windns/windnscontext.h b/windows/windns/src/windns/windnscontext.h index c514153edd..c83ec6e482 100644 --- a/windows/windns/src/windns/windnscontext.h +++ b/windows/windns/src/windns/windnscontext.h @@ -5,11 +5,12 @@ #include "libcommon/wmi/notification.h" #include "configmanager.h" #include "clientsinkinfo.h" +#include "iclientsinkproxy.h" #include <vector> #include <string> #include <memory> -class WinDnsContext +class WinDnsContext : public IClientSinkProxy { public: @@ -19,6 +20,9 @@ public: void set(const std::vector<std::wstring> &servers, const ClientSinkInfo &sinkInfo); void reset(); + void IClientSinkProxy::error(const char *errorMessage, const char **details, uint32_t numDetails) override; + void IClientSinkProxy::config(const void *configData, uint32_t dataLength) override; + private: std::shared_ptr<common::wmi::Connection> m_connection; |
