summaryrefslogtreecommitdiffhomepage
path: root/windows
diff options
context:
space:
mode:
Diffstat (limited to 'windows')
-rw-r--r--windows/driverlogic/driverlogic.vcxproj1
-rw-r--r--windows/driverlogic/driverlogic.vcxproj.filters1
-rw-r--r--windows/driverlogic/src/driverlogic.cpp30
-rw-r--r--windows/driverlogic/src/wireguard.h58
4 files changed, 89 insertions, 1 deletions
diff --git a/windows/driverlogic/driverlogic.vcxproj b/windows/driverlogic/driverlogic.vcxproj
index b91e97d86c..cc46c1ac72 100644
--- a/windows/driverlogic/driverlogic.vcxproj
+++ b/windows/driverlogic/driverlogic.vcxproj
@@ -117,6 +117,7 @@
<ClInclude Include="src\util.h" />
<ClInclude Include="src\version.h" />
<ClInclude Include="src\wintun.h" />
+ <ClInclude Include="src\wireguard.h" />
</ItemGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
<ImportGroup Label="ExtensionTargets">
diff --git a/windows/driverlogic/driverlogic.vcxproj.filters b/windows/driverlogic/driverlogic.vcxproj.filters
index 91ed2267da..9665231376 100644
--- a/windows/driverlogic/driverlogic.vcxproj.filters
+++ b/windows/driverlogic/driverlogic.vcxproj.filters
@@ -28,5 +28,6 @@
<ClInclude Include="src\util.h" />
<ClInclude Include="src\wintun.h" />
<ClInclude Include="src\devenum.h" />
+ <ClInclude Include="src\wireguard.h" />
</ItemGroup>
</Project> \ No newline at end of file
diff --git a/windows/driverlogic/src/driverlogic.cpp b/windows/driverlogic/src/driverlogic.cpp
index 93af20b400..3cb1739e21 100644
--- a/windows/driverlogic/src/driverlogic.cpp
+++ b/windows/driverlogic/src/driverlogic.cpp
@@ -5,6 +5,7 @@
#include "log.h"
#include "version.h"
#include "wintun.h"
+#include "wireguard.h"
#include "devenum.h"
#include <string>
#include <libcommon/error.h>
@@ -278,6 +279,32 @@ ReturnCode CommandWintunDeleteAbandonedDevice(const std::vector<std::wstring> &a
return GENERAL_SUCCESS;
}
+ReturnCode CommandWireGuardNtCleanup(const std::vector<std::wstring> &args)
+{
+ ArgumentContext argsContext(args);
+
+ argsContext.ensureExactArgumentCount(1);
+
+ const auto poolName = argsContext.next();
+
+ WireGuardNtDll wgNt;
+
+ BOOL rebootRequired;
+
+ if (FALSE == wgNt.deletePoolDriver(poolName.c_str(), &rebootRequired))
+ {
+ throw std::runtime_error("Failed to delete WireGuardNT pool");
+ }
+
+ std::wstringstream ss;
+
+ ss << L"Successfully deleted WireGuardNT pool. Reboot required: " << rebootRequired;
+
+ Log(ss.str());
+
+ return ReturnCode::GENERAL_SUCCESS;
+}
+
} // anonymous namespace
int wmain(int argc, const wchar_t *argv[])
@@ -325,7 +352,8 @@ int wmain(int argc, const wchar_t *argv[])
{ L"st-force-install", CommandSplitTunnelForceInstall },
{ L"st-remove", CommandSplitTunnelRemove },
{ L"wintun-delete-pool-driver", CommandWintunDeletePool },
- { L"wintun-delete-abandoned-device", CommandWintunDeleteAbandonedDevice }
+ { L"wintun-delete-abandoned-device", CommandWintunDeleteAbandonedDevice },
+ { L"wg-nt-cleanup", CommandWireGuardNtCleanup }
};
//
diff --git a/windows/driverlogic/src/wireguard.h b/windows/driverlogic/src/wireguard.h
new file mode 100644
index 0000000000..5892b248f1
--- /dev/null
+++ b/windows/driverlogic/src/wireguard.h
@@ -0,0 +1,58 @@
+#pragma once
+
+#include <wireguard-nt/wireguard.h>
+#include <libcommon/error.h>
+#include "util.h"
+
+class WireGuardNtDll
+{
+public:
+
+ WireGuardNtDll() : dllHandle(nullptr)
+ {
+ auto path = GetProcessModulePath().replace_filename(L"wireguard.dll");
+ dllHandle = LoadLibraryExW(path.c_str(), nullptr, LOAD_WITH_ALTERED_SEARCH_PATH);
+
+ if (nullptr == dllHandle)
+ {
+ THROW_WINDOWS_ERROR(GetLastError(), "LoadLibraryExW");
+ }
+
+ try
+ {
+ deletePoolDriver = getProcAddressOrThrow<WIREGUARD_DELETE_POOL_DRIVER_FUNC*>("WireGuardDeletePoolDriver");
+ }
+ catch (...)
+ {
+ FreeLibrary(dllHandle);
+ throw;
+ }
+ }
+
+ ~WireGuardNtDll()
+ {
+ if (nullptr != dllHandle)
+ {
+ FreeLibrary(dllHandle);
+ }
+ }
+
+ WIREGUARD_DELETE_POOL_DRIVER_FUNC *deletePoolDriver;
+
+private:
+
+ template<typename T>
+ T getProcAddressOrThrow(const char *procName)
+ {
+ const T result = reinterpret_cast<T>(GetProcAddress(dllHandle, procName));
+
+ if (nullptr == result)
+ {
+ THROW_WINDOWS_ERROR(GetLastError(), "GetProcAddress");
+ }
+
+ return result;
+ }
+
+ HMODULE dllHandle;
+};