diff options
35 files changed, 438 insertions, 988 deletions
diff --git a/Cargo.lock b/Cargo.lock index ba67067ff2..fc683a8d41 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1198,7 +1198,7 @@ dependencies = [ "libc", "option-ext", "redox_users", - "windows-sys 0.59.0", + "windows-sys 0.61.1", ] [[package]] @@ -2402,7 +2402,7 @@ dependencies = [ "serde", "talpid-platform-metadata", "tokio", - "windows-sys 0.52.0", + "windows-sys 0.61.1", "winres", ] @@ -3047,7 +3047,7 @@ dependencies = [ "talpid-types", "thiserror 2.0.9", "tokio", - "windows-sys 0.52.0", + "windows-sys 0.61.1", "winres", ] @@ -3105,7 +3105,7 @@ dependencies = [ "tokio-stream", "winapi", "windows-service", - "windows-sys 0.52.0", + "windows-sys 0.61.1", "winres", ] @@ -3202,7 +3202,7 @@ dependencies = [ "socket2 0.5.8", "talpid-windows", "tokio", - "windows-sys 0.52.0", + "windows-sys 0.61.1", ] [[package]] @@ -3252,7 +3252,7 @@ dependencies = [ "thiserror 2.0.9", "tokio", "typed-builder 0.21.0", - "windows-sys 0.52.0", + "windows-sys 0.61.1", ] [[package]] @@ -3272,7 +3272,7 @@ dependencies = [ "once_cell", "thiserror 2.0.9", "widestring", - "windows-sys 0.52.0", + "windows-sys 0.61.1", ] [[package]] @@ -3293,7 +3293,7 @@ dependencies = [ "thiserror 2.0.9", "tokio", "uuid", - "windows-sys 0.52.0", + "windows-sys 0.61.1", "winres", ] @@ -3352,7 +3352,7 @@ dependencies = [ "thiserror 2.0.9", "tokio", "windows-service", - "windows-sys 0.52.0", + "windows-sys 0.61.1", ] [[package]] @@ -5574,8 +5574,8 @@ dependencies = [ "windows 0.58.0", "windows-core 0.58.0", "windows-service", - "windows-sys 0.52.0", - "winreg 0.51.0", + "windows-sys 0.61.1", + "winreg 0.55.0", "wmi", ] @@ -5646,8 +5646,8 @@ dependencies = [ "uuid", "widestring", "winapi", - "windows-sys 0.52.0", - "winreg 0.51.0", + "windows-sys 0.61.1", + "winreg 0.55.0", ] [[package]] @@ -5669,7 +5669,7 @@ dependencies = [ "tonic-build", "tower", "winapi", - "windows-sys 0.52.0", + "windows-sys 0.61.1", "winres", ] @@ -5680,7 +5680,7 @@ dependencies = [ "rs-release", "talpid-dbus", "talpid-windows", - "windows-sys 0.52.0", + "windows-sys 0.61.1", ] [[package]] @@ -5704,7 +5704,7 @@ dependencies = [ "thiserror 2.0.9", "tokio", "widestring", - "windows-sys 0.52.0", + "windows-sys 0.61.1", ] [[package]] @@ -5731,7 +5731,7 @@ dependencies = [ "tokio", "tun 0.5.5", "tun 0.7.13", - "windows-sys 0.52.0", + "windows-sys 0.61.1", ] [[package]] @@ -5753,7 +5753,6 @@ dependencies = [ "tonic", "tonic-build", "tower", - "windows-sys 0.52.0", "zeroize", ] @@ -5780,7 +5779,7 @@ dependencies = [ "socket2 0.5.8", "talpid-types", "thiserror 2.0.9", - "windows-sys 0.52.0", + "windows-sys 0.61.1", ] [[package]] @@ -5799,7 +5798,6 @@ dependencies = [ "ipnetwork", "libc", "log", - "maybenot", "netlink-packet-core", "netlink-packet-route", "netlink-proto", @@ -5807,7 +5805,6 @@ dependencies = [ "once_cell", "parking_lot", "proptest", - "rand 0.8.5", "rand 0.9.2", "rand_chacha 0.3.1", "rtnetlink", @@ -5826,7 +5823,7 @@ dependencies = [ "tun 0.7.13", "tunnel-obfuscation", "widestring", - "windows-sys 0.52.0", + "windows-sys 0.61.1", "wireguard-go-rs", "zeroize", ] @@ -6749,12 +6746,24 @@ dependencies = [ [[package]] name = "windows" -version = "0.59.0" +version = "0.61.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9babd3a767a4c1aef6900409f85f5d53ce2544ccdfaa86dad48c91782c6d6893" +dependencies = [ + "windows-collections", + "windows-core 0.61.2", + "windows-future", + "windows-link 0.1.3", + "windows-numerics", +] + +[[package]] +name = "windows-collections" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f919aee0a93304be7f62e8e5027811bbba96bcb1de84d6618be56e43f8a32a1" +checksum = "3beeceb5e5cfd9eb1d76b381630e82c4241ccd0d27f1a39ed41b2760b255c5e8" dependencies = [ - "windows-core 0.59.0", - "windows-targets 0.53.0", + "windows-core 0.61.2", ] [[package]] @@ -6781,15 +6790,26 @@ dependencies = [ [[package]] name = "windows-core" -version = "0.59.0" +version = "0.61.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "810ce18ed2112484b0d4e15d022e5f598113e220c53e373fb31e67e21670c1ce" +checksum = "c0fdd3ddb90610c7638aa2b3a3ab2904fb9e5cdbecc643ddb3647212781c4ae3" dependencies = [ - "windows-implement 0.59.0", - "windows-interface 0.59.0", - "windows-result 0.3.0", - "windows-strings 0.3.0", - "windows-targets 0.53.0", + "windows-implement 0.60.1", + "windows-interface 0.59.2", + "windows-link 0.1.3", + "windows-result 0.3.4", + "windows-strings 0.4.2", +] + +[[package]] +name = "windows-future" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc6a41e98427b19fe4b73c550f060b59fa592d7d686537eebf9385621bfbad8e" +dependencies = [ + "windows-core 0.61.2", + "windows-link 0.1.3", + "windows-threading", ] [[package]] @@ -6805,9 +6825,9 @@ dependencies = [ [[package]] name = "windows-implement" -version = "0.59.0" +version = "0.60.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "83577b051e2f49a058c308f17f273b570a6a758386fc291b5f6a934dd84e48c1" +checksum = "edb307e42a74fb6de9bf3a02d9712678b22399c87e6fa869d6dfcd8c1b7754e0" dependencies = [ "proc-macro2", "quote", @@ -6839,9 +6859,9 @@ dependencies = [ [[package]] name = "windows-interface" -version = "0.59.0" +version = "0.59.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb26fd936d991781ea39e87c3a27285081e3c0da5ca0fcbc02d368cc6f52ff01" +checksum = "c0abd1ddbc6964ac14db11c7213d6532ef34bd9aa042c2e5935f59d7908b46a5" dependencies = [ "proc-macro2", "quote", @@ -6849,6 +6869,28 @@ dependencies = [ ] [[package]] +name = "windows-link" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e6ad25900d524eaabdbbb96d20b4311e1e7ae1699af4fb28c17ae66c80d798a" + +[[package]] +name = "windows-link" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45e46c0661abb7180e7b9c281db115305d49ca1709ab8242adf09666d2173c65" + +[[package]] +name = "windows-numerics" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9150af68066c4c5c07ddc0ce30421554771e528bde427614c61038bc2c92c2b1" +dependencies = [ + "windows-core 0.61.2", + "windows-link 0.1.3", +] + +[[package]] name = "windows-result" version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -6859,11 +6901,11 @@ dependencies = [ [[package]] name = "windows-result" -version = "0.3.0" +version = "0.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d08106ce80268c4067c0571ca55a9b4e9516518eaa1a1fe9b37ca403ae1d1a34" +checksum = "56f42bd332cc6c8eac5af113fc0c1fd6a8fd2aa08a0119358686e5160d0586c6" dependencies = [ - "windows-targets 0.53.0", + "windows-link 0.1.3", ] [[package]] @@ -6889,11 +6931,11 @@ dependencies = [ [[package]] name = "windows-strings" -version = "0.3.0" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b888f919960b42ea4e11c2f408fadb55f78a9f236d5eef084103c8ce52893491" +checksum = "56e6c93f3a0c3b36176cb1327a4958a0353d5d166c2a35cb268ace15e91d3b57" dependencies = [ - "windows-targets 0.53.0", + "windows-link 0.1.3", ] [[package]] @@ -6946,6 +6988,15 @@ dependencies = [ ] [[package]] +name = "windows-sys" +version = "0.61.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f109e41dd4a3c848907eb83d5a42ea98b3769495597450cf6d153507b166f0f" +dependencies = [ + "windows-link 0.2.0", +] + +[[package]] name = "windows-targets" version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -6984,7 +7035,7 @@ dependencies = [ "windows_aarch64_gnullvm 0.52.6", "windows_aarch64_msvc 0.52.6", "windows_i686_gnu 0.52.6", - "windows_i686_gnullvm 0.52.6", + "windows_i686_gnullvm", "windows_i686_msvc 0.52.6", "windows_x86_64_gnu 0.52.6", "windows_x86_64_gnullvm 0.52.6", @@ -6992,19 +7043,12 @@ dependencies = [ ] [[package]] -name = "windows-targets" -version = "0.53.0" +name = "windows-threading" +version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1e4c7e8ceaaf9cb7d7507c974735728ab453b67ef8f18febdd7c11fe59dca8b" +checksum = "b66463ad2e0ea3bbf808b7f1d371311c80e115c0b71d60efc142cafbcfb057a6" dependencies = [ - "windows_aarch64_gnullvm 0.53.0", - "windows_aarch64_msvc 0.53.0", - "windows_i686_gnu 0.53.0", - "windows_i686_gnullvm 0.53.0", - "windows_i686_msvc 0.53.0", - "windows_x86_64_gnu 0.53.0", - "windows_x86_64_gnullvm 0.53.0", - "windows_x86_64_msvc 0.53.0", + "windows-link 0.1.3", ] [[package]] @@ -7015,7 +7059,7 @@ dependencies = [ "talpid-types", "talpid-windows", "thiserror 2.0.9", - "windows 0.59.0", + "windows 0.61.3", ] [[package]] @@ -7037,12 +7081,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" [[package]] -name = "windows_aarch64_gnullvm" -version = "0.53.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "86b8d5f90ddd19cb4a147a5fa63ca848db3df085e25fee3cc10b39b6eebae764" - -[[package]] name = "windows_aarch64_msvc" version = "0.36.1" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -7067,12 +7105,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" [[package]] -name = "windows_aarch64_msvc" -version = "0.53.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7651a1f62a11b8cbd5e0d42526e55f2c99886c77e007179efff86c2b137e66c" - -[[package]] name = "windows_i686_gnu" version = "0.36.1" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -7097,24 +7129,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" [[package]] -name = "windows_i686_gnu" -version = "0.53.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1dc67659d35f387f5f6c479dc4e28f1d4bb90ddd1a5d3da2e5d97b42d6272c3" - -[[package]] name = "windows_i686_gnullvm" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" [[package]] -name = "windows_i686_gnullvm" -version = "0.53.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ce6ccbdedbf6d6354471319e781c0dfef054c81fbc7cf83f338a4296c0cae11" - -[[package]] name = "windows_i686_msvc" version = "0.36.1" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -7139,12 +7159,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" [[package]] -name = "windows_i686_msvc" -version = "0.53.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "581fee95406bb13382d2f65cd4a908ca7b1e4c2f1917f143ba16efe98a589b5d" - -[[package]] name = "windows_x86_64_gnu" version = "0.36.1" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -7169,12 +7183,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" [[package]] -name = "windows_x86_64_gnu" -version = "0.53.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2e55b5ac9ea33f2fc1716d1742db15574fd6fc8dadc51caab1c16a3d3b4190ba" - -[[package]] name = "windows_x86_64_gnullvm" version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -7193,12 +7201,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" [[package]] -name = "windows_x86_64_gnullvm" -version = "0.53.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0a6e035dd0599267ce1ee132e51c27dd29437f63325753051e71dd9e42406c57" - -[[package]] name = "windows_x86_64_msvc" version = "0.36.1" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -7223,12 +7225,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" [[package]] -name = "windows_x86_64_msvc" -version = "0.53.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "271414315aff87387382ec3d271b52d7ae78726f5d44ac98b4f4030c91880486" - -[[package]] name = "winnow" version = "0.6.20" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -7249,16 +7245,6 @@ dependencies = [ [[package]] name = "winreg" -version = "0.51.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "937f3df7948156640f46aacef17a70db0de5917bda9c92b0f751f3a955b588fc" -dependencies = [ - "cfg-if", - "windows-sys 0.48.0", -] - -[[package]] -name = "winreg" version = "0.55.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cb5a765337c50e9ec252c2069be9bf91c7df47afb103b642ba3a53bf8101be97" @@ -7301,7 +7287,7 @@ dependencies = [ "maybenot-ffi", "talpid-types", "thiserror 2.0.9", - "windows-sys 0.52.0", + "windows-sys 0.61.1", "zeroize", ] diff --git a/Cargo.toml b/Cargo.toml index 39d3603f46..f71bab9b30 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -95,7 +95,9 @@ clap = { version = "4.4.18", features = ["cargo", "derive"] } once_cell = "1.16" serde = "1.0.204" serde_json = "1.0.122" -windows-sys = "0.52.0" +windows = "0.61.0" +windows-sys = "0.61.0" +winreg = "0.55" nix = "0.30.1" strum = { version = "0.27" } rand = "0.9" diff --git a/desktop/packages/windows-utils/Cargo.toml b/desktop/packages/windows-utils/Cargo.toml index f64ea99812..a52b025a58 100644 --- a/desktop/packages/windows-utils/Cargo.toml +++ b/desktop/packages/windows-utils/Cargo.toml @@ -17,7 +17,7 @@ path = "windows-utils-rs/lib.rs" [target.'cfg(target_os = "windows")'.dependencies] neon = "1" -windows = { version = "0.59.0", features = ["Win32", "Win32_UI", "Win32_UI_Shell", "Win32_System", "Win32_System_Com", "Win32_Storage_FileSystem"] } +windows = { workspace = true, features = ["Win32", "Win32_UI", "Win32_UI_Shell", "Win32_System", "Win32_System_Com", "Win32_Storage_FileSystem"] } thiserror = { workspace = true } talpid-types = { path = "../../../talpid-types" } diff --git a/installer-downloader/src/winapi_impl/delegate.rs b/installer-downloader/src/winapi_impl/delegate.rs index 1734642771..475fc58786 100644 --- a/installer-downloader/src/winapi_impl/delegate.rs +++ b/installer-downloader/src/winapi_impl/delegate.rs @@ -3,7 +3,7 @@ use installer_downloader::delegate::ErrorMessage; use native_windows_gui::{self as nwg, Event}; -use windows_sys::Win32::UI::WindowsAndMessaging::PostMessageW; +use windows_sys::Win32::{Foundation::HWND, UI::WindowsAndMessaging::PostMessageW}; use super::ui::{AppWindow, QUEUE_MESSAGE}; use crate::delegate::{AppDelegate, AppDelegateQueue}; @@ -247,6 +247,6 @@ impl AppDelegateQueue<AppWindow> for Queue { }; let context_ptr = Box::into_raw(Box::new(context)); // SAFETY: This is safe since `callback` is Send - unsafe { PostMessageW(hwnd as isize, QUEUE_MESSAGE, 0, context_ptr as isize) }; + unsafe { PostMessageW(hwnd as HWND, QUEUE_MESSAGE, 0, context_ptr as isize) }; } } diff --git a/installer-downloader/src/winapi_impl/ui.rs b/installer-downloader/src/winapi_impl/ui.rs index 6c44f36101..df9439ccb0 100644 --- a/installer-downloader/src/winapi_impl/ui.rs +++ b/installer-downloader/src/winapi_impl/ui.rs @@ -8,7 +8,8 @@ use native_windows_gui::{self as nwg, ControlHandle, ImageDecoder, WindowFlags}; use windows_sys::Win32::Foundation::COLORREF; use windows_sys::Win32::Graphics::Gdi::{ - COLOR_WINDOW, CreateFontIndirectW, LOGFONTW, SetBkColor, SetBkMode, SetTextColor, TRANSPARENT, + COLOR_WINDOW, CreateFontIndirectW, HDC, LOGFONTW, SetBkColor, SetBkMode, SetTextColor, + TRANSPARENT, }; use windows_sys::Win32::UI::WindowsAndMessaging::WM_CTLCOLORSTATIC; @@ -392,8 +393,8 @@ fn handle_banner_label_colors( if msg == WM_CTLCOLORSTATIC { // SAFETY: `w` is a valid device context for WM_CTLCOLORSTATIC unsafe { - SetTextColor(w as isize, rgb([255, 255, 255])); - SetBkColor(w as isize, rgb(BACKGROUND_COLOR)); + SetTextColor(w as HDC, rgb([255, 255, 255])); + SetBkColor(w as HDC, rgb(BACKGROUND_COLOR)); } } None @@ -411,8 +412,8 @@ fn handle_link_messages( if msg == WM_CTLCOLORSTATIC && Some(p) == link_hwnd { // SAFETY: `w` is a valid device context for WM_CTLCOLORSTATIC unsafe { - SetBkMode(w as isize, TRANSPARENT as _); - SetTextColor(w as isize, rgb(LINK_COLOR)); + SetBkMode(w as HDC, TRANSPARENT as _); + SetTextColor(w as HDC, rgb(LINK_COLOR)); } // Out of bounds background return Some(COLOR_WINDOW as isize); @@ -488,7 +489,7 @@ fn create_link_font() -> Result<&'static nwg::Font, nwg::NwgError> { // SAFETY: `logfont` is a valid font let raw_font = unsafe { CreateFontIndirectW(&logfont) }; - if raw_font == 0 { + if raw_font.is_null() { return Err(nwg::NwgError::Unknown); } diff --git a/mullvad-daemon/Cargo.toml b/mullvad-daemon/Cargo.toml index dfc1607467..5b1132fa10 100644 --- a/mullvad-daemon/Cargo.toml +++ b/mullvad-daemon/Cargo.toml @@ -104,9 +104,7 @@ mullvad-version = { path = "../mullvad-version" } [target.'cfg(windows)'.build-dependencies.windows-sys] workspace = true -features = [ - "Win32_System_SystemServices", -] +features = ["Win32_System_SystemServices"] [package.metadata.winres] ProductName = "Mullvad VPN" diff --git a/mullvad-daemon/src/migrations/mod.rs b/mullvad-daemon/src/migrations/mod.rs index 284928073e..ea7d4523ee 100644 --- a/mullvad-daemon/src/migrations/mod.rs +++ b/mullvad-daemon/src/migrations/mod.rs @@ -243,11 +243,12 @@ mod windows { use talpid_types::ErrorExt; use tokio::fs; use windows_sys::Win32::{ - Foundation::{ERROR_SUCCESS, LocalFree, PSID}, + Foundation::{ERROR_SUCCESS, LocalFree}, Security::{ Authorization::{GetNamedSecurityInfoW, SE_FILE_OBJECT, SE_OBJECT_TYPE}, - IsWellKnownSid, OWNER_SECURITY_INFORMATION, PSECURITY_DESCRIPTOR, SECURITY_DESCRIPTOR, - SID, WELL_KNOWN_SID_TYPE, WinBuiltinAdministratorsSid, WinLocalSystemSid, + IsWellKnownSid, OWNER_SECURITY_INFORMATION, PSECURITY_DESCRIPTOR, PSID, + SECURITY_DESCRIPTOR, SID, WELL_KNOWN_SID_TYPE, WinBuiltinAdministratorsSid, + WinLocalSystemSid, }, }; diff --git a/mullvad-paths/src/windows.rs b/mullvad-paths/src/windows.rs index a81cf3ae34..ae58891e3e 100644 --- a/mullvad-paths/src/windows.rs +++ b/mullvad-paths/src/windows.rs @@ -1,11 +1,15 @@ -#![allow(clippy::undocumented_unsafe_blocks)] // Remove me if you dare. - use crate::{Error, Result, UserPermissions}; use once_cell::sync::OnceCell; use std::{ ffi::OsStr, io, mem, - os::windows::prelude::OsStrExt, + os::windows::{ + io::{ + AsHandle, AsRawHandle, BorrowedHandle, FromRawHandle, HandleOrNull, IntoRawHandle, + OwnedHandle, + }, + prelude::OsStrExt, + }, path::{Path, PathBuf}, ptr, }; @@ -13,8 +17,8 @@ use widestring::{WideCStr, WideCString}; use windows_sys::{ Win32::{ Foundation::{ - CloseHandle, ERROR_INSUFFICIENT_BUFFER, ERROR_SUCCESS, GENERIC_ALL, GENERIC_EXECUTE, - GENERIC_READ, GENERIC_WRITE, HANDLE, INVALID_HANDLE_VALUE, LUID, LocalFree, S_OK, + ERROR_INSUFFICIENT_BUFFER, ERROR_SUCCESS, GENERIC_ALL, GENERIC_EXECUTE, GENERIC_READ, + GENERIC_WRITE, HANDLE, LUID, LocalFree, S_OK, }, Security::{ self, AdjustTokenPrivileges, @@ -74,18 +78,6 @@ pub fn create_dir(path: PathBuf, user_permissions: Option<UserPermissions>) -> R Ok(path) } -struct Handle(HANDLE); - -impl Drop for Handle { - fn drop(&mut self) { - if self.0 != 0 && self.0 != INVALID_HANDLE_VALUE { - unsafe { - CloseHandle(self.0); - } - } - } -} - fn get_wide_str<S: AsRef<OsStr>>(string: S) -> Vec<u16> { let wide_string: Vec<u16> = string.as_ref() .encode_wide() @@ -178,6 +170,8 @@ fn set_security_permissions(path: &Path, user_permissions: UserPermissions) -> R let mut admin_psid = [0u8; MAX_SID_SIZE as usize]; let mut admin_psid_len = u32::try_from(admin_psid.len()).unwrap(); + + // SAFETY: The pointer to the PSID is valid for writes of `admin_psid_len` bytes if unsafe { CreateWellKnownSid( WinBuiltinAdministratorsSid, @@ -210,6 +204,8 @@ fn set_security_permissions(path: &Path, user_permissions: UserPermissions) -> R let mut au_psid = [0u8; MAX_SID_SIZE as usize]; let mut au_psid_len = u32::try_from(au_psid.len()).unwrap(); + + // SAFETY: The pointer to the PSID is valid for writes of `au_psid_len` bytes if unsafe { CreateWellKnownSid( WinAuthenticatedUserSid, @@ -243,6 +239,8 @@ fn set_security_permissions(path: &Path, user_permissions: UserPermissions) -> R let ea_entries = [admin_ea, authenticated_users_ea]; let mut new_dacl = ptr::null_mut(); + // SAFETY: `ea_entries` is valid for reads of `ea_entries.len()` elements + // `new_dacl` is a valid pointer to an ACL pointer let result = unsafe { SetEntriesInAclW( u32::try_from(ea_entries.len()).unwrap(), @@ -261,6 +259,7 @@ fn set_security_permissions(path: &Path, user_permissions: UserPermissions) -> R } // new_dacl is now allocated and must be freed with FreeLocal + // SAFETY: All pointers are valid let result = unsafe { SetNamedSecurityInfoW( wide_path.as_ptr(), @@ -273,6 +272,7 @@ fn set_security_permissions(path: &Path, user_permissions: UserPermissions) -> R ) }; + // SAFETY: `new_dacl` is a valid pointer since `SetEntriesInAclW` succeeded unsafe { LocalFree(new_dacl.cast()) }; if result != ERROR_SUCCESS { @@ -295,8 +295,15 @@ pub fn get_system_service_appdata() -> io::Result<PathBuf> { .get_or_try_init(|| { let join_handle = std::thread::spawn(|| { impersonate_self(|| { - let user_token = get_system_user_token()?; - get_known_folder_path(&FOLDERID_LocalAppData, KF_FLAG_DEFAULT, user_token.0) + let user_token = OwnedHandle::try_from(get_system_user_token()?).ok(); + // SAFETY: `FOLDERID_LocalAppData` is a valid known folder ID + unsafe { + get_known_folder_path( + &FOLDERID_LocalAppData, + KF_FLAG_DEFAULT, + user_token.as_ref().map(|t| t.as_handle()), + ) + } }) .or_else(|error| { log::error!("Failed to get AppData path: {error}"); @@ -311,54 +318,60 @@ pub fn get_system_service_appdata() -> io::Result<PathBuf> { /// Get user token for the system service user. Requires elevated privileges to work. /// Useful for deducing the config path for the daemon on Windows when running as a user that /// isn't the system service. -/// If the current user is system, this function succeeds and returns a `NULL` handle; -fn get_system_user_token() -> io::Result<Handle> { +/// If the current user is system, this function succeeds and returns a NULL handle +fn get_system_user_token() -> io::Result<HandleOrNull> { let thread_token = get_current_thread_token()?; - if is_local_system_user_token(thread_token.0)? { - return Ok(Handle(0)); + if is_local_system_user_token(&thread_token)? { + // SAFETY: It is safe to pass a null handle + return Ok(unsafe { HandleOrNull::from_raw_handle(ptr::null_mut()) }); } let system_debug_priv = WideCString::from_str("SeDebugPrivilege").unwrap(); - adjust_token_privilege(thread_token.0, &system_debug_priv, true)?; + adjust_token_privilege(&thread_token, &system_debug_priv, true)?; let find_result = find_process(|process_handle| { let process_token = open_process_token( - process_handle, + &process_handle, GENERIC_READ | TOKEN_IMPERSONATE | TOKEN_DUPLICATE, ) .ok()?; - match is_local_system_user_token(process_token.0) { + match is_local_system_user_token(&process_token) { Ok(true) => Some(process_token), _ => None, } }); - if let Err(err) = adjust_token_privilege(thread_token.0, &system_debug_priv, false) { + if let Err(err) = adjust_token_privilege(&thread_token, &system_debug_priv, false) { log::error!("Failed to drop SeDebugPrivilege: {}", err); } - find_result + // SAFETY: The handle is valid + find_result.map(|h| unsafe { HandleOrNull::from_raw_handle(h.into_raw_handle()) }) } -fn open_process_token(process: HANDLE, access: u32) -> io::Result<Handle> { - let mut process_token = 0; - if unsafe { OpenProcessToken(process, access, &mut process_token) } == 0 { +fn open_process_token(process: &impl AsRawHandle, access: u32) -> io::Result<OwnedHandle> { + let mut process_token = ptr::null_mut(); + // SAFETY: `process` is a valid handle + if unsafe { OpenProcessToken(process.as_raw_handle(), access, &mut process_token) } == 0 { return Err(io::Error::last_os_error()); } - Ok(Handle(process_token)) + // SAFETY: `process_token` is a valid handle since `OpenProcessToken` succeeded + Ok(unsafe { OwnedHandle::from_raw_handle(process_token) }) } /// If all else fails, infer the AppData path from the system directory. fn infer_appdata_from_system_directory() -> io::Result<PathBuf> { - let mut sysdir = get_known_folder_path(&FOLDERID_System, KF_FLAG_DEFAULT, 0)?; + // SAFETY: `FOLDERID_System` is a valid known folder ID + let mut sysdir = unsafe { get_known_folder_path(&FOLDERID_System, KF_FLAG_DEFAULT, None) }?; sysdir.extend(["config", "systemprofile", "AppData", "Local"]); Ok(sysdir) } -fn get_current_thread_token() -> std::io::Result<Handle> { - let mut token_handle: HANDLE = 0; +fn get_current_thread_token() -> std::io::Result<OwnedHandle> { + let mut token_handle: HANDLE = ptr::null_mut(); + // SAFETY: `GetCurrentThread` always returns a valid handle if unsafe { OpenThreadToken( GetCurrentThread(), @@ -370,16 +383,19 @@ fn get_current_thread_token() -> std::io::Result<Handle> { { return Err(std::io::Error::last_os_error()); } - Ok(Handle(token_handle)) + // SAFETY: `token_handle` is a valid handle since `OpenThreadToken` succeeded + Ok(unsafe { OwnedHandle::from_raw_handle(token_handle) }) } fn impersonate_self<T>(func: impl FnOnce() -> io::Result<T>) -> io::Result<T> { + // SAFETY: Trivially safe if unsafe { ImpersonateSelf(SecurityImpersonation) } == 0 { return Err(std::io::Error::last_os_error()); } let result = func(); + // SAFETY: Trivially safe if unsafe { RevertToSelf() } == 0 { log::error!("RevertToSelf failed: {}", io::Error::last_os_error()); } @@ -388,13 +404,13 @@ fn impersonate_self<T>(func: impl FnOnce() -> io::Result<T>) -> io::Result<T> { } fn adjust_token_privilege( - token_handle: HANDLE, + token_handle: &impl AsRawHandle, privilege: &WideCStr, enable: bool, ) -> std::io::Result<()> { - // SAFETY: LUID is a C struct and can safely be zeroed. - let mut privilege_luid: LUID = unsafe { mem::zeroed() }; + let mut privilege_luid = LUID::default(); + // SAFETY: `privilege` is a valid null-terminated string, and `privilege_luid` points to a LUID if unsafe { LookupPrivilegeValueW(ptr::null(), privilege.as_ptr(), &mut privilege_luid) } == 0 { return Err(std::io::Error::last_os_error()); } @@ -406,9 +422,10 @@ fn adjust_token_privilege( Attributes: if enable { SE_PRIVILEGE_ENABLED } else { 0 }, }], }; + // SAFETY: All pointers are valid let result = unsafe { AdjustTokenPrivileges( - token_handle, + token_handle.as_raw_handle(), 0, &privileges, 0, @@ -426,15 +443,30 @@ fn adjust_token_privilege( Ok(()) } -fn get_known_folder_path( +/// Retrieve path to a known folder for a specific user token. +/// +/// # Safety +/// +/// `folder_id` must be a valid pointer to a known folder ID GUID. +unsafe fn get_known_folder_path( folder_id: *const GUID, flags: i32, - user_token: HANDLE, + user_token: Option<BorrowedHandle<'_>>, ) -> std::io::Result<PathBuf> { let mut folder_path: PWSTR = ptr::null_mut(); - let status = - unsafe { SHGetKnownFolderPath(folder_id, flags as u32, user_token, &mut folder_path) }; + // SAFETY: All arguments are valid + let status = unsafe { + SHGetKnownFolderPath( + folder_id, + flags as u32, + user_token + .map(|h| h.as_raw_handle()) + .unwrap_or(ptr::null_mut()), + &mut folder_path, + ) + }; let result = if status == S_OK { + // SAFETY: `folder_path` is valid and null-terminated since `SHGetKnownFolderPath` succeeded let path = unsafe { WideCStr::from_ptr_str(folder_path) }; Ok(PathBuf::from(path.to_os_string())) } else { @@ -444,18 +476,21 @@ fn get_known_folder_path( )) }; + // SAFETY: `folder_path` was allocated by `SHGetKnownFolderPath` and must be freed with `CoTaskMemFree unsafe { CoTaskMemFree(folder_path as *mut _) }; result } /// Enumerate over all processes until `handle_process` returns a result or until there are /// no more processes left. In the latter case, an error is returned. -fn find_process<T>(handle_process: impl Fn(HANDLE) -> Option<T>) -> io::Result<T> { +fn find_process<T>(handle_process: impl Fn(BorrowedHandle<'_>) -> Option<T>) -> io::Result<T> { let mut pid_buffer = vec![0u32; 2048]; let mut num_procs: u32 = u32::try_from(pid_buffer.len()).unwrap(); let bytes_available = num_procs * (mem::size_of::<u32>() as u32); let mut bytes_written = 0; + + // SAFETY: `pid_buffer` is valid for writes of `bytes_available` bytes if unsafe { EnumProcesses(pid_buffer.as_mut_ptr(), bytes_available, &mut bytes_written) } == 0 { return Err(io::Error::last_os_error()); } @@ -466,12 +501,14 @@ fn find_process<T>(handle_process: impl Fn(HANDLE) -> Option<T>) -> io::Result<T pid_buffer .into_iter() .find_map(|process| { - let process_handle = - Handle(unsafe { OpenProcess(PROCESS_QUERY_INFORMATION, 0, process) }); - if process_handle.0 == 0 { + // SAFETY: Trivially safe + let process_handle = unsafe { OpenProcess(PROCESS_QUERY_INFORMATION, 0, process) }; + if process_handle.is_null() { return None; } - handle_process(process_handle.0) + // SAFETY: `process_handle` is a valid handle since `OpenProcess` succeeded + let process_handle = unsafe { OwnedHandle::from_raw_handle(process_handle) }; + handle_process(process_handle.as_handle()) }) .ok_or(io::Error::new( io::ErrorKind::NotFound, @@ -479,15 +516,17 @@ fn find_process<T>(handle_process: impl Fn(HANDLE) -> Option<T>) -> io::Result<T )) } -fn is_local_system_user_token(token: HANDLE) -> io::Result<bool> { +fn is_local_system_user_token(token: &impl AsRawHandle) -> io::Result<bool> { let mut token_info = vec![0u8; 1024]; loop { let mut returned_info_len = 0; + // SAFETY: `token` is a valid handle, and `token_info` is valid for writes of + // `token_info.len()` bytes let info_result = unsafe { GetTokenInformation( - token, + token.as_raw_handle(), TokenUser, token_info.as_mut_ptr() as _, u32::try_from(token_info.len()).expect("len must fit in u32"), @@ -517,6 +556,7 @@ fn is_local_system_user_token(token: HANDLE) -> io::Result<bool> { let mut local_system_sid = [0u8; MAX_SID_SIZE as usize]; let mut local_system_size = u32::try_from(local_system_sid.len()).unwrap(); + // SAFETY: `local_system_sid` is valid for writes of `local_system_size` bytes if unsafe { CreateWellKnownSid( WinLocalSystemSid, @@ -531,5 +571,6 @@ fn is_local_system_user_token(token: HANDLE) -> io::Result<bool> { return Err(err); } + // SAFETY: Both arguments point to valid security identifiers Ok(unsafe { EqualSid(token_user.User.Sid, local_system_sid.as_ptr() as _) } != 0) } diff --git a/talpid-core/Cargo.toml b/talpid-core/Cargo.toml index 0c14671c6e..4f99bd9c5a 100644 --- a/talpid-core/Cargo.toml +++ b/talpid-core/Cargo.toml @@ -71,7 +71,7 @@ talpid-net = { path = "../talpid-net" } bitflags = "2.6" csv = "1.3.1" widestring = "1.0" -winreg = { version = "0.51", features = ["transactions"] } +winreg = { workspace = true, features = ["transactions"] } memoffset = "0.6" once_cell = { workspace = true } windows-service = "0.6.0" diff --git a/talpid-core/src/dns/windows/dnsapi.rs b/talpid-core/src/dns/windows/dnsapi.rs index 6a71dd6ce4..0465b233e6 100644 --- a/talpid-core/src/dns/windows/dnsapi.rs +++ b/talpid-core/src/dns/windows/dnsapi.rs @@ -6,7 +6,6 @@ use std::{ }, time::{Duration, Instant}, }; -use windows_sys::Win32::Foundation::BOOL; static FLUSH_TIMEOUT: Duration = Duration::from_secs(5); static DNSAPI_HANDLE: OnceLock<DnsApi> = OnceLock::new(); @@ -64,7 +63,7 @@ impl DnsApi { let begin = Instant::now(); // SAFETY: this function is trivially safe to call - let result = if unsafe { (DnsFlushResolverCache)() } != 0 { + let result = if unsafe { (DnsFlushResolverCache)() } { let elapsed = begin.elapsed(); if elapsed >= FLUSH_TIMEOUT { log::warn!( @@ -93,5 +92,5 @@ impl DnsApi { #[link(name = "dnsapi")] unsafe extern "system" { // Flushes the DNS resolver cache - pub fn DnsFlushResolverCache() -> BOOL; + pub fn DnsFlushResolverCache() -> bool; } diff --git a/talpid-core/src/dns/windows/iphlpapi.rs b/talpid-core/src/dns/windows/iphlpapi.rs index d862e15036..cfb61d2771 100644 --- a/talpid-core/src/dns/windows/iphlpapi.rs +++ b/talpid-core/src/dns/windows/iphlpapi.rs @@ -72,8 +72,14 @@ static IPHLPAPI_HANDLE: OnceCell<IphlpApi> = OnceCell::new(); impl IphlpApi { fn new() -> Result<Self, Error> { - let module = unsafe { LoadLibraryExW(w!("iphlpapi.dll"), 0, LOAD_LIBRARY_SEARCH_SYSTEM32) }; - if module == 0 { + let module = unsafe { + LoadLibraryExW( + w!("iphlpapi.dll"), + ptr::null_mut(), + LOAD_LIBRARY_SEARCH_SYSTEM32, + ) + }; + if module.is_null() { log::error!("Failed to load iphlpapi.dll"); return Err(Error::LoadDll(io::Error::last_os_error())); } diff --git a/talpid-core/src/dns/windows/netsh.rs b/talpid-core/src/dns/windows/netsh.rs index 2f8a4d77d1..624fe34c69 100644 --- a/talpid-core/src/dns/windows/netsh.rs +++ b/talpid-core/src/dns/windows/netsh.rs @@ -170,7 +170,7 @@ fn wait_for_child(subproc: &mut Child, timeout: Duration) -> io::Result<Option<E let dur_millis = u32::try_from(timeout.as_millis()).unwrap_or(INFINITE); let subproc_handle = subproc.as_raw_handle(); - match unsafe { WaitForSingleObject(subproc_handle as isize, dur_millis) } { + match unsafe { WaitForSingleObject(subproc_handle, dur_millis) } { WAIT_OBJECT_0 => subproc.try_wait(), WAIT_TIMEOUT => Ok(None), _error => Err(io::Error::last_os_error()), diff --git a/talpid-core/src/split_tunnel/windows/driver.rs b/talpid-core/src/split_tunnel/windows/driver.rs index 27f10d72dd..5cbd172266 100644 --- a/talpid-core/src/split_tunnel/windows/driver.rs +++ b/talpid-core/src/split_tunnel/windows/driver.rs @@ -28,7 +28,7 @@ use talpid_windows::{io::Overlapped, process::ProcessSnapshot, sync::Event}; use windows_sys::Win32::{ Foundation::{ ERROR_ACCESS_DENIED, ERROR_FILE_NOT_FOUND, ERROR_INVALID_PARAMETER, ERROR_IO_PENDING, - HANDLE, NTSTATUS, WAIT_ABANDONED, WAIT_ABANDONED_0, WAIT_FAILED, WAIT_OBJECT_0, + NTSTATUS, WAIT_ABANDONED, WAIT_ABANDONED_0, WAIT_FAILED, WAIT_OBJECT_0, }, Networking::WinSock::{IN_ADDR, IN6_ADDR}, Storage::FileSystem::FILE_FLAG_OVERLAPPED, @@ -849,7 +849,7 @@ pub unsafe fn device_io_control_buffer_async( let result = unsafe { DeviceIoControl( - device.as_raw_handle() as HANDLE, + device.as_raw_handle(), ioctl_code, input_ptr, u32::try_from(input_len).map_err(|_error| { @@ -888,13 +888,13 @@ pub fn get_overlapped_result( let event = overlapped.get_event().unwrap(); // SAFETY: This is a valid event object. - unsafe { wait_for_single_object(event.as_raw(), None) }?; + unsafe { wait_for_single_object(event, None) }?; // SAFETY: The handle and overlapped object are valid. let mut returned_bytes = 0u32; let result = unsafe { GetOverlappedResult( - device.as_raw_handle() as HANDLE, + device.as_raw_handle(), overlapped.as_mut_ptr(), &mut returned_bytes, 0, @@ -911,14 +911,17 @@ pub fn get_overlapped_result( /// # Safety /// /// * `object` must be a valid object that can be signaled, such as an event object. -pub unsafe fn wait_for_single_object(object: HANDLE, timeout: Option<Duration>) -> io::Result<()> { +pub unsafe fn wait_for_single_object( + object: &impl AsRawHandle, + timeout: Option<Duration>, +) -> io::Result<()> { let timeout = match timeout { Some(timeout) => u32::try_from(timeout.as_millis()).map_err(|_error| { io::Error::new(io::ErrorKind::InvalidInput, "the duration is too long") })?, None => INFINITE, }; - let result = unsafe { WaitForSingleObject(object, timeout) }; + let result = unsafe { WaitForSingleObject(object.as_raw_handle(), timeout) }; match result { WAIT_OBJECT_0 => Ok(()), WAIT_FAILED => Err(io::Error::last_os_error()), @@ -933,7 +936,10 @@ pub unsafe fn wait_for_single_object(object: HANDLE, timeout: Option<Duration>) /// # Safety /// /// * `objects` must be a slice of valid objects that can be signaled, such as event objects. -pub unsafe fn wait_for_multiple_objects(objects: &[HANDLE], wait_all: bool) -> io::Result<HANDLE> { +pub unsafe fn wait_for_multiple_objects( + objects: &[RawHandle], + wait_all: bool, +) -> io::Result<RawHandle> { unsafe { let objects_len = u32::try_from(objects.len()) .map_err(|_error| io::Error::new(io::ErrorKind::InvalidInput, "too many objects"))?; diff --git a/talpid-core/src/split_tunnel/windows/mod.rs b/talpid-core/src/split_tunnel/windows/mod.rs index 8b9db7843c..4c80c9e147 100644 --- a/talpid-core/src/split_tunnel/windows/mod.rs +++ b/talpid-core/src/split_tunnel/windows/mod.rs @@ -13,6 +13,7 @@ use std::{ ffi::{OsStr, OsString}, io, net::{IpAddr, Ipv4Addr, Ipv6Addr}, + os::windows::io::AsRawHandle, path::{Path, PathBuf}, sync::{ Arc, Mutex, MutexGuard, RwLock, Weak, @@ -245,9 +246,7 @@ impl SplitTunnel { overlapped: &mut Overlapped, data_buffer: &mut Vec<u8>, ) -> io::Result<EventResult> { - if unsafe { driver::wait_for_single_object(quit_event.as_raw(), Some(Duration::ZERO)) } - .is_ok() - { + if unsafe { driver::wait_for_single_object(quit_event, Some(Duration::ZERO)) }.is_ok() { return Ok(EventResult::Quit); } @@ -271,8 +270,8 @@ impl SplitTunnel { })?; let event_objects = [ - overlapped.get_event().unwrap().as_raw(), - quit_event.as_raw(), + overlapped.get_event().unwrap().as_raw_handle(), + quit_event.as_raw_handle(), ]; let signaled_object = @@ -285,7 +284,7 @@ impl SplitTunnel { }, )?; - if signaled_object == quit_event.as_raw() { + if signaled_object == quit_event.as_raw_handle() { // Quit event was signaled return Ok(EventResult::Quit); } diff --git a/talpid-core/src/split_tunnel/windows/path_monitor.rs b/talpid-core/src/split_tunnel/windows/path_monitor.rs index a4ea65ddde..b7fd5cc337 100644 --- a/talpid-core/src/split_tunnel/windows/path_monitor.rs +++ b/talpid-core/src/split_tunnel/windows/path_monitor.rs @@ -280,7 +280,7 @@ impl DirContext { ) }; - if handle == 0 { + if handle.is_null() { return Err(io::Error::last_os_error()); } @@ -356,9 +356,10 @@ struct CompletionPort { impl CompletionPort { // `concurrent_threads`: 0 ==> number of processors fn create(concurrent_threads: u32) -> io::Result<Self> { - let handle = - unsafe { CreateIoCompletionPort(INVALID_HANDLE_VALUE, 0, 0, concurrent_threads) }; - if handle == 0 { + let handle = unsafe { + CreateIoCompletionPort(INVALID_HANDLE_VALUE, ptr::null_mut(), 0, concurrent_threads) + }; + if handle.is_null() { return Err(io::Error::last_os_error()); } Ok(CompletionPort { handle }) diff --git a/talpid-core/src/split_tunnel/windows/windows.rs b/talpid-core/src/split_tunnel/windows/windows.rs index 1d6c03d39c..fd042802f4 100644 --- a/talpid-core/src/split_tunnel/windows/windows.rs +++ b/talpid-core/src/split_tunnel/windows/windows.rs @@ -139,8 +139,7 @@ pub fn open_process( pid: u32, ) -> Result<WinHandle, io::Error> { let handle = unsafe { OpenProcess(access as u32, if inherit_handle { 1 } else { 0 }, pid) }; - - if handle == 0 { + if handle.is_null() { return Err(io::Error::last_os_error()); } Ok(WinHandle(handle)) diff --git a/talpid-core/src/window.rs b/talpid-core/src/window.rs index 69373991f2..b0615c8ed7 100644 --- a/talpid-core/src/window.rs +++ b/talpid-core/src/window.rs @@ -52,8 +52,8 @@ pub fn create_hidden_window<F: (Fn(HWND, u32, WPARAM, LPARAM) -> LRESULT) + Send 0, 0, 0, - 0, - 0, + ptr::null_mut(), + ptr::null_mut(), GetModuleHandleW(ptr::null_mut()), ptr::null_mut(), ) @@ -77,7 +77,7 @@ pub fn create_hidden_window<F: (Fn(HWND, u32, WPARAM, LPARAM) -> LRESULT) + Send let mut msg = unsafe { std::mem::zeroed() }; loop { - let status = unsafe { GetMessageW(&mut msg, 0, 0, 0) }; + let status = unsafe { GetMessageW(&mut msg, ptr::null_mut(), 0, 0) }; if status < 0 { continue; @@ -86,7 +86,7 @@ pub fn create_hidden_window<F: (Fn(HWND, u32, WPARAM, LPARAM) -> LRESULT) + Send break; } - if msg.hwnd == 0 { + if msg.hwnd.is_null() { if msg.message == REQUEST_THREAD_SHUTDOWN { unsafe { DestroyWindow(dummy_window) }; } diff --git a/talpid-openvpn/Cargo.toml b/talpid-openvpn/Cargo.toml index 031d6f996f..656c7de32e 100644 --- a/talpid-openvpn/Cargo.toml +++ b/talpid-openvpn/Cargo.toml @@ -31,7 +31,7 @@ prost = { workspace = true } [target.'cfg(windows)'.dependencies] widestring = "1.0" -winreg = { version = "0.51", features = ["transactions"] } +winreg = { workspace = true, features = ["transactions"] } talpid-windows = { path = "../talpid-windows" } once_cell = { workspace = true } # Only needed because parity-tokio-ipc has forgotten to enable the winerror feature of winapi .. diff --git a/talpid-openvpn/src/wintun.rs b/talpid-openvpn/src/wintun.rs index 230675c54b..5e86fb3516 100644 --- a/talpid-openvpn/src/wintun.rs +++ b/talpid-openvpn/src/wintun.rs @@ -180,9 +180,14 @@ impl WintunDll { fn new(resource_dir: &Path) -> io::Result<Self> { let wintun_dll = U16CString::from_os_str_truncate(resource_dir.join("wintun.dll")); - let handle = - unsafe { LoadLibraryExW(wintun_dll.as_ptr(), 0, LOAD_WITH_ALTERED_SEARCH_PATH) }; - if handle == 0 { + let handle = unsafe { + LoadLibraryExW( + wintun_dll.as_ptr(), + ptr::null_mut(), + LOAD_WITH_ALTERED_SEARCH_PATH, + ) + }; + if handle.is_null() { return Err(io::Error::last_os_error()); } Self::new_inner(handle, Self::get_proc_address) @@ -373,7 +378,7 @@ mod tests { #[test] fn test_wintun_imports() { - WintunDll::new_inner(0, get_proc_fn).unwrap(); + WintunDll::new_inner(ptr::null_mut(), get_proc_fn).unwrap(); } #[test] diff --git a/talpid-platform-metadata/src/windows.rs b/talpid-platform-metadata/src/windows.rs index 6cb5599ceb..5841a54e45 100644 --- a/talpid-platform-metadata/src/windows.rs +++ b/talpid-platform-metadata/src/windows.rs @@ -66,7 +66,7 @@ impl WindowsVersion { // SAFETY: module_name is a valid UTF-16/WTF-16 null-terminated string. let ntdll = unsafe { GetModuleHandleW(module_name.as_ptr()) }; - if ntdll == 0 { + if ntdll.is_null() { return Err(io::Error::last_os_error()); } diff --git a/talpid-routing/src/windows/default_route_monitor.rs b/talpid-routing/src/windows/default_route_monitor.rs index 6064a789b3..e67f2d4a50 100644 --- a/talpid-routing/src/windows/default_route_monitor.rs +++ b/talpid-routing/src/windows/default_route_monitor.rs @@ -6,26 +6,23 @@ use crate::debounce::BurstGuard; use std::{ ffi::c_void, + os::windows::io::RawHandle, + ptr, sync::{Arc, Mutex}, time::Duration, }; use talpid_types::win32_err; -use windows_sys::Win32::{ - Foundation::{BOOLEAN, HANDLE}, - NetworkManagement::{ - IpHelper::{ - CancelMibChangeNotify2, ConvertInterfaceLuidToIndex, MIB_IPFORWARD_ROW2, - MIB_IPINTERFACE_ROW, MIB_NOTIFICATION_TYPE, MIB_UNICASTIPADDRESS_ROW, - NotifyIpInterfaceChange, NotifyRouteChange2, NotifyUnicastIpAddressChange, - }, - Ndis::NET_LUID_LH, +use windows_sys::Win32::NetworkManagement::{ + IpHelper::{ + CancelMibChangeNotify2, ConvertInterfaceLuidToIndex, MIB_IPFORWARD_ROW2, + MIB_IPINTERFACE_ROW, MIB_NOTIFICATION_TYPE, MIB_UNICASTIPADDRESS_ROW, + NotifyIpInterfaceChange, NotifyRouteChange2, NotifyUnicastIpAddressChange, }, + Ndis::NET_LUID_LH, }; use talpid_windows::net::AddressFamily; -const WIN_FALSE: BOOLEAN = 0; - struct DefaultRouteMonitorContext { callback: Box<dyn for<'a> Fn(EventType<'a>) + Send + 'static>, refresh_current_route: bool, @@ -128,7 +125,7 @@ impl Drop for DefaultRouteMonitor { } } -struct NotifyChangeHandle(HANDLE); +struct NotifyChangeHandle(RawHandle); /// SAFETY: NotifyChangeHandle is `Send` since it holds sole ownership of a pointer provided by C unsafe impl Send for NotifyChangeHandle {} @@ -240,7 +237,7 @@ impl DefaultRouteMonitor { // we cancel the callbacks. This will leak the weak pointer but the context state itself // will be correctly dropped when DefaultRouteManager is dropped. let context_ptr = context_and_burst; - let mut handle_ptr = 0; + let mut handle_ptr = ptr::null_mut(); // SAFETY: No clear safety specifications, context_ptr must be valid for as long as handle // has not been dropped. win32_err!(unsafe { @@ -248,14 +245,14 @@ impl DefaultRouteMonitor { family, Some(route_change_callback), context_ptr as *const _, - WIN_FALSE, + false, &mut handle_ptr, ) }) .map_err(Error::RegisterNotifyRouteCallback)?; let notify_route_change_handle = NotifyChangeHandle(handle_ptr); - let mut handle_ptr = 0; + let mut handle_ptr = ptr::null_mut(); // SAFETY: No clear safety specifications, context_ptr must be valid for as long as handle // has not been dropped. win32_err!(unsafe { @@ -263,14 +260,14 @@ impl DefaultRouteMonitor { family, Some(interface_change_callback), context_ptr as *const _, - WIN_FALSE, + false, &mut handle_ptr, ) }) .map_err(Error::RegisterNotifyIpInterfaceCallback)?; let notify_interface_change_handle = NotifyChangeHandle(handle_ptr); - let mut handle_ptr = 0; + let mut handle_ptr = ptr::null_mut(); // SAFETY: No clear safety specifications, context_ptr must be valid for as long as handle // has not been dropped. win32_err!(unsafe { @@ -278,7 +275,7 @@ impl DefaultRouteMonitor { family, Some(ip_address_change_callback), context_ptr as *const _, - WIN_FALSE, + false, &mut handle_ptr, ) }) diff --git a/talpid-routing/src/windows/get_best_default_route.rs b/talpid-routing/src/windows/get_best_default_route.rs index 9ce5987dd7..0fc0af02a1 100644 --- a/talpid-routing/src/windows/get_best_default_route.rs +++ b/talpid-routing/src/windows/get_best_default_route.rs @@ -163,12 +163,8 @@ fn annotate_route(route: &MIB_IPFORWARD_ROW2) -> Option<AnnotatedRoute<'_>> { ) .ok()?; - if iface.Connected == 0 { - None - } else { - Some(AnnotatedRoute { - route, - effective_metric: route.Metric + iface.Metric, - }) - } + iface.Connected.then(|| AnnotatedRoute { + route, + effective_metric: route.Metric + iface.Metric, + }) } diff --git a/talpid-tunnel-config-client/Cargo.toml b/talpid-tunnel-config-client/Cargo.toml index a7290c3613..04c7968050 100644 --- a/talpid-tunnel-config-client/Cargo.toml +++ b/talpid-tunnel-config-client/Cargo.toml @@ -32,10 +32,6 @@ zeroize = "1.5.7" [target.'cfg(unix)'.dependencies] nix = { workspace = true, features = ["socket"] } -[target.'cfg(windows)'.dependencies.windows-sys] -workspace = true -features = ["Win32_Networking_WinSock"] - [build-dependencies] tonic-build = { workspace = true, default-features = false, features = [ "transport", diff --git a/talpid-tunnel/src/windows.rs b/talpid-tunnel/src/windows.rs index 21315d45ff..28759bda10 100644 --- a/talpid-tunnel/src/windows.rs +++ b/talpid-tunnel/src/windows.rs @@ -30,12 +30,12 @@ pub fn initialize_interfaces( row.SitePrefixLength = 0; row.RouterDiscoveryBehavior = RouterDiscoveryDisabled; row.DadTransmits = 0; - row.ManagedAddressConfigurationSupported = 0; - row.OtherStatefulConfigurationSupported = 0; + row.ManagedAddressConfigurationSupported = false; + row.OtherStatefulConfigurationSupported = false; // Ensure lowest interface metric row.Metric = 1; - row.UseAutomaticMetric = 0; + row.UseAutomaticMetric = false; set_ip_interface_entry(&mut row)?; } diff --git a/talpid-windows/src/fs.rs b/talpid-windows/src/fs.rs index fd716dddb1..8523cee918 100644 --- a/talpid-windows/src/fs.rs +++ b/talpid-windows/src/fs.rs @@ -20,7 +20,7 @@ pub fn is_admin_owned<T: AsRawHandle>(handle: T) -> io::Result<bool> { // SAFETY: `handle` is a valid handle let result = unsafe { GetSecurityInfo( - handle.as_raw_handle() as isize, + handle.as_raw_handle(), SE_FILE_OBJECT, OWNER_SECURITY_INFORMATION, (&mut owner) as *mut *mut SID as *mut *mut c_void, diff --git a/talpid-windows/src/io.rs b/talpid-windows/src/io.rs index 8a8966f620..cfb6a23ec3 100644 --- a/talpid-windows/src/io.rs +++ b/talpid-windows/src/io.rs @@ -1,4 +1,4 @@ -use std::{io, mem}; +use std::{io, mem, os::windows::io::AsRawHandle, ptr}; use windows_sys::Win32::System::IO::OVERLAPPED; use crate::sync::Event; @@ -41,11 +41,11 @@ impl Overlapped { fn set_event(&mut self, event: Option<Event>) { match event { Some(event) => { - self.overlapped.hEvent = event.as_raw(); + self.overlapped.hEvent = event.as_raw_handle(); self.event = Some(event); } None => { - self.overlapped.hEvent = 0; + self.overlapped.hEvent = ptr::null_mut(); self.event = None; } } diff --git a/talpid-windows/src/net.rs b/talpid-windows/src/net.rs index 557af9e961..5aa92fb089 100644 --- a/talpid-windows/src/net.rs +++ b/talpid-windows/src/net.rs @@ -4,9 +4,10 @@ use socket2::SockAddr; use std::{ ffi::{OsStr, OsString}, fmt, io, - mem::{self, MaybeUninit}, + mem::MaybeUninit, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}, os::windows::ffi::{OsStrExt, OsStringExt}, + ptr, sync::Mutex, time::{Duration, Instant}, }; @@ -30,7 +31,6 @@ use windows_sys::{ AF_INET, AF_INET6, AF_UNSPEC, IN_ADDR, IN6_ADDR, IpDadStateDeprecated, IpDadStateDuplicate, IpDadStateInvalid, IpDadStatePreferred, IpDadStateTentative, NL_DAD_STATE, SOCKADDR_IN as sockaddr_in, SOCKADDR_IN6 as sockaddr_in6, SOCKADDR_INET, - SOCKADDR_STORAGE as sockaddr_storage, }, }, core::GUID, @@ -174,7 +174,7 @@ pub fn notify_ip_interface_change<'a, T: FnMut(&MIB_IPINTERFACE_ROW, i32) + Send ) -> io::Result<Box<IpNotifierHandle<'a>>> { let mut context = Box::new(IpNotifierHandle { callback: Mutex::new(Box::new(callback)), - handle: 0, + handle: ptr::null_mut(), }); win32_err!(unsafe { @@ -182,7 +182,7 @@ pub fn notify_ip_interface_change<'a, T: FnMut(&MIB_IPINTERFACE_ROW, i32) + Send af_family_from_family(family), Some(inner_callback), &mut *context as *mut _ as *mut _, - 0, + false, (&mut context.handle) as *mut _, ) })?; @@ -194,9 +194,11 @@ pub fn get_ip_interface_entry( family: AddressFamily, luid: &NET_LUID_LH, ) -> io::Result<MIB_IPINTERFACE_ROW> { - let mut row: MIB_IPINTERFACE_ROW = unsafe { mem::zeroed() }; - row.Family = family as u16; - row.InterfaceLuid = *luid; + let mut row = MIB_IPINTERFACE_ROW { + Family: family as u16, + InterfaceLuid: *luid, + ..Default::default() + }; win32_err!(unsafe { GetIpInterfaceEntry(&mut row) })?; Ok(row) @@ -324,7 +326,7 @@ pub fn get_ip_address_for_interface( /// Adds a unicast IP address for the given interface. pub fn add_ip_address_for_interface(luid: NET_LUID_LH, address: IpAddr) -> Result<()> { - let mut row = unsafe { mem::zeroed() }; + let mut row = MIB_UNICASTIPADDRESS_ROW::default(); unsafe { InitializeUnicastIpAddressEntry(&mut row) }; row.InterfaceLuid = luid; @@ -385,7 +387,7 @@ pub fn luid_from_alias<T: AsRef<OsStr>>(alias: T) -> io::Result<NET_LUID_LH> { .encode_wide() .chain(std::iter::once(0u16)) .collect(); - let mut luid: NET_LUID_LH = unsafe { std::mem::zeroed() }; + let mut luid = NET_LUID_LH::default(); win32_err!(unsafe { ConvertInterfaceAliasToLuid(alias_wide.as_ptr(), &mut luid) })?; Ok(luid) } @@ -427,7 +429,7 @@ pub fn ipaddr_from_in6addr(addr: IN6_ADDR) -> Ipv6Addr { /// Converts a `SocketAddr` to `SOCKADDR_INET` pub fn inet_sockaddr_from_socketaddr(addr: SocketAddr) -> SOCKADDR_INET { // SAFETY: SOCKADDR_INET is a union of C structs, these can be safely zeroed. - let mut sockaddr: SOCKADDR_INET = unsafe { mem::zeroed() }; + let mut sockaddr = SOCKADDR_INET::default(); match addr { // SAFETY: `*const sockaddr` may be treated as `*const sockaddr_in` since we know it's a v4 // address. @@ -447,13 +449,30 @@ pub fn inet_sockaddr_from_socketaddr(addr: SocketAddr) -> SOCKADDR_INET { pub fn try_socketaddr_from_inet_sockaddr(addr: SOCKADDR_INET) -> Result<SocketAddr> { // SAFETY: si_family is always valid let family = unsafe { addr.si_family }; - unsafe { - let mut storage: sockaddr_storage = mem::zeroed(); - *(&mut storage as *mut _ as *mut SOCKADDR_INET) = addr; - SockAddr::new(storage, mem::size_of_val(&addr) as i32) + match family { + AF_INET => { + // SAFETY: We know this is an IPv4 address based on the family + let ipv4_addr = unsafe { addr.Ipv4 }; + // SAFETY: The IPv4 address is initialized + let ip = Ipv4Addr::from(u32::from_be(unsafe { ipv4_addr.sin_addr.S_un.S_addr })); + let port = u16::from_be(ipv4_addr.sin_port); + Ok(SocketAddr::V4(SocketAddrV4::new(ip, port))) + } + AF_INET6 => { + // SAFETY: We know this is an IPv6 address based on the family + let ipv6_addr = unsafe { addr.Ipv6 }; + // SAFETY: The IPv6 address is initialized + let ip = Ipv6Addr::from(unsafe { ipv6_addr.sin6_addr.u.Byte }); + let port = u16::from_be(ipv6_addr.sin6_port); + let flowinfo = ipv6_addr.sin6_flowinfo; + // SAFETY: The scope ID is initialized + let scope_id = unsafe { ipv6_addr.Anonymous.sin6_scope_id }; + Ok(SocketAddr::V6(SocketAddrV6::new( + ip, port, flowinfo, scope_id, + ))) + } + _ => Err(Error::UnknownAddressFamily(family)), } - .as_socket() - .ok_or(Error::UnknownAddressFamily(family)) } /// Address family. These correspond to the `AF_*` constants. diff --git a/talpid-windows/src/process.rs b/talpid-windows/src/process.rs index b7c5d8b972..5dbd52f3c7 100644 --- a/talpid-windows/src/process.rs +++ b/talpid-windows/src/process.rs @@ -1,11 +1,12 @@ #![allow(clippy::undocumented_unsafe_blocks)] // Remove me if you dare use std::{ - ffi::{CStr, c_char}, + ffi::CStr, io, mem, + os::windows::io::{AsRawHandle, FromRawHandle, OwnedHandle, RawHandle}, }; use windows_sys::Win32::{ - Foundation::{CloseHandle, ERROR_NO_MORE_FILES, HANDLE, INVALID_HANDLE_VALUE}, + Foundation::{ERROR_NO_MORE_FILES, INVALID_HANDLE_VALUE}, System::Diagnostics::ToolHelp::{ CreateToolhelp32Snapshot, MODULEENTRY32, Module32First, Module32Next, PROCESSENTRY32W, Process32FirstW, Process32NextW, @@ -14,30 +15,31 @@ use windows_sys::Win32::{ /// A snapshot of process modules, threads, and heaps pub struct ProcessSnapshot { - handle: HANDLE, + handle: OwnedHandle, } impl ProcessSnapshot { /// Create a new process snapshot using `CreateToolhelp32Snapshot` pub fn new(flags: u32, process_id: u32) -> io::Result<ProcessSnapshot> { + // SAFETY: `CreateToolhelp32Snapshot` should handle invalid flags and process IDs let snap = unsafe { CreateToolhelp32Snapshot(flags, process_id) }; if snap == INVALID_HANDLE_VALUE { Err(io::Error::last_os_error()) } else { - Ok(ProcessSnapshot { handle: snap }) + Ok(ProcessSnapshot { + // SAFETY: `snap` is a valid handle since `CreateToolhelp32Snapshot` succeeded + handle: unsafe { OwnedHandle::from_raw_handle(snap) }, + }) } } - /// Return the raw handle - pub fn as_raw(&self) -> HANDLE { - self.handle - } - /// Return an iterator over the modules in the snapshot pub fn modules(&self) -> ProcessSnapshotModules<'_> { - let mut entry: MODULEENTRY32 = unsafe { mem::zeroed() }; - entry.dwSize = mem::size_of::<MODULEENTRY32>() as u32; + let entry = MODULEENTRY32 { + dwSize: mem::size_of::<MODULEENTRY32>() as u32, + ..Default::default() + }; ProcessSnapshotModules { snapshot: self, @@ -48,8 +50,10 @@ impl ProcessSnapshot { /// Return an iterator over the processes in the snapshot pub fn processes(&self) -> ProcessSnapshotEntries<'_> { - let mut entry: PROCESSENTRY32W = unsafe { mem::zeroed() }; - entry.dwSize = mem::size_of::<PROCESSENTRY32W>() as u32; + let entry = PROCESSENTRY32W { + dwSize: mem::size_of::<PROCESSENTRY32W>() as u32, + ..Default::default() + }; ProcessSnapshotEntries { snapshot: self, @@ -59,11 +63,9 @@ impl ProcessSnapshot { } } -impl Drop for ProcessSnapshot { - fn drop(&mut self) { - unsafe { - CloseHandle(self.handle); - } +impl AsRawHandle for ProcessSnapshot { + fn as_raw_handle(&self) -> RawHandle { + self.handle.as_raw_handle() } } @@ -89,7 +91,8 @@ impl Iterator for ProcessSnapshotModules<'_> { fn next(&mut self) -> Option<io::Result<ModuleEntry>> { if self.iter_started { - if unsafe { Module32Next(self.snapshot.as_raw(), &mut self.temp_entry) } == 0 { + // SAFETY: `self.snapshot` is a valid pointer, and `temp_entry` is a valid `MODULEENTRY32` + if unsafe { Module32Next(self.snapshot.as_raw_handle(), &mut self.temp_entry) } == 0 { let last_error = io::Error::last_os_error(); return if last_error.raw_os_error().unwrap() as u32 == ERROR_NO_MORE_FILES { @@ -99,14 +102,16 @@ impl Iterator for ProcessSnapshotModules<'_> { }; } } else { - if unsafe { Module32First(self.snapshot.as_raw(), &mut self.temp_entry) } == 0 { + // SAFETY: `self.snapshot` is a valid pointer, and `temp_entry` is a valid `MODULEENTRY32` + if unsafe { Module32First(self.snapshot.as_raw_handle(), &mut self.temp_entry) } == 0 { return Some(Err(io::Error::last_os_error())); } self.iter_started = true; } let cstr_ref = &self.temp_entry.szModule[0]; - let cstr = unsafe { CStr::from_ptr(cstr_ref as *const u8 as *const c_char) }; + // SAFETY: `szModule` is a null-terminated C string + let cstr = unsafe { CStr::from_ptr(cstr_ref) }; Some(Ok(ModuleEntry { name: cstr.to_string_lossy().into_owned(), base_address: self.temp_entry.modBaseAddr, @@ -135,7 +140,8 @@ impl Iterator for ProcessSnapshotEntries<'_> { fn next(&mut self) -> Option<io::Result<ProcessEntry>> { if self.iter_started { - if unsafe { Process32NextW(self.snapshot.as_raw(), &mut self.temp_entry) } == 0 { + // SAFETY: `self.snapshot` is a valid pointer, and `temp_entry` is a valid `PROCESSENTRY32W` + if unsafe { Process32NextW(self.snapshot.as_raw_handle(), &mut self.temp_entry) } == 0 { let last_error = io::Error::last_os_error(); return if last_error.raw_os_error().unwrap() as u32 == ERROR_NO_MORE_FILES { @@ -145,7 +151,9 @@ impl Iterator for ProcessSnapshotEntries<'_> { }; } } else { - if unsafe { Process32FirstW(self.snapshot.as_raw(), &mut self.temp_entry) } == 0 { + // SAFETY: `self.snapshot` is a valid pointer, and `temp_entry` is a valid `PROCESSENTRY32W` + if unsafe { Process32FirstW(self.snapshot.as_raw_handle(), &mut self.temp_entry) } == 0 + { return Some(Err(io::Error::last_os_error())); } self.iter_started = true; diff --git a/talpid-windows/src/sync.rs b/talpid-windows/src/sync.rs index ddd97facdb..567f556a73 100644 --- a/talpid-windows/src/sync.rs +++ b/talpid-windows/src/sync.rs @@ -1,13 +1,17 @@ #![allow(clippy::undocumented_unsafe_blocks)] // Remove me if you dare. -use std::{io, ptr}; +use std::{ + io, + os::windows::io::{AsRawHandle, FromRawHandle, OwnedHandle, RawHandle}, + ptr, +}; use windows_sys::Win32::{ - Foundation::{BOOL, CloseHandle, DUPLICATE_SAME_ACCESS, DuplicateHandle, HANDLE}, + Foundation::{DUPLICATE_SAME_ACCESS, DuplicateHandle}, System::Threading::{CreateEventW, GetCurrentProcess, SetEvent}, }; /// Windows event object -pub struct Event(HANDLE); +pub struct Event(OwnedHandle); unsafe impl Send for Event {} unsafe impl Sync for Event {} @@ -18,37 +22,34 @@ impl Event { let event = unsafe { CreateEventW( ptr::null_mut(), - bool_to_winbool(manual_reset), - bool_to_winbool(initial_state), + i32::from(manual_reset), + i32::from(initial_state), ptr::null(), ) }; - if event == 0 { + if event.is_null() { return Err(io::Error::last_os_error()); } - Ok(Self(event)) + // SAFETY: `event` is a valid handle since `CreateEventW` succeeded + Ok(Self(unsafe { OwnedHandle::from_raw_handle(event) })) } /// Signal the event object pub fn set(&self) -> io::Result<()> { - if unsafe { SetEvent(self.0) } == 0 { + // SAFETY: `self.0` is a valid handle + if unsafe { SetEvent(self.0.as_raw_handle()) } == 0 { return Err(io::Error::last_os_error()); } Ok(()) } - /// Return raw event object - pub fn as_raw(&self) -> HANDLE { - self.0 - } - /// Duplicate the event object with `DuplicateHandle()` pub fn duplicate(&self) -> io::Result<Event> { - let mut new_event = 0; + let mut new_event = ptr::null_mut(); let status = unsafe { DuplicateHandle( GetCurrentProcess(), - self.0, + self.0.as_raw_handle(), GetCurrentProcess(), &mut new_event, 0, @@ -59,19 +60,13 @@ impl Event { if status == 0 { return Err(io::Error::last_os_error()); } - Ok(Event(new_event)) - } -} - -impl Drop for Event { - fn drop(&mut self) { - unsafe { CloseHandle(self.0) }; + // SAFETY: `new_event` is a valid handle since `DuplicateHandle` succeeded + Ok(Event(unsafe { OwnedHandle::from_raw_handle(new_event) })) } } -const fn bool_to_winbool(val: bool) -> BOOL { - match val { - true => 1, - false => 0, +impl AsRawHandle for Event { + fn as_raw_handle(&self) -> RawHandle { + self.0.as_raw_handle() } } diff --git a/talpid-wireguard/Cargo.toml b/talpid-wireguard/Cargo.toml index 306c58e845..162bc7156e 100644 --- a/talpid-wireguard/Cargo.toml +++ b/talpid-wireguard/Cargo.toml @@ -66,10 +66,6 @@ talpid-dbus = { path = "../talpid-dbus" } bitflags = "1.2" talpid-windows = { path = "../talpid-windows" } widestring = "1.0" -maybenot = "2.0.0" -# TODO: rand 0.8 is a hard requirement of maybenot-ffi 2.0. May be upgraded to rand 0.9 -# when maybenot 2.2 is released. -rand08 = { package = "rand", version = "0.8.5" } rand_chacha = "0.3.1" # TODO: Figure out which features are needed and which are not diff --git a/talpid-wireguard/src/wireguard_nt/daita.rs b/talpid-wireguard/src/wireguard_nt/daita.rs deleted file mode 100644 index a2c5d7e6c1..0000000000 --- a/talpid-wireguard/src/wireguard_nt/daita.rs +++ /dev/null @@ -1,480 +0,0 @@ -use super::WIREGUARD_KEY_LENGTH; -use maybenot::{MachineId, Timer}; -use once_cell::sync::OnceCell; -use rand08::{ - SeedableRng, - rngs::{OsRng, adapter::ReseedingRng}, -}; -use std::{ - collections::HashMap, fs, io, os::windows::prelude::RawHandle, path::Path, sync::Arc, - time::Duration, -}; -use talpid_types::net::wireguard::PublicKey; -use tokio::task::JoinHandle; -use windows_sys::Win32::{ - Foundation::{BOOLEAN, ERROR_NO_MORE_ITEMS}, - System::Threading::{INFINITE, WaitForMultipleObjects, WaitForSingleObject}, -}; - -type Rng = ReseedingRng<rand_chacha::ChaCha12Core, OsRng>; -const RNG_RESEED_THRESHOLD: u64 = 1024 * 64; // 64 KiB - -#[derive(Debug, thiserror::Error)] -pub enum Error { - /// Failed to find maybenot machines - #[error("Failed to enumerate maybenot machines")] - EnumerateMachines(#[source] io::Error), - /// Failed to parse maybenot machine - #[error("Failed to parse maybenot machine \"{0}\"")] - InvalidMachine(String), - /// Failed to initialize quit event - #[error("Failed to initialize quit event")] - InitializeQuitEvent(#[source] io::Error), - /// Failed to initialize machinist handle - #[error("Failed to initialize machinist handle")] - InitializeHandle(#[source] io::Error), - /// Failed to initialize maybenot framework - #[error("Failed to initialize maybenot framework: {0}")] - InitializeMaybenot(String), -} - -// See DAITA_EVENT_TYPE: -// https://github.com/mullvad/wireguard-nt-priv/blob/mullvad-patches/driver/daita.h -#[repr(C)] -#[derive(Debug)] -#[allow(dead_code)] -pub enum EventType { - NonpaddingSent, - NonpaddingReceived, - PaddingSent, - PaddingReceived, -} - -// See DAITA_EVENT: -// https://github.com/mullvad/wireguard-nt-priv/blob/mullvad-patches/driver/daita.h -#[repr(C)] -#[derive(Debug)] -pub struct Event { - pub peer: [u8; WIREGUARD_KEY_LENGTH], - pub event_type: EventType, - pub xmit_bytes: u16, - pub user_context: usize, -} - -// See DAITA_ACTION_TYPE: -// https://github.com/mullvad/wireguard-nt-priv/blob/mullvad-patches/driver/daita.h -#[repr(C)] -pub enum ActionType { - InjectPadding, -} - -// See DAITA_PADDING_ACTION: -// https://github.com/mullvad/wireguard-nt-priv/blob/mullvad-patches/driver/daita.h -#[repr(C)] -#[derive(Debug, Clone, Copy)] -pub struct PaddingAction { - pub byte_count: u16, - pub replace: BOOLEAN, -} - -// See DAITA_ACTION: -// https://github.com/mullvad/wireguard-nt-priv/blob/mullvad-patches/driver/daita.h -#[repr(C)] -pub struct Action { - pub peer: [u8; WIREGUARD_KEY_LENGTH], - pub action_type: ActionType, - pub payload: ActionPayload, - pub user_context: usize, -} - -#[repr(C)] -pub union ActionPayload { - pub padding: PaddingAction, -} - -/// Maximum number of events that can be stored in the underlying buffer -const EVENTS_CAPACITY: usize = 1000; -/// Maximum number of actions that can be stored in the underlying buffer -const ACTIONS_CAPACITY: usize = 1000; - -pub mod bindings { - use super::*; - use windows_sys::Win32::Foundation::BOOL; - - pub type WireGuardDaitaActivateFn = unsafe extern "stdcall" fn( - adapter: RawHandle, - events_capacity: usize, - actions_capacity: usize, - ) -> BOOL; - pub type WireGuardDaitaEventDataAvailableEventFn = - unsafe extern "stdcall" fn(adapter: RawHandle) -> RawHandle; - pub type WireGuardDaitaReceiveEventsFn = - unsafe extern "stdcall" fn(adapter: RawHandle, events: *mut Event) -> usize; - pub type WireGuardDaitaSendActionFn = - unsafe extern "stdcall" fn(adapter: RawHandle, action: *const Action) -> BOOL; -} - -#[derive(Debug)] -pub struct Session { - adapter: Arc<super::WgNtAdapter>, -} - -impl Session { - /// Call `WireGuardDaitaActivate` for an existing WireGuard interface - pub(super) fn from_adapter(adapter: Arc<super::WgNtAdapter>) -> io::Result<Session> { - // SAFETY: `WgNtAdapter` has a valid adapter handle - unsafe { - adapter - .dll_handle - .daita_activate(adapter.handle, EVENTS_CAPACITY, ACTIONS_CAPACITY) - }?; - Ok(Self { adapter }) - } - - pub fn receive_events<'a>( - &self, - buffer: &'a mut [Event; EVENTS_CAPACITY], - ) -> io::Result<&'a [Event]> { - let num_events = unsafe { - // SAFETY: The adapter is valid, and the buffer is large enough to accommodate all - // events. - self.adapter - .dll_handle - .daita_receive_events(self.adapter.handle, buffer.as_mut_ptr())? - }; - Ok(unsafe { std::slice::from_raw_parts(buffer.as_ptr(), num_events) }) - } - - pub fn send_action(&self, action: &Action) -> io::Result<()> { - // SAFETY: The adapter is valid - unsafe { - self.adapter - .dll_handle - .daita_send_action(self.adapter.handle, action) - } - } - - pub fn event_data_available_event(&self) -> RawHandle { - // SAFETY: The adapter is valid - // This never fails when there's a DAITA session - unsafe { - self.adapter - .dll_handle - .daita_event_data_available_event(self.adapter.handle) - .unwrap() - } - } -} - -fn maybenot_event_from_event( - event: &Event, - machine_ids: &MachineMap, -) -> Option<maybenot::TriggerEvent> { - match event.event_type { - EventType::PaddingReceived => Some(maybenot::TriggerEvent::PaddingRecv), - EventType::NonpaddingSent => Some(maybenot::TriggerEvent::NormalSent), - EventType::NonpaddingReceived => Some(maybenot::TriggerEvent::NormalRecv), - EventType::PaddingSent => Some(maybenot::TriggerEvent::PaddingSent { - machine: machine_ids.get_machine_id(event.user_context)?.to_owned(), - }), - } -} - -/// Handle for a set of DAITA machines. -/// Note: `close` is NOT called implicitly when this is dropped. -pub struct MachinistHandle { - quit_event: talpid_windows::sync::Event, -} - -impl MachinistHandle { - fn new(quit_event: &talpid_windows::sync::Event) -> io::Result<MachinistHandle> { - Ok(MachinistHandle { - quit_event: quit_event.duplicate()?, - }) - } - - /// Signal quit event - pub fn close(&self) -> io::Result<()> { - self.quit_event.set() - } -} - -pub struct Machinist { - daita: Arc<Session>, - machine_ids: MachineMap, - machine_tasks: HashMap<usize, JoinHandle<()>>, - tokio_handle: tokio::runtime::Handle, - quit_event: talpid_windows::sync::Event, - peer: PublicKey, - mtu: u16, -} - -// TODO: This is silly. Let me use the raw ID of MachineId, please. -struct MachineMap { - id_to_num: HashMap<MachineId, usize>, - num_to_id: HashMap<usize, MachineId>, -} - -impl MachineMap { - fn new() -> Self { - Self { - id_to_num: HashMap::new(), - num_to_id: HashMap::new(), - } - } - - fn get_or_create_raw_id(&mut self, machine_id: MachineId) -> usize { - *self.id_to_num.entry(machine_id).or_insert_with(|| { - let raw_id = self.num_to_id.len(); - self.num_to_id.insert(raw_id, machine_id); - raw_id - }) - } - - fn get_machine_id(&self, raw_id: usize) -> Option<&MachineId> { - self.num_to_id.get(&raw_id) - } -} - -impl Machinist { - /// Spawn an actor that handles scheduling of Maybenot actions and forwards DAITA events to the - /// framework. - pub fn spawn( - resource_dir: &Path, - daita: Session, - peer: PublicKey, - mtu: u16, - ) -> std::result::Result<MachinistHandle, Error> { - const MAX_PADDING_BYTES: f64 = 0.0; - const MAX_BLOCKING_BYTES: f64 = 0.0; - - static MAYBENOT_MACHINES: OnceCell<Vec<maybenot::Machine>> = OnceCell::new(); - - let machines = - MAYBENOT_MACHINES.get_or_try_init(|| load_maybenot_machines(resource_dir))?; - - let quit_event = - talpid_windows::sync::Event::new(true, false).map_err(Error::InitializeQuitEvent)?; - let handle = MachinistHandle::new(&quit_event).map_err(Error::InitializeHandle)?; - - let framework = maybenot::Framework::new( - machines.clone(), - MAX_PADDING_BYTES, - MAX_BLOCKING_BYTES, - std::time::Instant::now(), - Rng::new( - rand_chacha::ChaCha12Core::from_entropy(), - RNG_RESEED_THRESHOLD, - OsRng, - ), - ) - .map_err(|error| Error::InitializeMaybenot(error.to_string()))?; - - let daita = Arc::new(daita); - let tokio_handle = tokio::runtime::Handle::current(); - - std::thread::spawn(move || { - Self { - daita, - machine_ids: MachineMap::new(), - machine_tasks: HashMap::new(), - tokio_handle, - quit_event, - peer, - mtu, - } - .event_loop(framework); - }); - - Ok(handle) - } - - fn event_loop(mut self, mut framework: maybenot::Framework<Vec<maybenot::Machine>, Rng>) { - use windows_sys::Win32::Foundation::WAIT_OBJECT_0; - - loop { - if unsafe { WaitForSingleObject(self.quit_event.as_raw(), 0) } == WAIT_OBJECT_0 { - break; - } - - let events = match self.wait_for_events() { - Ok(events) => { - if events.is_empty() { - break; - } - events - } - Err(error) => { - log::error!("Error while waiting for DAITA events: {error}"); - break; - } - }; - - for action in framework.trigger_events(&events, std::time::Instant::now()) { - self.handle_action(action); - } - } - - log::debug!("Stopped DAITA event loop"); - } - - fn handle_action(&mut self, action: &maybenot::action::TriggerAction) { - match *action { - maybenot::action::TriggerAction::Cancel { machine, timer } => { - debug_assert_ne!(timer, Timer::Internal, "machine timers not implemented"); - - let raw_id = self.machine_ids.get_or_create_raw_id(machine); - - // Drop all scheduled actions for a given machine - if let Some(task) = self.machine_tasks.get_mut(&raw_id) { - task.abort(); - } - } - maybenot::action::TriggerAction::SendPadding { - timeout, - machine, - replace, - .. - } => { - let peer = self.peer.clone(); - - let raw_id = self.machine_ids.get_or_create_raw_id(machine); - self.machine_tasks.entry(raw_id).and_modify(|f| f.abort()); - - let action = Action { - peer: *peer.as_bytes(), - action_type: ActionType::InjectPadding, - user_context: raw_id, - payload: ActionPayload { - padding: PaddingAction { - byte_count: self.mtu, - replace: if replace { 1 } else { 0 }, - }, - }, - }; - - if timeout == Duration::ZERO { - if let Err(error) = self.daita.send_action(&action) { - log::error!("Failed to send DAITA action: {error}"); - } - } else { - // Schedule action on the tokio runtime - let daita = Arc::downgrade(&self.daita); - let task = self.tokio_handle.spawn(async move { - tokio::time::sleep(timeout).await; - - let Some(daita) = daita.upgrade() else { return }; - - if let Err(error) = daita.send_action(&action) { - log::error!("Failed to send DAITA action: {error}"); - } - }); - self.machine_tasks.insert(raw_id, task); - } - } - maybenot::action::TriggerAction::BlockOutgoing { .. } => { - if cfg!(debug_assertions) { - unimplemented!("received BlockOutgoing action"); - } - } - maybenot::action::TriggerAction::UpdateTimer { .. } => { - if cfg!(debug_assertions) { - unimplemented!("received UpdateTimer action"); - } - } - } - } - - /// Take all events from the ring buffer while there are any left. - /// If there are no events available, wait for events to arrive. - /// Otherwise, break and return a non-zero number of events to be processed. - /// If the quit event was signaled, this returns an empty vector. - fn wait_for_events(&mut self) -> io::Result<Vec<maybenot::TriggerEvent>> { - use windows_sys::Win32::Foundation::WAIT_OBJECT_0; - - let wait_events = [ - self.quit_event.as_raw(), - self.daita.event_data_available_event() as isize, - ]; - - let mut event_buffer: [Event; EVENTS_CAPACITY] = unsafe { std::mem::zeroed() }; - - loop { - match self.daita.receive_events(&mut event_buffer) { - Ok(events) => { - let converted_events: Vec<_> = events - .iter() - .filter(|event| &event.peer == self.peer.as_bytes()) - .filter_map(|event| maybenot_event_from_event(event, &self.machine_ids)) - .collect(); - if !converted_events.is_empty() { - return Ok(converted_events); - } - // Try again if we only received events for irrelevant peers - } - Err(error) => { - if error.raw_os_error() == Some(ERROR_NO_MORE_ITEMS as i32) { - let wait_result = unsafe { - WaitForMultipleObjects( - u32::try_from(wait_events.len()).unwrap(), - wait_events.as_ptr(), - 0, - INFINITE, - ) - }; - - if wait_result == WAIT_OBJECT_0 { - // Quit event signaled - break Ok(vec![]); - } - if wait_result == WAIT_OBJECT_0 + 1 { - // Event object signaled -- try to receive more events - continue; - } - } - break Err(std::io::Error::last_os_error()); - } - } - } - } -} - -fn load_maybenot_machines(resource_dir: &Path) -> Result<Vec<maybenot::Machine>, Error> { - let path = resource_dir.join("maybenot_machines"); - log::debug!("Reading maybenot machines from {}", path.display()); - - let mut machines = vec![]; - let machines_str = fs::read_to_string(path).map_err(Error::EnumerateMachines)?; - for machine_str in machines_str.lines() { - let machine_str = machine_str.trim(); - if matches!(machine_str.chars().next(), None | Some('#')) { - continue; - } - log::debug!("Adding maybenot machine: {machine_str}"); - machines.push( - machine_str - .parse::<maybenot::Machine>() - .map_err(|_error| Error::InvalidMachine(machine_str.to_owned()))?, - ); - } - Ok(machines) -} - -#[cfg(test)] -mod test { - use super::load_maybenot_machines; - use std::path::PathBuf; - - /// Test whether `maybenot_machines` in dist-assets contains valid machines. - /// TODO: Remove when switching to dynamic machines. - #[test] - fn test_load_maybenot_machines() { - let dist_assets = std::env::var("CARGO_MANIFEST_DIR") - .map(PathBuf::from) - .expect("CARGO_MANIFEST_DIR env var not set") - .join("..") - .join("dist-assets"); - - load_maybenot_machines(&dist_assets).unwrap(); - } -} diff --git a/talpid-wireguard/src/wireguard_nt/mod.rs b/talpid-wireguard/src/wireguard_nt/mod.rs index 0f650d3866..a1472bb2ff 100644 --- a/talpid-wireguard/src/wireguard_nt/mod.rs +++ b/talpid-wireguard/src/wireguard_nt/mod.rs @@ -1,5 +1,7 @@ #![allow(clippy::undocumented_unsafe_blocks)] // Remove me if you dare. +use crate::TunnelError; + use super::{ Tunnel, config::Config, @@ -11,7 +13,7 @@ use futures::SinkExt; use ipnetwork::IpNetwork; use once_cell::sync::OnceCell; use std::{ - ffi::CStr, + ffi::{CStr, c_uchar}, fmt, future::Future, io, @@ -24,14 +26,13 @@ use std::{ sync::{Arc, LazyLock, Mutex}, time::{Duration, SystemTime, UNIX_EPOCH}, }; -#[cfg(daita)] -use std::{ffi::c_uchar, path::PathBuf}; +use talpid_tunnel_config_client::DaitaSettings; use talpid_types::{BoxedError, ErrorExt}; use talpid_windows::net; use widestring::{U16CStr, U16CString}; use windows_sys::{ Win32::{ - Foundation::{BOOL, ERROR_MORE_DATA, FreeLibrary, HMODULE}, + Foundation::{ERROR_MORE_DATA, FreeLibrary, HMODULE}, NetworkManagement::Ndis::NET_LUID_LH, Networking::WinSock::{ ADDRESS_FAMILY, AF_INET, AF_INET6, IN_ADDR, IN6_ADDR, SOCKADDR_INET, @@ -41,9 +42,6 @@ use windows_sys::{ core::GUID, }; -#[cfg(daita)] -mod daita; - static WG_NT_DLL: OnceCell<WgNtDll> = OnceCell::new(); static ADAPTER_TYPE: LazyLock<U16CString> = LazyLock::new(|| U16CString::from_str("Mullvad").unwrap()); @@ -69,14 +67,14 @@ type WireGuardSetConfigurationFn = unsafe extern "stdcall" fn( adapter: RawHandle, config: *const MaybeUninit<u8>, bytes: u32, -) -> BOOL; +) -> bool; type WireGuardGetConfigurationFn = unsafe extern "stdcall" fn( adapter: RawHandle, config: *const MaybeUninit<u8>, bytes: *mut u32, -) -> BOOL; +) -> bool; type WireGuardSetStateFn = - unsafe extern "stdcall" fn(adapter: RawHandle, state: WgAdapterState) -> BOOL; + unsafe extern "stdcall" fn(adapter: RawHandle, state: WgAdapterState) -> bool; #[repr(C)] #[allow(dead_code)] @@ -108,7 +106,7 @@ enum WireGuardAdapterLogState { } type WireGuardSetAdapterLoggingFn = - unsafe extern "stdcall" fn(adapter: RawHandle, state: WireGuardAdapterLogState) -> BOOL; + unsafe extern "stdcall" fn(adapter: RawHandle, state: WireGuardAdapterLogState) -> bool; pub type Result<T> = std::result::Result<T, Error>; @@ -165,27 +163,13 @@ pub enum Error { /// Failed to parse data returned by the driver #[error("Failed to parse data returned by wireguard-nt")] InvalidConfigData, - - /// DAITA machinist failed - #[cfg(daita)] - #[error("Failed to enable DAITA on tunnel device")] - EnableTunnelDaita(#[source] io::Error), - - /// DAITA machinist failed - #[cfg(daita)] - #[error("Failed to initialize DAITA machinist")] - InitializeMachinist(#[source] daita::Error), } pub struct WgNtTunnel { - #[cfg(daita)] - resource_dir: PathBuf, config: Arc<Mutex<Config>>, device: Option<Arc<WgNtAdapter>>, interface_name: String, setup_handle: tokio::task::JoinHandle<()>, - #[cfg(daita)] - daita_handle: Option<daita::MachinistHandle>, _logger_handle: LoggerHandle, } @@ -326,8 +310,6 @@ bitflags! { const REPLACE_ALLOWED_IPS = 0b00100000; const REMOVE = 0b01000000; const UPDATE = 0b10000000; - #[cfg(daita)] - const HAS_CONSTANT_PACKET_SIZE = 0b100000000; } } @@ -345,7 +327,6 @@ struct WgPeer { rx_bytes: u64, last_handshake: u64, allowed_ips_count: u32, - #[cfg(daita)] constant_packet_size: c_uchar, } @@ -484,54 +465,18 @@ impl WgNtTunnel { }); Ok(WgNtTunnel { - #[cfg(daita)] - resource_dir: resource_dir.to_owned(), config: Arc::new(Mutex::new(config.clone())), device, interface_name, setup_handle, - #[cfg(daita)] - daita_handle: None, _logger_handle: logger_handle, }) } fn stop_tunnel(&mut self) { self.setup_handle.abort(); - #[cfg(daita)] - if let Some(daita_handle) = self.daita_handle.take() { - let _ = daita_handle.close(); - } let _ = self.device.take(); } - - #[cfg(daita)] - fn spawn_machinist(&mut self) -> Result<()> { - if let Some(handle) = self.daita_handle.take() { - log::info!("Stopping previous DAITA machines"); - let _ = handle.close(); - } - - let Some(device) = self.device.clone() else { - log::debug!("Tunnel is stopped; not starting machines"); - return Ok(()); - }; - - let config = self.config.lock().unwrap(); - - log::info!("Initializing DAITA for wireguard device"); - let session = daita::Session::from_adapter(device).map_err(Error::EnableTunnelDaita)?; - self.daita_handle = Some( - daita::Machinist::spawn( - &self.resource_dir, - session, - config.entry_peer.public_key.clone(), - config.mtu, - ) - .map_err(Error::InitializeMachinist)?, - ); - Ok(()) - } } async fn setup_ip_listener(device: Arc<WgNtAdapter>, mtu: u32, has_ipv6: bool) -> Result<()> { @@ -689,14 +634,6 @@ struct WgNtDll { func_set_adapter_state: WireGuardSetStateFn, func_set_logger: WireGuardSetLoggerFn, func_set_adapter_logging: WireGuardSetAdapterLoggingFn, - #[cfg(daita)] - func_daita_activate: daita::bindings::WireGuardDaitaActivateFn, - #[cfg(daita)] - func_daita_event_data_available_event: daita::bindings::WireGuardDaitaEventDataAvailableEventFn, - #[cfg(daita)] - func_daita_receive_events: daita::bindings::WireGuardDaitaReceiveEventsFn, - #[cfg(daita)] - func_daita_send_action: daita::bindings::WireGuardDaitaSendActionFn, } unsafe impl Send for WgNtDll {} @@ -707,9 +644,14 @@ impl WgNtDll { let wg_nt_dll = U16CString::from_os_str_truncate(resource_dir.join("mullvad-wireguard.dll")); - let handle = - unsafe { LoadLibraryExW(wg_nt_dll.as_ptr(), 0, LOAD_WITH_ALTERED_SEARCH_PATH) }; - if handle == 0 { + let handle = unsafe { + LoadLibraryExW( + wg_nt_dll.as_ptr(), + ptr::null_mut(), + LOAD_WITH_ALTERED_SEARCH_PATH, + ) + }; + if handle.is_null() { return Err(io::Error::last_os_error()); } Self::new_inner(handle, Self::get_proc_address) @@ -745,23 +687,6 @@ impl WgNtDll { func_set_adapter_logging: unsafe { *((&get_proc_fn(handle, c"WireGuardSetAdapterLogging")?) as *const _ as *const _) }, - #[cfg(daita)] - func_daita_activate: unsafe { - *((&get_proc_fn(handle, c"WireGuardDaitaActivate")?) as *const _ as *const _) - }, - #[cfg(daita)] - func_daita_event_data_available_event: unsafe { - *((&get_proc_fn(handle, c"WireGuardDaitaEventDataAvailableEvent")?) as *const _ - as *const _) - }, - #[cfg(daita)] - func_daita_receive_events: unsafe { - *((&get_proc_fn(handle, c"WireGuardDaitaReceiveEvents")?) as *const _ as *const _) - }, - #[cfg(daita)] - func_daita_send_action: unsafe { - *((&get_proc_fn(handle, c"WireGuardDaitaSendAction")?) as *const _ as *const _) - }, }) } @@ -808,10 +733,10 @@ impl WgNtDll { config: *const MaybeUninit<u8>, config_size: usize, ) -> io::Result<()> { - let result = unsafe { + let succeeded = unsafe { (self.func_set_configuration)(adapter, config, u32::try_from(config_size).unwrap()) }; - if result == 0 { + if !succeeded { return Err(io::Error::last_os_error()); } Ok(()) @@ -821,10 +746,10 @@ impl WgNtDll { let mut config_size = 0; let mut config = vec![]; loop { - let result = unsafe { + let succeeded = unsafe { (self.func_get_configuration)(adapter, config.as_mut_ptr(), &mut config_size) }; - if result == 0 { + if !succeeded { let last_error = io::Error::last_os_error(); if last_error.raw_os_error() != Some(ERROR_MORE_DATA as i32) { break Err(last_error); @@ -841,8 +766,8 @@ impl WgNtDll { adapter: RawHandle, state: WgAdapterState, ) -> io::Result<()> { - let result = unsafe { (self.func_set_adapter_state)(adapter, state) }; - if result == 0 { + let succeeded = unsafe { (self.func_set_adapter_state)(adapter, state) }; + if !succeeded { return Err(io::Error::last_os_error()); } Ok(()) @@ -857,57 +782,7 @@ impl WgNtDll { adapter: RawHandle, state: WireGuardAdapterLogState, ) -> io::Result<()> { - if unsafe { (self.func_set_adapter_logging)(adapter, state) } == 0 { - return Err(io::Error::last_os_error()); - } - Ok(()) - } - - #[cfg(daita)] - pub unsafe fn daita_activate( - &self, - adapter: RawHandle, - events_capacity: usize, - actions_capacity: usize, - ) -> io::Result<()> { - if unsafe { (self.func_daita_activate)(adapter, events_capacity, actions_capacity) } == 0 { - return Err(io::Error::last_os_error()); - } - Ok(()) - } - - #[cfg(daita)] - pub unsafe fn daita_event_data_available_event( - &self, - adapter: RawHandle, - ) -> io::Result<RawHandle> { - let ready_event = unsafe { (self.func_daita_event_data_available_event)(adapter) }; - if ready_event.is_null() { - return Err(io::Error::last_os_error()); - } - Ok(ready_event) - } - - #[cfg(daita)] - pub unsafe fn daita_receive_events( - &self, - adapter: RawHandle, - events: *mut daita::Event, - ) -> io::Result<usize> { - let num_events = unsafe { (self.func_daita_receive_events)(adapter, events) }; - if num_events == 0 { - return Err(io::Error::last_os_error()); - } - Ok(num_events) - } - - #[cfg(daita)] - pub unsafe fn daita_send_action( - &self, - adapter: RawHandle, - action: *const daita::Action, - ) -> io::Result<()> { - if unsafe { (self.func_daita_send_action)(adapter, action) } == 0 { + if !unsafe { (self.func_set_adapter_logging)(adapter, state) } { return Err(io::Error::last_os_error()); } Ok(()) @@ -938,17 +813,10 @@ fn serialize_config(config: &Config) -> Result<Vec<MaybeUninit<u8>>> { buffer.extend(as_uninit_byte_slice(&header)); for peer in config.peers() { - #[cfg(not(daita))] let mut flags = WgPeerFlag::HAS_PUBLIC_KEY | WgPeerFlag::HAS_ENDPOINT; - #[cfg(daita)] - let mut flags = WgPeerFlag::HAS_PUBLIC_KEY - | WgPeerFlag::HAS_ENDPOINT - | WgPeerFlag::HAS_CONSTANT_PACKET_SIZE; if peer.psk.is_some() { flags |= WgPeerFlag::HAS_PRESHARED_KEY; } - #[cfg(daita)] - let constant_packet_size = if peer.constant_packet_size { 1 } else { 0 }; let wg_peer = WgPeer { flags, reserved: 0, @@ -964,8 +832,7 @@ fn serialize_config(config: &Config) -> Result<Vec<MaybeUninit<u8>>> { rx_bytes: 0, last_handshake: 0, allowed_ips_count: u32::try_from(peer.allowed_ips.len()).unwrap(), - #[cfg(daita)] - constant_packet_size, + constant_packet_size: 0, }; buffer.extend(as_uninit_byte_slice(&wg_peer)); @@ -1122,18 +989,8 @@ impl Tunnel for WgNtTunnel { }) } - #[cfg(daita)] - fn start_daita( - &mut self, - _: talpid_tunnel_config_client::DaitaSettings, - ) -> std::result::Result<(), crate::TunnelError> { - self.spawn_machinist().map_err(|error| { - log::error!( - "{}", - error.display_chain_with_msg("Failed to start DAITA for wg-nt tunnel") - ); - super::TunnelError::SetConfigError - }) + fn start_daita(&mut self, _settings: DaitaSettings) -> std::result::Result<(), TunnelError> { + unimplemented!("DAITA is not supported on wireguard-nt") } } @@ -1188,9 +1045,8 @@ mod tests { ipv6_gateway: None, mtu: 0, obfuscator_config: None, - #[cfg(daita)] - daita: false, quantum_resistant: false, + daita: false, }); static WG_STRUCT_CONFIG: LazyLock<Interface> = LazyLock::new(|| Interface { @@ -1202,9 +1058,7 @@ mod tests { peers_count: 1, }, p0: WgPeer { - flags: WgPeerFlag::HAS_PUBLIC_KEY - | WgPeerFlag::HAS_ENDPOINT - | WgPeerFlag::HAS_CONSTANT_PACKET_SIZE, + flags: WgPeerFlag::HAS_PUBLIC_KEY | WgPeerFlag::HAS_ENDPOINT, reserved: 0, public_key: *WG_PUBLIC_KEY.as_bytes(), preshared_key: [0; WIREGUARD_KEY_LENGTH], @@ -1235,7 +1089,7 @@ mod tests { #[test] fn test_dll_imports() { - WgNtDll::new_inner(0, get_proc_fn).unwrap(); + WgNtDll::new_inner(ptr::null_mut(), get_proc_fn).unwrap(); } #[test] diff --git a/test/Cargo.lock b/test/Cargo.lock index ccdedcc77f..8d0bdfb29b 100644 --- a/test/Cargo.lock +++ b/test/Cargo.lock @@ -682,7 +682,7 @@ dependencies = [ "libc", "option-ext", "redox_users", - "windows-sys 0.59.0", + "windows-sys 0.61.1", ] [[package]] @@ -1641,7 +1641,7 @@ dependencies = [ "socket2 0.5.8", "widestring", "windows-sys 0.48.0", - "winreg", + "winreg 0.50.0", ] [[package]] @@ -2119,7 +2119,7 @@ dependencies = [ "once_cell", "thiserror 2.0.3", "widestring", - "windows-sys 0.52.0", + "windows-sys 0.61.1", ] [[package]] @@ -3558,7 +3558,7 @@ version = "0.0.0" dependencies = [ "rs-release", "talpid-windows", - "windows-sys 0.52.0", + "windows-sys 0.61.1", ] [[package]] @@ -3591,7 +3591,7 @@ dependencies = [ "socket2 0.5.8", "talpid-types", "thiserror 2.0.3", - "windows-sys 0.52.0", + "windows-sys 0.61.1", ] [[package]] @@ -3752,7 +3752,7 @@ dependencies = [ "tokio-util", "windows-service", "windows-sys 0.45.0", - "winreg", + "winreg 0.55.0", ] [[package]] @@ -4508,6 +4508,12 @@ dependencies = [ ] [[package]] +name = "windows-link" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45e46c0661abb7180e7b9c281db115305d49ca1709ab8242adf09666d2173c65" + +[[package]] name = "windows-service" version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -4568,6 +4574,15 @@ dependencies = [ ] [[package]] +name = "windows-sys" +version = "0.61.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f109e41dd4a3c848907eb83d5a42ea98b3769495597450cf6d153507b166f0f" +dependencies = [ + "windows-link", +] + +[[package]] name = "windows-targets" version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -4795,6 +4810,16 @@ dependencies = [ ] [[package]] +name = "winreg" +version = "0.55.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb5a765337c50e9ec252c2069be9bf91c7df47afb103b642ba3a53bf8101be97" +dependencies = [ + "cfg-if", + "windows-sys 0.59.0", +] + +[[package]] name = "wit-bindgen-rt" version = "0.39.0" source = "registry+https://github.com/rust-lang/crates.io-index" diff --git a/test/test-runner/Cargo.toml b/test/test-runner/Cargo.toml index fa277332d9..3e1c8df913 100644 --- a/test/test-runner/Cargo.toml +++ b/test/test-runner/Cargo.toml @@ -40,7 +40,7 @@ socket2 = { workspace = true, features = ["all"] } talpid-windows = { path = "../../talpid-windows" } windows-service = "0.6" -winreg = "0.50" +winreg = "0.55" [target.'cfg(windows)'.dependencies.windows-sys] version = "0.45.0" diff --git a/wireguard-go-rs/Cargo.toml b/wireguard-go-rs/Cargo.toml index 415365e772..40d70623a5 100644 --- a/wireguard-go-rs/Cargo.toml +++ b/wireguard-go-rs/Cargo.toml @@ -25,7 +25,7 @@ talpid-types.path = "../talpid-types" maybenot-ffi = "2.0.1" [target.'cfg(target_os = "windows")'.dependencies] -windows-sys = { version = "0.52.0", features = [ +windows-sys = { workspace = true, features = [ "Win32_Networking", "Win32_NetworkManagement", "Win32_NetworkManagement_Ndis", |
