summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorOdd Stranne <odd@mullvad.net>2018-09-28 14:47:55 +0200
committerOdd Stranne <odd@mullvad.net>2018-10-03 23:13:55 +0200
commit3291c2c839aea203686737c12cd0bd22a0dd20f6 (patch)
treebd031f7c61d1ec92707d3621f57b32c066476c4d
parent5ddf1c021ea4dc25ad00730f5f0553150c326899 (diff)
downloadmullvadvpn-3291c2c839aea203686737c12cd0bd22a0dd20f6.tar.xz
mullvadvpn-3291c2c839aea203686737c12cd0bd22a0dd20f6.zip
Rewrite 'windns' to get rid of WMI + fix bugs and shortcomings
-rw-r--r--windows/windns/extras.sln1
-rw-r--r--windows/windns/src/extras/loader/loader.cpp48
-rw-r--r--windows/windns/src/windns/clientsinkinfo.h14
-rw-r--r--windows/windns/src/windns/configmanager.cpp133
-rw-r--r--windows/windns/src/windns/configmanager.h110
-rw-r--r--windows/windns/src/windns/dnsagent.cpp442
-rw-r--r--windows/windns/src/windns/dnsagent.h84
-rw-r--r--windows/windns/src/windns/iclientsinkproxy.h13
-rw-r--r--windows/windns/src/windns/ilogsink.h24
-rw-r--r--windows/windns/src/windns/inameserversource.h23
-rw-r--r--windows/windns/src/windns/interfaceconfig.cpp66
-rw-r--r--windows/windns/src/windns/interfaceconfig.h62
-rw-r--r--windows/windns/src/windns/interfacemonitor.cpp24
-rw-r--r--windows/windns/src/windns/interfacemonitor.h30
-rw-r--r--windows/windns/src/windns/interfacesnap.cpp151
-rw-r--r--windows/windns/src/windns/interfacesnap.h35
-rw-r--r--windows/windns/src/windns/irecoverysink.h14
-rw-r--r--windows/windns/src/windns/logsink.cpp34
-rw-r--r--windows/windns/src/windns/logsink.h22
-rw-r--r--windows/windns/src/windns/nameserversource.cpp77
-rw-r--r--windows/windns/src/windns/nameserversource.h29
-rw-r--r--windows/windns/src/windns/netconfigeventsink.cpp51
-rw-r--r--windows/windns/src/windns/netconfigeventsink.h24
-rw-r--r--windows/windns/src/windns/netconfighelpers.cpp65
-rw-r--r--windows/windns/src/windns/netconfighelpers.h24
-rw-r--r--windows/windns/src/windns/netsh.cpp274
-rw-r--r--windows/windns/src/windns/netsh.h34
-rw-r--r--windows/windns/src/windns/recoveryformatter.cpp88
-rw-r--r--windows/windns/src/windns/recoveryformatter.h24
-rw-r--r--windows/windns/src/windns/recoverylogic.cpp90
-rw-r--r--windows/windns/src/windns/recoverylogic.h14
-rw-r--r--windows/windns/src/windns/recoverysink.cpp58
-rw-r--r--windows/windns/src/windns/recoverysink.h30
-rw-r--r--windows/windns/src/windns/registrypaths.cpp20
-rw-r--r--windows/windns/src/windns/registrypaths.h14
-rw-r--r--windows/windns/src/windns/types.h10
-rw-r--r--windows/windns/src/windns/windns.cpp117
-rw-r--r--windows/windns/src/windns/windns.h34
-rw-r--r--windows/windns/src/windns/windns.vcxproj38
-rw-r--r--windows/windns/src/windns/windns.vcxproj.filters32
-rw-r--r--windows/windns/src/windns/windnscontext.cpp130
-rw-r--r--windows/windns/src/windns/windnscontext.h30
42 files changed, 1724 insertions, 913 deletions
diff --git a/windows/windns/extras.sln b/windows/windns/extras.sln
index ed1e71d4c2..ca9347638e 100644
--- a/windows/windns/extras.sln
+++ b/windows/windns/extras.sln
@@ -5,6 +5,7 @@ MinimumVisualStudioVersion = 10.0.40219.1
Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "loader", "src\extras\loader\loader.vcxproj", "{1476A8B9-4A9E-4358-8744-A350CB97E152}"
ProjectSection(ProjectDependencies) = postProject
{A5344205-FC37-4572-9C63-8564ECC410AC} = {A5344205-FC37-4572-9C63-8564ECC410AC}
+ {B52E2D10-A94A-4605-914A-2DCEF6A757EF} = {B52E2D10-A94A-4605-914A-2DCEF6A757EF}
EndProjectSection
EndProject
Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "windns", "src\windns\windns.vcxproj", "{A5344205-FC37-4572-9C63-8564ECC410AC}"
diff --git a/windows/windns/src/extras/loader/loader.cpp b/windows/windns/src/extras/loader/loader.cpp
index 46d8a7030c..5c6603c3dd 100644
--- a/windows/windns/src/extras/loader/loader.cpp
+++ b/windows/windns/src/extras/loader/loader.cpp
@@ -7,9 +7,19 @@
#include <vector>
#include <windows.h>
-void WINDNS_API ErrorSink(const char *errorMessage, const char **details, uint32_t numDetails, void *context)
+void WINDNS_API LogSink(WinDnsLogCategory category, const char *message, const char **details,
+ uint32_t numDetails, void *context)
{
- std::cout << "WINDNS Error: " << errorMessage << std::endl;
+ if (WINDNS_LOG_CATEGORY_ERROR == category)
+ {
+ std::cout << "WINDNS Error: ";
+ }
+ else
+ {
+ std::cout << "WINDNS Info: ";
+ }
+
+ std::cout << message << std::endl;
for (uint32_t i = 0; i < numDetails; ++i)
{
@@ -17,9 +27,9 @@ void WINDNS_API ErrorSink(const char *errorMessage, const char **details, uint32
}
}
-void WINDNS_API ConfigSink(const void *configData, uint32_t dataLength, void *context)
+void WINDNS_API RecoverySink(const void *recoveryData, uint32_t dataLength, void *context)
{
- std::wcout << L"Updated config was delivered to WINDNS client code" << std::endl;
+ std::wcout << L"Updated recovery data was delivered to WINDNS client code" << std::endl;
auto f = CreateFileW(L"windns_recovery", GENERIC_WRITE, 0, nullptr, CREATE_ALWAYS, 0, nullptr);
@@ -29,7 +39,7 @@ void WINDNS_API ConfigSink(const void *configData, uint32_t dataLength, void *co
return;
}
- if (FALSE == WriteFile(f, configData, dataLength, nullptr, nullptr))
+ if (FALSE == WriteFile(f, recoveryData, dataLength, nullptr, nullptr))
{
std::wcout << L"Failed to update recovery file" << std::endl;
}
@@ -59,7 +69,8 @@ void Recover()
return;
}
- std::wcout << L"WinDns_Recover: " << WinDns_Recover(&data[0], static_cast<uint32_t>(data.size())) << std::endl;
+ std::wcout << L"WinDns_Recover: " << std::boolalpha <<
+ WinDns_Recover(&data[0], static_cast<uint32_t>(data.size())) << std::endl;
}
bool Ask(const std::wstring &question)
@@ -88,7 +99,7 @@ int main()
{
common::trace::Trace::RegisterSink(new common::trace::ConsoleTraceSink);
- std::wcout << L"WinDns_Initialize: " << WinDns_Initialize(ErrorSink, nullptr) << std::endl;
+ std::wcout << L"WinDns_Initialize: " << std::boolalpha << WinDns_Initialize(LogSink, nullptr) << std::endl;
if (Ask(L"Perform recovery?"))
{
@@ -98,28 +109,39 @@ int main()
const wchar_t *servers[] =
{
- L"8.8.8.8"
+ L"8.8.8.8",
+ L"8.8.4.4"
};
- std::wcout << L"WinDns_Set: " << WinDns_Set(servers, _countof(servers), ConfigSink, nullptr) << std::endl;
+ const wchar_t *v6Servers[] =
+ {
+ L"2001:4860:4860::8888",
+ L"2001:4860:4860::8844"
+ };
+
+ auto status = WinDns_Set(servers, _countof(servers), v6Servers, _countof(v6Servers), RecoverySink, nullptr);
+
+ std::wcout << L"WinDns_Set: " << std::boolalpha << status << std::endl;
WaitInput(L"Press a key to abort DNS monitoring + enforcing...");
if (Ask(L"Perform WinDns_Reset() before next WinDns_Set()?"))
{
- std::wcout << L"WinDns_Reset: " << WinDns_Reset() << std::endl;
+ std::wcout << L"WinDns_Reset: " << std::boolalpha << WinDns_Reset() << std::endl;
}
- std::wcout << L"WinDns_Set: " << WinDns_Set(servers, _countof(servers), ConfigSink, nullptr) << std::endl;
+ status = WinDns_Set(servers, _countof(servers), v6Servers, _countof(v6Servers), RecoverySink, nullptr);
+
+ std::wcout << L"WinDns_Set: " << std::boolalpha << status << std::endl;
WaitInput(L"Press a key to abort DNS monitoring + enforcing...");
if (Ask(L"Perform WinDns_Reset() before WinDns_Deinitialize()?"))
{
- std::wcout << L"WinDns_Reset: " << WinDns_Reset() << std::endl;
+ std::wcout << L"WinDns_Reset: " << std::boolalpha << WinDns_Reset() << std::endl;
}
- std::wcout << L"WinDns_Deinitialize: " << WinDns_Deinitialize() << std::endl;
+ std::wcout << L"WinDns_Deinitialize: " << std::boolalpha << WinDns_Deinitialize() << std::endl;
return 0;
}
diff --git a/windows/windns/src/windns/clientsinkinfo.h b/windows/windns/src/windns/clientsinkinfo.h
index db42485d8e..39f602eeac 100644
--- a/windows/windns/src/windns/clientsinkinfo.h
+++ b/windows/windns/src/windns/clientsinkinfo.h
@@ -2,20 +2,14 @@
#include "windns.h"
-struct ErrorSinkInfo
+struct LogSinkInfo
{
- WinDnsErrorSink sink;
+ WinDnsLogSink sink;
void *context;
};
-struct ConfigSinkInfo
+struct RecoverySinkInfo
{
- WinDnsConfigSink sink;
+ WinDnsRecoverySink sink;
void *context;
};
-
-struct ClientSinkInfo
-{
- ErrorSinkInfo errorSinkInfo;
- ConfigSinkInfo configSinkInfo;
-};
diff --git a/windows/windns/src/windns/configmanager.cpp b/windows/windns/src/windns/configmanager.cpp
deleted file mode 100644
index 498e58a381..0000000000
--- a/windows/windns/src/windns/configmanager.cpp
+++ /dev/null
@@ -1,133 +0,0 @@
-#include "stdafx.h"
-#include "configmanager.h"
-#include "libcommon/serialization/serializer.h"
-#include "libcommon/trace/xtrace.h"
-#include <utility>
-#include <algorithm>
-
-ConfigManager::ConfigManager
-(
- const std::vector<std::wstring> &servers,
- IClientSinkProxy *clientSinkProxy
-)
- : m_servers(servers)
- , m_clientSinkProxy(clientSinkProxy)
-{
-}
-
-void ConfigManager::lock()
-{
- m_mutex.lock();
-}
-
-void ConfigManager::unlock()
-{
- m_mutex.unlock();
-}
-
-void ConfigManager::updateServers(const std::vector<std::wstring> &servers)
-{
- XTRACE(L"Updating DNS server list");
- m_servers = servers;
-}
-
-const std::vector<std::wstring> &ConfigManager::getServers() const
-{
- return m_servers;
-}
-
-ConfigManager::UpdateStatus ConfigManager::updateConfig(const InterfaceConfig &previous, const InterfaceConfig &target)
-{
- XTRACE(L"Interface configuration update for interface=", target.interfaceIndex());
-
- //
- // There are a few cases we need to deal with:
- //
- // 1/ An interface being offline and coming online.
- // 2/ An external application changing the interface settings.
- // 3/ Us changing the interface settings.
- // a. On an interface the ConfigManager hasn't seen before.
- // b. On an interface the ConfigManager already knows about.
- //
-
- const auto configIndex = target.configIndex();
- auto iter = m_configs.find(configIndex);
-
- if (verifyServers(target))
- {
- XTRACE(L"Update event was initiated by WINDNS or did not include DNS changes");
-
- //
- // If we haven't seen this config id before, it means the 'previous' instance
- // is the original configuration on the system, and as such must be recorded.
- //
- if (m_configs.end() == iter)
- {
- XTRACE(L"Creating new interface configuration entry");
- m_configs.insert(std::make_pair(configIndex, previous));
-
- exportConfigs();
- }
-
- return UpdateStatus::DnsApproved;
- }
-
- //
- // The update was not initiated by us so store the updated configuration.
- //
- if (m_configs.end() == iter)
- {
- XTRACE(L"Creating new interface configuration entry");
- m_configs.insert(std::make_pair(configIndex, target));
- }
- else
- {
- XTRACE(L"Updating interface configuration entry");
- iter->second.updateServers(target);
- }
-
- exportConfigs();
-
- return UpdateStatus::DnsDeviates;
-}
-
-bool ConfigManager::processConfigs(std::function<bool(const InterfaceConfig &)> configSink)
-{
- for (auto it = m_configs.begin(); it != m_configs.end(); ++it)
- {
- if (false == configSink(it->second))
- {
- return false;
- }
- }
-
- return true;
-}
-
-bool ConfigManager::verifyServers(const InterfaceConfig &config)
-{
- auto updatedServers = config.servers();
-
- if (nullptr == updatedServers)
- {
- return false;
- }
-
- return std::equal(m_servers.begin(), m_servers.end(), updatedServers->begin(), updatedServers->end());
-}
-
-void ConfigManager::exportConfigs()
-{
- common::serialization::Serializer s;
-
- s << static_cast<uint32_t>(m_configs.size());
-
- for (auto it = m_configs.begin(); it != m_configs.end(); ++it)
- {
- it->second.serialize(s);
- }
-
- auto data = s.blob();
-
- m_clientSinkProxy->config(&data[0], static_cast<uint32_t>(data.size()));
-}
diff --git a/windows/windns/src/windns/configmanager.h b/windows/windns/src/windns/configmanager.h
deleted file mode 100644
index c6f14f7ea9..0000000000
--- a/windows/windns/src/windns/configmanager.h
+++ /dev/null
@@ -1,110 +0,0 @@
-#pragma once
-
-#include "interfaceconfig.h"
-#include "iclientsinkproxy.h"
-#include <map>
-#include <string>
-#include <mutex>
-#include <memory>
-#include <functional>
-
-//
-// The ConfigManager is engineered to track the "real" DNS configuration for an adapter.
-//
-// The situation is somewhat complicated, because a given system may have several adapters, which
-// in turn may have several configurations?
-//
-// Every update for every configuration is recorded, bar the ones that correspond to us
-// overriding the DNS settings.
-//
-
-class ConfigManager
-{
-public:
-
- struct Mutex
- {
- Mutex(const Mutex &) = delete;
- Mutex &operator=(const Mutex &) = delete;
- Mutex(Mutex &&) = delete;
- Mutex &operator=(Mutex &&) = delete;
-
- Mutex(ConfigManager &manager)
- : m_manager(manager)
- {
- m_manager.lock();
- }
-
- ~Mutex()
- {
- m_manager.unlock();
- }
-
- ConfigManager &m_manager;
- };
-
- //
- // "servers" specifies the set of servers used when overriding settings.
- // This enables filtering out the corresponding event.
- //
- ConfigManager
- (
- const std::vector<std::wstring> &servers,
- IClientSinkProxy *clientSinkProxy
- );
-
- //
- // The ConfigManager is shared between threads.
- // Locking is managed externally for reasons of efficiency.
- //
- void lock();
- void unlock();
-
- //
- // Establish the set of servers to use when overriding DNS settings.
- //
- void updateServers(const std::vector<std::wstring> &servers);
-
- //
- // Get the current set of servers used for overriding DNS settings.
- //
- const std::vector<std::wstring> &getServers() const;
-
- //
- // Notify the ConfigManager that a live configuration has been updated.
- //
- enum class UpdateStatus
- {
- DnsApproved,
- DnsDeviates
- };
-
- UpdateStatus updateConfig(const InterfaceConfig &previous, const InterfaceConfig &target);
-
- //
- // Enumerate recorded configs.
- //
- bool processConfigs(std::function<bool(const InterfaceConfig &)> configSink);
-
-private:
-
- std::mutex m_mutex;
-
- std::vector<std::wstring> m_servers;
- IClientSinkProxy *m_clientSinkProxy;
-
- //
- // Organize configs based on their system assigned index.
- //
- std::map<uint32_t, InterfaceConfig> m_configs;
-
- //
- // Check DNS server list to see if it matches what we're trying to enforce.
- //
- bool verifyServers(const InterfaceConfig &config);
-
- //
- // Bundle the current config details and send them into the config sink.
- //
- void exportConfigs();
-};
diff --git a/windows/windns/src/windns/dnsagent.cpp b/windows/windns/src/windns/dnsagent.cpp
new file mode 100644
index 0000000000..6ff8ee9b39
--- /dev/null
+++ b/windows/windns/src/windns/dnsagent.cpp
@@ -0,0 +1,442 @@
+#include "stdafx.h"
+#include "dnsagent.h"
+#include "registrypaths.h"
+#include "netsh.h"
+#include <libcommon/trace/xtrace.h>
+#include <libcommon/error.h>
+#include <process.h>
+#include <algorithm>
+
+DnsAgent::DnsAgent(Protocol protocol, INameServerSource *nameServerSource, IRecoverySink *recoverySink, ILogSink *logSink)
+ : m_protocol(protocol)
+ , m_nameServerSource(nameServerSource)
+ , m_recoverySink(recoverySink)
+ , m_logSink(logSink)
+ , m_thread(nullptr)
+ , m_shutdownEvent(nullptr)
+{
+ constructNameServerUpdateEvent();
+ constructRootMonitor();
+
+ startTrackingInterfaces(discoverInterfaces());
+ updateRecoveryData();
+
+ constructThread();
+}
+
+DnsAgent::~DnsAgent()
+{
+ SetEvent(m_shutdownEvent);
+ WaitForSingleObject(m_thread, INFINITE);
+
+ CloseHandle(m_shutdownEvent);
+ CloseHandle(m_thread);
+
+ m_nameServerSource->unsubscribe(m_serverSourceEvent);
+ CloseHandle(m_serverSourceEvent);
+}
+
+void DnsAgent::constructNameServerUpdateEvent()
+{
+ m_serverSourceEvent = CreateEventW(nullptr, TRUE, FALSE, nullptr);
+
+ THROW_GLE_IF(nullptr, m_serverSourceEvent, "Create name server subscription event");
+
+ m_nameServerSource->subscribe(m_serverSourceEvent);
+}
+
+void DnsAgent::constructRootMonitor()
+{
+ m_rootMonitor = common::registry::Registry::MonitorKey(HKEY_LOCAL_MACHINE,
+ RegistryPaths::InterfaceRoot(m_protocol), { common::registry::RegistryEventFlag::SubkeyChange });
+}
+
+void DnsAgent::constructThread()
+{
+ m_shutdownEvent = CreateEventW(nullptr, TRUE, FALSE, nullptr);
+
+ THROW_GLE_IF(nullptr, m_shutdownEvent, "Create shutdown event");
+
+ auto rawThreadHandle = _beginthreadex(nullptr, 0, &DnsAgent::ThreadEntry, this, 0, nullptr);
+
+ if (0 == rawThreadHandle)
+ {
+ throw std::runtime_error("Could not create monitoring thread");
+ }
+
+ m_thread = reinterpret_cast<HANDLE>(rawThreadHandle);
+}
+
+//static
+unsigned __stdcall DnsAgent::ThreadEntry(void *parameters)
+{
+ try
+ {
+ reinterpret_cast<DnsAgent *>(parameters)->thread();
+ }
+ catch (std::exception &err)
+ {
+ const char *what = err.what();
+
+ reinterpret_cast<DnsAgent *>(parameters)->m_logSink->error
+ (
+ "Critical error in monitoring thread", &what, 1
+ );
+ }
+ catch (...)
+ {
+ reinterpret_cast<DnsAgent *>(parameters)->m_logSink->error
+ (
+ "Unspecified critical error in monitoring thread"
+ );
+ }
+
+ return 0;
+}
+
+void DnsAgent::thread()
+{
+ for (;;)
+ {
+ std::vector<HANDLE> waitHandles;
+
+ //
+ // Reserve enough space in the array to hold:
+ //
+ // Shutdown event handle
+ // Name servers source update event
+ // Monitor handle for interfaces root key
+ // Monitor handles for all interfaces
+ //
+ waitHandles.reserve(3 + m_trackedInterfaces.size());
+
+ const size_t shutdownEventIndex = 0;
+ const size_t serverSourceEventIndex = 1;
+ const size_t rootKeyEventIndex = 2;
+ const size_t firstInterfaceIndex = 3;
+
+ waitHandles.push_back(m_shutdownEvent);
+ waitHandles.push_back(m_serverSourceEvent);
+ waitHandles.push_back(m_rootMonitor->queueSingleEvent());
+
+ for (auto &interfaceData : m_trackedInterfaces)
+ {
+ waitHandles.push_back(interfaceData.monitor->queueSingleEvent());
+ }
+
+ //
+ // Wait for one or more events to become signalled.
+ //
+
+ const auto status = WaitForMultipleObjects(static_cast<DWORD>(waitHandles.size()), &waitHandles[0], FALSE, INFINITE);
+
+ if (WAIT_FAILED == status)
+ {
+ m_logSink->error("Failed to wait on events. Restarting wait in 1 minute.");
+
+ if (WAIT_OBJECT_0 == WaitForSingleObject(m_shutdownEvent, 1000 * 60))
+ {
+ break;
+ }
+
+ continue;
+ }
+
+ const size_t firstSignalledIndex = status - WAIT_OBJECT_0;
+
+ if (firstSignalledIndex == shutdownEventIndex)
+ {
+ break;
+ }
+
+ if (firstSignalledIndex >= firstInterfaceIndex)
+ {
+ XTRACE(L"Interface event is signalled");
+
+ const auto result = processInterfaceEvent(&waitHandles[firstInterfaceIndex],
+ firstSignalledIndex - firstInterfaceIndex);
+
+ if (ProcessingResult::TrackingUpdated == result)
+ {
+ updateRecoveryData();
+ }
+
+ continue;
+ }
+
+ //
+ // We can't easily tell which events have been signalled.
+ //
+
+ const auto interfaceResult = processInterfaceEvent(&waitHandles[firstInterfaceIndex], 0);
+ auto rootResult = ProcessingResult::Nop;
+
+ if (WAIT_OBJECT_0 == WaitForSingleObject(waitHandles[rootKeyEventIndex], 0))
+ {
+ XTRACE(L"Interfaces root key event is signalled");
+
+ rootResult = processRootKeyEvent();
+ }
+
+ if (ProcessingResult::TrackingUpdated == interfaceResult
+ || ProcessingResult::TrackingUpdated == rootResult)
+ {
+ updateRecoveryData();
+ }
+
+ if (WAIT_OBJECT_0 == WaitForSingleObject(waitHandles[serverSourceEventIndex], 0))
+ {
+ XTRACE(L"Server source update event is signalled");
+ ResetEvent(m_serverSourceEvent);
+
+ processServerSourceEvent();
+ }
+ }
+
+ XTRACE(L"Thread is exiting");
+}
+
+void DnsAgent::processServerSourceEvent()
+{
+ //
+ // Check actual interface settings to determine which interfaces
+ // need to have their settings overridden.
+ //
+ // Do NOT update 'preservedSettings' on the tracking entries because
+ // it would overwrite legitimate settings with the previously enforced settings.
+ //
+
+ std::vector<std::wstring> interfaces;
+ interfaces.reserve(m_trackedInterfaces.size());
+
+ std::transform(m_trackedInterfaces.begin(), m_trackedInterfaces.end(), std::back_inserter(interfaces), [](const InterfaceData &interfaceData)
+ {
+ return interfaceData.interfaceGuid;
+ });
+
+ const auto updatedSnaps = createSnaps(interfaces);
+ const auto enforcedServers = m_nameServerSource->getNameServers(m_protocol);
+
+ for (const auto snap : updatedSnaps)
+ {
+ if (snap.needsOverriding(enforcedServers))
+ {
+ setNameServers(snap.interfaceGuid(), enforcedServers);
+ }
+ }
+}
+
+DnsAgent::ProcessingResult DnsAgent::processRootKeyEvent()
+{
+ ProcessingResult result = ProcessingResult::Nop;
+
+ std::vector<std::wstring> oldInterfaces;
+ oldInterfaces.reserve(m_trackedInterfaces.size());
+
+ std::transform(m_trackedInterfaces.begin(), m_trackedInterfaces.end(), std::back_inserter(oldInterfaces), [](const InterfaceData &interfaceData)
+ {
+ return interfaceData.interfaceGuid;
+ });
+
+ auto currentInterfaces = discoverInterfaces();
+
+ std::sort(oldInterfaces.begin(), oldInterfaces.end());
+ std::sort(currentInterfaces.begin(), currentInterfaces.end());
+
+ //
+ // Stop tracking interfaces that have been removed.
+ //
+
+ std::vector<std::wstring> removedInterfaces;
+
+ std::set_difference(oldInterfaces.begin(), oldInterfaces.end(), currentInterfaces.begin(), currentInterfaces.end(),
+ std::back_inserter(removedInterfaces));
+
+ if (false == removedInterfaces.empty())
+ {
+ result = ProcessingResult::TrackingUpdated;
+ stopTrackingInterfaces(removedInterfaces);
+ }
+
+ //
+ // Start tracking new interfaces.
+ //
+
+ std::vector<std::wstring> newInterfaces;
+
+ std::set_difference(currentInterfaces.begin(), currentInterfaces.end(), oldInterfaces.begin(), oldInterfaces.end(),
+ std::back_inserter(newInterfaces));
+
+ if (false == newInterfaces.empty())
+ {
+ result = ProcessingResult::TrackingUpdated;
+ startTrackingInterfaces(newInterfaces);
+ }
+
+ return result;
+}
+
+DnsAgent::ProcessingResult DnsAgent::processInterfaceEvent(const HANDLE *interfaceEvents, size_t startIndex)
+{
+ ProcessingResult result = ProcessingResult::Nop;
+
+ //
+ // 'interfaceEvents' runs in parallel with 'm_trackedInterfaces'.
+ //
+
+ const auto enforcedNameServers = m_nameServerSource->getNameServers(m_protocol);
+
+ for (size_t i = startIndex; i < m_trackedInterfaces.size(); ++i)
+ {
+ if (WAIT_TIMEOUT == WaitForSingleObject(interfaceEvents[i], 0))
+ {
+ continue;
+ }
+
+ auto &interface = m_trackedInterfaces[i];
+
+ XTRACE(L"Processing event for interface ", interface.interfaceGuid);
+
+ try
+ {
+ InterfaceSnap updatedSnap(m_protocol, interface.interfaceGuid);
+
+ if (updatedSnap.needsOverriding(enforcedNameServers))
+ {
+ result = ProcessingResult::TrackingUpdated;
+
+ interface.preservedSettings = std::move(updatedSnap);
+ setNameServers(interface.interfaceGuid, enforcedNameServers);
+ }
+ }
+ catch (std::exception &err)
+ {
+ const char *what = err.what();
+
+ m_logSink->error("Could not fetch updated interface settings. Probably because the interface was removed.", &what, 1);
+
+ continue;
+ }
+ catch (...)
+ {
+ m_logSink->error("Could not fetch updated interface settings. Probably because the interface was removed.");
+
+ continue;
+ }
+ }
+
+ return result;
+}
+
+std::vector<std::wstring> DnsAgent::discoverInterfaces()
+{
+ auto regKey = common::registry::Registry::OpenKey(HKEY_LOCAL_MACHINE, RegistryPaths::InterfaceRoot(m_protocol));
+
+ std::vector<std::wstring> interfaces;
+
+ interfaces.reserve(20);
+
+ regKey->enumerateSubKeys([&interfaces](const std::wstring &keyName)
+ {
+ interfaces.push_back(keyName);
+ return true;
+ });
+
+ return interfaces;
+}
+
+std::vector<InterfaceSnap> DnsAgent::createSnaps(const std::vector<std::wstring> &interfaces)
+{
+ std::vector<InterfaceSnap> snaps;
+
+ snaps.reserve(interfaces.size());
+
+ for (const auto &interface : interfaces)
+ {
+ snaps.emplace_back(m_protocol, interface);
+ }
+
+ return snaps;
+}
+
+void DnsAgent::setNameServers(const std::wstring &interfaceGuid, const std::vector<std::wstring> &enforcedServers)
+{
+ XTRACE(L"Overriding name servers for interface ", interfaceGuid);
+
+ if (Protocol::IPv4 == m_protocol)
+ {
+ NetSh::Instance().SetIpv4StaticDns(NetSh::ConvertInterfaceGuidToIndex(interfaceGuid), enforcedServers);
+ }
+ else
+ {
+ NetSh::Instance().SetIpv6StaticDns(NetSh::ConvertInterfaceGuidToIndex(interfaceGuid), enforcedServers);
+ }
+}
+
+void DnsAgent::startTrackingInterfaces(const std::vector<std::wstring> &interfaces)
+{
+ const auto snaps = createSnaps(interfaces);
+
+ //
+ // Override configured name servers on all interfaces, as necessary.
+ //
+
+ const auto enforcedServers = m_nameServerSource->getNameServers(m_protocol);
+
+ for (const auto &snap : snaps)
+ {
+ if (snap.needsOverriding(enforcedServers))
+ {
+ setNameServers(snap.interfaceGuid(), enforcedServers);
+ }
+ }
+
+ //
+ // Create a tracking record for each interface.
+ //
+
+ for (const auto &snap : snaps)
+ {
+ const auto interfaceGuid = snap.interfaceGuid();
+
+ XTRACE(L"Creating tracking entry for interface ", interfaceGuid);
+
+ m_trackedInterfaces.emplace_back(interfaceGuid, snap, std::make_unique<InterfaceMonitor>(m_protocol, interfaceGuid));
+ }
+}
+
+void DnsAgent::stopTrackingInterfaces(const std::vector<std::wstring> &interfaces)
+{
+ for (const auto &interfaceGuid : interfaces)
+ {
+ auto iter = std::find_if(m_trackedInterfaces.begin(), m_trackedInterfaces.end(), [&interfaceGuid](const InterfaceData &candidate)
+ {
+ return candidate.interfaceGuid == interfaceGuid;
+ });
+
+ if (m_trackedInterfaces.end() == iter)
+ {
+ m_logSink->error("Request to stop tracking non-tracked interface, ignoring.");
+
+ continue;
+ }
+
+ XTRACE(L"Cancel tracking of interface ", interfaceGuid);
+
+ m_trackedInterfaces.erase(iter);
+ }
+}
+
+void DnsAgent::updateRecoveryData()
+{
+ std::vector<InterfaceSnap> snaps;
+
+ snaps.reserve(m_trackedInterfaces.size());
+
+ std::transform(m_trackedInterfaces.begin(), m_trackedInterfaces.end(), std::back_inserter(snaps), [](const InterfaceData &interfaceData)
+ {
+ return interfaceData.preservedSettings;
+ });
+
+ m_recoverySink->preserveSnaps(m_protocol, snaps);
+}
diff --git a/windows/windns/src/windns/dnsagent.h b/windows/windns/src/windns/dnsagent.h
new file mode 100644
index 0000000000..48dd12c52d
--- /dev/null
+++ b/windows/windns/src/windns/dnsagent.h
@@ -0,0 +1,84 @@
+#pragma once
+
+#include "interfacesnap.h"
+#include "interfacemonitor.h"
+#include "types.h"
+#include "inameserversource.h"
+#include "irecoverysink.h"
+#include "ilogsink.h"
+#include <libcommon/registry/registry.h>
+#include <string>
+#include <vector>
+#include <memory>
+#include <windows.h>
+
+//
+// DnsAgent:
+// Monitor interfaces and enforce name server settings.
+//
+class DnsAgent
+{
+public:
+
+ DnsAgent(Protocol protocol, INameServerSource *nameServerSource, IRecoverySink *recoverySink, ILogSink *logSink);
+ ~DnsAgent();
+
+private:
+
+ Protocol m_protocol;
+ INameServerSource *m_nameServerSource;
+ IRecoverySink *m_recoverySink;
+ ILogSink *m_logSink;
+
+ //
+ // InterfaceData:
+ // Tracking entry for a network interface.
+ //
+ struct InterfaceData
+ {
+ InterfaceData(const std::wstring &interfaceGuid_, const InterfaceSnap &snap_, std::unique_ptr<InterfaceMonitor> &&monitor_)
+ : interfaceGuid(interfaceGuid_), preservedSettings(snap_), monitor(std::move(monitor_))
+ {
+ }
+
+ std::wstring interfaceGuid;
+ InterfaceSnap preservedSettings;
+ std::unique_ptr<InterfaceMonitor> monitor;
+ };
+
+ std::vector<InterfaceData> m_trackedInterfaces;
+
+ std::unique_ptr<common::registry::RegistryMonitor> m_rootMonitor;
+
+ HANDLE m_serverSourceEvent;
+ HANDLE m_thread;
+ HANDLE m_shutdownEvent;
+
+ void constructNameServerUpdateEvent();
+ void constructRootMonitor();
+ void constructThread();
+
+ static unsigned __stdcall ThreadEntry(void *);
+ void thread();
+
+ void processServerSourceEvent();
+
+ enum class ProcessingResult
+ {
+ TrackingUpdated,
+ Nop
+ };
+
+ ProcessingResult processRootKeyEvent();
+ ProcessingResult processInterfaceEvent(const HANDLE *interfaceEvents, size_t startIndex);
+
+ std::vector<std::wstring> discoverInterfaces();
+ std::vector<InterfaceSnap> createSnaps(const std::vector<std::wstring> &interfaces);
+
+ void setNameServers(const std::wstring &interfaceGuid, const std::vector<std::wstring> &enforcedServers);
+
+ void startTrackingInterfaces(const std::vector<std::wstring> &interfaces);
+ void stopTrackingInterfaces(const std::vector<std::wstring> &interfaces);
+
+ void updateRecoveryData();
+};
diff --git a/windows/windns/src/windns/iclientsinkproxy.h b/windows/windns/src/windns/iclientsinkproxy.h
deleted file mode 100644
index 9270a12c50..0000000000
--- a/windows/windns/src/windns/iclientsinkproxy.h
+++ /dev/null
@@ -1,13 +0,0 @@
-#pragma once
-
-#include <cstdint>
-
-struct IClientSinkProxy
-{
- virtual ~IClientSinkProxy() = 0
- {
- }
-
- virtual void error(const char *errorMessage, const char **details, uint32_t numDetails) = 0;
- virtual void config(const void *configData, uint32_t dataLength) = 0;
-};
diff --git a/windows/windns/src/windns/ilogsink.h b/windows/windns/src/windns/ilogsink.h
new file mode 100644
index 0000000000..dec1584ca3
--- /dev/null
+++ b/windows/windns/src/windns/ilogsink.h
@@ -0,0 +1,24 @@
+#pragma once
+
+#include <cstdint>
+
+struct ILogSink
+{
+ virtual ~ILogSink() = 0
+ {
+ }
+
+ virtual void error(const char *msg, const char **details, uint32_t numDetails) = 0;
+
+ virtual void error(const char *msg)
+ {
+ error(msg, nullptr, 0);
+ }
+
+ virtual void info(const char *msg, const char **details, uint32_t numDetails) = 0;
+
+ virtual void info(const char *msg)
+ {
+ info(msg, nullptr, 0);
+ }
+};
diff --git a/windows/windns/src/windns/inameserversource.h b/windows/windns/src/windns/inameserversource.h
new file mode 100644
index 0000000000..f97209ce7d
--- /dev/null
+++ b/windows/windns/src/windns/inameserversource.h
@@ -0,0 +1,23 @@
+#pragma once
+
+#include "types.h"
+#include <vector>
+#include <string>
+
+//
+// Provide the array of name servers that we enforce on all adapters.
+//
+struct INameServerSource
+{
+ virtual ~INameServerSource() = 0
+ {
+ }
+
+ virtual std::vector<std::wstring> getNameServers(Protocol protocol) const = 0;
+
+ //
+ // Get notified if the servers array is updated.
+ //
+ virtual void subscribe(HANDLE eventHandle) = 0;
+ virtual void unsubscribe(HANDLE eventHandle) = 0;
+};
diff --git a/windows/windns/src/windns/interfaceconfig.cpp b/windows/windns/src/windns/interfaceconfig.cpp
deleted file mode 100644
index 74e912de89..0000000000
--- a/windows/windns/src/windns/interfaceconfig.cpp
+++ /dev/null
@@ -1,66 +0,0 @@
-#include "stdafx.h"
-#include "interfaceconfig.h"
-#include "netconfighelpers.h"
-#include "libcommon/com.h"
-#include "libcommon/wmi/wmi.h"
-
-using namespace common;
-
-InterfaceConfig::InterfaceConfig(CComPtr<IWbemClassObject> instance)
-{
- //
- // V_xxx macros seem to require an l-value so access the correct field directly instead.
- //
-
- m_configIndex = wmi::WmiGetPropertyAlways(instance, L"Index").ulVal;
-
- m_dhcp = wmi::WmiGetPropertyAlways(instance, L"DHCPEnabled").boolVal;
-
- m_interfaceIndex = wmi::WmiGetPropertyAlways(instance, L"InterfaceIndex").ulVal;
- m_interfaceGuid = ComConvertString(wmi::WmiGetPropertyAlways(instance, L"SettingID").bstrVal);
-
- m_servers = nchelpers::GetDnsServers(instance);
-}
-
-InterfaceConfig::InterfaceConfig(common::serialization::Deserializer &deserializer)
-{
- common::serialization::Deserializer &d = deserializer;
-
- d >> m_configIndex;
- d >> (uint8_t &)m_dhcp;
- d >> m_interfaceIndex;
- d >> m_interfaceGuid;
-
- bool serversAvailable;
-
- d >> (uint8_t &)serversAvailable;
-
- if (serversAvailable)
- {
- m_servers = std::make_shared<std::vector<std::wstring> >();
- d >> *m_servers;
- }
-}
-
-void InterfaceConfig::serialize(common::serialization::Serializer &serializer) const
-{
- common::serialization::Serializer &s = serializer;
-
- s << m_configIndex;
- s << (uint8_t)m_dhcp;
- s << m_interfaceIndex;
- s << m_interfaceGuid;
-
- //
- // TODO: Encapsulate this inside a new type.
- //
- if (nullptr == m_servers.get())
- {
- s << (uint8_t)0;
- }
- else
- {
- s << (uint8_t)1;
- s << *m_servers;
- }
-}
diff --git a/windows/windns/src/windns/interfaceconfig.h b/windows/windns/src/windns/interfaceconfig.h
deleted file mode 100644
index e95296b7fb..0000000000
--- a/windows/windns/src/windns/interfaceconfig.h
+++ /dev/null
@@ -1,62 +0,0 @@
-#pragma once
-
-#include "types.h"
-#include "libcommon/serialization/deserializer.h"
-#include "libcommon/serialization/serializer.h"
-#include <cstdint>
-#include <string>
-#include <vector>
-#include <atlbase.h>
-#include <wbemidl.h>
-
-class InterfaceConfig
-{
-public:
-
- // instance = Win32_NetworkAdapterConfiguration.
- explicit InterfaceConfig(CComPtr<IWbemClassObject> instance);
-
- explicit InterfaceConfig(common::serialization::Deserializer &deserializer);
- void serialize(common::serialization::Serializer &serializer) const;
-
- void updateServers(const InterfaceConfig &rhs)
- {
- m_servers = rhs.m_servers;
- }
-
- uint32_t configIndex() const
- {
- return m_configIndex;
- }
-
- bool dhcp() const
- {
- return m_dhcp;
- }
-
- uint32_t interfaceIndex() const
- {
- return m_interfaceIndex;
- }
-
- const std::wstring &interfaceGuid() const
- {
- return m_interfaceGuid;
- }
-
- const std::vector<std::wstring> *servers() const
- {
- return m_servers.get();
- }
-
-private:
-
- uint32_t m_configIndex;
-
- bool m_dhcp;
-
- uint32_t m_interfaceIndex;
- std::wstring m_interfaceGuid;
-
- OptionalStringList m_servers;
-};
diff --git a/windows/windns/src/windns/interfacemonitor.cpp b/windows/windns/src/windns/interfacemonitor.cpp
new file mode 100644
index 0000000000..322bb69e57
--- /dev/null
+++ b/windows/windns/src/windns/interfacemonitor.cpp
@@ -0,0 +1,24 @@
+#include "stdafx.h"
+#include "interfacemonitor.h"
+#include "registrypaths.h"
+
+using namespace common::registry;
+
+InterfaceMonitor::InterfaceMonitor(Protocol protocol, const std::wstring &interfaceGuid)
+ : m_protocol(protocol)
+ , m_interfaceGuid(interfaceGuid)
+{
+ const auto interfacePath = RegistryPaths::InterfaceKey(interfaceGuid, protocol);
+
+ m_monitor = Registry::MonitorKey(HKEY_LOCAL_MACHINE, interfacePath, { RegistryEventFlag::ValueChange });
+}
+
+HANDLE InterfaceMonitor::queueSingleEvent()
+{
+ return m_monitor->queueSingleEvent();
+}
+
+const std::wstring &InterfaceMonitor::interfaceGuid() const
+{
+ return m_interfaceGuid;
+}
diff --git a/windows/windns/src/windns/interfacemonitor.h b/windows/windns/src/windns/interfacemonitor.h
new file mode 100644
index 0000000000..f40c88dab6
--- /dev/null
+++ b/windows/windns/src/windns/interfacemonitor.h
@@ -0,0 +1,30 @@
+#pragma once
+
+#include "types.h"
+#include <libcommon/registry/registry.h>
+#include <string>
+#include <memory>
+#include <windows.h>
+
+class InterfaceMonitor
+{
+public:
+
+ explicit InterfaceMonitor(Protocol protocol, const std::wstring &interfaceGuid);
+
+ //
+ // The event becomes signalled if:
+ // 1. A value change occurs.
+ // 2. The monitored interface's registry key is deleted.
+ //
+ HANDLE queueSingleEvent();
+
+ const std::wstring &interfaceGuid() const;
+
+private:
+
+ Protocol m_protocol;
+ std::wstring m_interfaceGuid;
+
+ std::unique_ptr<common::registry::RegistryMonitor> m_monitor;
+};
diff --git a/windows/windns/src/windns/interfacesnap.cpp b/windows/windns/src/windns/interfacesnap.cpp
new file mode 100644
index 0000000000..309296897f
--- /dev/null
+++ b/windows/windns/src/windns/interfacesnap.cpp
@@ -0,0 +1,151 @@
+#include "stdafx.h"
+#include "interfacesnap.h"
+#include "registrypaths.h"
+#include <libcommon/registry/registry.h>
+#include <libcommon/string.h>
+#include <cstdint>
+
+using namespace common::registry;
+
+namespace
+{
+
+enum class NameServerType
+{
+ Static,
+ Dhcp
+};
+
+std::vector<std::wstring> GetNameServers(const std::wstring &interfaceGuid,
+ Protocol protocol, NameServerType nameServerType)
+{
+ const auto interfacePath = RegistryPaths::InterfaceKey(interfaceGuid, protocol);
+
+ const auto regKey = Registry::OpenKey(HKEY_LOCAL_MACHINE, interfacePath);
+
+ std::wstring nameservers;
+
+ try
+ {
+ //
+ // This particular value is a string array packed into a string data type.
+ // REG_MULTI_SZ would have been the correct type to use, but there
+ // are probably historical reasons for the value type currently being used.
+ //
+ nameservers = regKey->readString(NameServerType::Static == nameServerType ? L"NameServer" : L"DhcpNameServer");
+ }
+ catch (...)
+ {
+ }
+
+ if (nameservers.empty())
+ {
+ return std::vector<std::wstring>();
+ }
+
+ return common::string::Tokenize(nameservers, L",");
+}
+
+bool GetDhcpEnabled(const std::wstring &interfaceGuid, Protocol protocol)
+{
+ const auto interfacePath = RegistryPaths::InterfaceKey(interfaceGuid, protocol);
+
+ const auto regKey = Registry::OpenKey(HKEY_LOCAL_MACHINE, interfacePath);
+
+ bool enabled = false;
+
+ try
+ {
+ const auto flag = regKey->readUint32(L"EnableDHCP");
+
+ enabled = (1 == flag);
+ }
+ catch (...)
+ {
+ }
+
+ return enabled;
+}
+
+} // anonymous namespace
+
+InterfaceSnap::InterfaceSnap(Protocol protocol, const std::wstring &interfaceGuid)
+ : m_protocol(protocol)
+ , m_interfaceGuid(interfaceGuid)
+{
+ m_configuredForDhcp = GetDhcpEnabled(m_interfaceGuid, m_protocol);
+
+ // Static name servers are configured by the user.
+ m_staticNameServers = GetNameServers(m_interfaceGuid, m_protocol, NameServerType::Static);
+
+ // DHCP name servers are the servers most recently supplied by DHCP.
+ // An adapter can be configured for DHCP and static name servers at the same time.
+ // Static name servers always have precedence.
+ m_dhcpNameServers = GetNameServers(m_interfaceGuid, m_protocol, NameServerType::Dhcp);
+}
+
+InterfaceSnap::InterfaceSnap(common::serialization::Deserializer &deserializer)
+{
+ common::serialization::Deserializer &d = deserializer;
+
+ d >> (uint8_t &)m_protocol;
+
+ if (m_protocol != Protocol::IPv4
+ && m_protocol != Protocol::IPv6)
+ {
+ throw std::runtime_error("Serialized data for 'InterfaceSnap' instance is invalid (protocol)");
+ }
+
+ d >> m_interfaceGuid;
+ d >> (uint8_t &)m_configuredForDhcp;
+ d >> m_staticNameServers;
+ d >> m_dhcpNameServers;
+}
+
+void InterfaceSnap::serialize(common::serialization::Serializer &serializer) const
+{
+ common::serialization::Serializer &s = serializer;
+
+ s << (uint8_t &)m_protocol;
+ s << m_interfaceGuid;
+ s << (uint8_t &)m_configuredForDhcp;
+ s << m_staticNameServers;
+ s << m_dhcpNameServers;
+}
+
+bool InterfaceSnap::needsOverriding(const std::vector<std::wstring> &enforcedServers) const
+{
+ if (internalInterface())
+ {
+ return false;
+ }
+
+ //
+ // The interface has static DNS, or
+ // The interface has DNS provided by the DHCP server, or
+ // The interface *will get* DNS provided to it by the DHCP server
+ //
+
+ //
+ // It's not enough that m_staticNameServers has the same elements.
+ // The order defines primary and secondary name server and has to match.
+ //
+
+ return m_staticNameServers != enforcedServers;
+}
+
+const std::wstring &InterfaceSnap::interfaceGuid() const
+{
+ return m_interfaceGuid;
+}
+
+const std::vector<std::wstring> &InterfaceSnap::nameServers() const
+{
+ return m_staticNameServers;
+}
+
+bool InterfaceSnap::internalInterface() const
+{
+ return false == m_configuredForDhcp
+ && m_staticNameServers.empty();
+}
diff --git a/windows/windns/src/windns/interfacesnap.h b/windows/windns/src/windns/interfacesnap.h
new file mode 100644
index 0000000000..4d1156b615
--- /dev/null
+++ b/windows/windns/src/windns/interfacesnap.h
@@ -0,0 +1,35 @@
+#pragma once
+
+#include "types.h"
+#include <libcommon/serialization/serializer.h>
+#include <libcommon/serialization/deserializer.h>
+#include <string>
+#include <vector>
+
+class InterfaceSnap
+{
+public:
+
+ explicit InterfaceSnap(Protocol protocol, const std::wstring &interfaceGuid);
+
+ explicit InterfaceSnap(common::serialization::Deserializer &deserializer);
+ void serialize(common::serialization::Serializer &serializer) const;
+
+ bool needsOverriding(const std::vector<std::wstring> &enforcedServers) const;
+
+ const std::wstring &interfaceGuid() const;
+
+ const std::vector<std::wstring> &nameServers() const;
+
+ bool internalInterface() const;
+
+private:
+
+ Protocol m_protocol;
+ std::wstring m_interfaceGuid;
+
+ bool m_configuredForDhcp;
+
+ std::vector<std::wstring> m_staticNameServers;
+ std::vector<std::wstring> m_dhcpNameServers;
+};
diff --git a/windows/windns/src/windns/irecoverysink.h b/windows/windns/src/windns/irecoverysink.h
new file mode 100644
index 0000000000..8251625053
--- /dev/null
+++ b/windows/windns/src/windns/irecoverysink.h
@@ -0,0 +1,14 @@
+#pragma once
+
+#include "types.h"
+#include "interfacesnap.h"
+#include <vector>
+
+struct IRecoverySink
+{
+ virtual ~IRecoverySink() = 0
+ {
+ }
+
+ virtual void preserveSnaps(Protocol protocol, const std::vector<InterfaceSnap> &snaps) = 0;
+};
diff --git a/windows/windns/src/windns/logsink.cpp b/windows/windns/src/windns/logsink.cpp
new file mode 100644
index 0000000000..b5f127d25c
--- /dev/null
+++ b/windows/windns/src/windns/logsink.cpp
@@ -0,0 +1,34 @@
+#include "stdafx.h"
+#include "logsink.h"
+
+LogSink::LogSink(const LogSinkInfo &target)
+ : m_target(target)
+{
+}
+
+void LogSink::setTarget(const LogSinkInfo &target)
+{
+ std::scoped_lock<std::mutex> lock(m_targetMutex);
+
+ m_target = target;
+}
+
+void LogSink::error(const char *msg, const char **details, uint32_t numDetails)
+{
+ std::scoped_lock<std::mutex> lock(m_targetMutex);
+
+ if (nullptr != m_target.sink)
+ {
+ m_target.sink(WINDNS_LOG_CATEGORY_ERROR, msg, details, numDetails, m_target.context);
+ }
+}
+
+void LogSink::info(const char *msg, const char **details, uint32_t numDetails)
+{
+ std::scoped_lock<std::mutex> lock(m_targetMutex);
+
+ if (nullptr != m_target.sink)
+ {
+ m_target.sink(WINDNS_LOG_CATEGORY_INFO, msg, details, numDetails, m_target.context);
+ }
+}
diff --git a/windows/windns/src/windns/logsink.h b/windows/windns/src/windns/logsink.h
new file mode 100644
index 0000000000..396320a0a8
--- /dev/null
+++ b/windows/windns/src/windns/logsink.h
@@ -0,0 +1,22 @@
+#pragma once
+
+#include "clientsinkinfo.h"
+#include "ilogsink.h"
+#include <mutex>
+
+class LogSink : public ILogSink
+{
+public:
+
+ LogSink(const LogSinkInfo &target);
+
+ void setTarget(const LogSinkInfo &target);
+
+ void error(const char *msg, const char **details, uint32_t numDetails) override;
+ void info(const char *msg, const char **details, uint32_t numDetails) override;
+
+private:
+
+ std::mutex m_targetMutex;
+ LogSinkInfo m_target;
+};
diff --git a/windows/windns/src/windns/nameserversource.cpp b/windows/windns/src/windns/nameserversource.cpp
new file mode 100644
index 0000000000..b574266efb
--- /dev/null
+++ b/windows/windns/src/windns/nameserversource.cpp
@@ -0,0 +1,77 @@
+#include "stdafx.h"
+#include "nameserversource.h"
+
+NameServerSource::NameServerSource(const std::vector<std::wstring> &ipv4NameServers,
+ const std::vector<std::wstring> &ipv6NameServers)
+ : m_ipv4NameServers(ipv4NameServers)
+ , m_ipv6NameServers(ipv6NameServers)
+{
+}
+
+void NameServerSource::setNameServers(Protocol protocol, const std::vector<std::wstring> &nameServers)
+{
+ {
+ std::scoped_lock<std::mutex> lock(m_nameServersMutex);
+
+ if (Protocol::IPv4 == protocol)
+ {
+ m_ipv4NameServers = nameServers;
+ }
+ else
+ {
+ m_ipv6NameServers = nameServers;
+ }
+ }
+
+ //
+ // Notify all subscribers.
+ //
+
+ std::scoped_lock<std::mutex> lock(m_subscriptionMutex);
+
+ for (HANDLE eventHandle : m_subscriptions)
+ {
+ SetEvent(eventHandle);
+ }
+}
+
+std::vector<std::wstring> NameServerSource::getNameServers(Protocol protocol) const
+{
+ std::vector<std::wstring> copy;
+
+ std::scoped_lock<std::mutex> lock(m_nameServersMutex);
+
+ if (Protocol::IPv4 == protocol)
+ {
+ copy = m_ipv4NameServers;
+ }
+ else
+ {
+ copy = m_ipv6NameServers;
+ }
+
+ return copy;
+}
+
+void NameServerSource::subscribe(HANDLE eventHandle)
+{
+ ResetEvent(eventHandle);
+
+ std::scoped_lock<std::mutex> lock(m_subscriptionMutex);
+
+ m_subscriptions.push_back(eventHandle);
+}
+
+void NameServerSource::unsubscribe(HANDLE eventHandle)
+{
+ std::scoped_lock<std::mutex> lock(m_subscriptionMutex);
+
+ auto it = std::find(m_subscriptions.begin(), m_subscriptions.end(), eventHandle);
+
+ if (m_subscriptions.end() == it)
+ {
+ return;
+ }
+
+ m_subscriptions.erase(it);
+}
diff --git a/windows/windns/src/windns/nameserversource.h b/windows/windns/src/windns/nameserversource.h
new file mode 100644
index 0000000000..7d46fcf264
--- /dev/null
+++ b/windows/windns/src/windns/nameserversource.h
@@ -0,0 +1,29 @@
+#pragma once
+
+#include "inameserversource.h"
+#include <mutex>
+
+class NameServerSource : public INameServerSource
+{
+public:
+
+ NameServerSource(const std::vector<std::wstring> &ipv4NameServers,
+ const std::vector<std::wstring> &ipv6NameServers);
+
+ void setNameServers(Protocol protocol, const std::vector<std::wstring> &nameServers);
+
+ std::vector<std::wstring> getNameServers(Protocol protocol) const override;
+
+ void subscribe(HANDLE eventHandle) override;
+ void unsubscribe(HANDLE eventHandle) override;
+
+private:
+
+ mutable std::mutex m_nameServersMutex;
+
+ std::vector<std::wstring> m_ipv4NameServers;
+ std::vector<std::wstring> m_ipv6NameServers;
+
+ std::mutex m_subscriptionMutex;
+ std::list<HANDLE> m_subscriptions;
+};
diff --git a/windows/windns/src/windns/netconfigeventsink.cpp b/windows/windns/src/windns/netconfigeventsink.cpp
deleted file mode 100644
index 5b8d4efd07..0000000000
--- a/windows/windns/src/windns/netconfigeventsink.cpp
+++ /dev/null
@@ -1,51 +0,0 @@
-#include "stdafx.h"
-#include "netconfigeventsink.h"
-#include "netconfighelpers.h"
-#include "netsh.h"
-#include "confineoperation.h"
-#include <functional>
-
-using namespace common;
-
-NetConfigEventSink::NetConfigEventSink
-(
- std::shared_ptr<wmi::IConnection> connection,
- std::shared_ptr<ConfigManager> configManager,
- IClientSinkProxy *clientSinkProxy
-)
- : m_connection(connection)
- , m_configManager(configManager)
- , m_clientSinkProxy(clientSinkProxy)
-{
-}
-
-void NetConfigEventSink::update(CComPtr<IWbemClassObject> previous, CComPtr<IWbemClassObject> target)
-{
- auto forwardError = [this](const char *errorMessage, const char **details, uint32_t numDetails)
- {
- m_clientSinkProxy->error(errorMessage, details, numDetails);
- };
-
- ConfineOperation("Process adapter update event", forwardError, [&]()
- {
- InterfaceConfig previousConfig(previous);
- InterfaceConfig targetConfig(target);
-
- ConfigManager::Mutex mutex(*m_configManager);
-
- //
- // This is OK because the config manager will reject updates
- // that set our DNS servers.
- //
- if (ConfigManager::UpdateStatus::DnsApproved == m_configManager->updateConfig(previousConfig, targetConfig))
- {
- return;
- }
-
- //
- // The update was initiated from an external source.
- // Override current settings to enforce our selected DNS servers.
- //
- nchelpers::SetDnsServers(targetConfig.interfaceIndex(), m_configManager->getServers());
- });
-}
diff --git a/windows/windns/src/windns/netconfigeventsink.h b/windows/windns/src/windns/netconfigeventsink.h
deleted file mode 100644
index c3fd9242f7..0000000000
--- a/windows/windns/src/windns/netconfigeventsink.h
+++ /dev/null
@@ -1,24 +0,0 @@
-#pragma once
-
-#include "libcommon/wmi/ieventsink.h"
-#include "libcommon/wmi/iconnection.h"
-#include "configmanager.h"
-#include "iclientsinkproxy.h"
-#include <memory>
-
-class NetConfigEventSink : public common::wmi::IModificationEventSink
-{
-public:
-
- NetConfigEventSink(std::shared_ptr<common::wmi::IConnection> connection,
- std::shared_ptr<ConfigManager> configManager, IClientSinkProxy *clientSinkProxy);
-
- void update(CComPtr<IWbemClassObject> previous, CComPtr<IWbemClassObject> target) override;
-
-private:
-
- std::shared_ptr<common::wmi::IConnection> m_connection;
- std::shared_ptr<ConfigManager> m_configManager;
-
- IClientSinkProxy *m_clientSinkProxy;
-};
diff --git a/windows/windns/src/windns/netconfighelpers.cpp b/windows/windns/src/windns/netconfighelpers.cpp
deleted file mode 100644
index 7156038b4e..0000000000
--- a/windows/windns/src/windns/netconfighelpers.cpp
+++ /dev/null
@@ -1,65 +0,0 @@
-#include "stdafx.h"
-#include "netconfighelpers.h"
-#include "libcommon/com.h"
-#include "libcommon/wmi/wmi.h"
-#include "libcommon/trace/xtrace.h"
-#include "netsh.h"
-
-using namespace common;
-
-namespace nchelpers
-{
-
-OptionalStringList GetDnsServers(CComPtr<IWbemClassObject> instance)
-{
- OptionalStringList result;
-
- auto servers = wmi::WmiGetProperty(instance, L"DNSServerSearchOrder");
-
- if (VT_EMPTY == V_VT(&servers) || VT_NULL == V_VT(&servers))
- {
- return result;
- }
-
- result = std::make_shared<std::vector<std::wstring> >(
- ComConvertStringArray(V_ARRAY(&servers)));
-
- return result;
-}
-
-uint32_t GetInterfaceIndex(CComPtr<IWbemClassObject> instance)
-{
- return wmi::WmiGetPropertyAlways(instance, L"InterfaceIndex").ulVal;
-}
-
-void SetDnsServers(uint32_t interfaceIndex, const std::vector<std::wstring> &servers)
-{
- NetSh::SetIpv4PrimaryDns(interfaceIndex, servers[0]);
-
- if (servers.size() > 1)
- {
- NetSh::SetIpv4SecondaryDns(interfaceIndex, servers[1]);
- }
-}
-
-void RevertDnsServers(const InterfaceConfig &config, uint32_t timeout)
-{
- XTRACE("Reverting DNS configuration for interface with index=", config.interfaceIndex());
-
- auto servers = config.servers();
-
- if (config.dhcp() || nullptr == servers || 0 == servers->size())
- {
- NetSh::SetIpv4Dhcp(config.interfaceIndex(), timeout);
- return;
- }
-
- NetSh::SetIpv4PrimaryDns(config.interfaceIndex(), (*servers)[0], timeout);
-
- if (servers->size() > 1)
- {
- NetSh::SetIpv4SecondaryDns(config.interfaceIndex(), (*servers)[1], timeout);
- }
-}
-
-}
diff --git a/windows/windns/src/windns/netconfighelpers.h b/windows/windns/src/windns/netconfighelpers.h
deleted file mode 100644
index 33b6c64cc2..0000000000
--- a/windows/windns/src/windns/netconfighelpers.h
+++ /dev/null
@@ -1,24 +0,0 @@
-#pragma once
-
-#include "types.h"
-#include "interfaceconfig.h"
-#include <string>
-#include <vector>
-#include <cstdint>
-#include <atlbase.h>
-#include <wbemidl.h>
-
-namespace nchelpers
-{
-
-// instance = Win32_NetworkAdapterConfiguration
-OptionalStringList GetDnsServers(CComPtr<IWbemClassObject> instance);
-
-// instance = Win32_NetworkAdapterConfiguration
-uint32_t GetInterfaceIndex(CComPtr<IWbemClassObject> instance);
-
-void SetDnsServers(uint32_t interfaceIndex, const std::vector<std::wstring> &servers);
-
-void RevertDnsServers(const InterfaceConfig &config, uint32_t timeout = 0);
-
-}
diff --git a/windows/windns/src/windns/netsh.cpp b/windows/windns/src/windns/netsh.cpp
index fdd8fbfa03..b8c2afff68 100644
--- a/windows/windns/src/windns/netsh.cpp
+++ b/windows/windns/src/windns/netsh.cpp
@@ -1,44 +1,17 @@
#include "stdafx.h"
#include "netsh.h"
-#include "libcommon/applicationrunner.h"
-#include "libcommon/string.h"
-#include "libcommon/filesystem.h"
+#include <libcommon/string.h>
+#include <libcommon/filesystem.h>
+#include <libcommon/guid.h>
#include <sstream>
#include <stdexcept>
#include <experimental/filesystem>
+#include <iphlpapi.h>
namespace
{
-ErrorSinkInfo g_ErrorSink = { nullptr, nullptr };
-
-std::wstring g_NetShPath;
-
-void InitializePath()
-{
- if (false == g_NetShPath.empty())
- {
- return;
- }
-
- const auto system32 = common::fs::GetKnownFolderPath(FOLDERID_System, 0, nullptr);
-
- g_NetShPath = std::experimental::filesystem::path(system32).append(L"netsh.exe");
-}
-
-const std::wstring &NetShPath()
-{
- InitializePath();
-
- return g_NetShPath;
-}
-
-void InfoSink(const char *msg)
-{
- auto infoMsg = std::string("INFO: ").append(msg);
-
- g_ErrorSink.sink(infoMsg.c_str(), nullptr, 0, g_ErrorSink.context);
-}
+NetSh *g_Instance = nullptr;
std::vector<std::string> BlockToRows(const std::string &textBlock)
{
@@ -87,97 +60,89 @@ __declspec(noreturn) void ThrowWithDetails(std::string &&error, common::Applicat
throw NetShError(std::move(error), std::move(details));
}
-void ValidateShellOut(common::ApplicationRunner &netsh, uint32_t timeout)
-{
- // Use default timeout of 4 seconds.
- const uint32_t actualTimeout = (0 == timeout ? 3000 : timeout);
-
- const auto startTime = GetTickCount64();
-
- DWORD returnCode;
+} // anonymous namespace
- if (false == netsh.join(returnCode, actualTimeout))
+//static
+void NetSh::Construct(ILogSink *logSink)
+{
+ if (nullptr != g_Instance)
{
- ThrowWithDetails("'netsh' did not complete in a timely manner", netsh);
+ throw std::runtime_error("NetSh is already constructed");
}
- if (returnCode != 0)
+ if (nullptr == logSink)
{
- std::stringstream ss;
-
- ss << "'netsh' failed the requested operation. Error: " << returnCode;
-
- ThrowWithDetails(ss.str(), netsh);
+ throw std::runtime_error("Invalid logger sink");
}
- const auto elapsed = static_cast<uint32_t>(GetTickCount64() - startTime);
-
- if (elapsed > (actualTimeout / 2))
- {
- std::stringstream ss;
-
- ss << L"'netsh' completed successfully, albeit a little slowly. It consumed "
- << elapsed << " ms of "
- << actualTimeout << " ms max permitted execution time";
-
- InfoSink(ss.str().c_str());
- }
+ g_Instance = new NetSh(logSink);
}
-} // anonymous namespace
-
//static
-void NetSh::RegisterErrorSink(const ErrorSinkInfo &errorSink)
+NetSh &NetSh::Instance()
{
- g_ErrorSink = errorSink;
+ if (nullptr == g_Instance)
+ {
+ throw std::runtime_error("NetSh is being referenced prior to being constructed");
+ }
+
+ return *g_Instance;
}
-//static
-void NetSh::SetIpv4PrimaryDns(uint32_t interfaceIndex, std::wstring server, uint32_t timeout)
+void NetSh::SetIpv4StaticDns(uint32_t interfaceIndex,
+ const std::vector<std::wstring> &nameServers, uint32_t timeout)
{
//
+ // Setting primary and secondary name server requires two invokations:
+ //
// netsh interface ipv4 set dnsservers name="Ethernet 2" source=static address=8.8.8.8 validate=no
+ // netsh interface ipv4 add dnsservers name="Ethernet 2" address=8.8.4.4 index=2 validate=no
//
// Note: we're specifying the interface by index instead.
//
- std::wstringstream ss;
+ if (nameServers.empty())
+ {
+ throw std::runtime_error("Invalid list of name servers (zero length list)");
+ }
- ss << L"interface ipv4 set dnsservers name="
- << interfaceIndex
- << L" source=static address="
- << server
- << L" validate=no";
+ {
+ std::wstringstream ss;
- auto netsh = common::ApplicationRunner::StartWithoutConsole(NetShPath(), ss.str());
+ ss << L"interface ipv4 set dnsservers name="
+ << interfaceIndex
+ << L" source=static address="
+ << nameServers[0]
+ << L" validate=no";
- ValidateShellOut(*netsh, timeout);
-}
+ auto netsh = common::ApplicationRunner::StartWithoutConsole(m_netShPath, ss.str());
+
+ ValidateShellOut(*netsh, timeout);
+ }
-//static
-void NetSh::SetIpv4SecondaryDns(uint32_t interfaceIndex, std::wstring server, uint32_t timeout)
-{
- //
- // netsh interface ipv4 add dnsservers name="Ethernet 2" address=8.8.4.4 index=2 validate=no
//
- // Note: we're specifying the interface by index instead.
+ // Set additional name servers.
//
- std::wstringstream ss;
+ for (size_t i = 1; i < nameServers.size(); ++i)
+ {
+ std::wstringstream ss;
- ss << L"interface ipv4 add dnsservers name="
- << interfaceIndex
- << L" address="
- << server
- << L" index=2 validate=no";
+ ss << L"interface ipv4 add dnsservers name="
+ << interfaceIndex
+ << L" address="
+ << nameServers[i]
+ << L" index="
+ << i + 1
+ << L" validate=no";
- auto netsh = common::ApplicationRunner::StartWithoutConsole(NetShPath(), ss.str());
+ auto netsh = common::ApplicationRunner::StartWithoutConsole(m_netShPath, ss.str());
- ValidateShellOut(*netsh, timeout);
+ ValidateShellOut(*netsh, timeout);
+ }
}
-//static
-void NetSh::SetIpv4Dhcp(uint32_t interfaceIndex, uint32_t timeout)
+void NetSh::SetIpv4DhcpDns(uint32_t interfaceIndex, uint32_t timeout)
{
//
// netsh interface ipv4 set dnsservers name="Ethernet 2" source=dhcp
@@ -191,57 +156,65 @@ void NetSh::SetIpv4Dhcp(uint32_t interfaceIndex, uint32_t timeout)
<< interfaceIndex
<< L" source=dhcp";
- auto netsh = common::ApplicationRunner::StartWithoutConsole(NetShPath(), ss.str());
+ auto netsh = common::ApplicationRunner::StartWithoutConsole(m_netShPath, ss.str());
ValidateShellOut(*netsh, timeout);
}
-//static
-void NetSh::SetIpv6PrimaryDns(uint32_t interfaceIndex, std::wstring server, uint32_t timeout)
+void NetSh::SetIpv6StaticDns(uint32_t interfaceIndex,
+ const std::vector<std::wstring> &nameServers, uint32_t timeout)
{
//
+ // Setting primary and secondary name server requires two invokations:
+ //
// netsh interface ipv6 set dnsservers name="Ethernet 2" source=static address=2001:4860:4860::8888 validate=no
+ // netsh interface ipv6 add dnsservers name="Ethernet 2" address=2001:4860:4860::8844 index=2 validate=no
//
// Note: we're specifying the interface by index instead.
//
- std::wstringstream ss;
+ if (nameServers.empty())
+ {
+ throw std::runtime_error("Invalid list of name servers (zero length list)");
+ }
- ss << L"interface ipv6 set dnsservers name="
- << interfaceIndex
- << L" source=static address="
- << server
- << L" validate=no";
+ {
+ std::wstringstream ss;
- auto netsh = common::ApplicationRunner::StartWithoutConsole(NetShPath(), ss.str());
+ ss << L"interface ipv6 set dnsservers name="
+ << interfaceIndex
+ << L" source=static address="
+ << nameServers[0]
+ << L" validate=no";
- ValidateShellOut(*netsh, timeout);
-}
+ auto netsh = common::ApplicationRunner::StartWithoutConsole(m_netShPath, ss.str());
+
+ ValidateShellOut(*netsh, timeout);
+ }
-//static
-void NetSh::SetIpv6SecondaryDns(uint32_t interfaceIndex, std::wstring server, uint32_t timeout)
-{
- //
- // netsh interface ipv6 add dnsservers name="Ethernet 2" address=2001:4860:4860::8844 index=2 validate=no
//
- // Note: we're specifying the interface by index instead.
+ // Set additional name servers.
//
- std::wstringstream ss;
+ for (size_t i = 1; i < nameServers.size(); ++i)
+ {
+ std::wstringstream ss;
- ss << L"interface ipv6 add dnsservers name="
- << interfaceIndex
- << L"address ="
- << server
- << L" index=2 validate=no";
+ ss << L"interface ipv6 add dnsservers name="
+ << interfaceIndex
+ << L" address="
+ << nameServers[i]
+ << L" index="
+ << i + 1
+ << L" validate=no";
- auto netsh = common::ApplicationRunner::StartWithoutConsole(NetShPath(), ss.str());
+ auto netsh = common::ApplicationRunner::StartWithoutConsole(m_netShPath, ss.str());
- ValidateShellOut(*netsh, timeout);
+ ValidateShellOut(*netsh, timeout);
+ }
}
-//static
-void NetSh::SetIpv6Dhcp(uint32_t interfaceIndex, uint32_t timeout)
+void NetSh::SetIpv6DhcpDns(uint32_t interfaceIndex, uint32_t timeout)
{
//
// netsh interface ipv6 set dnsservers name="Ethernet 2" source=dhcp
@@ -255,7 +228,68 @@ void NetSh::SetIpv6Dhcp(uint32_t interfaceIndex, uint32_t timeout)
<< interfaceIndex
<< L" source=dhcp";
- auto netsh = common::ApplicationRunner::StartWithoutConsole(NetShPath(), ss.str());
+ auto netsh = common::ApplicationRunner::StartWithoutConsole(m_netShPath, ss.str());
ValidateShellOut(*netsh, timeout);
}
+
+//static
+uint32_t NetSh::ConvertInterfaceGuidToIndex(const std::wstring &interfaceGuid)
+{
+ auto rawGuid = common::Guid::FromString(interfaceGuid);
+
+ NET_LUID luid;
+ NET_IFINDEX index;
+
+ if (NO_ERROR != ConvertInterfaceGuidToLuid(&rawGuid, &luid)
+ || NO_ERROR != ConvertInterfaceLuidToIndex(&luid, &index))
+ {
+ throw std::runtime_error("Invalid interface GUID");
+ }
+
+ return index;
+}
+
+NetSh::NetSh(ILogSink *logSink)
+ : m_logSink(logSink)
+{
+ const auto system32 = common::fs::GetKnownFolderPath(FOLDERID_System, 0, nullptr);
+
+ m_netShPath = std::experimental::filesystem::path(system32).append(L"netsh.exe");
+}
+
+void NetSh::ValidateShellOut(common::ApplicationRunner &netsh, uint32_t timeout)
+{
+ const uint32_t actualTimeout = (0 == timeout ? 3000 : timeout);
+
+ const auto startTime = GetTickCount64();
+
+ DWORD returnCode;
+
+ if (false == netsh.join(returnCode, actualTimeout))
+ {
+ ThrowWithDetails("'netsh' did not complete in a timely manner", netsh);
+ }
+
+ if (returnCode != 0)
+ {
+ std::stringstream ss;
+
+ ss << "'netsh' failed the requested operation. Error: " << returnCode;
+
+ ThrowWithDetails(ss.str(), netsh);
+ }
+
+ const auto elapsed = static_cast<uint32_t>(GetTickCount64() - startTime);
+
+ if (elapsed > (actualTimeout / 2))
+ {
+ std::stringstream ss;
+
+ ss << L"'netsh' completed successfully, albeit a little slowly. It consumed "
+ << elapsed << " ms of "
+ << actualTimeout << " ms max permitted execution time";
+
+ m_logSink->info(ss.str().c_str(), nullptr, 0);
+ }
+}
diff --git a/windows/windns/src/windns/netsh.h b/windows/windns/src/windns/netsh.h
index 5ba43fba66..f47408c9e2 100644
--- a/windows/windns/src/windns/netsh.h
+++ b/windows/windns/src/windns/netsh.h
@@ -1,7 +1,9 @@
#pragma once
-#include "clientsinkinfo.h"
+#include "ilogsink.h"
+#include <libcommon/applicationrunner.h>
#include <string>
+#include <vector>
#include <cstdint>
#include <stdexcept>
@@ -9,24 +11,30 @@ class NetSh
{
public:
- static void RegisterErrorSink(const ErrorSinkInfo &errorSink);
+ static void Construct(ILogSink *logSink);
- static void SetIpv4PrimaryDns(uint32_t interfaceIndex, std::wstring server, uint32_t timeout = 0);
-
- //
- // Caveat: This sets the primary DNS server if there isn't already one.
- //
- static void SetIpv4SecondaryDns(uint32_t interfaceIndex, std::wstring server, uint32_t timeout = 0);
+ static NetSh &Instance();
- static void SetIpv4Dhcp(uint32_t interfaceIndex, uint32_t timeout = 0);
+ void SetIpv4StaticDns(uint32_t interfaceIndex,
+ const std::vector<std::wstring> &nameServers, uint32_t timeout = 0);
- static void SetIpv6PrimaryDns(uint32_t interfaceIndex, std::wstring server, uint32_t timeout = 0);
- static void SetIpv6SecondaryDns(uint32_t interfaceIndex, std::wstring server, uint32_t timeout = 0);
- static void SetIpv6Dhcp(uint32_t interfaceIndex, uint32_t timeout = 0);
+ void SetIpv4DhcpDns(uint32_t interfaceIndex, uint32_t timeout = 0);
+
+ void SetIpv6StaticDns(uint32_t interfaceIndex,
+ const std::vector<std::wstring> &nameServers, uint32_t timeout = 0);
+
+ void SetIpv6DhcpDns(uint32_t interfaceIndex, uint32_t timeout = 0);
+
+ static uint32_t ConvertInterfaceGuidToIndex(const std::wstring &interfaceGuid);
private:
- NetSh();
+ ILogSink *m_logSink;
+ std::wstring m_netShPath;
+
+ NetSh(ILogSink *logSink);
+
+ void ValidateShellOut(common::ApplicationRunner &netsh, uint32_t timeout);
};
class NetShError : public std::exception
diff --git a/windows/windns/src/windns/recoveryformatter.cpp b/windows/windns/src/windns/recoveryformatter.cpp
new file mode 100644
index 0000000000..e6ed3d75fd
--- /dev/null
+++ b/windows/windns/src/windns/recoveryformatter.cpp
@@ -0,0 +1,88 @@
+#include "stdafx.h"
+#include <libcommon/serialization/serializer.h>
+#include <libcommon/serialization/deserializer.h>
+#include "recoveryformatter.h"
+#include <stdexcept>
+
+namespace
+{
+
+uint32_t RF_MAGIC = 0x21534E44; // stores as 'DNS!'
+uint32_t RF_VERSION = 0x01;
+
+} // anonymous namespace
+
+//static
+std::vector<uint8_t> RecoveryFormatter::Pack(const std::vector<InterfaceSnap> &v4Snaps,
+ const std::vector<InterfaceSnap> &v6Snaps)
+{
+ common::serialization::Serializer s;
+
+ //
+ // Format of binary blob
+ //
+ // u32 tag
+ // u32 version
+ // u32 number of ipv4 snaps
+ // [] ipv4 snaps
+ // u32 number of ipv6 snaps
+ // [] ipv6 snaps
+ //
+
+ s << RF_MAGIC;
+ s << RF_VERSION;
+
+ s << static_cast<uint32_t>(v4Snaps.size());
+
+ for (const auto &snap : v4Snaps)
+ {
+ snap.serialize(s);
+ }
+
+ s << static_cast<uint32_t>(v6Snaps.size());
+
+ for (const auto &snap : v6Snaps)
+ {
+ snap.serialize(s);
+ }
+
+ return s.blob();
+}
+
+//static
+RecoveryFormatter::Unpacked RecoveryFormatter::Unpack(const uint8_t *data, uint32_t dataSize)
+{
+ common::serialization::Deserializer d(data, dataSize);
+
+ if (RF_MAGIC != d.decode<uint32_t>()
+ || RF_VERSION != d.decode<uint32_t>())
+ {
+ throw std::runtime_error("Invalid header in recovery data");
+ }
+
+ Unpacked unpacked;
+
+ auto numV4Snaps = d.decode<uint32_t>();
+
+ for (; 0 != numV4Snaps; --numV4Snaps)
+ {
+ // Invoke deserializing ctor on InterfaceSnap.
+ unpacked.v4Snaps.emplace_back(d);
+ }
+
+ auto numV6Snaps = d.decode<uint32_t>();
+
+ for (; 0 != numV6Snaps; --numV6Snaps)
+ {
+ // Invoke deserializing ctor on InterfaceSnap.
+ unpacked.v6Snaps.emplace_back(d);
+ }
+
+ return unpacked;
+}
+
+//static
+RecoveryFormatter::Unpacked RecoveryFormatter::Unpack(const std::vector<uint8_t> &data)
+{
+ return RecoveryFormatter::Unpack(&data[0], static_cast<uint32_t>(data.size()));
+}
diff --git a/windows/windns/src/windns/recoveryformatter.h b/windows/windns/src/windns/recoveryformatter.h
new file mode 100644
index 0000000000..c3d715cd9e
--- /dev/null
+++ b/windows/windns/src/windns/recoveryformatter.h
@@ -0,0 +1,24 @@
+#pragma once
+
+#include "interfacesnap.h"
+#include <vector>
+#include <cstdint>
+
+class RecoveryFormatter
+{
+public:
+
+ RecoveryFormatter() = delete;
+
+ static std::vector<uint8_t> Pack(const std::vector<InterfaceSnap> &v4Snaps,
+ const std::vector<InterfaceSnap> &v6Snaps);
+
+ struct Unpacked
+ {
+ std::vector<InterfaceSnap> v4Snaps;
+ std::vector<InterfaceSnap> v6Snaps;
+ };
+
+ static Unpacked Unpack(const uint8_t *data, uint32_t dataSize);
+ static Unpacked Unpack(const std::vector<uint8_t> &data);
+};
diff --git a/windows/windns/src/windns/recoverylogic.cpp b/windows/windns/src/windns/recoverylogic.cpp
new file mode 100644
index 0000000000..0c630a89b2
--- /dev/null
+++ b/windows/windns/src/windns/recoverylogic.cpp
@@ -0,0 +1,90 @@
+#include "stdafx.h"
+#include "recoverylogic.h"
+#include "netsh.h"
+#include "confineoperation.h"
+#include <libcommon/trace/xtrace.h>
+#include <stdexcept>
+
+//static
+void RecoveryLogic::RestoreInterfaces(const RecoveryFormatter::Unpacked &data,
+ ILogSink *logSink, uint32_t timeout)
+{
+ if (nullptr == logSink)
+ {
+ throw std::runtime_error("Invalid logger sink");
+ }
+
+ auto forwardError = [logSink](const char *msg, const char **details, uint32_t numDetails)
+ {
+ logSink->error(msg, details, numDetails);
+ };
+
+ bool success = true;
+
+ for (const auto &snap : data.v4Snaps)
+ {
+ const auto status = ConfineOperation("Reset interface DNS settings", forwardError, [&snap, &timeout]()
+ {
+ if (snap.internalInterface())
+ {
+ //
+ // This is an interface used for internal communication.
+ // We haven't changed any settings on it and therefore should not restore it.
+ //
+ return;
+ }
+
+ XTRACE("Resetting name server configuration for interface ", snap.interfaceGuid());
+
+ if (snap.nameServers().empty())
+ {
+ NetSh::Instance().SetIpv4DhcpDns(NetSh::ConvertInterfaceGuidToIndex(snap.interfaceGuid()), timeout);
+ }
+ else
+ {
+ NetSh::Instance().SetIpv4StaticDns(NetSh::ConvertInterfaceGuidToIndex(snap.interfaceGuid()), snap.nameServers(), timeout);
+ }
+ });
+
+ if (false == status)
+ {
+ success = false;
+ }
+ }
+
+ for (const auto &snap : data.v6Snaps)
+ {
+ const auto status = ConfineOperation("Reset interface DNS settings", forwardError, [&snap, &timeout]()
+ {
+ if (snap.internalInterface())
+ {
+ //
+ // This is an interface used for internal communication.
+ // We haven't changed any settings on it and therefore should not restore it.
+ //
+ return;
+ }
+
+ XTRACE("Resetting name server configuration for interface ", snap.interfaceGuid());
+
+ if (snap.nameServers().empty())
+ {
+ NetSh::Instance().SetIpv6DhcpDns(NetSh::ConvertInterfaceGuidToIndex(snap.interfaceGuid()), timeout);
+ }
+ else
+ {
+ NetSh::Instance().SetIpv6StaticDns(NetSh::ConvertInterfaceGuidToIndex(snap.interfaceGuid()), snap.nameServers(), timeout);
+ }
+ });
+
+ if (false == status)
+ {
+ success = false;
+ }
+ }
+
+ if (false == success)
+ {
+ throw std::runtime_error("Could not reset DNS settings for one of more interfaces");
+ }
+}
diff --git a/windows/windns/src/windns/recoverylogic.h b/windows/windns/src/windns/recoverylogic.h
new file mode 100644
index 0000000000..5937b0635c
--- /dev/null
+++ b/windows/windns/src/windns/recoverylogic.h
@@ -0,0 +1,14 @@
+#pragma once
+
+#include "recoveryformatter.h"
+#include "ilogsink.h"
+
+class RecoveryLogic
+{
+public:
+
+ RecoveryLogic() = delete;
+
+ static void RestoreInterfaces(const RecoveryFormatter::Unpacked &data,
+ ILogSink *logSink, uint32_t timeout = 0);
+};
diff --git a/windows/windns/src/windns/recoverysink.cpp b/windows/windns/src/windns/recoverysink.cpp
new file mode 100644
index 0000000000..55a6620c7d
--- /dev/null
+++ b/windows/windns/src/windns/recoverysink.cpp
@@ -0,0 +1,58 @@
+#include "stdafx.h"
+#include "recoverysink.h"
+#include "recoveryformatter.h"
+#include <stdexcept>
+
+RecoverySink::RecoverySink(const RecoverySinkInfo &target)
+ : m_target(target)
+{
+}
+
+void RecoverySink::setTarget(const RecoverySinkInfo &target)
+{
+ std::scoped_lock<std::mutex> lock(m_targetMutex);
+
+ m_target = target;
+}
+
+void RecoverySink::preserveSnaps(Protocol protocol, const std::vector<InterfaceSnap> &snaps)
+{
+ std::scoped_lock<std::mutex> dataLock(m_dataMutex);
+
+ switch (protocol)
+ {
+ case Protocol::IPv4:
+ {
+ m_v4Snaps = snaps;
+ break;
+ }
+ case Protocol::IPv6:
+ {
+ m_v6Snaps = snaps;
+ break;
+ }
+ default:
+ {
+ throw std::runtime_error("Missing case handler");
+ }
+ }
+
+ m_recoveryData = RecoveryFormatter::Pack(m_v4Snaps, m_v6Snaps);
+
+ std::scoped_lock<std::mutex> lock(m_targetMutex);
+
+ m_target.sink(&m_recoveryData[0], static_cast<uint32_t>(m_recoveryData.size()), m_target.context);
+}
+
+std::vector<uint8_t> RecoverySink::recoveryData() const
+{
+ std::vector<uint8_t> copy;
+
+ {
+ std::scoped_lock<std::mutex> dataLock(m_dataMutex);
+
+ copy = m_recoveryData;
+ }
+
+ return copy;
+}
diff --git a/windows/windns/src/windns/recoverysink.h b/windows/windns/src/windns/recoverysink.h
new file mode 100644
index 0000000000..b685ccd12e
--- /dev/null
+++ b/windows/windns/src/windns/recoverysink.h
@@ -0,0 +1,30 @@
+#pragma once
+
+#include "irecoverysink.h"
+#include "clientsinkinfo.h"
+#include <vector>
+#include <cstdint>
+#include <mutex>
+
+class RecoverySink : public IRecoverySink
+{
+public:
+
+ RecoverySink(const RecoverySinkInfo &target);
+
+ void setTarget(const RecoverySinkInfo &target);
+
+ void preserveSnaps(Protocol protocol, const std::vector<InterfaceSnap> &snaps) override;
+
+ std::vector<uint8_t> recoveryData() const;
+
+private:
+
+ std::mutex m_targetMutex;
+ RecoverySinkInfo m_target;
+
+ mutable std::mutex m_dataMutex;
+ std::vector<InterfaceSnap> m_v4Snaps;
+ std::vector<InterfaceSnap> m_v6Snaps;
+ std::vector<uint8_t> m_recoveryData;
+};
diff --git a/windows/windns/src/windns/registrypaths.cpp b/windows/windns/src/windns/registrypaths.cpp
new file mode 100644
index 0000000000..65117ee4fe
--- /dev/null
+++ b/windows/windns/src/windns/registrypaths.cpp
@@ -0,0 +1,20 @@
+#include "stdafx.h"
+#include "registrypaths.h"
+
+//static
+std::wstring RegistryPaths::InterfaceRoot(Protocol protocol)
+{
+ return
+ std::wstring(L"SYSTEM\\CurrentControlSet\\Services\\")
+ .append(Protocol::IPv4 == protocol ? L"Tcpip" : L"Tcpip6")
+ .append(L"\\Parameters\\Interfaces");
+}
+
+//static
+std::wstring RegistryPaths::InterfaceKey(const std::wstring &interfaceGuid, Protocol protocol)
+{
+ return
+ InterfaceRoot(protocol)
+ .append(L"\\")
+ .append(interfaceGuid);
+}
diff --git a/windows/windns/src/windns/registrypaths.h b/windows/windns/src/windns/registrypaths.h
new file mode 100644
index 0000000000..34e8022649
--- /dev/null
+++ b/windows/windns/src/windns/registrypaths.h
@@ -0,0 +1,14 @@
+#pragma once
+
+#include "types.h"
+#include <string>
+
+class RegistryPaths
+{
+public:
+
+ RegistryPaths() = delete;
+
+ static std::wstring InterfaceRoot(Protocol protocol);
+ static std::wstring InterfaceKey(const std::wstring &interfaceGuid, Protocol protocol);
+};
diff --git a/windows/windns/src/windns/types.h b/windows/windns/src/windns/types.h
index d30d556c5d..4782621cd2 100644
--- a/windows/windns/src/windns/types.h
+++ b/windows/windns/src/windns/types.h
@@ -1,7 +1,9 @@
#pragma once
-#include <string>
-#include <vector>
-#include <memory>
+#include <cstdint>
-using OptionalStringList = std::shared_ptr<std::vector<std::wstring> >;
+enum class Protocol : uint8_t
+{
+ IPv4,
+ IPv6
+};
diff --git a/windows/windns/src/windns/windns.cpp b/windows/windns/src/windns/windns.cpp
index 8bf9a1b649..fe674723e2 100644
--- a/windows/windns/src/windns/windns.cpp
+++ b/windows/windns/src/windns/windns.cpp
@@ -2,20 +2,19 @@
#include "windns.h"
#include "windnscontext.h"
#include "clientsinkinfo.h"
-#include "interfaceconfig.h"
-#include "netconfighelpers.h"
#include "confineoperation.h"
+#include "recoveryformatter.h"
+#include "recoverylogic.h"
#include "netsh.h"
-#include "libcommon/serialization/deserializer.h"
+#include "logsink.h"
+#include <memory>
#include <vector>
#include <string>
namespace
{
-WinDnsErrorSink g_ErrorSink = nullptr;
-void *g_ErrorContext = nullptr;
-
+LogSink *g_LogSink = nullptr;
WinDnsContext *g_Context = nullptr;
std::vector<std::wstring> MakeStringArray(const wchar_t **strings, uint32_t numStrings)
@@ -30,11 +29,11 @@ std::vector<std::wstring> MakeStringArray(const wchar_t **strings, uint32_t numS
return v;
}
-void ForwardError(const char *errorMessage, const char **details, uint32_t numDetails)
+void ForwardError(const char *message, const char **details, uint32_t numDetails)
{
- if (nullptr != g_ErrorSink)
+ if (nullptr != g_LogSink)
{
- g_ErrorSink(errorMessage, details, numDetails, g_ErrorContext);
+ g_LogSink->error(message, details, numDetails);
}
}
@@ -44,8 +43,8 @@ WINDNS_LINKAGE
bool
WINDNS_API
WinDns_Initialize(
- WinDnsErrorSink errorSink,
- void *errorContext
+ WinDnsLogSink logSink,
+ void *logContext
)
{
if (nullptr != g_Context)
@@ -53,14 +52,19 @@ WinDns_Initialize(
return false;
}
- g_ErrorSink = errorSink;
- g_ErrorContext = errorContext;
-
- return ConfineOperation("Initialize", ForwardError, []()
+ return ConfineOperation("Initialize", ForwardError, [&]()
{
- NetSh::RegisterErrorSink(ErrorSinkInfo{ g_ErrorSink, g_ErrorContext });
+ if (nullptr == g_LogSink)
+ {
+ g_LogSink = new LogSink(LogSinkInfo{ logSink, logContext });
+ NetSh::Construct(g_LogSink);
+ }
+ else
+ {
+ g_LogSink->setTarget(LogSinkInfo{ logSink, logContext });
+ }
- g_Context = new WinDnsContext;
+ g_Context = new WinDnsContext(g_LogSink);
});
}
@@ -78,6 +82,10 @@ WinDns_Deinitialize(
delete g_Context;
g_Context = nullptr;
+ // Maintain a single instance forever and invoke setTarget() on it.
+ //delete g_LogSink;
+ //g_LogSink = nullptr;
+
return true;
}
@@ -85,27 +93,28 @@ WINDNS_LINKAGE
bool
WINDNS_API
WinDns_Set(
- const wchar_t **servers,
- uint32_t numServers,
- WinDnsConfigSink configSink,
- void *configContext
+ const wchar_t **ipv4Servers,
+ uint32_t numIpv4Servers,
+ const wchar_t **ipv6Servers,
+ uint32_t numIpv6Servers,
+ WinDnsRecoverySink recoverySink,
+ void *recoveryContext
)
{
if (nullptr == g_Context
- || 0 == numServers
- || nullptr == configSink)
+ || nullptr == ipv4Servers
+ || 0 == numIpv4Servers
+ || nullptr == ipv6Servers
+ || 0 == numIpv6Servers
+ || nullptr == recoverySink)
{
return false;
}
return ConfineOperation("Enforce DNS settings", ForwardError, [&]()
{
- ClientSinkInfo sinkInfo;
-
- sinkInfo.errorSinkInfo = ErrorSinkInfo{ g_ErrorSink, g_ErrorContext };
- sinkInfo.configSinkInfo = ConfigSinkInfo{ configSink, configContext };
-
- g_Context->set(MakeStringArray(servers, numServers), sinkInfo);
+ g_Context->set(MakeStringArray(ipv4Servers, numIpv4Servers), MakeStringArray(ipv6Servers, \
+ numIpv6Servers), RecoverySinkInfo{ recoverySink, recoveryContext });
});
}
@@ -130,56 +139,16 @@ WINDNS_LINKAGE
bool
WINDNS_API
WinDns_Recover(
- const void *configData,
+ const void *recoveryData,
uint32_t dataLength
)
{
- std::vector<InterfaceConfig> configs;
-
- const auto status = ConfineOperation("Deserialize recovery data", ForwardError, [&]()
+ return ConfineOperation("Recover DNS settings", ForwardError, [&]()
{
- common::serialization::Deserializer d(reinterpret_cast<const uint8_t *>(configData), dataLength);
-
- auto numConfigs = d.decode<uint32_t>();
-
- if (numConfigs > 50)
- {
- throw std::runtime_error("Too many configuration entries");
- }
+ auto unpacked = RecoveryFormatter::Unpack(reinterpret_cast<const uint8_t *>(recoveryData), dataLength);
- configs.reserve(numConfigs);
+ static const uint32_t TIMEOUT_TEN_SECONDS = 1000 * 10;
- for (; numConfigs != 0; --numConfigs)
- {
- configs.emplace_back(InterfaceConfig(d));
- }
+ RecoveryLogic::RestoreInterfaces(unpacked, g_LogSink, TIMEOUT_TEN_SECONDS);
});
-
- if (false == status)
- {
- return false;
- }
-
- //
- // Try to restore each config and update 'success' if any update fails.
- //
-
- static const uint32_t TIMEOUT_10_SECONDS = 10 * 1000;
-
- bool success = true;
-
- for (const auto &config : configs)
- {
- const auto adapterStatus = ConfineOperation("Restore adapter DNS settings", ForwardError, [&config]()
- {
- nchelpers::RevertDnsServers(config, TIMEOUT_10_SECONDS);
- });
-
- if (false == adapterStatus)
- {
- success = false;
- }
- }
-
- return success;
}
diff --git a/windows/windns/src/windns/windns.h b/windows/windns/src/windns/windns.h
index ab4ec9a539..a7ebfd8f95 100644
--- a/windows/windns/src/windns/windns.h
+++ b/windows/windns/src/windns/windns.h
@@ -17,8 +17,16 @@
// Functions
///////////////////////////////////////////////////////////////////////////////
-typedef void (WINDNS_API *WinDnsErrorSink)(const char *errorMessage, const char **details, uint32_t numDetails, void *context);
-typedef void (WINDNS_API *WinDnsConfigSink)(const void *configData, uint32_t dataLength, void *context);
+enum WinDnsLogCategory
+{
+ WINDNS_LOG_CATEGORY_ERROR = 0x01,
+ WINDNS_LOG_CATEGORY_INFO = 0x02
+};
+
+typedef void (WINDNS_API *WinDnsLogSink)(WinDnsLogCategory category, const char *message,
+ const char **details, uint32_t numDetails, void *context);
+
+typedef void (WINDNS_API *WinDnsRecoverySink)(const void *recoveryData, uint32_t dataLength, void *context);
//
// WinDns_Initialize:
@@ -35,8 +43,8 @@ WINDNS_LINKAGE
bool
WINDNS_API
WinDns_Initialize(
- WinDnsErrorSink errorSink,
- void *errorContext
+ WinDnsLogSink logSink,
+ void *logContext
);
//
@@ -56,20 +64,22 @@ WinDns_Deinitialize(
//
// Configure which DNS servers should be used and start enforcing these settings.
//
-// The 'configSink' will receive periodic callbacks with updated config data
+// The 'recoverySink' will receive periodic callbacks with updated recovery data
// until you call WinDns_Reset.
//
-// You should persist the config data in preparation for an eventual recovery.
+// You should persist the recovery data in preparation for an eventual recovery.
//
extern "C"
WINDNS_LINKAGE
bool
WINDNS_API
WinDns_Set(
- const wchar_t **servers,
- uint32_t numServers,
- WinDnsConfigSink configSink,
- void *configContext
+ const wchar_t **ipv4Servers,
+ uint32_t numIpv4Servers,
+ const wchar_t **ipv6Servers,
+ uint32_t numIpv6Servers,
+ WinDnsRecoverySink recoverySink,
+ void *recoveryContext
);
//
@@ -80,7 +90,7 @@ WinDns_Set(
// (Also taking into account external changes to DNS settings that have occurred
// during the period of enforcing specific settings.)
//
-// It's safe to discard persisted config data once WinDns_Reset returns 'true'.
+// It's safe to discard persisted recovery data once WinDns_Reset returns 'true'.
//
extern "C"
WINDNS_LINKAGE
@@ -102,6 +112,6 @@ WINDNS_LINKAGE
bool
WINDNS_API
WinDns_Recover(
- const void *configData,
+ const void *recoveryData,
uint32_t dataLength
);
diff --git a/windows/windns/src/windns/windns.vcxproj b/windows/windns/src/windns/windns.vcxproj
index 8b8bbb124a..bf84c01690 100644
--- a/windows/windns/src/windns/windns.vcxproj
+++ b/windows/windns/src/windns/windns.vcxproj
@@ -106,7 +106,7 @@
<SubSystem>Windows</SubSystem>
<GenerateDebugInformation>true</GenerateDebugInformation>
<AdditionalLibraryDirectories>$(SolutionDir)bin\$(Platform)-$(Configuration)\</AdditionalLibraryDirectories>
- <AdditionalDependencies>libcommon.lib;wbemuuid.lib;comsuppw.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
+ <AdditionalDependencies>libcommon.lib;Iphlpapi.lib;wbemuuid.lib;comsuppw.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
</Link>
</ItemDefinitionGroup>
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">
@@ -125,7 +125,7 @@
<SubSystem>Windows</SubSystem>
<GenerateDebugInformation>true</GenerateDebugInformation>
<AdditionalLibraryDirectories>$(SolutionDir)bin\$(Platform)-$(Configuration)\</AdditionalLibraryDirectories>
- <AdditionalDependencies>libcommon.lib;wbemuuid.lib;comsuppw.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
+ <AdditionalDependencies>libcommon.lib;Iphlpapi.lib;wbemuuid.lib;comsuppw.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
</Link>
</ItemDefinitionGroup>
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">
@@ -148,7 +148,7 @@
<OptimizeReferences>true</OptimizeReferences>
<GenerateDebugInformation>true</GenerateDebugInformation>
<AdditionalLibraryDirectories>$(SolutionDir)bin\$(Platform)-$(Configuration)\</AdditionalLibraryDirectories>
- <AdditionalDependencies>libcommon.lib;wbemuuid.lib;comsuppw.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
+ <AdditionalDependencies>libcommon.lib;Iphlpapi.lib;wbemuuid.lib;comsuppw.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
</Link>
</ItemDefinitionGroup>
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
@@ -171,18 +171,25 @@
<OptimizeReferences>true</OptimizeReferences>
<GenerateDebugInformation>true</GenerateDebugInformation>
<AdditionalLibraryDirectories>$(SolutionDir)bin\$(Platform)-$(Configuration)\</AdditionalLibraryDirectories>
- <AdditionalDependencies>libcommon.lib;wbemuuid.lib;comsuppw.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
+ <AdditionalDependencies>libcommon.lib;Iphlpapi.lib;wbemuuid.lib;comsuppw.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
</Link>
</ItemDefinitionGroup>
<ItemGroup>
<ClInclude Include="clientsinkinfo.h" />
- <ClInclude Include="configmanager.h" />
<ClInclude Include="confineoperation.h" />
- <ClInclude Include="iclientsinkproxy.h" />
- <ClInclude Include="interfaceconfig.h" />
- <ClInclude Include="netconfigeventsink.h" />
- <ClInclude Include="netconfighelpers.h" />
+ <ClInclude Include="logsink.h" />
+ <ClInclude Include="ilogsink.h" />
+ <ClInclude Include="inameserversource.h" />
+ <ClInclude Include="interfacemonitor.h" />
+ <ClInclude Include="interfacesnap.h" />
+ <ClInclude Include="irecoverysink.h" />
+ <ClInclude Include="dnsagent.h" />
+ <ClInclude Include="nameserversource.h" />
<ClInclude Include="netsh.h" />
+ <ClInclude Include="recoveryformatter.h" />
+ <ClInclude Include="recoverylogic.h" />
+ <ClInclude Include="recoverysink.h" />
+ <ClInclude Include="registrypaths.h" />
<ClInclude Include="stdafx.h" />
<ClInclude Include="targetver.h" />
<ClInclude Include="types.h" />
@@ -190,13 +197,18 @@
<ClInclude Include="windnscontext.h" />
</ItemGroup>
<ItemGroup>
- <ClCompile Include="configmanager.cpp" />
<ClCompile Include="confineoperation.cpp" />
<ClCompile Include="dllmain.cpp" />
- <ClCompile Include="interfaceconfig.cpp" />
- <ClCompile Include="netconfigeventsink.cpp" />
- <ClCompile Include="netconfighelpers.cpp" />
+ <ClCompile Include="logsink.cpp" />
+ <ClCompile Include="interfacemonitor.cpp" />
+ <ClCompile Include="interfacesnap.cpp" />
+ <ClCompile Include="dnsagent.cpp" />
+ <ClCompile Include="nameserversource.cpp" />
<ClCompile Include="netsh.cpp" />
+ <ClCompile Include="recoveryformatter.cpp" />
+ <ClCompile Include="recoverylogic.cpp" />
+ <ClCompile Include="recoverysink.cpp" />
+ <ClCompile Include="registrypaths.cpp" />
<ClCompile Include="stdafx.cpp">
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">Create</PrecompiledHeader>
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">Create</PrecompiledHeader>
diff --git a/windows/windns/src/windns/windns.vcxproj.filters b/windows/windns/src/windns/windns.vcxproj.filters
index 57e9eeaef0..3c846bb6ba 100644
--- a/windows/windns/src/windns/windns.vcxproj.filters
+++ b/windows/windns/src/windns/windns.vcxproj.filters
@@ -5,27 +5,39 @@
<ClInclude Include="targetver.h" />
<ClInclude Include="windns.h" />
<ClInclude Include="windnscontext.h" />
- <ClInclude Include="configmanager.h" />
- <ClInclude Include="netconfigeventsink.h" />
- <ClInclude Include="netconfighelpers.h" />
<ClInclude Include="clientsinkinfo.h" />
<ClInclude Include="netsh.h" />
- <ClInclude Include="interfaceconfig.h" />
- <ClInclude Include="types.h" />
- <ClInclude Include="iclientsinkproxy.h" />
<ClInclude Include="confineoperation.h" />
+ <ClInclude Include="interfacesnap.h" />
+ <ClInclude Include="interfacemonitor.h" />
+ <ClInclude Include="registrypaths.h" />
+ <ClInclude Include="types.h" />
+ <ClInclude Include="inameserversource.h" />
+ <ClInclude Include="irecoverysink.h" />
+ <ClInclude Include="dnsagent.h" />
+ <ClInclude Include="recoverysink.h" />
+ <ClInclude Include="recoveryformatter.h" />
+ <ClInclude Include="nameserversource.h" />
+ <ClInclude Include="ilogsink.h" />
+ <ClInclude Include="logsink.h" />
+ <ClInclude Include="recoverylogic.h" />
</ItemGroup>
<ItemGroup>
<ClCompile Include="dllmain.cpp" />
<ClCompile Include="stdafx.cpp" />
<ClCompile Include="windns.cpp" />
<ClCompile Include="windnscontext.cpp" />
- <ClCompile Include="configmanager.cpp" />
- <ClCompile Include="netconfigeventsink.cpp" />
- <ClCompile Include="netconfighelpers.cpp" />
<ClCompile Include="netsh.cpp" />
- <ClCompile Include="interfaceconfig.cpp" />
<ClCompile Include="confineoperation.cpp" />
+ <ClCompile Include="interfacesnap.cpp" />
+ <ClCompile Include="interfacemonitor.cpp" />
+ <ClCompile Include="registrypaths.cpp" />
+ <ClCompile Include="dnsagent.cpp" />
+ <ClCompile Include="recoverysink.cpp" />
+ <ClCompile Include="recoveryformatter.cpp" />
+ <ClCompile Include="nameserversource.cpp" />
+ <ClCompile Include="logsink.cpp" />
+ <ClCompile Include="recoverylogic.cpp" />
</ItemGroup>
<ItemGroup>
<ResourceCompile Include="windns.rc" />
diff --git a/windows/windns/src/windns/windnscontext.cpp b/windows/windns/src/windns/windnscontext.cpp
index 29cb034566..cf80c0e619 100644
--- a/windows/windns/src/windns/windnscontext.cpp
+++ b/windows/windns/src/windns/windnscontext.cpp
@@ -1,131 +1,97 @@
#include "stdafx.h"
#include "windnscontext.h"
-#include "libcommon/wmi/connection.h"
-#include "netconfigeventsink.h"
-#include "netconfighelpers.h"
#include "confineoperation.h"
+#include "recoveryformatter.h"
+#include "recoverylogic.h"
#include <functional>
using namespace common;
-WinDnsContext::WinDnsContext()
+WinDnsContext::WinDnsContext(ILogSink *logSink)
+ : m_logSink(logSink)
{
- m_connection = std::make_shared<wmi::Connection>(wmi::Connection::Namespace::Cimv2);
+ if (nullptr == logSink)
+ {
+ throw std::runtime_error("Invalid logger sink");
+ }
}
WinDnsContext::~WinDnsContext()
{
- try
+ auto forwardError = [this](const char *msg, const char **details, uint32_t numDetails)
{
- reset();
- }
- catch (...)
+ m_logSink->error(msg, details, numDetails);
+ };
+
+ ConfineOperation("Reset DNS settings", forwardError, [this]()
{
- }
+ this->reset();
+ });
}
-void WinDnsContext::set(const std::vector<std::wstring> &servers, const ClientSinkInfo &sinkInfo)
+void WinDnsContext::set(const std::vector<std::wstring> &ipv4NameServers,
+ const std::vector<std::wstring> &ipv6NameServers, const RecoverySinkInfo &recoverySinkInfo)
{
- m_sinkInfo = sinkInfo;
+ //
+ // The 'sink' and 'source' instances must be kept alive for the lifetime of the agents.
+ //
- if (nullptr == m_notification)
+ if (!m_recoverySink)
{
- m_configManager = std::make_shared<ConfigManager>(servers, this);
-
- //
- // Register interface configuration monitoring.
- //
-
- auto eventSink = std::make_shared<NetConfigEventSink>(m_connection, m_configManager, this);
- auto eventDispatcher = CComPtr<wmi::IEventDispatcher>(new wmi::ModificationEventDispatcher(eventSink));
-
- m_notification = std::make_unique<wmi::Notification>(m_connection, eventDispatcher);
-
- m_notification->activate
- (
- L"SELECT * "
- L"FROM __InstanceModificationEvent "
- L"WITHIN 1 "
- L"WHERE TargetInstance ISA 'Win32_NetworkAdapterConfiguration'"
- L"AND TargetInstance.IPEnabled = True"
- );
+ m_recoverySink = std::make_unique<RecoverySink>(recoverySinkInfo);
}
else
{
- ConfigManager::Mutex mutex(*m_configManager);
+ m_recoverySink->setTarget(recoverySinkInfo);
+ }
- m_configManager->updateServers(servers);
+ if (!m_nameServerSource)
+ {
+ m_nameServerSource = std::make_unique<NameServerSource>(ipv4NameServers, ipv6NameServers);
+ }
+ else
+ {
+ m_nameServerSource->setNameServers(Protocol::IPv4, ipv4NameServers);
+ m_nameServerSource->setNameServers(Protocol::IPv6, ipv6NameServers);
}
//
- // Discover all active interfaces and apply our DNS settings.
+ // Instantiate agents unless they're already set up.
//
- auto resultSet = m_connection->query(L"SELECT * from Win32_NetworkAdapterConfiguration WHERE IPEnabled = True");
+ if (!m_ipv4Agent)
+ {
+ m_ipv4Agent = std::make_unique<DnsAgent>(Protocol::IPv4, m_nameServerSource.get(), m_recoverySink.get(), m_logSink);
+ }
- while (resultSet.advance())
+ if (!m_ipv6Agent)
{
- nchelpers::SetDnsServers(nchelpers::GetInterfaceIndex(resultSet.result()), servers);
+ m_ipv6Agent = std::make_unique<DnsAgent>(Protocol::IPv6, m_nameServerSource.get(), m_recoverySink.get(), m_logSink);
}
}
void WinDnsContext::reset()
{
- if (nullptr == m_notification)
+ if (!m_ipv4Agent && !m_ipv6Agent)
{
return;
}
- m_notification->deactivate();
- m_notification = nullptr;
-
- //
- // Reset adapter configs.
//
- // Safe to do without a mutex guarding the config manager.
+ // Destructing the agents will abort all monitoring + enforcing.
//
- // Try to reset as many adapters as possible, even if one or more fails to reset.
- //
-
- bool success = true;
-
- auto forwardError = std::bind(&WinDnsContext::error, this, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3);
-
- m_configManager->processConfigs([&success, &forwardError](const InterfaceConfig &config)
- {
- const auto adapterStatus = ConfineOperation("Reset adapter DNS configuration", forwardError, [&config]()
- {
- nchelpers::RevertDnsServers(config);
- });
-
- if (false == adapterStatus)
- {
- success = false;
- }
- return true;
- });
-
- if (false == success)
+ if (m_ipv4Agent)
{
- throw std::runtime_error("Resetting DNS failed for one or more adapters");
+ m_ipv4Agent.reset(nullptr);
}
-}
-// IClientSinkProxy
-void WinDnsContext::error(const char *errorMessage, const char **details, uint32_t numDetails)
-{
- if (nullptr != m_sinkInfo.errorSinkInfo.sink)
+ if (m_ipv6Agent)
{
- m_sinkInfo.errorSinkInfo.sink(errorMessage, details, numDetails, m_sinkInfo.errorSinkInfo.context);
+ m_ipv6Agent.reset(nullptr);
}
-}
-// IClientSinkProxy
-void WinDnsContext::config(const void *configData, uint32_t dataLength)
-{
- if (nullptr != m_sinkInfo.configSinkInfo.sink)
- {
- m_sinkInfo.configSinkInfo.sink(configData, dataLength, m_sinkInfo.configSinkInfo.context);
- }
+ auto recoveryData = RecoveryFormatter::Unpack(m_recoverySink->recoveryData());
+
+ RecoveryLogic::RestoreInterfaces(recoveryData, m_logSink);
}
diff --git a/windows/windns/src/windns/windnscontext.h b/windows/windns/src/windns/windnscontext.h
index c83ec6e482..c64ab96f31 100644
--- a/windows/windns/src/windns/windnscontext.h
+++ b/windows/windns/src/windns/windnscontext.h
@@ -1,32 +1,34 @@
#pragma once
#include "windns.h"
-#include "libcommon/wmi/connection.h"
-#include "libcommon/wmi/notification.h"
-#include "configmanager.h"
#include "clientsinkinfo.h"
-#include "iclientsinkproxy.h"
+#include "ilogsink.h"
+#include "recoverysink.h"
+#include "nameserversource.h"
+#include "dnsagent.h"
#include <vector>
#include <string>
#include <memory>
-class WinDnsContext : public IClientSinkProxy
+class WinDnsContext
{
public:
- WinDnsContext();
+ WinDnsContext(ILogSink *logSink);
~WinDnsContext();
- void set(const std::vector<std::wstring> &servers, const ClientSinkInfo &sinkInfo);
- void reset();
+ void set(const std::vector<std::wstring> &ipv4NameServers, const std::vector<std::wstring> &ipv6NameServers,
+ const RecoverySinkInfo &recoverySinkInfo);
- void IClientSinkProxy::error(const char *errorMessage, const char **details, uint32_t numDetails) override;
- void IClientSinkProxy::config(const void *configData, uint32_t dataLength) override;
+ void reset();
private:
- std::shared_ptr<common::wmi::Connection> m_connection;
- std::shared_ptr<ConfigManager> m_configManager;
- std::unique_ptr<common::wmi::Notification> m_notification;
- ClientSinkInfo m_sinkInfo;
+ ILogSink *m_logSink;
+
+ std::unique_ptr<RecoverySink> m_recoverySink;
+ std::unique_ptr<NameServerSource> m_nameServerSource;
+
+ std::unique_ptr<DnsAgent> m_ipv4Agent;
+ std::unique_ptr<DnsAgent> m_ipv6Agent;
};