diff options
| author | Odd Stranne <odd@mullvad.net> | 2018-09-28 14:47:55 +0200 |
|---|---|---|
| committer | Odd Stranne <odd@mullvad.net> | 2018-10-03 23:13:55 +0200 |
| commit | 3291c2c839aea203686737c12cd0bd22a0dd20f6 (patch) | |
| tree | bd031f7c61d1ec92707d3621f57b32c066476c4d | |
| parent | 5ddf1c021ea4dc25ad00730f5f0553150c326899 (diff) | |
| download | mullvadvpn-3291c2c839aea203686737c12cd0bd22a0dd20f6.tar.xz mullvadvpn-3291c2c839aea203686737c12cd0bd22a0dd20f6.zip | |
Rewrite 'windns' to get rid of WMI + fix bugs and shortcomings
42 files changed, 1724 insertions, 913 deletions
diff --git a/windows/windns/extras.sln b/windows/windns/extras.sln index ed1e71d4c2..ca9347638e 100644 --- a/windows/windns/extras.sln +++ b/windows/windns/extras.sln @@ -5,6 +5,7 @@ MinimumVisualStudioVersion = 10.0.40219.1 Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "loader", "src\extras\loader\loader.vcxproj", "{1476A8B9-4A9E-4358-8744-A350CB97E152}" ProjectSection(ProjectDependencies) = postProject {A5344205-FC37-4572-9C63-8564ECC410AC} = {A5344205-FC37-4572-9C63-8564ECC410AC} + {B52E2D10-A94A-4605-914A-2DCEF6A757EF} = {B52E2D10-A94A-4605-914A-2DCEF6A757EF} EndProjectSection EndProject Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "windns", "src\windns\windns.vcxproj", "{A5344205-FC37-4572-9C63-8564ECC410AC}" diff --git a/windows/windns/src/extras/loader/loader.cpp b/windows/windns/src/extras/loader/loader.cpp index 46d8a7030c..5c6603c3dd 100644 --- a/windows/windns/src/extras/loader/loader.cpp +++ b/windows/windns/src/extras/loader/loader.cpp @@ -7,9 +7,19 @@ #include <vector> #include <windows.h> -void WINDNS_API ErrorSink(const char *errorMessage, const char **details, uint32_t numDetails, void *context) +void WINDNS_API LogSink(WinDnsLogCategory category, const char *message, const char **details, + uint32_t numDetails, void *context) { - std::cout << "WINDNS Error: " << errorMessage << std::endl; + if (WINDNS_LOG_CATEGORY_ERROR == category) + { + std::cout << "WINDNS Error: "; + } + else + { + std::cout << "WINDNS Info: "; + } + + std::cout << message << std::endl; for (uint32_t i = 0; i < numDetails; ++i) { @@ -17,9 +27,9 @@ void WINDNS_API ErrorSink(const char *errorMessage, const char **details, uint32 } } -void WINDNS_API ConfigSink(const void *configData, uint32_t dataLength, void *context) +void WINDNS_API RecoverySink(const void *recoveryData, uint32_t dataLength, void *context) { - std::wcout << L"Updated config was delivered to WINDNS client code" << std::endl; + std::wcout << L"Updated recovery data was delivered to WINDNS client code" << std::endl; auto f = CreateFileW(L"windns_recovery", GENERIC_WRITE, 0, nullptr, CREATE_ALWAYS, 0, nullptr); @@ -29,7 +39,7 @@ void WINDNS_API ConfigSink(const void *configData, uint32_t dataLength, void *co return; } - if (FALSE == WriteFile(f, configData, dataLength, nullptr, nullptr)) + if (FALSE == WriteFile(f, recoveryData, dataLength, nullptr, nullptr)) { std::wcout << L"Failed to update recovery file" << std::endl; } @@ -59,7 +69,8 @@ void Recover() return; } - std::wcout << L"WinDns_Recover: " << WinDns_Recover(&data[0], static_cast<uint32_t>(data.size())) << std::endl; + std::wcout << L"WinDns_Recover: " << std::boolalpha << + WinDns_Recover(&data[0], static_cast<uint32_t>(data.size())) << std::endl; } bool Ask(const std::wstring &question) @@ -88,7 +99,7 @@ int main() { common::trace::Trace::RegisterSink(new common::trace::ConsoleTraceSink); - std::wcout << L"WinDns_Initialize: " << WinDns_Initialize(ErrorSink, nullptr) << std::endl; + std::wcout << L"WinDns_Initialize: " << std::boolalpha << WinDns_Initialize(LogSink, nullptr) << std::endl; if (Ask(L"Perform recovery?")) { @@ -98,28 +109,39 @@ int main() const wchar_t *servers[] = { - L"8.8.8.8" + L"8.8.8.8", + L"8.8.4.4" }; - std::wcout << L"WinDns_Set: " << WinDns_Set(servers, _countof(servers), ConfigSink, nullptr) << std::endl; + const wchar_t *v6Servers[] = + { + L"2001:4860:4860::8888", + L"2001:4860:4860::8844" + }; + + auto status = WinDns_Set(servers, _countof(servers), v6Servers, _countof(v6Servers), RecoverySink, nullptr); + + std::wcout << L"WinDns_Set: " << std::boolalpha << status << std::endl; WaitInput(L"Press a key to abort DNS monitoring + enforcing..."); if (Ask(L"Perform WinDns_Reset() before next WinDns_Set()?")) { - std::wcout << L"WinDns_Reset: " << WinDns_Reset() << std::endl; + std::wcout << L"WinDns_Reset: " << std::boolalpha << WinDns_Reset() << std::endl; } - std::wcout << L"WinDns_Set: " << WinDns_Set(servers, _countof(servers), ConfigSink, nullptr) << std::endl; + status = WinDns_Set(servers, _countof(servers), v6Servers, _countof(v6Servers), RecoverySink, nullptr); + + std::wcout << L"WinDns_Set: " << std::boolalpha << status << std::endl; WaitInput(L"Press a key to abort DNS monitoring + enforcing..."); if (Ask(L"Perform WinDns_Reset() before WinDns_Deinitialize()?")) { - std::wcout << L"WinDns_Reset: " << WinDns_Reset() << std::endl; + std::wcout << L"WinDns_Reset: " << std::boolalpha << WinDns_Reset() << std::endl; } - std::wcout << L"WinDns_Deinitialize: " << WinDns_Deinitialize() << std::endl; + std::wcout << L"WinDns_Deinitialize: " << std::boolalpha << WinDns_Deinitialize() << std::endl; return 0; } diff --git a/windows/windns/src/windns/clientsinkinfo.h b/windows/windns/src/windns/clientsinkinfo.h index db42485d8e..39f602eeac 100644 --- a/windows/windns/src/windns/clientsinkinfo.h +++ b/windows/windns/src/windns/clientsinkinfo.h @@ -2,20 +2,14 @@ #include "windns.h" -struct ErrorSinkInfo +struct LogSinkInfo { - WinDnsErrorSink sink; + WinDnsLogSink sink; void *context; }; -struct ConfigSinkInfo +struct RecoverySinkInfo { - WinDnsConfigSink sink; + WinDnsRecoverySink sink; void *context; }; - -struct ClientSinkInfo -{ - ErrorSinkInfo errorSinkInfo; - ConfigSinkInfo configSinkInfo; -}; diff --git a/windows/windns/src/windns/configmanager.cpp b/windows/windns/src/windns/configmanager.cpp deleted file mode 100644 index 498e58a381..0000000000 --- a/windows/windns/src/windns/configmanager.cpp +++ /dev/null @@ -1,133 +0,0 @@ -#include "stdafx.h" -#include "configmanager.h" -#include "libcommon/serialization/serializer.h" -#include "libcommon/trace/xtrace.h" -#include <utility> -#include <algorithm> - -ConfigManager::ConfigManager -( - const std::vector<std::wstring> &servers, - IClientSinkProxy *clientSinkProxy -) - : m_servers(servers) - , m_clientSinkProxy(clientSinkProxy) -{ -} - -void ConfigManager::lock() -{ - m_mutex.lock(); -} - -void ConfigManager::unlock() -{ - m_mutex.unlock(); -} - -void ConfigManager::updateServers(const std::vector<std::wstring> &servers) -{ - XTRACE(L"Updating DNS server list"); - m_servers = servers; -} - -const std::vector<std::wstring> &ConfigManager::getServers() const -{ - return m_servers; -} - -ConfigManager::UpdateStatus ConfigManager::updateConfig(const InterfaceConfig &previous, const InterfaceConfig &target) -{ - XTRACE(L"Interface configuration update for interface=", target.interfaceIndex()); - - // - // There are a few cases we need to deal with: - // - // 1/ An interface being offline and coming online. - // 2/ An external application changing the interface settings. - // 3/ Us changing the interface settings. - // a. On an interface the ConfigManager hasn't seen before. - // b. On an interface the ConfigManager already knows about. - // - - const auto configIndex = target.configIndex(); - auto iter = m_configs.find(configIndex); - - if (verifyServers(target)) - { - XTRACE(L"Update event was initiated by WINDNS or did not include DNS changes"); - - // - // If we haven't seen this config id before, it means the 'previous' instance - // is the original configuration on the system, and as such must be recorded. - // - if (m_configs.end() == iter) - { - XTRACE(L"Creating new interface configuration entry"); - m_configs.insert(std::make_pair(configIndex, previous)); - - exportConfigs(); - } - - return UpdateStatus::DnsApproved; - } - - // - // The update was not initiated by us so store the updated configuration. - // - if (m_configs.end() == iter) - { - XTRACE(L"Creating new interface configuration entry"); - m_configs.insert(std::make_pair(configIndex, target)); - } - else - { - XTRACE(L"Updating interface configuration entry"); - iter->second.updateServers(target); - } - - exportConfigs(); - - return UpdateStatus::DnsDeviates; -} - -bool ConfigManager::processConfigs(std::function<bool(const InterfaceConfig &)> configSink) -{ - for (auto it = m_configs.begin(); it != m_configs.end(); ++it) - { - if (false == configSink(it->second)) - { - return false; - } - } - - return true; -} - -bool ConfigManager::verifyServers(const InterfaceConfig &config) -{ - auto updatedServers = config.servers(); - - if (nullptr == updatedServers) - { - return false; - } - - return std::equal(m_servers.begin(), m_servers.end(), updatedServers->begin(), updatedServers->end()); -} - -void ConfigManager::exportConfigs() -{ - common::serialization::Serializer s; - - s << static_cast<uint32_t>(m_configs.size()); - - for (auto it = m_configs.begin(); it != m_configs.end(); ++it) - { - it->second.serialize(s); - } - - auto data = s.blob(); - - m_clientSinkProxy->config(&data[0], static_cast<uint32_t>(data.size())); -} diff --git a/windows/windns/src/windns/configmanager.h b/windows/windns/src/windns/configmanager.h deleted file mode 100644 index c6f14f7ea9..0000000000 --- a/windows/windns/src/windns/configmanager.h +++ /dev/null @@ -1,110 +0,0 @@ -#pragma once - -#include "interfaceconfig.h" -#include "iclientsinkproxy.h" -#include <map> -#include <string> -#include <mutex> -#include <memory> -#include <functional> - -// -// The ConfigManager is engineered to track the "real" DNS configuration for an adapter. -// -// The situation is somewhat complicated, because a given system may have several adapters, which -// in turn may have several configurations? -// -// Every update for every configuration is recorded, bar the ones that correspond to us -// overriding the DNS settings. -// - -class ConfigManager -{ -public: - - struct Mutex - { - Mutex(const Mutex &) = delete; - Mutex &operator=(const Mutex &) = delete; - Mutex(Mutex &&) = delete; - Mutex &operator=(Mutex &&) = delete; - - Mutex(ConfigManager &manager) - : m_manager(manager) - { - m_manager.lock(); - } - - ~Mutex() - { - m_manager.unlock(); - } - - ConfigManager &m_manager; - }; - - // - // "servers" specifies the set of servers used when overriding settings. - // This enables filtering out the corresponding event. - // - ConfigManager - ( - const std::vector<std::wstring> &servers, - IClientSinkProxy *clientSinkProxy - ); - - // - // The ConfigManager is shared between threads. - // Locking is managed externally for reasons of efficiency. - // - void lock(); - void unlock(); - - // - // Establish the set of servers to use when overriding DNS settings. - // - void updateServers(const std::vector<std::wstring> &servers); - - // - // Get the current set of servers used for overriding DNS settings. - // - const std::vector<std::wstring> &getServers() const; - - // - // Notify the ConfigManager that a live configuration has been updated. - // - enum class UpdateStatus - { - DnsApproved, - DnsDeviates - }; - - UpdateStatus updateConfig(const InterfaceConfig &previous, const InterfaceConfig &target); - - // - // Enumerate recorded configs. - // - bool processConfigs(std::function<bool(const InterfaceConfig &)> configSink); - -private: - - std::mutex m_mutex; - - std::vector<std::wstring> m_servers; - IClientSinkProxy *m_clientSinkProxy; - - // - // Organize configs based on their system assigned index. - // - std::map<uint32_t, InterfaceConfig> m_configs; - - // - // Check DNS server list to see if it matches what we're trying to enforce. - // - bool verifyServers(const InterfaceConfig &config); - - // - // Bundle the current config details and send them into the config sink. - // - void exportConfigs(); -}; diff --git a/windows/windns/src/windns/dnsagent.cpp b/windows/windns/src/windns/dnsagent.cpp new file mode 100644 index 0000000000..6ff8ee9b39 --- /dev/null +++ b/windows/windns/src/windns/dnsagent.cpp @@ -0,0 +1,442 @@ +#include "stdafx.h" +#include "dnsagent.h" +#include "registrypaths.h" +#include "netsh.h" +#include <libcommon/trace/xtrace.h> +#include <libcommon/error.h> +#include <process.h> +#include <algorithm> + +DnsAgent::DnsAgent(Protocol protocol, INameServerSource *nameServerSource, IRecoverySink *recoverySink, ILogSink *logSink) + : m_protocol(protocol) + , m_nameServerSource(nameServerSource) + , m_recoverySink(recoverySink) + , m_logSink(logSink) + , m_thread(nullptr) + , m_shutdownEvent(nullptr) +{ + constructNameServerUpdateEvent(); + constructRootMonitor(); + + startTrackingInterfaces(discoverInterfaces()); + updateRecoveryData(); + + constructThread(); +} + +DnsAgent::~DnsAgent() +{ + SetEvent(m_shutdownEvent); + WaitForSingleObject(m_thread, INFINITE); + + CloseHandle(m_shutdownEvent); + CloseHandle(m_thread); + + m_nameServerSource->unsubscribe(m_serverSourceEvent); + CloseHandle(m_serverSourceEvent); +} + +void DnsAgent::constructNameServerUpdateEvent() +{ + m_serverSourceEvent = CreateEventW(nullptr, TRUE, FALSE, nullptr); + + THROW_GLE_IF(nullptr, m_serverSourceEvent, "Create name server subscription event"); + + m_nameServerSource->subscribe(m_serverSourceEvent); +} + +void DnsAgent::constructRootMonitor() +{ + m_rootMonitor = common::registry::Registry::MonitorKey(HKEY_LOCAL_MACHINE, + RegistryPaths::InterfaceRoot(m_protocol), { common::registry::RegistryEventFlag::SubkeyChange }); +} + +void DnsAgent::constructThread() +{ + m_shutdownEvent = CreateEventW(nullptr, TRUE, FALSE, nullptr); + + THROW_GLE_IF(nullptr, m_shutdownEvent, "Create shutdown event"); + + auto rawThreadHandle = _beginthreadex(nullptr, 0, &DnsAgent::ThreadEntry, this, 0, nullptr); + + if (0 == rawThreadHandle) + { + throw std::runtime_error("Could not create monitoring thread"); + } + + m_thread = reinterpret_cast<HANDLE>(rawThreadHandle); +} + +//static +unsigned __stdcall DnsAgent::ThreadEntry(void *parameters) +{ + try + { + reinterpret_cast<DnsAgent *>(parameters)->thread(); + } + catch (std::exception &err) + { + const char *what = err.what(); + + reinterpret_cast<DnsAgent *>(parameters)->m_logSink->error + ( + "Critical error in monitoring thread", &what, 1 + ); + } + catch (...) + { + reinterpret_cast<DnsAgent *>(parameters)->m_logSink->error + ( + "Unspecified critical error in monitoring thread" + ); + } + + return 0; +} + +void DnsAgent::thread() +{ + for (;;) + { + std::vector<HANDLE> waitHandles; + + // + // Reserve enough space in the array to hold: + // + // Shutdown event handle + // Name servers source update event + // Monitor handle for interfaces root key + // Monitor handles for all interfaces + // + waitHandles.reserve(3 + m_trackedInterfaces.size()); + + const size_t shutdownEventIndex = 0; + const size_t serverSourceEventIndex = 1; + const size_t rootKeyEventIndex = 2; + const size_t firstInterfaceIndex = 3; + + waitHandles.push_back(m_shutdownEvent); + waitHandles.push_back(m_serverSourceEvent); + waitHandles.push_back(m_rootMonitor->queueSingleEvent()); + + for (auto &interfaceData : m_trackedInterfaces) + { + waitHandles.push_back(interfaceData.monitor->queueSingleEvent()); + } + + // + // Wait for one or more events to become signalled. + // + + const auto status = WaitForMultipleObjects(static_cast<DWORD>(waitHandles.size()), &waitHandles[0], FALSE, INFINITE); + + if (WAIT_FAILED == status) + { + m_logSink->error("Failed to wait on events. Restarting wait in 1 minute."); + + if (WAIT_OBJECT_0 == WaitForSingleObject(m_shutdownEvent, 1000 * 60)) + { + break; + } + + continue; + } + + const size_t firstSignalledIndex = status - WAIT_OBJECT_0; + + if (firstSignalledIndex == shutdownEventIndex) + { + break; + } + + if (firstSignalledIndex >= firstInterfaceIndex) + { + XTRACE(L"Interface event is signalled"); + + const auto result = processInterfaceEvent(&waitHandles[firstInterfaceIndex], + firstSignalledIndex - firstInterfaceIndex); + + if (ProcessingResult::TrackingUpdated == result) + { + updateRecoveryData(); + } + + continue; + } + + // + // We can't easily tell which events have been signalled. + // + + const auto interfaceResult = processInterfaceEvent(&waitHandles[firstInterfaceIndex], 0); + auto rootResult = ProcessingResult::Nop; + + if (WAIT_OBJECT_0 == WaitForSingleObject(waitHandles[rootKeyEventIndex], 0)) + { + XTRACE(L"Interfaces root key event is signalled"); + + rootResult = processRootKeyEvent(); + } + + if (ProcessingResult::TrackingUpdated == interfaceResult + || ProcessingResult::TrackingUpdated == rootResult) + { + updateRecoveryData(); + } + + if (WAIT_OBJECT_0 == WaitForSingleObject(waitHandles[serverSourceEventIndex], 0)) + { + XTRACE(L"Server source update event is signalled"); + ResetEvent(m_serverSourceEvent); + + processServerSourceEvent(); + } + } + + XTRACE(L"Thread is exiting"); +} + +void DnsAgent::processServerSourceEvent() +{ + // + // Check actual interface settings to determine which interfaces + // need to have their settings overridden. + // + // Do NOT update 'preservedSettings' on the tracking entries because + // it would overwrite legitimate settings with the previously enforced settings. + // + + std::vector<std::wstring> interfaces; + interfaces.reserve(m_trackedInterfaces.size()); + + std::transform(m_trackedInterfaces.begin(), m_trackedInterfaces.end(), std::back_inserter(interfaces), [](const InterfaceData &interfaceData) + { + return interfaceData.interfaceGuid; + }); + + const auto updatedSnaps = createSnaps(interfaces); + const auto enforcedServers = m_nameServerSource->getNameServers(m_protocol); + + for (const auto snap : updatedSnaps) + { + if (snap.needsOverriding(enforcedServers)) + { + setNameServers(snap.interfaceGuid(), enforcedServers); + } + } +} + +DnsAgent::ProcessingResult DnsAgent::processRootKeyEvent() +{ + ProcessingResult result = ProcessingResult::Nop; + + std::vector<std::wstring> oldInterfaces; + oldInterfaces.reserve(m_trackedInterfaces.size()); + + std::transform(m_trackedInterfaces.begin(), m_trackedInterfaces.end(), std::back_inserter(oldInterfaces), [](const InterfaceData &interfaceData) + { + return interfaceData.interfaceGuid; + }); + + auto currentInterfaces = discoverInterfaces(); + + std::sort(oldInterfaces.begin(), oldInterfaces.end()); + std::sort(currentInterfaces.begin(), currentInterfaces.end()); + + // + // Stop tracking interfaces that have been removed. + // + + std::vector<std::wstring> removedInterfaces; + + std::set_difference(oldInterfaces.begin(), oldInterfaces.end(), currentInterfaces.begin(), currentInterfaces.end(), + std::back_inserter(removedInterfaces)); + + if (false == removedInterfaces.empty()) + { + result = ProcessingResult::TrackingUpdated; + stopTrackingInterfaces(removedInterfaces); + } + + // + // Start tracking new interfaces. + // + + std::vector<std::wstring> newInterfaces; + + std::set_difference(currentInterfaces.begin(), currentInterfaces.end(), oldInterfaces.begin(), oldInterfaces.end(), + std::back_inserter(newInterfaces)); + + if (false == newInterfaces.empty()) + { + result = ProcessingResult::TrackingUpdated; + startTrackingInterfaces(newInterfaces); + } + + return result; +} + +DnsAgent::ProcessingResult DnsAgent::processInterfaceEvent(const HANDLE *interfaceEvents, size_t startIndex) +{ + ProcessingResult result = ProcessingResult::Nop; + + // + // 'interfaceEvents' runs in parallel with 'm_trackedInterfaces'. + // + + const auto enforcedNameServers = m_nameServerSource->getNameServers(m_protocol); + + for (size_t i = startIndex; i < m_trackedInterfaces.size(); ++i) + { + if (WAIT_TIMEOUT == WaitForSingleObject(interfaceEvents[i], 0)) + { + continue; + } + + auto &interface = m_trackedInterfaces[i]; + + XTRACE(L"Processing event for interface ", interface.interfaceGuid); + + try + { + InterfaceSnap updatedSnap(m_protocol, interface.interfaceGuid); + + if (updatedSnap.needsOverriding(enforcedNameServers)) + { + result = ProcessingResult::TrackingUpdated; + + interface.preservedSettings = std::move(updatedSnap); + setNameServers(interface.interfaceGuid, enforcedNameServers); + } + } + catch (std::exception &err) + { + const char *what = err.what(); + + m_logSink->error("Could not fetch updated interface settings. Probably because the interface was removed.", &what, 1); + + continue; + } + catch (...) + { + m_logSink->error("Could not fetch updated interface settings. Probably because the interface was removed."); + + continue; + } + } + + return result; +} + +std::vector<std::wstring> DnsAgent::discoverInterfaces() +{ + auto regKey = common::registry::Registry::OpenKey(HKEY_LOCAL_MACHINE, RegistryPaths::InterfaceRoot(m_protocol)); + + std::vector<std::wstring> interfaces; + + interfaces.reserve(20); + + regKey->enumerateSubKeys([&interfaces](const std::wstring &keyName) + { + interfaces.push_back(keyName); + return true; + }); + + return interfaces; +} + +std::vector<InterfaceSnap> DnsAgent::createSnaps(const std::vector<std::wstring> &interfaces) +{ + std::vector<InterfaceSnap> snaps; + + snaps.reserve(interfaces.size()); + + for (const auto &interface : interfaces) + { + snaps.emplace_back(m_protocol, interface); + } + + return snaps; +} + +void DnsAgent::setNameServers(const std::wstring &interfaceGuid, const std::vector<std::wstring> &enforcedServers) +{ + XTRACE(L"Overriding name servers for interface ", interfaceGuid); + + if (Protocol::IPv4 == m_protocol) + { + NetSh::Instance().SetIpv4StaticDns(NetSh::ConvertInterfaceGuidToIndex(interfaceGuid), enforcedServers); + } + else + { + NetSh::Instance().SetIpv6StaticDns(NetSh::ConvertInterfaceGuidToIndex(interfaceGuid), enforcedServers); + } +} + +void DnsAgent::startTrackingInterfaces(const std::vector<std::wstring> &interfaces) +{ + const auto snaps = createSnaps(interfaces); + + // + // Override configured name servers on all interfaces, as necessary. + // + + const auto enforcedServers = m_nameServerSource->getNameServers(m_protocol); + + for (const auto &snap : snaps) + { + if (snap.needsOverriding(enforcedServers)) + { + setNameServers(snap.interfaceGuid(), enforcedServers); + } + } + + // + // Create a tracking record for each interface. + // + + for (const auto &snap : snaps) + { + const auto interfaceGuid = snap.interfaceGuid(); + + XTRACE(L"Creating tracking entry for interface ", interfaceGuid); + + m_trackedInterfaces.emplace_back(interfaceGuid, snap, std::make_unique<InterfaceMonitor>(m_protocol, interfaceGuid)); + } +} + +void DnsAgent::stopTrackingInterfaces(const std::vector<std::wstring> &interfaces) +{ + for (const auto &interfaceGuid : interfaces) + { + auto iter = std::find_if(m_trackedInterfaces.begin(), m_trackedInterfaces.end(), [&interfaceGuid](const InterfaceData &candidate) + { + return candidate.interfaceGuid == interfaceGuid; + }); + + if (m_trackedInterfaces.end() == iter) + { + m_logSink->error("Request to stop tracking non-tracked interface, ignoring."); + + continue; + } + + XTRACE(L"Cancel tracking of interface ", interfaceGuid); + + m_trackedInterfaces.erase(iter); + } +} + +void DnsAgent::updateRecoveryData() +{ + std::vector<InterfaceSnap> snaps; + + snaps.reserve(m_trackedInterfaces.size()); + + std::transform(m_trackedInterfaces.begin(), m_trackedInterfaces.end(), std::back_inserter(snaps), [](const InterfaceData &interfaceData) + { + return interfaceData.preservedSettings; + }); + + m_recoverySink->preserveSnaps(m_protocol, snaps); +} diff --git a/windows/windns/src/windns/dnsagent.h b/windows/windns/src/windns/dnsagent.h new file mode 100644 index 0000000000..48dd12c52d --- /dev/null +++ b/windows/windns/src/windns/dnsagent.h @@ -0,0 +1,84 @@ +#pragma once + +#include "interfacesnap.h" +#include "interfacemonitor.h" +#include "types.h" +#include "inameserversource.h" +#include "irecoverysink.h" +#include "ilogsink.h" +#include <libcommon/registry/registry.h> +#include <string> +#include <vector> +#include <memory> +#include <windows.h> + +// +// DnsAgent: +// Monitor interfaces and enforce name server settings. +// +class DnsAgent +{ +public: + + DnsAgent(Protocol protocol, INameServerSource *nameServerSource, IRecoverySink *recoverySink, ILogSink *logSink); + ~DnsAgent(); + +private: + + Protocol m_protocol; + INameServerSource *m_nameServerSource; + IRecoverySink *m_recoverySink; + ILogSink *m_logSink; + + // + // InterfaceData: + // Tracking entry for a network interface. + // + struct InterfaceData + { + InterfaceData(const std::wstring &interfaceGuid_, const InterfaceSnap &snap_, std::unique_ptr<InterfaceMonitor> &&monitor_) + : interfaceGuid(interfaceGuid_), preservedSettings(snap_), monitor(std::move(monitor_)) + { + } + + std::wstring interfaceGuid; + InterfaceSnap preservedSettings; + std::unique_ptr<InterfaceMonitor> monitor; + }; + + std::vector<InterfaceData> m_trackedInterfaces; + + std::unique_ptr<common::registry::RegistryMonitor> m_rootMonitor; + + HANDLE m_serverSourceEvent; + HANDLE m_thread; + HANDLE m_shutdownEvent; + + void constructNameServerUpdateEvent(); + void constructRootMonitor(); + void constructThread(); + + static unsigned __stdcall ThreadEntry(void *); + void thread(); + + void processServerSourceEvent(); + + enum class ProcessingResult + { + TrackingUpdated, + Nop + }; + + ProcessingResult processRootKeyEvent(); + ProcessingResult processInterfaceEvent(const HANDLE *interfaceEvents, size_t startIndex); + + std::vector<std::wstring> discoverInterfaces(); + std::vector<InterfaceSnap> createSnaps(const std::vector<std::wstring> &interfaces); + + void setNameServers(const std::wstring &interfaceGuid, const std::vector<std::wstring> &enforcedServers); + + void startTrackingInterfaces(const std::vector<std::wstring> &interfaces); + void stopTrackingInterfaces(const std::vector<std::wstring> &interfaces); + + void updateRecoveryData(); +}; diff --git a/windows/windns/src/windns/iclientsinkproxy.h b/windows/windns/src/windns/iclientsinkproxy.h deleted file mode 100644 index 9270a12c50..0000000000 --- a/windows/windns/src/windns/iclientsinkproxy.h +++ /dev/null @@ -1,13 +0,0 @@ -#pragma once - -#include <cstdint> - -struct IClientSinkProxy -{ - virtual ~IClientSinkProxy() = 0 - { - } - - virtual void error(const char *errorMessage, const char **details, uint32_t numDetails) = 0; - virtual void config(const void *configData, uint32_t dataLength) = 0; -}; diff --git a/windows/windns/src/windns/ilogsink.h b/windows/windns/src/windns/ilogsink.h new file mode 100644 index 0000000000..dec1584ca3 --- /dev/null +++ b/windows/windns/src/windns/ilogsink.h @@ -0,0 +1,24 @@ +#pragma once + +#include <cstdint> + +struct ILogSink +{ + virtual ~ILogSink() = 0 + { + } + + virtual void error(const char *msg, const char **details, uint32_t numDetails) = 0; + + virtual void error(const char *msg) + { + error(msg, nullptr, 0); + } + + virtual void info(const char *msg, const char **details, uint32_t numDetails) = 0; + + virtual void info(const char *msg) + { + info(msg, nullptr, 0); + } +}; diff --git a/windows/windns/src/windns/inameserversource.h b/windows/windns/src/windns/inameserversource.h new file mode 100644 index 0000000000..f97209ce7d --- /dev/null +++ b/windows/windns/src/windns/inameserversource.h @@ -0,0 +1,23 @@ +#pragma once + +#include "types.h" +#include <vector> +#include <string> + +// +// Provide the array of name servers that we enforce on all adapters. +// +struct INameServerSource +{ + virtual ~INameServerSource() = 0 + { + } + + virtual std::vector<std::wstring> getNameServers(Protocol protocol) const = 0; + + // + // Get notified if the servers array is updated. + // + virtual void subscribe(HANDLE eventHandle) = 0; + virtual void unsubscribe(HANDLE eventHandle) = 0; +}; diff --git a/windows/windns/src/windns/interfaceconfig.cpp b/windows/windns/src/windns/interfaceconfig.cpp deleted file mode 100644 index 74e912de89..0000000000 --- a/windows/windns/src/windns/interfaceconfig.cpp +++ /dev/null @@ -1,66 +0,0 @@ -#include "stdafx.h" -#include "interfaceconfig.h" -#include "netconfighelpers.h" -#include "libcommon/com.h" -#include "libcommon/wmi/wmi.h" - -using namespace common; - -InterfaceConfig::InterfaceConfig(CComPtr<IWbemClassObject> instance) -{ - // - // V_xxx macros seem to require an l-value so access the correct field directly instead. - // - - m_configIndex = wmi::WmiGetPropertyAlways(instance, L"Index").ulVal; - - m_dhcp = wmi::WmiGetPropertyAlways(instance, L"DHCPEnabled").boolVal; - - m_interfaceIndex = wmi::WmiGetPropertyAlways(instance, L"InterfaceIndex").ulVal; - m_interfaceGuid = ComConvertString(wmi::WmiGetPropertyAlways(instance, L"SettingID").bstrVal); - - m_servers = nchelpers::GetDnsServers(instance); -} - -InterfaceConfig::InterfaceConfig(common::serialization::Deserializer &deserializer) -{ - common::serialization::Deserializer &d = deserializer; - - d >> m_configIndex; - d >> (uint8_t &)m_dhcp; - d >> m_interfaceIndex; - d >> m_interfaceGuid; - - bool serversAvailable; - - d >> (uint8_t &)serversAvailable; - - if (serversAvailable) - { - m_servers = std::make_shared<std::vector<std::wstring> >(); - d >> *m_servers; - } -} - -void InterfaceConfig::serialize(common::serialization::Serializer &serializer) const -{ - common::serialization::Serializer &s = serializer; - - s << m_configIndex; - s << (uint8_t)m_dhcp; - s << m_interfaceIndex; - s << m_interfaceGuid; - - // - // TODO: Encapsulate this inside a new type. - // - if (nullptr == m_servers.get()) - { - s << (uint8_t)0; - } - else - { - s << (uint8_t)1; - s << *m_servers; - } -} diff --git a/windows/windns/src/windns/interfaceconfig.h b/windows/windns/src/windns/interfaceconfig.h deleted file mode 100644 index e95296b7fb..0000000000 --- a/windows/windns/src/windns/interfaceconfig.h +++ /dev/null @@ -1,62 +0,0 @@ -#pragma once - -#include "types.h" -#include "libcommon/serialization/deserializer.h" -#include "libcommon/serialization/serializer.h" -#include <cstdint> -#include <string> -#include <vector> -#include <atlbase.h> -#include <wbemidl.h> - -class InterfaceConfig -{ -public: - - // instance = Win32_NetworkAdapterConfiguration. - explicit InterfaceConfig(CComPtr<IWbemClassObject> instance); - - explicit InterfaceConfig(common::serialization::Deserializer &deserializer); - void serialize(common::serialization::Serializer &serializer) const; - - void updateServers(const InterfaceConfig &rhs) - { - m_servers = rhs.m_servers; - } - - uint32_t configIndex() const - { - return m_configIndex; - } - - bool dhcp() const - { - return m_dhcp; - } - - uint32_t interfaceIndex() const - { - return m_interfaceIndex; - } - - const std::wstring &interfaceGuid() const - { - return m_interfaceGuid; - } - - const std::vector<std::wstring> *servers() const - { - return m_servers.get(); - } - -private: - - uint32_t m_configIndex; - - bool m_dhcp; - - uint32_t m_interfaceIndex; - std::wstring m_interfaceGuid; - - OptionalStringList m_servers; -}; diff --git a/windows/windns/src/windns/interfacemonitor.cpp b/windows/windns/src/windns/interfacemonitor.cpp new file mode 100644 index 0000000000..322bb69e57 --- /dev/null +++ b/windows/windns/src/windns/interfacemonitor.cpp @@ -0,0 +1,24 @@ +#include "stdafx.h" +#include "interfacemonitor.h" +#include "registrypaths.h" + +using namespace common::registry; + +InterfaceMonitor::InterfaceMonitor(Protocol protocol, const std::wstring &interfaceGuid) + : m_protocol(protocol) + , m_interfaceGuid(interfaceGuid) +{ + const auto interfacePath = RegistryPaths::InterfaceKey(interfaceGuid, protocol); + + m_monitor = Registry::MonitorKey(HKEY_LOCAL_MACHINE, interfacePath, { RegistryEventFlag::ValueChange }); +} + +HANDLE InterfaceMonitor::queueSingleEvent() +{ + return m_monitor->queueSingleEvent(); +} + +const std::wstring &InterfaceMonitor::interfaceGuid() const +{ + return m_interfaceGuid; +} diff --git a/windows/windns/src/windns/interfacemonitor.h b/windows/windns/src/windns/interfacemonitor.h new file mode 100644 index 0000000000..f40c88dab6 --- /dev/null +++ b/windows/windns/src/windns/interfacemonitor.h @@ -0,0 +1,30 @@ +#pragma once + +#include "types.h" +#include <libcommon/registry/registry.h> +#include <string> +#include <memory> +#include <windows.h> + +class InterfaceMonitor +{ +public: + + explicit InterfaceMonitor(Protocol protocol, const std::wstring &interfaceGuid); + + // + // The event becomes signalled if: + // 1. A value change occurs. + // 2. The monitored interface's registry key is deleted. + // + HANDLE queueSingleEvent(); + + const std::wstring &interfaceGuid() const; + +private: + + Protocol m_protocol; + std::wstring m_interfaceGuid; + + std::unique_ptr<common::registry::RegistryMonitor> m_monitor; +}; diff --git a/windows/windns/src/windns/interfacesnap.cpp b/windows/windns/src/windns/interfacesnap.cpp new file mode 100644 index 0000000000..309296897f --- /dev/null +++ b/windows/windns/src/windns/interfacesnap.cpp @@ -0,0 +1,151 @@ +#include "stdafx.h" +#include "interfacesnap.h" +#include "registrypaths.h" +#include <libcommon/registry/registry.h> +#include <libcommon/string.h> +#include <cstdint> + +using namespace common::registry; + +namespace +{ + +enum class NameServerType +{ + Static, + Dhcp +}; + +std::vector<std::wstring> GetNameServers(const std::wstring &interfaceGuid, + Protocol protocol, NameServerType nameServerType) +{ + const auto interfacePath = RegistryPaths::InterfaceKey(interfaceGuid, protocol); + + const auto regKey = Registry::OpenKey(HKEY_LOCAL_MACHINE, interfacePath); + + std::wstring nameservers; + + try + { + // + // This particular value is a string array packed into a string data type. + // REG_MULTI_SZ would have been the correct type to use, but there + // are probably historical reasons for the value type currently being used. + // + nameservers = regKey->readString(NameServerType::Static == nameServerType ? L"NameServer" : L"DhcpNameServer"); + } + catch (...) + { + } + + if (nameservers.empty()) + { + return std::vector<std::wstring>(); + } + + return common::string::Tokenize(nameservers, L","); +} + +bool GetDhcpEnabled(const std::wstring &interfaceGuid, Protocol protocol) +{ + const auto interfacePath = RegistryPaths::InterfaceKey(interfaceGuid, protocol); + + const auto regKey = Registry::OpenKey(HKEY_LOCAL_MACHINE, interfacePath); + + bool enabled = false; + + try + { + const auto flag = regKey->readUint32(L"EnableDHCP"); + + enabled = (1 == flag); + } + catch (...) + { + } + + return enabled; +} + +} // anonymous namespace + +InterfaceSnap::InterfaceSnap(Protocol protocol, const std::wstring &interfaceGuid) + : m_protocol(protocol) + , m_interfaceGuid(interfaceGuid) +{ + m_configuredForDhcp = GetDhcpEnabled(m_interfaceGuid, m_protocol); + + // Static name servers are configured by the user. + m_staticNameServers = GetNameServers(m_interfaceGuid, m_protocol, NameServerType::Static); + + // DHCP name servers are the servers most recently supplied by DHCP. + // An adapter can be configured for DHCP and static name servers at the same time. + // Static name servers always have precedence. + m_dhcpNameServers = GetNameServers(m_interfaceGuid, m_protocol, NameServerType::Dhcp); +} + +InterfaceSnap::InterfaceSnap(common::serialization::Deserializer &deserializer) +{ + common::serialization::Deserializer &d = deserializer; + + d >> (uint8_t &)m_protocol; + + if (m_protocol != Protocol::IPv4 + && m_protocol != Protocol::IPv6) + { + throw std::runtime_error("Serialized data for 'InterfaceSnap' instance is invalid (protocol)"); + } + + d >> m_interfaceGuid; + d >> (uint8_t &)m_configuredForDhcp; + d >> m_staticNameServers; + d >> m_dhcpNameServers; +} + +void InterfaceSnap::serialize(common::serialization::Serializer &serializer) const +{ + common::serialization::Serializer &s = serializer; + + s << (uint8_t &)m_protocol; + s << m_interfaceGuid; + s << (uint8_t &)m_configuredForDhcp; + s << m_staticNameServers; + s << m_dhcpNameServers; +} + +bool InterfaceSnap::needsOverriding(const std::vector<std::wstring> &enforcedServers) const +{ + if (internalInterface()) + { + return false; + } + + // + // The interface has static DNS, or + // The interface has DNS provided by the DHCP server, or + // The interface *will get* DNS provided to it by the DHCP server + // + + // + // It's not enough that m_staticNameServers has the same elements. + // The order defines primary and secondary name server and has to match. + // + + return m_staticNameServers != enforcedServers; +} + +const std::wstring &InterfaceSnap::interfaceGuid() const +{ + return m_interfaceGuid; +} + +const std::vector<std::wstring> &InterfaceSnap::nameServers() const +{ + return m_staticNameServers; +} + +bool InterfaceSnap::internalInterface() const +{ + return false == m_configuredForDhcp + && m_staticNameServers.empty(); +} diff --git a/windows/windns/src/windns/interfacesnap.h b/windows/windns/src/windns/interfacesnap.h new file mode 100644 index 0000000000..4d1156b615 --- /dev/null +++ b/windows/windns/src/windns/interfacesnap.h @@ -0,0 +1,35 @@ +#pragma once + +#include "types.h" +#include <libcommon/serialization/serializer.h> +#include <libcommon/serialization/deserializer.h> +#include <string> +#include <vector> + +class InterfaceSnap +{ +public: + + explicit InterfaceSnap(Protocol protocol, const std::wstring &interfaceGuid); + + explicit InterfaceSnap(common::serialization::Deserializer &deserializer); + void serialize(common::serialization::Serializer &serializer) const; + + bool needsOverriding(const std::vector<std::wstring> &enforcedServers) const; + + const std::wstring &interfaceGuid() const; + + const std::vector<std::wstring> &nameServers() const; + + bool internalInterface() const; + +private: + + Protocol m_protocol; + std::wstring m_interfaceGuid; + + bool m_configuredForDhcp; + + std::vector<std::wstring> m_staticNameServers; + std::vector<std::wstring> m_dhcpNameServers; +}; diff --git a/windows/windns/src/windns/irecoverysink.h b/windows/windns/src/windns/irecoverysink.h new file mode 100644 index 0000000000..8251625053 --- /dev/null +++ b/windows/windns/src/windns/irecoverysink.h @@ -0,0 +1,14 @@ +#pragma once + +#include "types.h" +#include "interfacesnap.h" +#include <vector> + +struct IRecoverySink +{ + virtual ~IRecoverySink() = 0 + { + } + + virtual void preserveSnaps(Protocol protocol, const std::vector<InterfaceSnap> &snaps) = 0; +}; diff --git a/windows/windns/src/windns/logsink.cpp b/windows/windns/src/windns/logsink.cpp new file mode 100644 index 0000000000..b5f127d25c --- /dev/null +++ b/windows/windns/src/windns/logsink.cpp @@ -0,0 +1,34 @@ +#include "stdafx.h" +#include "logsink.h" + +LogSink::LogSink(const LogSinkInfo &target) + : m_target(target) +{ +} + +void LogSink::setTarget(const LogSinkInfo &target) +{ + std::scoped_lock<std::mutex> lock(m_targetMutex); + + m_target = target; +} + +void LogSink::error(const char *msg, const char **details, uint32_t numDetails) +{ + std::scoped_lock<std::mutex> lock(m_targetMutex); + + if (nullptr != m_target.sink) + { + m_target.sink(WINDNS_LOG_CATEGORY_ERROR, msg, details, numDetails, m_target.context); + } +} + +void LogSink::info(const char *msg, const char **details, uint32_t numDetails) +{ + std::scoped_lock<std::mutex> lock(m_targetMutex); + + if (nullptr != m_target.sink) + { + m_target.sink(WINDNS_LOG_CATEGORY_INFO, msg, details, numDetails, m_target.context); + } +} diff --git a/windows/windns/src/windns/logsink.h b/windows/windns/src/windns/logsink.h new file mode 100644 index 0000000000..396320a0a8 --- /dev/null +++ b/windows/windns/src/windns/logsink.h @@ -0,0 +1,22 @@ +#pragma once + +#include "clientsinkinfo.h" +#include "ilogsink.h" +#include <mutex> + +class LogSink : public ILogSink +{ +public: + + LogSink(const LogSinkInfo &target); + + void setTarget(const LogSinkInfo &target); + + void error(const char *msg, const char **details, uint32_t numDetails) override; + void info(const char *msg, const char **details, uint32_t numDetails) override; + +private: + + std::mutex m_targetMutex; + LogSinkInfo m_target; +}; diff --git a/windows/windns/src/windns/nameserversource.cpp b/windows/windns/src/windns/nameserversource.cpp new file mode 100644 index 0000000000..b574266efb --- /dev/null +++ b/windows/windns/src/windns/nameserversource.cpp @@ -0,0 +1,77 @@ +#include "stdafx.h" +#include "nameserversource.h" + +NameServerSource::NameServerSource(const std::vector<std::wstring> &ipv4NameServers, + const std::vector<std::wstring> &ipv6NameServers) + : m_ipv4NameServers(ipv4NameServers) + , m_ipv6NameServers(ipv6NameServers) +{ +} + +void NameServerSource::setNameServers(Protocol protocol, const std::vector<std::wstring> &nameServers) +{ + { + std::scoped_lock<std::mutex> lock(m_nameServersMutex); + + if (Protocol::IPv4 == protocol) + { + m_ipv4NameServers = nameServers; + } + else + { + m_ipv6NameServers = nameServers; + } + } + + // + // Notify all subscribers. + // + + std::scoped_lock<std::mutex> lock(m_subscriptionMutex); + + for (HANDLE eventHandle : m_subscriptions) + { + SetEvent(eventHandle); + } +} + +std::vector<std::wstring> NameServerSource::getNameServers(Protocol protocol) const +{ + std::vector<std::wstring> copy; + + std::scoped_lock<std::mutex> lock(m_nameServersMutex); + + if (Protocol::IPv4 == protocol) + { + copy = m_ipv4NameServers; + } + else + { + copy = m_ipv6NameServers; + } + + return copy; +} + +void NameServerSource::subscribe(HANDLE eventHandle) +{ + ResetEvent(eventHandle); + + std::scoped_lock<std::mutex> lock(m_subscriptionMutex); + + m_subscriptions.push_back(eventHandle); +} + +void NameServerSource::unsubscribe(HANDLE eventHandle) +{ + std::scoped_lock<std::mutex> lock(m_subscriptionMutex); + + auto it = std::find(m_subscriptions.begin(), m_subscriptions.end(), eventHandle); + + if (m_subscriptions.end() == it) + { + return; + } + + m_subscriptions.erase(it); +} diff --git a/windows/windns/src/windns/nameserversource.h b/windows/windns/src/windns/nameserversource.h new file mode 100644 index 0000000000..7d46fcf264 --- /dev/null +++ b/windows/windns/src/windns/nameserversource.h @@ -0,0 +1,29 @@ +#pragma once + +#include "inameserversource.h" +#include <mutex> + +class NameServerSource : public INameServerSource +{ +public: + + NameServerSource(const std::vector<std::wstring> &ipv4NameServers, + const std::vector<std::wstring> &ipv6NameServers); + + void setNameServers(Protocol protocol, const std::vector<std::wstring> &nameServers); + + std::vector<std::wstring> getNameServers(Protocol protocol) const override; + + void subscribe(HANDLE eventHandle) override; + void unsubscribe(HANDLE eventHandle) override; + +private: + + mutable std::mutex m_nameServersMutex; + + std::vector<std::wstring> m_ipv4NameServers; + std::vector<std::wstring> m_ipv6NameServers; + + std::mutex m_subscriptionMutex; + std::list<HANDLE> m_subscriptions; +}; diff --git a/windows/windns/src/windns/netconfigeventsink.cpp b/windows/windns/src/windns/netconfigeventsink.cpp deleted file mode 100644 index 5b8d4efd07..0000000000 --- a/windows/windns/src/windns/netconfigeventsink.cpp +++ /dev/null @@ -1,51 +0,0 @@ -#include "stdafx.h" -#include "netconfigeventsink.h" -#include "netconfighelpers.h" -#include "netsh.h" -#include "confineoperation.h" -#include <functional> - -using namespace common; - -NetConfigEventSink::NetConfigEventSink -( - std::shared_ptr<wmi::IConnection> connection, - std::shared_ptr<ConfigManager> configManager, - IClientSinkProxy *clientSinkProxy -) - : m_connection(connection) - , m_configManager(configManager) - , m_clientSinkProxy(clientSinkProxy) -{ -} - -void NetConfigEventSink::update(CComPtr<IWbemClassObject> previous, CComPtr<IWbemClassObject> target) -{ - auto forwardError = [this](const char *errorMessage, const char **details, uint32_t numDetails) - { - m_clientSinkProxy->error(errorMessage, details, numDetails); - }; - - ConfineOperation("Process adapter update event", forwardError, [&]() - { - InterfaceConfig previousConfig(previous); - InterfaceConfig targetConfig(target); - - ConfigManager::Mutex mutex(*m_configManager); - - // - // This is OK because the config manager will reject updates - // that set our DNS servers. - // - if (ConfigManager::UpdateStatus::DnsApproved == m_configManager->updateConfig(previousConfig, targetConfig)) - { - return; - } - - // - // The update was initiated from an external source. - // Override current settings to enforce our selected DNS servers. - // - nchelpers::SetDnsServers(targetConfig.interfaceIndex(), m_configManager->getServers()); - }); -} diff --git a/windows/windns/src/windns/netconfigeventsink.h b/windows/windns/src/windns/netconfigeventsink.h deleted file mode 100644 index c3fd9242f7..0000000000 --- a/windows/windns/src/windns/netconfigeventsink.h +++ /dev/null @@ -1,24 +0,0 @@ -#pragma once - -#include "libcommon/wmi/ieventsink.h" -#include "libcommon/wmi/iconnection.h" -#include "configmanager.h" -#include "iclientsinkproxy.h" -#include <memory> - -class NetConfigEventSink : public common::wmi::IModificationEventSink -{ -public: - - NetConfigEventSink(std::shared_ptr<common::wmi::IConnection> connection, - std::shared_ptr<ConfigManager> configManager, IClientSinkProxy *clientSinkProxy); - - void update(CComPtr<IWbemClassObject> previous, CComPtr<IWbemClassObject> target) override; - -private: - - std::shared_ptr<common::wmi::IConnection> m_connection; - std::shared_ptr<ConfigManager> m_configManager; - - IClientSinkProxy *m_clientSinkProxy; -}; diff --git a/windows/windns/src/windns/netconfighelpers.cpp b/windows/windns/src/windns/netconfighelpers.cpp deleted file mode 100644 index 7156038b4e..0000000000 --- a/windows/windns/src/windns/netconfighelpers.cpp +++ /dev/null @@ -1,65 +0,0 @@ -#include "stdafx.h" -#include "netconfighelpers.h" -#include "libcommon/com.h" -#include "libcommon/wmi/wmi.h" -#include "libcommon/trace/xtrace.h" -#include "netsh.h" - -using namespace common; - -namespace nchelpers -{ - -OptionalStringList GetDnsServers(CComPtr<IWbemClassObject> instance) -{ - OptionalStringList result; - - auto servers = wmi::WmiGetProperty(instance, L"DNSServerSearchOrder"); - - if (VT_EMPTY == V_VT(&servers) || VT_NULL == V_VT(&servers)) - { - return result; - } - - result = std::make_shared<std::vector<std::wstring> >( - ComConvertStringArray(V_ARRAY(&servers))); - - return result; -} - -uint32_t GetInterfaceIndex(CComPtr<IWbemClassObject> instance) -{ - return wmi::WmiGetPropertyAlways(instance, L"InterfaceIndex").ulVal; -} - -void SetDnsServers(uint32_t interfaceIndex, const std::vector<std::wstring> &servers) -{ - NetSh::SetIpv4PrimaryDns(interfaceIndex, servers[0]); - - if (servers.size() > 1) - { - NetSh::SetIpv4SecondaryDns(interfaceIndex, servers[1]); - } -} - -void RevertDnsServers(const InterfaceConfig &config, uint32_t timeout) -{ - XTRACE("Reverting DNS configuration for interface with index=", config.interfaceIndex()); - - auto servers = config.servers(); - - if (config.dhcp() || nullptr == servers || 0 == servers->size()) - { - NetSh::SetIpv4Dhcp(config.interfaceIndex(), timeout); - return; - } - - NetSh::SetIpv4PrimaryDns(config.interfaceIndex(), (*servers)[0], timeout); - - if (servers->size() > 1) - { - NetSh::SetIpv4SecondaryDns(config.interfaceIndex(), (*servers)[1], timeout); - } -} - -} diff --git a/windows/windns/src/windns/netconfighelpers.h b/windows/windns/src/windns/netconfighelpers.h deleted file mode 100644 index 33b6c64cc2..0000000000 --- a/windows/windns/src/windns/netconfighelpers.h +++ /dev/null @@ -1,24 +0,0 @@ -#pragma once - -#include "types.h" -#include "interfaceconfig.h" -#include <string> -#include <vector> -#include <cstdint> -#include <atlbase.h> -#include <wbemidl.h> - -namespace nchelpers -{ - -// instance = Win32_NetworkAdapterConfiguration -OptionalStringList GetDnsServers(CComPtr<IWbemClassObject> instance); - -// instance = Win32_NetworkAdapterConfiguration -uint32_t GetInterfaceIndex(CComPtr<IWbemClassObject> instance); - -void SetDnsServers(uint32_t interfaceIndex, const std::vector<std::wstring> &servers); - -void RevertDnsServers(const InterfaceConfig &config, uint32_t timeout = 0); - -} diff --git a/windows/windns/src/windns/netsh.cpp b/windows/windns/src/windns/netsh.cpp index fdd8fbfa03..b8c2afff68 100644 --- a/windows/windns/src/windns/netsh.cpp +++ b/windows/windns/src/windns/netsh.cpp @@ -1,44 +1,17 @@ #include "stdafx.h" #include "netsh.h" -#include "libcommon/applicationrunner.h" -#include "libcommon/string.h" -#include "libcommon/filesystem.h" +#include <libcommon/string.h> +#include <libcommon/filesystem.h> +#include <libcommon/guid.h> #include <sstream> #include <stdexcept> #include <experimental/filesystem> +#include <iphlpapi.h> namespace { -ErrorSinkInfo g_ErrorSink = { nullptr, nullptr }; - -std::wstring g_NetShPath; - -void InitializePath() -{ - if (false == g_NetShPath.empty()) - { - return; - } - - const auto system32 = common::fs::GetKnownFolderPath(FOLDERID_System, 0, nullptr); - - g_NetShPath = std::experimental::filesystem::path(system32).append(L"netsh.exe"); -} - -const std::wstring &NetShPath() -{ - InitializePath(); - - return g_NetShPath; -} - -void InfoSink(const char *msg) -{ - auto infoMsg = std::string("INFO: ").append(msg); - - g_ErrorSink.sink(infoMsg.c_str(), nullptr, 0, g_ErrorSink.context); -} +NetSh *g_Instance = nullptr; std::vector<std::string> BlockToRows(const std::string &textBlock) { @@ -87,97 +60,89 @@ __declspec(noreturn) void ThrowWithDetails(std::string &&error, common::Applicat throw NetShError(std::move(error), std::move(details)); } -void ValidateShellOut(common::ApplicationRunner &netsh, uint32_t timeout) -{ - // Use default timeout of 4 seconds. - const uint32_t actualTimeout = (0 == timeout ? 3000 : timeout); - - const auto startTime = GetTickCount64(); - - DWORD returnCode; +} // anonymous namespace - if (false == netsh.join(returnCode, actualTimeout)) +//static +void NetSh::Construct(ILogSink *logSink) +{ + if (nullptr != g_Instance) { - ThrowWithDetails("'netsh' did not complete in a timely manner", netsh); + throw std::runtime_error("NetSh is already constructed"); } - if (returnCode != 0) + if (nullptr == logSink) { - std::stringstream ss; - - ss << "'netsh' failed the requested operation. Error: " << returnCode; - - ThrowWithDetails(ss.str(), netsh); + throw std::runtime_error("Invalid logger sink"); } - const auto elapsed = static_cast<uint32_t>(GetTickCount64() - startTime); - - if (elapsed > (actualTimeout / 2)) - { - std::stringstream ss; - - ss << L"'netsh' completed successfully, albeit a little slowly. It consumed " - << elapsed << " ms of " - << actualTimeout << " ms max permitted execution time"; - - InfoSink(ss.str().c_str()); - } + g_Instance = new NetSh(logSink); } -} // anonymous namespace - //static -void NetSh::RegisterErrorSink(const ErrorSinkInfo &errorSink) +NetSh &NetSh::Instance() { - g_ErrorSink = errorSink; + if (nullptr == g_Instance) + { + throw std::runtime_error("NetSh is being referenced prior to being constructed"); + } + + return *g_Instance; } -//static -void NetSh::SetIpv4PrimaryDns(uint32_t interfaceIndex, std::wstring server, uint32_t timeout) +void NetSh::SetIpv4StaticDns(uint32_t interfaceIndex, + const std::vector<std::wstring> &nameServers, uint32_t timeout) { // + // Setting primary and secondary name server requires two invokations: + // // netsh interface ipv4 set dnsservers name="Ethernet 2" source=static address=8.8.8.8 validate=no + // netsh interface ipv4 add dnsservers name="Ethernet 2" address=8.8.4.4 index=2 validate=no // // Note: we're specifying the interface by index instead. // - std::wstringstream ss; + if (nameServers.empty()) + { + throw std::runtime_error("Invalid list of name servers (zero length list)"); + } - ss << L"interface ipv4 set dnsservers name=" - << interfaceIndex - << L" source=static address=" - << server - << L" validate=no"; + { + std::wstringstream ss; - auto netsh = common::ApplicationRunner::StartWithoutConsole(NetShPath(), ss.str()); + ss << L"interface ipv4 set dnsservers name=" + << interfaceIndex + << L" source=static address=" + << nameServers[0] + << L" validate=no"; - ValidateShellOut(*netsh, timeout); -} + auto netsh = common::ApplicationRunner::StartWithoutConsole(m_netShPath, ss.str()); + + ValidateShellOut(*netsh, timeout); + } -//static -void NetSh::SetIpv4SecondaryDns(uint32_t interfaceIndex, std::wstring server, uint32_t timeout) -{ - // - // netsh interface ipv4 add dnsservers name="Ethernet 2" address=8.8.4.4 index=2 validate=no // - // Note: we're specifying the interface by index instead. + // Set additional name servers. // - std::wstringstream ss; + for (size_t i = 1; i < nameServers.size(); ++i) + { + std::wstringstream ss; - ss << L"interface ipv4 add dnsservers name=" - << interfaceIndex - << L" address=" - << server - << L" index=2 validate=no"; + ss << L"interface ipv4 add dnsservers name=" + << interfaceIndex + << L" address=" + << nameServers[i] + << L" index=" + << i + 1 + << L" validate=no"; - auto netsh = common::ApplicationRunner::StartWithoutConsole(NetShPath(), ss.str()); + auto netsh = common::ApplicationRunner::StartWithoutConsole(m_netShPath, ss.str()); - ValidateShellOut(*netsh, timeout); + ValidateShellOut(*netsh, timeout); + } } -//static -void NetSh::SetIpv4Dhcp(uint32_t interfaceIndex, uint32_t timeout) +void NetSh::SetIpv4DhcpDns(uint32_t interfaceIndex, uint32_t timeout) { // // netsh interface ipv4 set dnsservers name="Ethernet 2" source=dhcp @@ -191,57 +156,65 @@ void NetSh::SetIpv4Dhcp(uint32_t interfaceIndex, uint32_t timeout) << interfaceIndex << L" source=dhcp"; - auto netsh = common::ApplicationRunner::StartWithoutConsole(NetShPath(), ss.str()); + auto netsh = common::ApplicationRunner::StartWithoutConsole(m_netShPath, ss.str()); ValidateShellOut(*netsh, timeout); } -//static -void NetSh::SetIpv6PrimaryDns(uint32_t interfaceIndex, std::wstring server, uint32_t timeout) +void NetSh::SetIpv6StaticDns(uint32_t interfaceIndex, + const std::vector<std::wstring> &nameServers, uint32_t timeout) { // + // Setting primary and secondary name server requires two invokations: + // // netsh interface ipv6 set dnsservers name="Ethernet 2" source=static address=2001:4860:4860::8888 validate=no + // netsh interface ipv6 add dnsservers name="Ethernet 2" address=2001:4860:4860::8844 index=2 validate=no // // Note: we're specifying the interface by index instead. // - std::wstringstream ss; + if (nameServers.empty()) + { + throw std::runtime_error("Invalid list of name servers (zero length list)"); + } - ss << L"interface ipv6 set dnsservers name=" - << interfaceIndex - << L" source=static address=" - << server - << L" validate=no"; + { + std::wstringstream ss; - auto netsh = common::ApplicationRunner::StartWithoutConsole(NetShPath(), ss.str()); + ss << L"interface ipv6 set dnsservers name=" + << interfaceIndex + << L" source=static address=" + << nameServers[0] + << L" validate=no"; - ValidateShellOut(*netsh, timeout); -} + auto netsh = common::ApplicationRunner::StartWithoutConsole(m_netShPath, ss.str()); + + ValidateShellOut(*netsh, timeout); + } -//static -void NetSh::SetIpv6SecondaryDns(uint32_t interfaceIndex, std::wstring server, uint32_t timeout) -{ - // - // netsh interface ipv6 add dnsservers name="Ethernet 2" address=2001:4860:4860::8844 index=2 validate=no // - // Note: we're specifying the interface by index instead. + // Set additional name servers. // - std::wstringstream ss; + for (size_t i = 1; i < nameServers.size(); ++i) + { + std::wstringstream ss; - ss << L"interface ipv6 add dnsservers name=" - << interfaceIndex - << L"address =" - << server - << L" index=2 validate=no"; + ss << L"interface ipv6 add dnsservers name=" + << interfaceIndex + << L" address=" + << nameServers[i] + << L" index=" + << i + 1 + << L" validate=no"; - auto netsh = common::ApplicationRunner::StartWithoutConsole(NetShPath(), ss.str()); + auto netsh = common::ApplicationRunner::StartWithoutConsole(m_netShPath, ss.str()); - ValidateShellOut(*netsh, timeout); + ValidateShellOut(*netsh, timeout); + } } -//static -void NetSh::SetIpv6Dhcp(uint32_t interfaceIndex, uint32_t timeout) +void NetSh::SetIpv6DhcpDns(uint32_t interfaceIndex, uint32_t timeout) { // // netsh interface ipv6 set dnsservers name="Ethernet 2" source=dhcp @@ -255,7 +228,68 @@ void NetSh::SetIpv6Dhcp(uint32_t interfaceIndex, uint32_t timeout) << interfaceIndex << L" source=dhcp"; - auto netsh = common::ApplicationRunner::StartWithoutConsole(NetShPath(), ss.str()); + auto netsh = common::ApplicationRunner::StartWithoutConsole(m_netShPath, ss.str()); ValidateShellOut(*netsh, timeout); } + +//static +uint32_t NetSh::ConvertInterfaceGuidToIndex(const std::wstring &interfaceGuid) +{ + auto rawGuid = common::Guid::FromString(interfaceGuid); + + NET_LUID luid; + NET_IFINDEX index; + + if (NO_ERROR != ConvertInterfaceGuidToLuid(&rawGuid, &luid) + || NO_ERROR != ConvertInterfaceLuidToIndex(&luid, &index)) + { + throw std::runtime_error("Invalid interface GUID"); + } + + return index; +} + +NetSh::NetSh(ILogSink *logSink) + : m_logSink(logSink) +{ + const auto system32 = common::fs::GetKnownFolderPath(FOLDERID_System, 0, nullptr); + + m_netShPath = std::experimental::filesystem::path(system32).append(L"netsh.exe"); +} + +void NetSh::ValidateShellOut(common::ApplicationRunner &netsh, uint32_t timeout) +{ + const uint32_t actualTimeout = (0 == timeout ? 3000 : timeout); + + const auto startTime = GetTickCount64(); + + DWORD returnCode; + + if (false == netsh.join(returnCode, actualTimeout)) + { + ThrowWithDetails("'netsh' did not complete in a timely manner", netsh); + } + + if (returnCode != 0) + { + std::stringstream ss; + + ss << "'netsh' failed the requested operation. Error: " << returnCode; + + ThrowWithDetails(ss.str(), netsh); + } + + const auto elapsed = static_cast<uint32_t>(GetTickCount64() - startTime); + + if (elapsed > (actualTimeout / 2)) + { + std::stringstream ss; + + ss << L"'netsh' completed successfully, albeit a little slowly. It consumed " + << elapsed << " ms of " + << actualTimeout << " ms max permitted execution time"; + + m_logSink->info(ss.str().c_str(), nullptr, 0); + } +} diff --git a/windows/windns/src/windns/netsh.h b/windows/windns/src/windns/netsh.h index 5ba43fba66..f47408c9e2 100644 --- a/windows/windns/src/windns/netsh.h +++ b/windows/windns/src/windns/netsh.h @@ -1,7 +1,9 @@ #pragma once -#include "clientsinkinfo.h" +#include "ilogsink.h" +#include <libcommon/applicationrunner.h> #include <string> +#include <vector> #include <cstdint> #include <stdexcept> @@ -9,24 +11,30 @@ class NetSh { public: - static void RegisterErrorSink(const ErrorSinkInfo &errorSink); + static void Construct(ILogSink *logSink); - static void SetIpv4PrimaryDns(uint32_t interfaceIndex, std::wstring server, uint32_t timeout = 0); - - // - // Caveat: This sets the primary DNS server if there isn't already one. - // - static void SetIpv4SecondaryDns(uint32_t interfaceIndex, std::wstring server, uint32_t timeout = 0); + static NetSh &Instance(); - static void SetIpv4Dhcp(uint32_t interfaceIndex, uint32_t timeout = 0); + void SetIpv4StaticDns(uint32_t interfaceIndex, + const std::vector<std::wstring> &nameServers, uint32_t timeout = 0); - static void SetIpv6PrimaryDns(uint32_t interfaceIndex, std::wstring server, uint32_t timeout = 0); - static void SetIpv6SecondaryDns(uint32_t interfaceIndex, std::wstring server, uint32_t timeout = 0); - static void SetIpv6Dhcp(uint32_t interfaceIndex, uint32_t timeout = 0); + void SetIpv4DhcpDns(uint32_t interfaceIndex, uint32_t timeout = 0); + + void SetIpv6StaticDns(uint32_t interfaceIndex, + const std::vector<std::wstring> &nameServers, uint32_t timeout = 0); + + void SetIpv6DhcpDns(uint32_t interfaceIndex, uint32_t timeout = 0); + + static uint32_t ConvertInterfaceGuidToIndex(const std::wstring &interfaceGuid); private: - NetSh(); + ILogSink *m_logSink; + std::wstring m_netShPath; + + NetSh(ILogSink *logSink); + + void ValidateShellOut(common::ApplicationRunner &netsh, uint32_t timeout); }; class NetShError : public std::exception diff --git a/windows/windns/src/windns/recoveryformatter.cpp b/windows/windns/src/windns/recoveryformatter.cpp new file mode 100644 index 0000000000..e6ed3d75fd --- /dev/null +++ b/windows/windns/src/windns/recoveryformatter.cpp @@ -0,0 +1,88 @@ +#include "stdafx.h" +#include <libcommon/serialization/serializer.h> +#include <libcommon/serialization/deserializer.h> +#include "recoveryformatter.h" +#include <stdexcept> + +namespace +{ + +uint32_t RF_MAGIC = 0x21534E44; // stores as 'DNS!' +uint32_t RF_VERSION = 0x01; + +} // anonymous namespace + +//static +std::vector<uint8_t> RecoveryFormatter::Pack(const std::vector<InterfaceSnap> &v4Snaps, + const std::vector<InterfaceSnap> &v6Snaps) +{ + common::serialization::Serializer s; + + // + // Format of binary blob + // + // u32 tag + // u32 version + // u32 number of ipv4 snaps + // [] ipv4 snaps + // u32 number of ipv6 snaps + // [] ipv6 snaps + // + + s << RF_MAGIC; + s << RF_VERSION; + + s << static_cast<uint32_t>(v4Snaps.size()); + + for (const auto &snap : v4Snaps) + { + snap.serialize(s); + } + + s << static_cast<uint32_t>(v6Snaps.size()); + + for (const auto &snap : v6Snaps) + { + snap.serialize(s); + } + + return s.blob(); +} + +//static +RecoveryFormatter::Unpacked RecoveryFormatter::Unpack(const uint8_t *data, uint32_t dataSize) +{ + common::serialization::Deserializer d(data, dataSize); + + if (RF_MAGIC != d.decode<uint32_t>() + || RF_VERSION != d.decode<uint32_t>()) + { + throw std::runtime_error("Invalid header in recovery data"); + } + + Unpacked unpacked; + + auto numV4Snaps = d.decode<uint32_t>(); + + for (; 0 != numV4Snaps; --numV4Snaps) + { + // Invoke deserializing ctor on InterfaceSnap. + unpacked.v4Snaps.emplace_back(d); + } + + auto numV6Snaps = d.decode<uint32_t>(); + + for (; 0 != numV6Snaps; --numV6Snaps) + { + // Invoke deserializing ctor on InterfaceSnap. + unpacked.v6Snaps.emplace_back(d); + } + + return unpacked; +} + +//static +RecoveryFormatter::Unpacked RecoveryFormatter::Unpack(const std::vector<uint8_t> &data) +{ + return RecoveryFormatter::Unpack(&data[0], static_cast<uint32_t>(data.size())); +} diff --git a/windows/windns/src/windns/recoveryformatter.h b/windows/windns/src/windns/recoveryformatter.h new file mode 100644 index 0000000000..c3d715cd9e --- /dev/null +++ b/windows/windns/src/windns/recoveryformatter.h @@ -0,0 +1,24 @@ +#pragma once + +#include "interfacesnap.h" +#include <vector> +#include <cstdint> + +class RecoveryFormatter +{ +public: + + RecoveryFormatter() = delete; + + static std::vector<uint8_t> Pack(const std::vector<InterfaceSnap> &v4Snaps, + const std::vector<InterfaceSnap> &v6Snaps); + + struct Unpacked + { + std::vector<InterfaceSnap> v4Snaps; + std::vector<InterfaceSnap> v6Snaps; + }; + + static Unpacked Unpack(const uint8_t *data, uint32_t dataSize); + static Unpacked Unpack(const std::vector<uint8_t> &data); +}; diff --git a/windows/windns/src/windns/recoverylogic.cpp b/windows/windns/src/windns/recoverylogic.cpp new file mode 100644 index 0000000000..0c630a89b2 --- /dev/null +++ b/windows/windns/src/windns/recoverylogic.cpp @@ -0,0 +1,90 @@ +#include "stdafx.h" +#include "recoverylogic.h" +#include "netsh.h" +#include "confineoperation.h" +#include <libcommon/trace/xtrace.h> +#include <stdexcept> + +//static +void RecoveryLogic::RestoreInterfaces(const RecoveryFormatter::Unpacked &data, + ILogSink *logSink, uint32_t timeout) +{ + if (nullptr == logSink) + { + throw std::runtime_error("Invalid logger sink"); + } + + auto forwardError = [logSink](const char *msg, const char **details, uint32_t numDetails) + { + logSink->error(msg, details, numDetails); + }; + + bool success = true; + + for (const auto &snap : data.v4Snaps) + { + const auto status = ConfineOperation("Reset interface DNS settings", forwardError, [&snap, &timeout]() + { + if (snap.internalInterface()) + { + // + // This is an interface used for internal communication. + // We haven't changed any settings on it and therefore should not restore it. + // + return; + } + + XTRACE("Resetting name server configuration for interface ", snap.interfaceGuid()); + + if (snap.nameServers().empty()) + { + NetSh::Instance().SetIpv4DhcpDns(NetSh::ConvertInterfaceGuidToIndex(snap.interfaceGuid()), timeout); + } + else + { + NetSh::Instance().SetIpv4StaticDns(NetSh::ConvertInterfaceGuidToIndex(snap.interfaceGuid()), snap.nameServers(), timeout); + } + }); + + if (false == status) + { + success = false; + } + } + + for (const auto &snap : data.v6Snaps) + { + const auto status = ConfineOperation("Reset interface DNS settings", forwardError, [&snap, &timeout]() + { + if (snap.internalInterface()) + { + // + // This is an interface used for internal communication. + // We haven't changed any settings on it and therefore should not restore it. + // + return; + } + + XTRACE("Resetting name server configuration for interface ", snap.interfaceGuid()); + + if (snap.nameServers().empty()) + { + NetSh::Instance().SetIpv6DhcpDns(NetSh::ConvertInterfaceGuidToIndex(snap.interfaceGuid()), timeout); + } + else + { + NetSh::Instance().SetIpv6StaticDns(NetSh::ConvertInterfaceGuidToIndex(snap.interfaceGuid()), snap.nameServers(), timeout); + } + }); + + if (false == status) + { + success = false; + } + } + + if (false == success) + { + throw std::runtime_error("Could not reset DNS settings for one of more interfaces"); + } +} diff --git a/windows/windns/src/windns/recoverylogic.h b/windows/windns/src/windns/recoverylogic.h new file mode 100644 index 0000000000..5937b0635c --- /dev/null +++ b/windows/windns/src/windns/recoverylogic.h @@ -0,0 +1,14 @@ +#pragma once + +#include "recoveryformatter.h" +#include "ilogsink.h" + +class RecoveryLogic +{ +public: + + RecoveryLogic() = delete; + + static void RestoreInterfaces(const RecoveryFormatter::Unpacked &data, + ILogSink *logSink, uint32_t timeout = 0); +}; diff --git a/windows/windns/src/windns/recoverysink.cpp b/windows/windns/src/windns/recoverysink.cpp new file mode 100644 index 0000000000..55a6620c7d --- /dev/null +++ b/windows/windns/src/windns/recoverysink.cpp @@ -0,0 +1,58 @@ +#include "stdafx.h" +#include "recoverysink.h" +#include "recoveryformatter.h" +#include <stdexcept> + +RecoverySink::RecoverySink(const RecoverySinkInfo &target) + : m_target(target) +{ +} + +void RecoverySink::setTarget(const RecoverySinkInfo &target) +{ + std::scoped_lock<std::mutex> lock(m_targetMutex); + + m_target = target; +} + +void RecoverySink::preserveSnaps(Protocol protocol, const std::vector<InterfaceSnap> &snaps) +{ + std::scoped_lock<std::mutex> dataLock(m_dataMutex); + + switch (protocol) + { + case Protocol::IPv4: + { + m_v4Snaps = snaps; + break; + } + case Protocol::IPv6: + { + m_v6Snaps = snaps; + break; + } + default: + { + throw std::runtime_error("Missing case handler"); + } + } + + m_recoveryData = RecoveryFormatter::Pack(m_v4Snaps, m_v6Snaps); + + std::scoped_lock<std::mutex> lock(m_targetMutex); + + m_target.sink(&m_recoveryData[0], static_cast<uint32_t>(m_recoveryData.size()), m_target.context); +} + +std::vector<uint8_t> RecoverySink::recoveryData() const +{ + std::vector<uint8_t> copy; + + { + std::scoped_lock<std::mutex> dataLock(m_dataMutex); + + copy = m_recoveryData; + } + + return copy; +} diff --git a/windows/windns/src/windns/recoverysink.h b/windows/windns/src/windns/recoverysink.h new file mode 100644 index 0000000000..b685ccd12e --- /dev/null +++ b/windows/windns/src/windns/recoverysink.h @@ -0,0 +1,30 @@ +#pragma once + +#include "irecoverysink.h" +#include "clientsinkinfo.h" +#include <vector> +#include <cstdint> +#include <mutex> + +class RecoverySink : public IRecoverySink +{ +public: + + RecoverySink(const RecoverySinkInfo &target); + + void setTarget(const RecoverySinkInfo &target); + + void preserveSnaps(Protocol protocol, const std::vector<InterfaceSnap> &snaps) override; + + std::vector<uint8_t> recoveryData() const; + +private: + + std::mutex m_targetMutex; + RecoverySinkInfo m_target; + + mutable std::mutex m_dataMutex; + std::vector<InterfaceSnap> m_v4Snaps; + std::vector<InterfaceSnap> m_v6Snaps; + std::vector<uint8_t> m_recoveryData; +}; diff --git a/windows/windns/src/windns/registrypaths.cpp b/windows/windns/src/windns/registrypaths.cpp new file mode 100644 index 0000000000..65117ee4fe --- /dev/null +++ b/windows/windns/src/windns/registrypaths.cpp @@ -0,0 +1,20 @@ +#include "stdafx.h" +#include "registrypaths.h" + +//static +std::wstring RegistryPaths::InterfaceRoot(Protocol protocol) +{ + return + std::wstring(L"SYSTEM\\CurrentControlSet\\Services\\") + .append(Protocol::IPv4 == protocol ? L"Tcpip" : L"Tcpip6") + .append(L"\\Parameters\\Interfaces"); +} + +//static +std::wstring RegistryPaths::InterfaceKey(const std::wstring &interfaceGuid, Protocol protocol) +{ + return + InterfaceRoot(protocol) + .append(L"\\") + .append(interfaceGuid); +} diff --git a/windows/windns/src/windns/registrypaths.h b/windows/windns/src/windns/registrypaths.h new file mode 100644 index 0000000000..34e8022649 --- /dev/null +++ b/windows/windns/src/windns/registrypaths.h @@ -0,0 +1,14 @@ +#pragma once + +#include "types.h" +#include <string> + +class RegistryPaths +{ +public: + + RegistryPaths() = delete; + + static std::wstring InterfaceRoot(Protocol protocol); + static std::wstring InterfaceKey(const std::wstring &interfaceGuid, Protocol protocol); +}; diff --git a/windows/windns/src/windns/types.h b/windows/windns/src/windns/types.h index d30d556c5d..4782621cd2 100644 --- a/windows/windns/src/windns/types.h +++ b/windows/windns/src/windns/types.h @@ -1,7 +1,9 @@ #pragma once -#include <string> -#include <vector> -#include <memory> +#include <cstdint> -using OptionalStringList = std::shared_ptr<std::vector<std::wstring> >; +enum class Protocol : uint8_t +{ + IPv4, + IPv6 +}; diff --git a/windows/windns/src/windns/windns.cpp b/windows/windns/src/windns/windns.cpp index 8bf9a1b649..fe674723e2 100644 --- a/windows/windns/src/windns/windns.cpp +++ b/windows/windns/src/windns/windns.cpp @@ -2,20 +2,19 @@ #include "windns.h" #include "windnscontext.h" #include "clientsinkinfo.h" -#include "interfaceconfig.h" -#include "netconfighelpers.h" #include "confineoperation.h" +#include "recoveryformatter.h" +#include "recoverylogic.h" #include "netsh.h" -#include "libcommon/serialization/deserializer.h" +#include "logsink.h" +#include <memory> #include <vector> #include <string> namespace { -WinDnsErrorSink g_ErrorSink = nullptr; -void *g_ErrorContext = nullptr; - +LogSink *g_LogSink = nullptr; WinDnsContext *g_Context = nullptr; std::vector<std::wstring> MakeStringArray(const wchar_t **strings, uint32_t numStrings) @@ -30,11 +29,11 @@ std::vector<std::wstring> MakeStringArray(const wchar_t **strings, uint32_t numS return v; } -void ForwardError(const char *errorMessage, const char **details, uint32_t numDetails) +void ForwardError(const char *message, const char **details, uint32_t numDetails) { - if (nullptr != g_ErrorSink) + if (nullptr != g_LogSink) { - g_ErrorSink(errorMessage, details, numDetails, g_ErrorContext); + g_LogSink->error(message, details, numDetails); } } @@ -44,8 +43,8 @@ WINDNS_LINKAGE bool WINDNS_API WinDns_Initialize( - WinDnsErrorSink errorSink, - void *errorContext + WinDnsLogSink logSink, + void *logContext ) { if (nullptr != g_Context) @@ -53,14 +52,19 @@ WinDns_Initialize( return false; } - g_ErrorSink = errorSink; - g_ErrorContext = errorContext; - - return ConfineOperation("Initialize", ForwardError, []() + return ConfineOperation("Initialize", ForwardError, [&]() { - NetSh::RegisterErrorSink(ErrorSinkInfo{ g_ErrorSink, g_ErrorContext }); + if (nullptr == g_LogSink) + { + g_LogSink = new LogSink(LogSinkInfo{ logSink, logContext }); + NetSh::Construct(g_LogSink); + } + else + { + g_LogSink->setTarget(LogSinkInfo{ logSink, logContext }); + } - g_Context = new WinDnsContext; + g_Context = new WinDnsContext(g_LogSink); }); } @@ -78,6 +82,10 @@ WinDns_Deinitialize( delete g_Context; g_Context = nullptr; + // Maintain a single instance forever and invoke setTarget() on it. + //delete g_LogSink; + //g_LogSink = nullptr; + return true; } @@ -85,27 +93,28 @@ WINDNS_LINKAGE bool WINDNS_API WinDns_Set( - const wchar_t **servers, - uint32_t numServers, - WinDnsConfigSink configSink, - void *configContext + const wchar_t **ipv4Servers, + uint32_t numIpv4Servers, + const wchar_t **ipv6Servers, + uint32_t numIpv6Servers, + WinDnsRecoverySink recoverySink, + void *recoveryContext ) { if (nullptr == g_Context - || 0 == numServers - || nullptr == configSink) + || nullptr == ipv4Servers + || 0 == numIpv4Servers + || nullptr == ipv6Servers + || 0 == numIpv6Servers + || nullptr == recoverySink) { return false; } return ConfineOperation("Enforce DNS settings", ForwardError, [&]() { - ClientSinkInfo sinkInfo; - - sinkInfo.errorSinkInfo = ErrorSinkInfo{ g_ErrorSink, g_ErrorContext }; - sinkInfo.configSinkInfo = ConfigSinkInfo{ configSink, configContext }; - - g_Context->set(MakeStringArray(servers, numServers), sinkInfo); + g_Context->set(MakeStringArray(ipv4Servers, numIpv4Servers), MakeStringArray(ipv6Servers, \ + numIpv6Servers), RecoverySinkInfo{ recoverySink, recoveryContext }); }); } @@ -130,56 +139,16 @@ WINDNS_LINKAGE bool WINDNS_API WinDns_Recover( - const void *configData, + const void *recoveryData, uint32_t dataLength ) { - std::vector<InterfaceConfig> configs; - - const auto status = ConfineOperation("Deserialize recovery data", ForwardError, [&]() + return ConfineOperation("Recover DNS settings", ForwardError, [&]() { - common::serialization::Deserializer d(reinterpret_cast<const uint8_t *>(configData), dataLength); - - auto numConfigs = d.decode<uint32_t>(); - - if (numConfigs > 50) - { - throw std::runtime_error("Too many configuration entries"); - } + auto unpacked = RecoveryFormatter::Unpack(reinterpret_cast<const uint8_t *>(recoveryData), dataLength); - configs.reserve(numConfigs); + static const uint32_t TIMEOUT_TEN_SECONDS = 1000 * 10; - for (; numConfigs != 0; --numConfigs) - { - configs.emplace_back(InterfaceConfig(d)); - } + RecoveryLogic::RestoreInterfaces(unpacked, g_LogSink, TIMEOUT_TEN_SECONDS); }); - - if (false == status) - { - return false; - } - - // - // Try to restore each config and update 'success' if any update fails. - // - - static const uint32_t TIMEOUT_10_SECONDS = 10 * 1000; - - bool success = true; - - for (const auto &config : configs) - { - const auto adapterStatus = ConfineOperation("Restore adapter DNS settings", ForwardError, [&config]() - { - nchelpers::RevertDnsServers(config, TIMEOUT_10_SECONDS); - }); - - if (false == adapterStatus) - { - success = false; - } - } - - return success; } diff --git a/windows/windns/src/windns/windns.h b/windows/windns/src/windns/windns.h index ab4ec9a539..a7ebfd8f95 100644 --- a/windows/windns/src/windns/windns.h +++ b/windows/windns/src/windns/windns.h @@ -17,8 +17,16 @@ // Functions /////////////////////////////////////////////////////////////////////////////// -typedef void (WINDNS_API *WinDnsErrorSink)(const char *errorMessage, const char **details, uint32_t numDetails, void *context); -typedef void (WINDNS_API *WinDnsConfigSink)(const void *configData, uint32_t dataLength, void *context); +enum WinDnsLogCategory +{ + WINDNS_LOG_CATEGORY_ERROR = 0x01, + WINDNS_LOG_CATEGORY_INFO = 0x02 +}; + +typedef void (WINDNS_API *WinDnsLogSink)(WinDnsLogCategory category, const char *message, + const char **details, uint32_t numDetails, void *context); + +typedef void (WINDNS_API *WinDnsRecoverySink)(const void *recoveryData, uint32_t dataLength, void *context); // // WinDns_Initialize: @@ -35,8 +43,8 @@ WINDNS_LINKAGE bool WINDNS_API WinDns_Initialize( - WinDnsErrorSink errorSink, - void *errorContext + WinDnsLogSink logSink, + void *logContext ); // @@ -56,20 +64,22 @@ WinDns_Deinitialize( // // Configure which DNS servers should be used and start enforcing these settings. // -// The 'configSink' will receive periodic callbacks with updated config data +// The 'recoverySink' will receive periodic callbacks with updated recovery data // until you call WinDns_Reset. // -// You should persist the config data in preparation for an eventual recovery. +// You should persist the recovery data in preparation for an eventual recovery. // extern "C" WINDNS_LINKAGE bool WINDNS_API WinDns_Set( - const wchar_t **servers, - uint32_t numServers, - WinDnsConfigSink configSink, - void *configContext + const wchar_t **ipv4Servers, + uint32_t numIpv4Servers, + const wchar_t **ipv6Servers, + uint32_t numIpv6Servers, + WinDnsRecoverySink recoverySink, + void *recoveryContext ); // @@ -80,7 +90,7 @@ WinDns_Set( // (Also taking into account external changes to DNS settings that have occurred // during the period of enforcing specific settings.) // -// It's safe to discard persisted config data once WinDns_Reset returns 'true'. +// It's safe to discard persisted recovery data once WinDns_Reset returns 'true'. // extern "C" WINDNS_LINKAGE @@ -102,6 +112,6 @@ WINDNS_LINKAGE bool WINDNS_API WinDns_Recover( - const void *configData, + const void *recoveryData, uint32_t dataLength ); diff --git a/windows/windns/src/windns/windns.vcxproj b/windows/windns/src/windns/windns.vcxproj index 8b8bbb124a..bf84c01690 100644 --- a/windows/windns/src/windns/windns.vcxproj +++ b/windows/windns/src/windns/windns.vcxproj @@ -106,7 +106,7 @@ <SubSystem>Windows</SubSystem> <GenerateDebugInformation>true</GenerateDebugInformation> <AdditionalLibraryDirectories>$(SolutionDir)bin\$(Platform)-$(Configuration)\</AdditionalLibraryDirectories> - <AdditionalDependencies>libcommon.lib;wbemuuid.lib;comsuppw.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies> + <AdditionalDependencies>libcommon.lib;Iphlpapi.lib;wbemuuid.lib;comsuppw.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies> </Link> </ItemDefinitionGroup> <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'"> @@ -125,7 +125,7 @@ <SubSystem>Windows</SubSystem> <GenerateDebugInformation>true</GenerateDebugInformation> <AdditionalLibraryDirectories>$(SolutionDir)bin\$(Platform)-$(Configuration)\</AdditionalLibraryDirectories> - <AdditionalDependencies>libcommon.lib;wbemuuid.lib;comsuppw.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies> + <AdditionalDependencies>libcommon.lib;Iphlpapi.lib;wbemuuid.lib;comsuppw.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies> </Link> </ItemDefinitionGroup> <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|Win32'"> @@ -148,7 +148,7 @@ <OptimizeReferences>true</OptimizeReferences> <GenerateDebugInformation>true</GenerateDebugInformation> <AdditionalLibraryDirectories>$(SolutionDir)bin\$(Platform)-$(Configuration)\</AdditionalLibraryDirectories> - <AdditionalDependencies>libcommon.lib;wbemuuid.lib;comsuppw.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies> + <AdditionalDependencies>libcommon.lib;Iphlpapi.lib;wbemuuid.lib;comsuppw.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies> </Link> </ItemDefinitionGroup> <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'"> @@ -171,18 +171,25 @@ <OptimizeReferences>true</OptimizeReferences> <GenerateDebugInformation>true</GenerateDebugInformation> <AdditionalLibraryDirectories>$(SolutionDir)bin\$(Platform)-$(Configuration)\</AdditionalLibraryDirectories> - <AdditionalDependencies>libcommon.lib;wbemuuid.lib;comsuppw.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies> + <AdditionalDependencies>libcommon.lib;Iphlpapi.lib;wbemuuid.lib;comsuppw.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies> </Link> </ItemDefinitionGroup> <ItemGroup> <ClInclude Include="clientsinkinfo.h" /> - <ClInclude Include="configmanager.h" /> <ClInclude Include="confineoperation.h" /> - <ClInclude Include="iclientsinkproxy.h" /> - <ClInclude Include="interfaceconfig.h" /> - <ClInclude Include="netconfigeventsink.h" /> - <ClInclude Include="netconfighelpers.h" /> + <ClInclude Include="logsink.h" /> + <ClInclude Include="ilogsink.h" /> + <ClInclude Include="inameserversource.h" /> + <ClInclude Include="interfacemonitor.h" /> + <ClInclude Include="interfacesnap.h" /> + <ClInclude Include="irecoverysink.h" /> + <ClInclude Include="dnsagent.h" /> + <ClInclude Include="nameserversource.h" /> <ClInclude Include="netsh.h" /> + <ClInclude Include="recoveryformatter.h" /> + <ClInclude Include="recoverylogic.h" /> + <ClInclude Include="recoverysink.h" /> + <ClInclude Include="registrypaths.h" /> <ClInclude Include="stdafx.h" /> <ClInclude Include="targetver.h" /> <ClInclude Include="types.h" /> @@ -190,13 +197,18 @@ <ClInclude Include="windnscontext.h" /> </ItemGroup> <ItemGroup> - <ClCompile Include="configmanager.cpp" /> <ClCompile Include="confineoperation.cpp" /> <ClCompile Include="dllmain.cpp" /> - <ClCompile Include="interfaceconfig.cpp" /> - <ClCompile Include="netconfigeventsink.cpp" /> - <ClCompile Include="netconfighelpers.cpp" /> + <ClCompile Include="logsink.cpp" /> + <ClCompile Include="interfacemonitor.cpp" /> + <ClCompile Include="interfacesnap.cpp" /> + <ClCompile Include="dnsagent.cpp" /> + <ClCompile Include="nameserversource.cpp" /> <ClCompile Include="netsh.cpp" /> + <ClCompile Include="recoveryformatter.cpp" /> + <ClCompile Include="recoverylogic.cpp" /> + <ClCompile Include="recoverysink.cpp" /> + <ClCompile Include="registrypaths.cpp" /> <ClCompile Include="stdafx.cpp"> <PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">Create</PrecompiledHeader> <PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">Create</PrecompiledHeader> diff --git a/windows/windns/src/windns/windns.vcxproj.filters b/windows/windns/src/windns/windns.vcxproj.filters index 57e9eeaef0..3c846bb6ba 100644 --- a/windows/windns/src/windns/windns.vcxproj.filters +++ b/windows/windns/src/windns/windns.vcxproj.filters @@ -5,27 +5,39 @@ <ClInclude Include="targetver.h" /> <ClInclude Include="windns.h" /> <ClInclude Include="windnscontext.h" /> - <ClInclude Include="configmanager.h" /> - <ClInclude Include="netconfigeventsink.h" /> - <ClInclude Include="netconfighelpers.h" /> <ClInclude Include="clientsinkinfo.h" /> <ClInclude Include="netsh.h" /> - <ClInclude Include="interfaceconfig.h" /> - <ClInclude Include="types.h" /> - <ClInclude Include="iclientsinkproxy.h" /> <ClInclude Include="confineoperation.h" /> + <ClInclude Include="interfacesnap.h" /> + <ClInclude Include="interfacemonitor.h" /> + <ClInclude Include="registrypaths.h" /> + <ClInclude Include="types.h" /> + <ClInclude Include="inameserversource.h" /> + <ClInclude Include="irecoverysink.h" /> + <ClInclude Include="dnsagent.h" /> + <ClInclude Include="recoverysink.h" /> + <ClInclude Include="recoveryformatter.h" /> + <ClInclude Include="nameserversource.h" /> + <ClInclude Include="ilogsink.h" /> + <ClInclude Include="logsink.h" /> + <ClInclude Include="recoverylogic.h" /> </ItemGroup> <ItemGroup> <ClCompile Include="dllmain.cpp" /> <ClCompile Include="stdafx.cpp" /> <ClCompile Include="windns.cpp" /> <ClCompile Include="windnscontext.cpp" /> - <ClCompile Include="configmanager.cpp" /> - <ClCompile Include="netconfigeventsink.cpp" /> - <ClCompile Include="netconfighelpers.cpp" /> <ClCompile Include="netsh.cpp" /> - <ClCompile Include="interfaceconfig.cpp" /> <ClCompile Include="confineoperation.cpp" /> + <ClCompile Include="interfacesnap.cpp" /> + <ClCompile Include="interfacemonitor.cpp" /> + <ClCompile Include="registrypaths.cpp" /> + <ClCompile Include="dnsagent.cpp" /> + <ClCompile Include="recoverysink.cpp" /> + <ClCompile Include="recoveryformatter.cpp" /> + <ClCompile Include="nameserversource.cpp" /> + <ClCompile Include="logsink.cpp" /> + <ClCompile Include="recoverylogic.cpp" /> </ItemGroup> <ItemGroup> <ResourceCompile Include="windns.rc" /> diff --git a/windows/windns/src/windns/windnscontext.cpp b/windows/windns/src/windns/windnscontext.cpp index 29cb034566..cf80c0e619 100644 --- a/windows/windns/src/windns/windnscontext.cpp +++ b/windows/windns/src/windns/windnscontext.cpp @@ -1,131 +1,97 @@ #include "stdafx.h" #include "windnscontext.h" -#include "libcommon/wmi/connection.h" -#include "netconfigeventsink.h" -#include "netconfighelpers.h" #include "confineoperation.h" +#include "recoveryformatter.h" +#include "recoverylogic.h" #include <functional> using namespace common; -WinDnsContext::WinDnsContext() +WinDnsContext::WinDnsContext(ILogSink *logSink) + : m_logSink(logSink) { - m_connection = std::make_shared<wmi::Connection>(wmi::Connection::Namespace::Cimv2); + if (nullptr == logSink) + { + throw std::runtime_error("Invalid logger sink"); + } } WinDnsContext::~WinDnsContext() { - try + auto forwardError = [this](const char *msg, const char **details, uint32_t numDetails) { - reset(); - } - catch (...) + m_logSink->error(msg, details, numDetails); + }; + + ConfineOperation("Reset DNS settings", forwardError, [this]() { - } + this->reset(); + }); } -void WinDnsContext::set(const std::vector<std::wstring> &servers, const ClientSinkInfo &sinkInfo) +void WinDnsContext::set(const std::vector<std::wstring> &ipv4NameServers, + const std::vector<std::wstring> &ipv6NameServers, const RecoverySinkInfo &recoverySinkInfo) { - m_sinkInfo = sinkInfo; + // + // The 'sink' and 'source' instances must be kept alive for the lifetime of the agents. + // - if (nullptr == m_notification) + if (!m_recoverySink) { - m_configManager = std::make_shared<ConfigManager>(servers, this); - - // - // Register interface configuration monitoring. - // - - auto eventSink = std::make_shared<NetConfigEventSink>(m_connection, m_configManager, this); - auto eventDispatcher = CComPtr<wmi::IEventDispatcher>(new wmi::ModificationEventDispatcher(eventSink)); - - m_notification = std::make_unique<wmi::Notification>(m_connection, eventDispatcher); - - m_notification->activate - ( - L"SELECT * " - L"FROM __InstanceModificationEvent " - L"WITHIN 1 " - L"WHERE TargetInstance ISA 'Win32_NetworkAdapterConfiguration'" - L"AND TargetInstance.IPEnabled = True" - ); + m_recoverySink = std::make_unique<RecoverySink>(recoverySinkInfo); } else { - ConfigManager::Mutex mutex(*m_configManager); + m_recoverySink->setTarget(recoverySinkInfo); + } - m_configManager->updateServers(servers); + if (!m_nameServerSource) + { + m_nameServerSource = std::make_unique<NameServerSource>(ipv4NameServers, ipv6NameServers); + } + else + { + m_nameServerSource->setNameServers(Protocol::IPv4, ipv4NameServers); + m_nameServerSource->setNameServers(Protocol::IPv6, ipv6NameServers); } // - // Discover all active interfaces and apply our DNS settings. + // Instantiate agents unless they're already set up. // - auto resultSet = m_connection->query(L"SELECT * from Win32_NetworkAdapterConfiguration WHERE IPEnabled = True"); + if (!m_ipv4Agent) + { + m_ipv4Agent = std::make_unique<DnsAgent>(Protocol::IPv4, m_nameServerSource.get(), m_recoverySink.get(), m_logSink); + } - while (resultSet.advance()) + if (!m_ipv6Agent) { - nchelpers::SetDnsServers(nchelpers::GetInterfaceIndex(resultSet.result()), servers); + m_ipv6Agent = std::make_unique<DnsAgent>(Protocol::IPv6, m_nameServerSource.get(), m_recoverySink.get(), m_logSink); } } void WinDnsContext::reset() { - if (nullptr == m_notification) + if (!m_ipv4Agent && !m_ipv6Agent) { return; } - m_notification->deactivate(); - m_notification = nullptr; - - // - // Reset adapter configs. // - // Safe to do without a mutex guarding the config manager. + // Destructing the agents will abort all monitoring + enforcing. // - // Try to reset as many adapters as possible, even if one or more fails to reset. - // - - bool success = true; - - auto forwardError = std::bind(&WinDnsContext::error, this, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3); - - m_configManager->processConfigs([&success, &forwardError](const InterfaceConfig &config) - { - const auto adapterStatus = ConfineOperation("Reset adapter DNS configuration", forwardError, [&config]() - { - nchelpers::RevertDnsServers(config); - }); - - if (false == adapterStatus) - { - success = false; - } - return true; - }); - - if (false == success) + if (m_ipv4Agent) { - throw std::runtime_error("Resetting DNS failed for one or more adapters"); + m_ipv4Agent.reset(nullptr); } -} -// IClientSinkProxy -void WinDnsContext::error(const char *errorMessage, const char **details, uint32_t numDetails) -{ - if (nullptr != m_sinkInfo.errorSinkInfo.sink) + if (m_ipv6Agent) { - m_sinkInfo.errorSinkInfo.sink(errorMessage, details, numDetails, m_sinkInfo.errorSinkInfo.context); + m_ipv6Agent.reset(nullptr); } -} -// IClientSinkProxy -void WinDnsContext::config(const void *configData, uint32_t dataLength) -{ - if (nullptr != m_sinkInfo.configSinkInfo.sink) - { - m_sinkInfo.configSinkInfo.sink(configData, dataLength, m_sinkInfo.configSinkInfo.context); - } + auto recoveryData = RecoveryFormatter::Unpack(m_recoverySink->recoveryData()); + + RecoveryLogic::RestoreInterfaces(recoveryData, m_logSink); } diff --git a/windows/windns/src/windns/windnscontext.h b/windows/windns/src/windns/windnscontext.h index c83ec6e482..c64ab96f31 100644 --- a/windows/windns/src/windns/windnscontext.h +++ b/windows/windns/src/windns/windnscontext.h @@ -1,32 +1,34 @@ #pragma once #include "windns.h" -#include "libcommon/wmi/connection.h" -#include "libcommon/wmi/notification.h" -#include "configmanager.h" #include "clientsinkinfo.h" -#include "iclientsinkproxy.h" +#include "ilogsink.h" +#include "recoverysink.h" +#include "nameserversource.h" +#include "dnsagent.h" #include <vector> #include <string> #include <memory> -class WinDnsContext : public IClientSinkProxy +class WinDnsContext { public: - WinDnsContext(); + WinDnsContext(ILogSink *logSink); ~WinDnsContext(); - void set(const std::vector<std::wstring> &servers, const ClientSinkInfo &sinkInfo); - void reset(); + void set(const std::vector<std::wstring> &ipv4NameServers, const std::vector<std::wstring> &ipv6NameServers, + const RecoverySinkInfo &recoverySinkInfo); - void IClientSinkProxy::error(const char *errorMessage, const char **details, uint32_t numDetails) override; - void IClientSinkProxy::config(const void *configData, uint32_t dataLength) override; + void reset(); private: - std::shared_ptr<common::wmi::Connection> m_connection; - std::shared_ptr<ConfigManager> m_configManager; - std::unique_ptr<common::wmi::Notification> m_notification; - ClientSinkInfo m_sinkInfo; + ILogSink *m_logSink; + + std::unique_ptr<RecoverySink> m_recoverySink; + std::unique_ptr<NameServerSource> m_nameServerSource; + + std::unique_ptr<DnsAgent> m_ipv4Agent; + std::unique_ptr<DnsAgent> m_ipv6Agent; }; |
