diff options
| author | David Lönnhager <david.l@mullvad.net> | 2022-08-25 14:49:55 +0200 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2022-08-25 14:49:55 +0200 |
| commit | 32e61accd7e01ffe08cdcf2fc743aa36346d3b43 (patch) | |
| tree | eb5611a64a250b4ac18ae5f85a959e4b755ddaf2 | |
| parent | b1f2398372ea2d929b320620120be3388a90bcef (diff) | |
| parent | dc9e97f758e0dddabaad482d1388b24fb9f431fe (diff) | |
| download | mullvadvpn-32e61accd7e01ffe08cdcf2fc743aa36346d3b43.tar.xz mullvadvpn-32e61accd7e01ffe08cdcf2fc743aa36346d3b43.zip | |
Merge branch 'winapi-to-windows-sys'
| -rw-r--r-- | talpid-core/Cargo.toml | 17 | ||||
| -rw-r--r-- | talpid-core/src/dns/windows/mod.rs | 2 | ||||
| -rw-r--r-- | talpid-core/src/logging/windows.rs | 16 | ||||
| -rw-r--r-- | talpid-core/src/routing/windows.rs | 4 | ||||
| -rw-r--r-- | talpid-core/src/split_tunnel/windows/driver.rs | 50 | ||||
| -rw-r--r-- | talpid-core/src/split_tunnel/windows/mod.rs | 23 | ||||
| -rw-r--r-- | talpid-core/src/split_tunnel/windows/path_monitor.rs | 67 | ||||
| -rw-r--r-- | talpid-core/src/split_tunnel/windows/volume_monitor.rs | 21 | ||||
| -rw-r--r-- | talpid-core/src/split_tunnel/windows/windows.rs | 112 | ||||
| -rw-r--r-- | talpid-core/src/tunnel/openvpn/mod.rs | 20 | ||||
| -rw-r--r-- | talpid-core/src/tunnel/openvpn/wintun.rs | 94 | ||||
| -rw-r--r-- | talpid-core/src/tunnel/windows.rs | 13 | ||||
| -rw-r--r-- | talpid-core/src/tunnel/wireguard/wireguard_go.rs | 10 | ||||
| -rw-r--r-- | talpid-core/src/tunnel/wireguard/wireguard_nt.rs | 114 | ||||
| -rw-r--r-- | talpid-core/src/windows/mod.rs | 143 | ||||
| -rw-r--r-- | talpid-core/src/windows/window.rs | 51 |
16 files changed, 365 insertions, 392 deletions
diff --git a/talpid-core/Cargo.toml b/talpid-core/Cargo.toml index b11407ae5c..6216a23bb2 100644 --- a/talpid-core/Cargo.toml +++ b/talpid-core/Cargo.toml @@ -81,7 +81,7 @@ subslice = "0.2" [target.'cfg(windows)'.dependencies] widestring = "0.5" winreg = { version = "0.7", features = ["transactions"] } -winapi = { version = "0.3.6", features = ["combaseapi", "handleapi", "ifdef", "ioapiset", "knownfolders", "libloaderapi", "netioapi", "psapi", "shlobj", "stringapiset", "synchapi", "tlhelp32", "winbase", "winioctl", "winuser", "dbt"] } +winapi = { version = "0.3.6", features = ["ws2def"] } talpid-platform-metadata = { path = "../talpid-platform-metadata" } memoffset = "0.6" @@ -89,7 +89,22 @@ memoffset = "0.6" version = "0.36.1" features = [ "Win32_Foundation", + "Win32_Globalization", + "Win32_System_Com", + "Win32_System_Diagnostics_ToolHelp", + "Win32_System_Ioctl", + "Win32_System_IO", "Win32_System_LibraryLoader", + "Win32_System_ProcessStatus", + "Win32_System_Registry", + "Win32_System_SystemServices", + "Win32_System_Threading", + "Win32_System_WindowsProgramming", + "Win32_Networking_WinSock", + "Win32_NetworkManagement_IpHelper", + "Win32_NetworkManagement_Ndis", + "Win32_UI_Shell", + "Win32_UI_WindowsAndMessaging", ] [build-dependencies] diff --git a/talpid-core/src/dns/windows/mod.rs b/talpid-core/src/dns/windows/mod.rs index 85d964aca1..d173290be3 100644 --- a/talpid-core/src/dns/windows/mod.rs +++ b/talpid-core/src/dns/windows/mod.rs @@ -1,7 +1,7 @@ use crate::windows::{guid_from_luid, luid_from_alias, string_from_guid}; use std::{io, net::IpAddr}; use talpid_types::ErrorExt; -use winapi::shared::guiddef::GUID; +use windows_sys::core::GUID; use winreg::{ enums::{HKEY_LOCAL_MACHINE, KEY_SET_VALUE}, transaction::Transaction, diff --git a/talpid-core/src/logging/windows.rs b/talpid-core/src/logging/windows.rs index 9e382d1bf5..2558660a5b 100644 --- a/talpid-core/src/logging/windows.rs +++ b/talpid-core/src/logging/windows.rs @@ -1,6 +1,6 @@ use libc::{c_char, c_void}; use std::{ffi::CStr, io, ptr}; -use winapi::um::{stringapiset::MultiByteToWideChar, winnls::CP_ACP}; +use windows_sys::Win32::Globalization::{MultiByteToWideChar, CP_ACP}; /// Logging callback type. pub type LogSink = extern "system" fn(level: log::Level, msg: *const c_char, context: *mut c_void); @@ -40,8 +40,16 @@ fn multibyte_to_wide(mb_string: &CStr, codepage: u32) -> Result<Vec<u16>, io::Er return Ok(vec![]); } - let wc_size = - unsafe { MultiByteToWideChar(codepage, 0, mb_string.as_ptr(), -1, ptr::null_mut(), 0) }; + let wc_size = unsafe { + MultiByteToWideChar( + codepage, + 0, + mb_string.as_ptr() as *const u8, + -1, + ptr::null_mut(), + 0, + ) + }; if wc_size == 0 { return Err(io::Error::last_os_error()); @@ -53,7 +61,7 @@ fn multibyte_to_wide(mb_string: &CStr, codepage: u32) -> Result<Vec<u16>, io::Er MultiByteToWideChar( codepage, 0, - mb_string.as_ptr(), + mb_string.as_ptr() as *const u8, -1, wc_buffer.as_mut_ptr(), wc_size, diff --git a/talpid-core/src/routing/windows.rs b/talpid-core/src/routing/windows.rs index 7d9bcba738..b391184f1a 100644 --- a/talpid-core/src/routing/windows.rs +++ b/talpid-core/src/routing/windows.rs @@ -8,6 +8,7 @@ use futures::{ StreamExt, }; use std::{collections::HashSet, net::IpAddr}; +use windows_sys::Win32::NetworkManagement::IpHelper::NET_LUID_LH; use winnet::WinNetAddrFamily; /// Windows routing errors. @@ -187,14 +188,13 @@ impl RouteManager { fn get_mtu_for_route(addr_family: WinNetAddrFamily) -> Result<Option<u16>> { use crate::windows::AddressFamily; - use winapi::shared::ifdef::NET_LUID; match winnet::get_best_default_route(addr_family) { Ok(Some(route)) => { let addr_family = match addr_family { WinNetAddrFamily::IPV4 => AddressFamily::Ipv4, WinNetAddrFamily::IPV6 => AddressFamily::Ipv6, }; - let luid = NET_LUID { + let luid = NET_LUID_LH { Value: route.interface_luid, }; let interface_row = crate::windows::get_ip_interface_entry(addr_family, &luid) diff --git a/talpid-core/src/split_tunnel/windows/driver.rs b/talpid-core/src/split_tunnel/windows/driver.rs index 2b612ff085..f73e170cce 100644 --- a/talpid-core/src/split_tunnel/windows/driver.rs +++ b/talpid-core/src/split_tunnel/windows/driver.rs @@ -23,26 +23,22 @@ use std::{ time::Duration, }; use talpid_types::ErrorExt; -use winapi::{ - shared::{ - in6addr::IN6_ADDR, - inaddr::IN_ADDR, - minwindef::{FALSE, TRUE}, - ntdef::NTSTATUS, - winerror::{ - ERROR_ACCESS_DENIED, ERROR_FILE_NOT_FOUND, ERROR_INVALID_PARAMETER, ERROR_IO_PENDING, - }, +use windows_sys::Win32::{ + Foundation::{ + ERROR_ACCESS_DENIED, ERROR_FILE_NOT_FOUND, ERROR_INVALID_PARAMETER, ERROR_IO_PENDING, + HANDLE, NTSTATUS, WAIT_FAILED, }, - um::{ - ioapiset::{DeviceIoControl, GetOverlappedResult}, - minwinbase::OVERLAPPED, - synchapi::{WaitForMultipleObjects, WaitForSingleObject}, - tlhelp32::TH32CS_SNAPPROCESS, - winbase::{ - FILE_FLAG_OVERLAPPED, INFINITE, WAIT_ABANDONED, WAIT_ABANDONED_0, WAIT_FAILED, + Networking::WinSock::{IN6_ADDR, IN_ADDR}, + Storage::FileSystem::FILE_FLAG_OVERLAPPED, + System::{ + Diagnostics::ToolHelp::TH32CS_SNAPPROCESS, + Ioctl::{FILE_ANY_ACCESS, METHOD_BUFFERED, METHOD_NEITHER}, + Threading::{ + WaitForMultipleObjects, WaitForSingleObject, WAIT_ABANDONED, WAIT_ABANDONED_0, WAIT_OBJECT_0, }, - winioctl::{FILE_ANY_ACCESS, METHOD_BUFFERED, METHOD_NEITHER}, + WindowsProgramming::INFINITE, + IO::{DeviceIoControl, GetOverlappedResult, OVERLAPPED}, }, }; @@ -842,7 +838,7 @@ pub unsafe fn device_io_control_buffer_async( let input_len = input.map(|input| input.len()).unwrap_or(0); let result = DeviceIoControl( - device.as_raw_handle(), + device.as_raw_handle() as HANDLE, ioctl_code, input_ptr, u32::try_from(input_len).map_err(|_error| { @@ -883,16 +879,16 @@ pub fn get_overlapped_result( let event = overlapped.get_event().unwrap(); // SAFETY: This is a valid event object. - unsafe { wait_for_single_object(event.as_raw_handle(), None) }?; + unsafe { wait_for_single_object(event.as_handle(), None) }?; // SAFETY: The handle and overlapped object are valid. let mut returned_bytes = 0u32; let result = unsafe { GetOverlappedResult( - device.as_raw_handle(), + device.as_raw_handle() as HANDLE, overlapped.as_mut_ptr(), &mut returned_bytes, - FALSE, + 0, ) }; if result == 0 { @@ -906,10 +902,7 @@ pub fn get_overlapped_result( /// # Safety /// /// * `object` must be a valid object that can be signaled, such as an event object. -pub unsafe fn wait_for_single_object( - object: RawHandle, - timeout: Option<Duration>, -) -> io::Result<()> { +pub unsafe fn wait_for_single_object(object: HANDLE, timeout: Option<Duration>) -> io::Result<()> { let timeout = match timeout { Some(timeout) => u32::try_from(timeout.as_millis()).map_err(|_error| { io::Error::new(io::ErrorKind::InvalidInput, "the duration is too long") @@ -931,16 +924,13 @@ pub unsafe fn wait_for_single_object( /// # Safety /// /// * `objects` must be a slice of valid objects that can be signaled, such as event objects. -pub unsafe fn wait_for_multiple_objects( - objects: &[RawHandle], - wait_all: bool, -) -> io::Result<RawHandle> { +pub unsafe fn wait_for_multiple_objects(objects: &[HANDLE], wait_all: bool) -> io::Result<HANDLE> { let objects_len = u32::try_from(objects.len()) .map_err(|_error| io::Error::new(io::ErrorKind::InvalidInput, "too many objects"))?; let result = WaitForMultipleObjects( objects_len, objects.as_ptr(), - if wait_all { TRUE } else { FALSE }, + if wait_all { 1 } else { 0 }, INFINITE, ); let signaled_index = if result >= WAIT_OBJECT_0 && result < WAIT_OBJECT_0 + objects_len { diff --git a/talpid-core/src/split_tunnel/windows/mod.rs b/talpid-core/src/split_tunnel/windows/mod.rs index 24679aa934..47c7d4ec8f 100644 --- a/talpid-core/src/split_tunnel/windows/mod.rs +++ b/talpid-core/src/split_tunnel/windows/mod.rs @@ -20,7 +20,6 @@ use std::{ ffi::{OsStr, OsString}, io, net::{IpAddr, Ipv4Addr, Ipv6Addr}, - os::windows::io::AsRawHandle, path::{Path, PathBuf}, sync::{ atomic::{AtomicBool, Ordering}, @@ -29,7 +28,9 @@ use std::{ time::Duration, }; use talpid_types::{tunnel::ErrorStateCause, ErrorExt}; -use winapi::shared::{ifdef::NET_LUID, winerror::ERROR_OPERATION_ABORTED}; +use windows_sys::Win32::{ + Foundation::ERROR_OPERATION_ABORTED, NetworkManagement::IpHelper::NET_LUID_LH, +}; const DRIVER_EVENT_BUFFER_SIZE: usize = 2048; const RESERVED_IP_V4: Ipv4Addr = Ipv4Addr::new(192, 0, 2, 123); @@ -254,10 +255,8 @@ impl SplitTunnel { overlapped: &mut windows::Overlapped, data_buffer: &mut Vec<u8>, ) -> io::Result<EventResult> { - if unsafe { - driver::wait_for_single_object(quit_event.as_raw_handle(), Some(Duration::ZERO)) - } - .is_ok() + if unsafe { driver::wait_for_single_object(quit_event.as_handle(), Some(Duration::ZERO)) } + .is_ok() { return Ok(EventResult::Quit); } @@ -283,8 +282,8 @@ impl SplitTunnel { })?; let event_objects = [ - overlapped.get_event().unwrap().as_raw_handle(), - quit_event.as_raw_handle(), + overlapped.get_event().unwrap().as_handle(), + quit_event.as_handle(), ]; let signaled_object = @@ -298,7 +297,7 @@ impl SplitTunnel { }, )?; - if signaled_object == quit_event.as_raw_handle() { + if signaled_object == quit_event.as_handle() { // Quit event was signaled return Ok(EventResult::Quit); } @@ -765,7 +764,7 @@ impl SplitTunnelDefaultRouteChangeHandlerContext { .map(|route| { get_ip_address_for_interface( AddressFamily::Ipv4, - NET_LUID { + NET_LUID_LH { Value: route.interface_luid, }, ) @@ -786,7 +785,7 @@ impl SplitTunnelDefaultRouteChangeHandlerContext { .map(|route| { get_ip_address_for_interface( AddressFamily::Ipv6, - NET_LUID { + NET_LUID_LH { Value: route.interface_luid, }, ) @@ -835,7 +834,7 @@ unsafe extern "system" fn split_tunnel_default_route_change_handler( DefaultRouteChanged | DefaultRouteUpdatedDetails => { match get_ip_address_for_interface( translated_family, - NET_LUID { + NET_LUID_LH { Value: default_route.interface_luid, }, ) { diff --git a/talpid-core/src/split_tunnel/windows/path_monitor.rs b/talpid-core/src/split_tunnel/windows/path_monitor.rs index 71e909a09b..22a3747d25 100644 --- a/talpid-core/src/split_tunnel/windows/path_monitor.rs +++ b/talpid-core/src/split_tunnel/windows/path_monitor.rs @@ -13,34 +13,25 @@ use std::{ sync::{mpsc as sync_mpsc, Arc}, time::{Duration, Instant}, }; -use winapi::{ - self, - shared::{ - minwindef::TRUE, - winerror::{ - ERROR_NOT_FOUND, ERROR_OPERATION_ABORTED, ERROR_PATH_NOT_FOUND, - ERROR_UNRECOGNIZED_VOLUME, - }, +use windows_sys::Win32::{ + Foundation::{ + CloseHandle, ERROR_NOT_FOUND, ERROR_OPERATION_ABORTED, ERROR_PATH_NOT_FOUND, + ERROR_UNRECOGNIZED_VOLUME, HANDLE, INVALID_HANDLE_VALUE, + }, + Globalization::CompareStringOrdinal, + Storage::FileSystem::{ + GetFileAttributesW, GetFullPathNameW, ReadDirectoryChangesW, FILE_ATTRIBUTE_REPARSE_POINT, + FILE_FLAG_BACKUP_SEMANTICS, FILE_FLAG_OPEN_REPARSE_POINT, FILE_FLAG_OVERLAPPED, + FILE_NOTIFY_CHANGE_ATTRIBUTES, FILE_NOTIFY_CHANGE_DIR_NAME, FILE_NOTIFY_CHANGE_FILE_NAME, + FILE_NOTIFY_INFORMATION, }, - um::{ - fileapi::{GetFileAttributesW, GetFullPathNameW}, - handleapi::{CloseHandle, INVALID_HANDLE_VALUE}, - ioapiset::{ + System::{ + Ioctl::FSCTL_GET_REPARSE_POINT, + SystemServices::{IO_REPARSE_TAG_MOUNT_POINT, IO_REPARSE_TAG_SYMLINK}, + WindowsProgramming::INFINITE, + IO::{ CancelIoEx, CreateIoCompletionPort, DeviceIoControl, GetQueuedCompletionStatus, - PostQueuedCompletionStatus, - }, - minwinbase::OVERLAPPED, - stringapiset::CompareStringOrdinal, - winbase::{ - ReadDirectoryChangesW, FILE_FLAG_BACKUP_SEMANTICS, FILE_FLAG_OPEN_REPARSE_POINT, - FILE_FLAG_OVERLAPPED, INFINITE, - }, - winioctl::FSCTL_GET_REPARSE_POINT, - winnt::{ - FILE_ATTRIBUTE_REPARSE_POINT, FILE_NOTIFY_CHANGE_ATTRIBUTES, - FILE_NOTIFY_CHANGE_DIR_NAME, FILE_NOTIFY_CHANGE_FILE_NAME, FILE_NOTIFY_INFORMATION, - HANDLE, IO_REPARSE_TAG_MOUNT_POINT, IO_REPARSE_TAG_SYMLINK, - MAXIMUM_REPARSE_DATA_BUFFER_SIZE, + PostQueuedCompletionStatus, OVERLAPPED, }, }, }; @@ -52,6 +43,7 @@ const PATH_MONITOR_COMPLETION_KEY_IGNORE: usize = usize::MAX; const CSTR_EQUAL: i32 = 2; const ANYSIZE_ARRAY: usize = 1; +const MAXIMUM_REPARSE_DATA_BUFFER_SIZE: u32 = 16384; const SYMLINK_FLAG_RELATIVE: u32 = 0x00000001; // See https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-fscc/c3a420cb-8a72-4adf-87e8-eee95379d78f. @@ -155,7 +147,7 @@ fn resolve_link<T: AsRef<Path> + Copy>(path: T) -> io::Result<Option<PathBuf>> { if unsafe { DeviceIoControl( - file.as_raw_handle() as *mut _, + file.as_raw_handle() as HANDLE, FSCTL_GET_REPARSE_POINT, ptr::null_mut(), 0u32, @@ -265,15 +257,15 @@ impl DirContext { let handle = unsafe { CreateIoCompletionPort( - dir_handle.as_raw_handle() as *mut _, - io_completion_port.as_raw_handle() as *mut _, + dir_handle.as_raw_handle() as HANDLE, + io_completion_port.as_raw_handle() as HANDLE, completion_key, // num of threads is ignored here 0, ) }; - if handle == ptr::null_mut() { + if handle == 0 { return Err(io::Error::last_os_error()); } @@ -290,10 +282,10 @@ impl DirContext { let mut _bytes_returned = 0; if unsafe { ReadDirectoryChangesW( - self.dir_handle.as_raw_handle() as *mut _, + self.dir_handle.as_raw_handle() as HANDLE, self.buffer.as_mut_ptr() as *mut _, self.buffer.len() as u32, - TRUE, + 1, FILE_NOTIFY_CHANGE_FILE_NAME | FILE_NOTIFY_CHANGE_DIR_NAME | FILE_NOTIFY_CHANGE_ATTRIBUTES, @@ -314,7 +306,7 @@ impl DirContext { /// Try to cancel a request. On success, return whether a request was cancelled. fn cancel_io(&mut self) -> io::Result<bool> { - if unsafe { CancelIoEx(self.dir_handle.as_raw_handle(), ptr::null_mut()) } == 0 { + if unsafe { CancelIoEx(self.dir_handle.as_raw_handle() as HANDLE, ptr::null_mut()) } == 0 { match io::Error::last_os_error() { _error if _error.raw_os_error() == Some(ERROR_NOT_FOUND as i32) => Ok(false), error => Err(error), @@ -349,10 +341,9 @@ struct CompletionPort { impl CompletionPort { // `concurrent_threads`: 0 ==> number of processors fn create(concurrent_threads: u32) -> io::Result<Self> { - let handle = unsafe { - CreateIoCompletionPort(INVALID_HANDLE_VALUE, ptr::null_mut(), 0, concurrent_threads) - }; - if handle == ptr::null_mut() { + let handle = + unsafe { CreateIoCompletionPort(INVALID_HANDLE_VALUE, 0, 0, concurrent_threads) }; + if handle == 0 { return Err(io::Error::last_os_error()); } Ok(CompletionPort { handle }) @@ -731,7 +722,7 @@ impl PathMonitor { file_name.len() as i32, file_name.as_ptr(), file_name.len() as i32, - TRUE, + 1, ) }; match cmp_status { diff --git a/talpid-core/src/split_tunnel/windows/volume_monitor.rs b/talpid-core/src/split_tunnel/windows/volume_monitor.rs index 1993cc809c..0758463399 100644 --- a/talpid-core/src/split_tunnel/windows/volume_monitor.rs +++ b/talpid-core/src/split_tunnel/windows/volume_monitor.rs @@ -10,16 +10,13 @@ use std::{ sync::{mpsc as sync_mpsc, Arc, Mutex, MutexGuard}, }; use talpid_types::ErrorExt; -use winapi::{ - shared::minwindef::TRUE, - um::{ - dbt::{ - DBTF_NET, DBT_DEVICEARRIVAL, DBT_DEVICEREMOVECOMPLETE, DBT_DEVTYP_VOLUME, - DEV_BROADCAST_HDR, DEV_BROADCAST_VOLUME, WM_DEVICECHANGE, - }, - fileapi::GetLogicalDrives, - winuser::DefWindowProcW, +use windows_sys::Win32::{ + Storage::FileSystem::GetLogicalDrives, + System::SystemServices::{ + DBTF_NET, DBT_DEVICEARRIVAL, DBT_DEVICEREMOVECOMPLETE, DBT_DEVTYP_VOLUME, + DEV_BROADCAST_HDR, DEV_BROADCAST_VOLUME, }, + UI::WindowsAndMessaging::{DefWindowProcW, WM_DEVICECHANGE}, }; pub(super) struct VolumeMonitor(()); @@ -125,7 +122,7 @@ fn start_internal_monitor( let volumes = unsafe { parse_device_volume_broadcast(&*(l_param as *const _)) }; let prev_state = *known_state_guard; - let is_arrival = w_param == DBT_DEVICEARRIVAL; + let is_arrival = w_param == DBT_DEVICEARRIVAL as usize; if is_arrival { *known_state_guard |= volumes; } else { @@ -144,7 +141,7 @@ fn start_internal_monitor( } // Always grant the request - TRUE as isize + 1 }) } @@ -184,7 +181,7 @@ fn matches_volume(volumes: u32, paths_guard: &MutexGuard<'_, Vec<OsString>>) -> fn is_device_arrival_or_removal(message: u32, w_param: usize) -> bool { message == WM_DEVICECHANGE - && (w_param == DBT_DEVICEARRIVAL || w_param == DBT_DEVICEREMOVECOMPLETE) + && (w_param == DBT_DEVICEARRIVAL as usize || w_param == DBT_DEVICEREMOVECOMPLETE as usize) } /// Return volumes affected by the device arrival or removal message as a mask. diff --git a/talpid-core/src/split_tunnel/windows/windows.rs b/talpid-core/src/split_tunnel/windows/windows.rs index 1cfcce15a5..aad3d2e63c 100644 --- a/talpid-core/src/split_tunnel/windows/windows.rs +++ b/talpid-core/src/split_tunnel/windows/windows.rs @@ -6,38 +6,36 @@ use std::{ fs, io, iter, mem, os::windows::{ ffi::{OsStrExt, OsStringExt}, - io::{AsRawHandle, RawHandle}, + prelude::AsRawHandle, }, path::{Component, Path}, ptr, }; -use winapi::{ - shared::{ - minwindef::{BOOL, DWORD, FALSE, FILETIME, TRUE}, - ntdef::ULARGE_INTEGER, - winerror::{ERROR_INSUFFICIENT_BUFFER, ERROR_NO_MORE_FILES}, +use windows_sys::Win32::{ + Foundation::{ + CloseHandle, BOOL, ERROR_INSUFFICIENT_BUFFER, ERROR_NO_MORE_FILES, FILETIME, HANDLE, + INVALID_HANDLE_VALUE, }, - um::{ - fileapi::{GetFinalPathNameByHandleW, QueryDosDeviceW}, - handleapi::{CloseHandle, INVALID_HANDLE_VALUE}, - minwinbase::OVERLAPPED, - processthreadsapi::{GetProcessTimes, OpenProcess}, - psapi::K32GetProcessImageFileNameW, - synchapi::{CreateEventW, SetEvent}, - tlhelp32::{CreateToolhelp32Snapshot, Process32FirstW, Process32NextW, PROCESSENTRY32W}, - winnt::{HANDLE, PROCESS_QUERY_LIMITED_INFORMATION}, + Storage::FileSystem::{GetFinalPathNameByHandleW, QueryDosDeviceW}, + System::{ + Diagnostics::ToolHelp::{ + CreateToolhelp32Snapshot, Process32FirstW, Process32NextW, PROCESSENTRY32W, + }, + ProcessStatus::K32GetProcessImageFileNameW, + Threading::{ + CreateEventW, GetProcessTimes, OpenProcess, SetEvent, PROCESS_QUERY_LIMITED_INFORMATION, + }, + WindowsProgramming::VOLUME_NAME_NT, + IO::OVERLAPPED, }, }; -/// Return path with the volume device path. -const VOLUME_NAME_NT: u32 = 0x02; - pub struct ProcessSnapshot { handle: HANDLE, } impl ProcessSnapshot { - pub fn new(flags: DWORD, process_id: DWORD) -> io::Result<ProcessSnapshot> { + pub fn new(flags: u32, process_id: u32) -> io::Result<ProcessSnapshot> { let snap = unsafe { CreateToolhelp32Snapshot(flags, process_id) }; if snap == INVALID_HANDLE_VALUE { @@ -87,7 +85,7 @@ impl Iterator for ProcessSnapshotEntries<'_> { fn next(&mut self) -> Option<io::Result<ProcessEntry>> { if self.iter_started { - if unsafe { Process32NextW(self.snapshot.handle(), &mut self.temp_entry) } == FALSE { + if unsafe { Process32NextW(self.snapshot.handle(), &mut self.temp_entry) } == 0 { let last_error = io::Error::last_os_error(); return if last_error.raw_os_error().unwrap() as u32 == ERROR_NO_MORE_FILES { @@ -97,7 +95,7 @@ impl Iterator for ProcessSnapshotEntries<'_> { }; } } else { - if unsafe { Process32FirstW(self.snapshot.handle(), &mut self.temp_entry) } == FALSE { + if unsafe { Process32FirstW(self.snapshot.handle(), &mut self.temp_entry) } == 0 { return Some(Err(io::Error::last_os_error())); } self.iter_started = true; @@ -115,7 +113,7 @@ pub fn get_device_path<T: AsRef<Path>>(path: T) -> Result<OsString, io::Error> { // Preferentially, use GetFinalPathNameByHandleW. If the file does not exist // or cannot be opened, infer the path from the label only. if let Ok(file) = fs::OpenOptions::new().read(true).open(path.as_ref()) { - return get_final_path_name_by_handle(file.as_raw_handle()); + return unsafe { get_final_path_name_by_handle(file.as_raw_handle() as HANDLE) }; } let mut components = path.as_ref().components(); @@ -142,21 +140,18 @@ pub fn get_device_path<T: AsRef<Path>>(path: T) -> Result<OsString, io::Error> { Ok(new_path) } -pub fn get_final_path_name_by_handle(raw_handle: RawHandle) -> Result<OsString, io::Error> { - let buffer_size = unsafe { - GetFinalPathNameByHandleW(raw_handle as *mut _, ptr::null_mut(), 0u32, VOLUME_NAME_NT) - } as usize; +pub unsafe fn get_final_path_name_by_handle(raw_handle: HANDLE) -> Result<OsString, io::Error> { + let buffer_size = + GetFinalPathNameByHandleW(raw_handle, ptr::null_mut(), 0u32, VOLUME_NAME_NT) as usize; let mut buffer = Vec::new(); buffer.resize(buffer_size, 0); - let status = unsafe { - GetFinalPathNameByHandleW( - raw_handle as *mut _, - buffer.as_mut_ptr(), - buffer_size as u32, - VOLUME_NAME_NT, - ) - } as usize; + let status = GetFinalPathNameByHandleW( + raw_handle, + buffer.as_mut_ptr(), + buffer_size as u32, + VOLUME_NAME_NT, + ) as usize; if status == 0 { return Err(io::Error::last_os_error()); @@ -206,10 +201,10 @@ fn query_dos_device<T: AsRef<OsStr>>(device_name: T) -> io::Result<OsString> { } /// Object that frees its handle when dropped. -pub struct WinHandle(RawHandle); +pub struct WinHandle(HANDLE); impl WinHandle { - pub fn get_raw(&self) -> RawHandle { + pub fn get_raw(&self) -> HANDLE { self.0 } } @@ -232,22 +227,16 @@ pub fn open_process( inherit_handle: bool, pid: u32, ) -> Result<WinHandle, io::Error> { - let handle = unsafe { - OpenProcess( - access as u32, - if inherit_handle { TRUE } else { FALSE }, - pid, - ) - }; + let handle = unsafe { OpenProcess(access as u32, if inherit_handle { 1 } else { 0 }, pid) }; - if handle == ptr::null_mut() { + if handle == 0 { return Err(io::Error::last_os_error()); } Ok(WinHandle(handle)) } /// Returns the age of a running process. -pub fn get_process_creation_time(handle: RawHandle) -> Result<u64, io::Error> { +pub fn get_process_creation_time(handle: HANDLE) -> Result<u64, io::Error> { // TODO: FileTimeToSystemTime -> chrono::NaiveDateTime let mut creation_time: FILETIME = unsafe { mem::zeroed() }; let mut dummy: FILETIME = unsafe { mem::zeroed() }; @@ -264,17 +253,13 @@ pub fn get_process_creation_time(handle: RawHandle) -> Result<u64, io::Error> { return Err(io::Error::last_os_error()); } - let mut uli_time: ULARGE_INTEGER = unsafe { mem::zeroed() }; - unsafe { - uli_time.s_mut().LowPart = creation_time.dwLowDateTime; - uli_time.s_mut().HighPart = creation_time.dwHighDateTime; - } - - Ok(*unsafe { uli_time.QuadPart() }) + let time = + ((creation_time.dwHighDateTime as u64) << u32::BITS) | (creation_time.dwLowDateTime as u64); + Ok(time) } /// Returns the device path for a running process. -pub fn get_process_device_path(handle: RawHandle) -> Result<OsString, io::Error> { +pub fn get_process_device_path(handle: HANDLE) -> Result<OsString, io::Error> { let mut initial_capacity = 512; loop { let result = get_process_device_path_inner(handle, initial_capacity); @@ -293,7 +278,7 @@ pub fn get_process_device_path(handle: RawHandle) -> Result<OsString, io::Error> } fn get_process_device_path_inner( - handle: RawHandle, + handle: HANDLE, buffer_capacity: usize, ) -> Result<OsString, io::Error> { let mut buffer = Vec::<u16>::new(); @@ -350,12 +335,11 @@ impl Overlapped { fn set_event(&mut self, event: Option<Event>) { match event { Some(event) => { - let raw_event = event.0; - self.overlapped.hEvent = raw_event; + self.overlapped.hEvent = event.0; self.event = Some(event); } None => { - self.overlapped.hEvent = ptr::null_mut(); + self.overlapped.hEvent = 0; self.event = None; } } @@ -363,7 +347,7 @@ impl Overlapped { } /// Abstraction over a Windows event object. -pub struct Event(RawHandle); +pub struct Event(HANDLE); unsafe impl Send for Event {} unsafe impl Sync for Event {} @@ -378,22 +362,20 @@ impl Event { ptr::null(), ) }; - if event == ptr::null_mut() { + if event == 0 { return Err(io::Error::last_os_error()); } Ok(Self(event)) } pub fn set(&self) -> io::Result<()> { - if unsafe { SetEvent(self.0) } == FALSE { + if unsafe { SetEvent(self.0) } == 0 { return Err(io::Error::last_os_error()); } Ok(()) } -} -impl AsRawHandle for Event { - fn as_raw_handle(&self) -> RawHandle { + pub fn as_handle(&self) -> HANDLE { self.0 } } @@ -406,7 +388,7 @@ impl Drop for Event { const fn bool_to_winbool(val: bool) -> BOOL { match val { - true => TRUE, - false => FALSE, + true => 1, + false => 0, } } diff --git a/talpid-core/src/tunnel/openvpn/mod.rs b/talpid-core/src/tunnel/openvpn/mod.rs index 15d9013610..95bcac3295 100644 --- a/talpid-core/src/tunnel/openvpn/mod.rs +++ b/talpid-core/src/tunnel/openvpn/mod.rs @@ -35,7 +35,7 @@ use which; #[cfg(windows)] use widestring::U16CString; #[cfg(windows)] -use winapi::shared::{guiddef::GUID, ifdef::NET_LUID}; +use windows_sys::{core::GUID, Win32::NetworkManagement::IpHelper::NET_LUID_LH}; #[cfg(windows)] mod wintun; @@ -48,10 +48,10 @@ lazy_static! { #[cfg(windows)] const ADAPTER_GUID: GUID = GUID { - Data1: 0xAFE43773, - Data2: 0xE1F8, - Data3: 0x4EBB, - Data4: [0x85, 0x36, 0x57, 0x6A, 0xB8, 0x6A, 0xFE, 0x9A], + data1: 0xAFE43773, + data2: 0xE1F8, + data3: 0x4EBB, + data4: [0x85, 0x36, 0x57, 0x6A, 0xB8, 0x6A, 0xFE, 0x9A], }; /// Results from fallible operations on the OpenVPN tunnel. @@ -189,7 +189,7 @@ pub struct OpenVpnMonitor<C: OpenVpnBuilder = OpenVpnCommand> { #[cfg(windows)] #[async_trait::async_trait] trait WintunContext: Send + Sync { - fn luid(&self) -> NET_LUID; + fn luid(&self) -> NET_LUID_LH; fn ipv6(&self) -> bool; async fn wait_for_interfaces(&self) -> io::Result<()>; fn prepare_interface(&self) {} @@ -201,7 +201,7 @@ impl std::fmt::Debug for dyn WintunContext { write!( f, "WintunContext {{ luid: {}, ipv6: {} }}", - self.luid().Value, + unsafe { self.luid().Value }, self.ipv6() ) } @@ -218,7 +218,7 @@ struct WintunContextImpl { #[cfg(windows)] #[async_trait::async_trait] impl WintunContext for WintunContextImpl { - fn luid(&self) -> NET_LUID { + fn luid(&self) -> NET_LUID_LH { self.adapter.luid() } @@ -1132,8 +1132,8 @@ mod tests { #[cfg(windows)] #[async_trait::async_trait] impl WintunContext for TestWintunContext { - fn luid(&self) -> NET_LUID { - NET_LUID { Value: 0u64 } + fn luid(&self) -> NET_LUID_LH { + NET_LUID_LH { Value: 0u64 } } fn ipv6(&self) -> bool { false diff --git a/talpid-core/src/tunnel/openvpn/wintun.rs b/talpid-core/src/tunnel/openvpn/wintun.rs index 9d9996aefc..5834587e59 100644 --- a/talpid-core/src/tunnel/openvpn/wintun.rs +++ b/talpid-core/src/tunnel/openvpn/wintun.rs @@ -10,19 +10,17 @@ use std::{ }; use talpid_types::ErrorExt; use widestring::{U16CStr, U16CString}; -use winapi::{ - shared::{ - guiddef::GUID, - ifdef::NET_LUID, - minwindef::{FARPROC, HINSTANCE, HMODULE}, - netioapi::ConvertInterfaceLuidToGuid, - winerror::NO_ERROR, - }, - um::{ - libloaderapi::{ - FreeLibrary, GetProcAddress, LoadLibraryExW, LOAD_WITH_ALTERED_SEARCH_PATH, +use windows_sys::{ + core::GUID, + Win32::{ + Foundation::{HINSTANCE, NO_ERROR}, + NetworkManagement::IpHelper::{ConvertInterfaceLuidToGuid, NET_LUID_LH}, + System::{ + LibraryLoader::{ + FreeLibrary, GetProcAddress, LoadLibraryExW, LOAD_WITH_ALTERED_SEARCH_PATH, + }, + Registry::REG_SAM_FLAGS, }, - winreg::REGSAM, }, }; use winreg::{ @@ -43,7 +41,8 @@ type WintunCreateAdapterFn = unsafe extern "stdcall" fn( type WintunCloseAdapterFn = unsafe extern "stdcall" fn(adapter: RawHandle); -type WintunGetAdapterLuidFn = unsafe extern "stdcall" fn(adapter: RawHandle, luid: *mut NET_LUID); +type WintunGetAdapterLuidFn = + unsafe extern "stdcall" fn(adapter: RawHandle, luid: *mut NET_LUID_LH); type WintunLoggerCbFn = extern "stdcall" fn(WintunLoggerLevel, u64, *const u16); @@ -116,14 +115,14 @@ impl WintunAdapter { self.name.to_owned() } - pub fn luid(&self) -> NET_LUID { + pub fn luid(&self) -> NET_LUID_LH { unsafe { self.dll_handle.get_adapter_luid(self.handle) } } pub fn guid(&self) -> io::Result<GUID> { let mut guid = mem::MaybeUninit::zeroed(); let result = unsafe { ConvertInterfaceLuidToGuid(&self.luid(), guid.as_mut_ptr()) }; - if result != NO_ERROR { + if result != NO_ERROR as i32 { return Err(io::Error::from_raw_os_error(result as i32)); } Ok(unsafe { guid.assume_init() }) @@ -196,22 +195,20 @@ impl WintunDll { fn new(resource_dir: &Path) -> io::Result<Self> { let wintun_dll = U16CString::from_os_str_truncate(resource_dir.join("wintun.dll")); - let handle = unsafe { - LoadLibraryExW( - wintun_dll.as_ptr(), - ptr::null_mut(), - LOAD_WITH_ALTERED_SEARCH_PATH, - ) - }; - if handle == ptr::null_mut() { + let handle = + unsafe { LoadLibraryExW(wintun_dll.as_ptr(), 0, LOAD_WITH_ALTERED_SEARCH_PATH) }; + if handle == 0 { return Err(io::Error::last_os_error()); } Self::new_inner(handle, Self::get_proc_address) } fn new_inner( - handle: HMODULE, - get_proc_fn: unsafe fn(HMODULE, &CStr) -> io::Result<FARPROC>, + handle: HINSTANCE, + get_proc_fn: unsafe fn( + HINSTANCE, + &CStr, + ) -> io::Result<unsafe extern "system" fn() -> isize>, ) -> io::Result<Self> { Ok(WintunDll { handle, @@ -242,12 +239,12 @@ impl WintunDll { }) } - unsafe fn get_proc_address(handle: HMODULE, name: &CStr) -> io::Result<FARPROC> { - let handle = GetProcAddress(handle, name.as_ptr()); - if handle == ptr::null_mut() { - return Err(io::Error::last_os_error()); - } - Ok(handle) + unsafe fn get_proc_address( + handle: HINSTANCE, + name: &CStr, + ) -> io::Result<unsafe extern "system" fn() -> isize> { + let handle = GetProcAddress(handle, name.as_ptr() as *const u8); + handle.ok_or(io::Error::last_os_error()) } pub fn create_adapter( @@ -271,8 +268,8 @@ impl WintunDll { (self.func_close)(adapter); } - pub unsafe fn get_adapter_luid(&self, adapter: RawHandle) -> NET_LUID { - let mut luid = mem::MaybeUninit::<NET_LUID>::zeroed(); + pub unsafe fn get_adapter_luid(&self, adapter: RawHandle) -> NET_LUID_LH { + let mut luid = mem::MaybeUninit::<NET_LUID_LH>::zeroed(); (self.func_get_adapter_luid)(adapter, luid.as_mut_ptr()); luid.assume_init() } @@ -331,7 +328,7 @@ impl Drop for WintunLoggerHandle { } /// Returns the registry key for a network device identified by its GUID. -fn find_adapter_registry_key(find_guid: &str, permissions: REGSAM) -> io::Result<RegKey> { +fn find_adapter_registry_key(find_guid: &str, permissions: REG_SAM_FLAGS) -> io::Result<RegKey> { let net_devs = RegKey::predef(HKEY_LOCAL_MACHINE).open_subkey_with_flags( r"SYSTEM\CurrentControlSet\Control\Class\{4d36e972-e325-11ce-bfc1-08002be10318}", permissions, @@ -362,13 +359,16 @@ fn find_adapter_registry_key(find_guid: &str, permissions: REGSAM) -> io::Result mod tests { use super::*; - fn get_proc_fn(_handle: HMODULE, _symbol: &CStr) -> io::Result<FARPROC> { - Ok(std::ptr::null_mut()) + fn get_proc_fn( + _handle: HINSTANCE, + _symbol: &CStr, + ) -> io::Result<unsafe extern "system" fn() -> isize> { + Ok(null_fn) } #[test] fn test_wintun_imports() { - WintunDll::new_inner(ptr::null_mut(), get_proc_fn).unwrap(); + WintunDll::new_inner(0, get_proc_fn).unwrap(); } #[test] @@ -377,19 +377,19 @@ mod tests { ( "{AFE43773-E1F8-4EBB-8536-576AB86AFE9A}", GUID { - Data1: 0xAFE43773, - Data2: 0xE1F8, - Data3: 0x4EBB, - Data4: [0x85, 0x36, 0x57, 0x6A, 0xB8, 0x6A, 0xFE, 0x9A], + data1: 0xAFE43773, + data2: 0xE1F8, + data3: 0x4EBB, + data4: [0x85, 0x36, 0x57, 0x6A, 0xB8, 0x6A, 0xFE, 0x9A], }, ), ( "{00000000-0000-0000-0000-000000000000}", GUID { - Data1: 0, - Data2: 0, - Data3: 0, - Data4: [0; 8], + data1: 0, + data2: 0, + data3: 0, + data4: [0; 8], }, ), ]; @@ -401,4 +401,8 @@ mod tests { ); } } + + unsafe extern "system" fn null_fn() -> isize { + unreachable!("unexpected call of function") + } } diff --git a/talpid-core/src/tunnel/windows.rs b/talpid-core/src/tunnel/windows.rs index cf4871904f..14206a0d83 100644 --- a/talpid-core/src/tunnel/windows.rs +++ b/talpid-core/src/tunnel/windows.rs @@ -1,12 +1,13 @@ use crate::windows::{get_ip_interface_entry, set_ip_interface_entry, AddressFamily}; use std::io; -use winapi::shared::{ - ifdef::NET_LUID, nldef::RouterDiscoveryDisabled, ntdef::FALSE, winerror::ERROR_NOT_FOUND, +use windows_sys::Win32::{ + Foundation::ERROR_NOT_FOUND, NetworkManagement::IpHelper::NET_LUID_LH, + Networking::WinSock::RouterDiscoveryDisabled, }; /// Sets MTU, metric, and disables unnecessary features for the IP interfaces /// on the specified network interface (identified by `luid`). -pub fn initialize_interfaces(luid: NET_LUID, mtu: Option<u32>) -> io::Result<()> { +pub fn initialize_interfaces(luid: NET_LUID_LH, mtu: Option<u32>) -> io::Result<()> { for family in &[AddressFamily::Ipv4, AddressFamily::Ipv6] { let mut row = match get_ip_interface_entry(*family, &luid) { Ok(row) => row, @@ -22,12 +23,12 @@ pub fn initialize_interfaces(luid: NET_LUID, mtu: Option<u32>) -> io::Result<()> row.SitePrefixLength = 0; row.RouterDiscoveryBehavior = RouterDiscoveryDisabled; row.DadTransmits = 0; - row.ManagedAddressConfigurationSupported = FALSE; - row.OtherStatefulConfigurationSupported = FALSE; + row.ManagedAddressConfigurationSupported = 0; + row.OtherStatefulConfigurationSupported = 0; // Ensure lowest interface metric row.Metric = 1; - row.UseAutomaticMetric = FALSE; + row.UseAutomaticMetric = 0; set_ip_interface_entry(&mut row)?; } diff --git a/talpid-core/src/tunnel/wireguard/wireguard_go.rs b/talpid-core/src/tunnel/wireguard/wireguard_go.rs index 5673406982..d78e422697 100644 --- a/talpid-core/src/tunnel/wireguard/wireguard_go.rs +++ b/talpid-core/src/tunnel/wireguard/wireguard_go.rs @@ -168,8 +168,8 @@ impl WgGoTunnel { let has_ipv6 = config.tunnel.addresses.iter().any(|addr| addr.is_ipv6()); let setup_handle = tokio::spawn(async move { - use winapi::shared::ifdef::NET_LUID; - let luid = NET_LUID { + use windows_sys::Win32::NetworkManagement::IpHelper::NET_LUID_LH; + let luid = NET_LUID_LH { Value: interface_luid, }; log::debug!("Waiting for tunnel IP interfaces to arrive"); @@ -214,13 +214,15 @@ impl WgGoTunnel { default_route: winnet::WinNetDefaultRoute, _ctx: *mut libc::c_void, ) { - use winapi::shared::{ifdef::NET_LUID, netioapi::ConvertInterfaceLuidToIndex}; + use windows_sys::Win32::NetworkManagement::IpHelper::{ + ConvertInterfaceLuidToIndex, NET_LUID_LH, + }; use winnet::WinNetDefaultRouteChangeEventType::*; let iface_idx: u32 = match event_type { DefaultRouteChanged => { let mut iface_idx = 0u32; - let iface_luid = NET_LUID { + let iface_luid = NET_LUID_LH { Value: default_route.interface_luid, }; let status = diff --git a/talpid-core/src/tunnel/wireguard/wireguard_nt.rs b/talpid-core/src/tunnel/wireguard/wireguard_nt.rs index 113f6534c0..b0ffab20fa 100644 --- a/talpid-core/src/tunnel/wireguard/wireguard_nt.rs +++ b/talpid-core/src/tunnel/wireguard/wireguard_nt.rs @@ -24,19 +24,17 @@ use std::{ }; use talpid_types::{BoxedError, ErrorExt}; use widestring::{U16CStr, U16CString}; -use winapi::{ - shared::{ - guiddef::GUID, - ifdef::NET_LUID, - in6addr::IN6_ADDR, - inaddr::IN_ADDR, - minwindef::{BOOL, FARPROC, HINSTANCE, HMODULE}, - winerror::ERROR_MORE_DATA, - ws2def::{ADDRESS_FAMILY, AF_INET, AF_INET6}, - ws2ipdef::SOCKADDR_INET, - }, - um::libloaderapi::{ - FreeLibrary, GetProcAddress, LoadLibraryExW, LOAD_WITH_ALTERED_SEARCH_PATH, +use windows_sys::{ + core::GUID, + Win32::{ + Foundation::{BOOL, ERROR_MORE_DATA, HINSTANCE}, + NetworkManagement::IpHelper::NET_LUID_LH, + Networking::WinSock::{ + ADDRESS_FAMILY, AF_INET, AF_INET6, IN6_ADDR, IN_ADDR, SOCKADDR_INET, + }, + System::LibraryLoader::{ + FreeLibrary, GetProcAddress, LoadLibraryExW, LOAD_WITH_ALTERED_SEARCH_PATH, + }, }, }; @@ -47,10 +45,10 @@ lazy_static! { } const ADAPTER_GUID: GUID = GUID { - Data1: 0x514a3988, - Data2: 0x9716, - Data3: 0x43d5, - Data4: [0x8b, 0x05, 0x31, 0xda, 0x25, 0xa0, 0x44, 0xa9], + data1: 0x514a3988, + data2: 0x9716, + data3: 0x43d5, + data4: [0x8b, 0x05, 0x31, 0xda, 0x25, 0xa0, 0x44, 0xa9], }; type WireGuardCreateAdapterFn = unsafe extern "stdcall" fn( @@ -60,7 +58,7 @@ type WireGuardCreateAdapterFn = unsafe extern "stdcall" fn( ) -> RawHandle; type WireGuardCloseAdapterFn = unsafe extern "stdcall" fn(adapter: RawHandle); type WireGuardGetAdapterLuidFn = - unsafe extern "stdcall" fn(adapter: RawHandle, luid: *mut NET_LUID); + unsafe extern "stdcall" fn(adapter: RawHandle, luid: *mut NET_LUID_LH); type WireGuardSetConfigurationFn = unsafe extern "stdcall" fn( adapter: RawHandle, config: *const MaybeUninit<u8>, @@ -147,7 +145,7 @@ pub enum Error { /// Unknown address family #[error(display = "Unknown address family: {}", _0)] - UnknownAddressFamily(i32), + UnknownAddressFamily(u32), /// Failure to set up logging #[error(display = "Failed to set up logging")] @@ -213,7 +211,7 @@ impl From<Ipv4Addr> for WgIpAddr { #[repr(C, align(8))] struct WgAllowedIp { address: WgIpAddr, - address_family: ADDRESS_FAMILY, + address_family: u16, cidr: u8, } @@ -222,19 +220,19 @@ impl WgAllowedIp { Self::validate(&address, address_family, cidr)?; Ok(Self { address, - address_family, + address_family: address_family as u16, cidr, }) } fn validate(address: &WgIpAddr, address_family: ADDRESS_FAMILY, cidr: u8) -> Result<()> { - match address_family as i32 { + match address_family { AF_INET => { if cidr > 32 { return Err(Error::InvalidAllowedIpCidr); } let host_mask = u32::MAX.checked_shr(u32::from(cidr)).unwrap_or(0); - if host_mask & (unsafe { *(address.v4.S_un.S_addr()) }.to_be()) != 0 { + if host_mask & unsafe { address.v4.S_un.S_addr }.to_be() != 0 { return Err(Error::InvalidAllowedIpBits); } } @@ -243,7 +241,7 @@ impl WgAllowedIp { return Err(Error::InvalidAllowedIpCidr); } let mut host_mask = u128::MAX.checked_shr(u32::from(cidr)).unwrap_or(0); - let bytes = unsafe { address.v6.u.Byte() }; + let bytes = unsafe { address.v6.u.Byte }; for byte in bytes.iter().rev() { if byte & ((host_mask & 0xff) as u8) != 0 { return Err(Error::InvalidAllowedIpBits); @@ -262,7 +260,7 @@ impl PartialEq for WgAllowedIp { if self.cidr != other.cidr { return false; } - match self.address_family as i32 { + match self.address_family as u32 { AF_INET => { windows::ipaddr_from_inaddr(unsafe { self.address.v4 }) == windows::ipaddr_from_inaddr(unsafe { other.address.v4 }) @@ -283,7 +281,7 @@ impl Eq for WgAllowedIp {} impl fmt::Debug for WgAllowedIp { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let mut s = f.debug_struct("WgAllowedIp"); - match self.address_family as i32 { + match self.address_family as u32 { AF_INET => s.field( "address", &windows::ipaddr_from_inaddr(unsafe { self.address.v4 }), @@ -470,9 +468,12 @@ async fn setup_ip_listener( has_ipv6: bool, ) -> Result<()> { let luid = { device.lock().unwrap().as_ref().unwrap().luid() }; + let luid = NET_LUID_LH { + Value: unsafe { luid.Value }, + }; log::debug!("Waiting for tunnel IP interfaces to arrive"); - windows::wait_for_interfaces(luid.clone(), true, has_ipv6) + windows::wait_for_interfaces(luid, true, has_ipv6) .await .map_err(Error::IpInterfacesError)?; log::debug!("Waiting for tunnel IP interfaces: Done"); @@ -573,7 +574,7 @@ impl WgNtAdapter { }) } - fn luid(&self) -> NET_LUID { + fn luid(&self) -> NET_LUID_LH { unsafe { self.dll_handle.get_adapter_luid(self.handle) } } @@ -632,22 +633,20 @@ impl WgNtDll { let wg_nt_dll = U16CString::from_os_str_truncate(resource_dir.join("mullvad-wireguard.dll")); - let handle = unsafe { - LoadLibraryExW( - wg_nt_dll.as_ptr(), - ptr::null_mut(), - LOAD_WITH_ALTERED_SEARCH_PATH, - ) - }; - if handle == ptr::null_mut() { + let handle = + unsafe { LoadLibraryExW(wg_nt_dll.as_ptr(), 0, LOAD_WITH_ALTERED_SEARCH_PATH) }; + if handle == 0 { return Err(io::Error::last_os_error()); } Self::new_inner(handle, Self::get_proc_address) } fn new_inner( - handle: HMODULE, - get_proc_fn: unsafe fn(HMODULE, &CStr) -> io::Result<FARPROC>, + handle: HINSTANCE, + get_proc_fn: unsafe fn( + HINSTANCE, + &CStr, + ) -> io::Result<unsafe extern "system" fn() -> isize>, ) -> io::Result<Self> { Ok(WgNtDll { handle, @@ -702,12 +701,12 @@ impl WgNtDll { }) } - unsafe fn get_proc_address(handle: HMODULE, name: &CStr) -> io::Result<FARPROC> { - let handle = GetProcAddress(handle, name.as_ptr()); - if handle == ptr::null_mut() { - return Err(io::Error::last_os_error()); - } - Ok(handle) + unsafe fn get_proc_address( + handle: HINSTANCE, + name: &CStr, + ) -> io::Result<unsafe extern "system" fn() -> isize> { + let handle = GetProcAddress(handle, name.as_ptr() as *const u8); + handle.ok_or(io::Error::last_os_error()) } pub fn create_adapter( @@ -731,8 +730,8 @@ impl WgNtDll { (self.func_close)(adapter); } - pub unsafe fn get_adapter_luid(&self, adapter: RawHandle) -> NET_LUID { - let mut luid = mem::MaybeUninit::<NET_LUID>::zeroed(); + pub unsafe fn get_adapter_luid(&self, adapter: RawHandle) -> NET_LUID_LH { + let mut luid = mem::MaybeUninit::<NET_LUID_LH>::zeroed(); (self.func_get_adapter_luid)(adapter, luid.as_mut_ptr()); luid.assume_init() } @@ -854,8 +853,8 @@ fn serialize_config(config: &Config) -> Result<Vec<MaybeUninit<u8>>> { for allowed_ip in &peer.allowed_ips { let address_family = match allowed_ip { - IpNetwork::V4(_) => AF_INET as u16, - IpNetwork::V6(_) => AF_INET6 as u16, + IpNetwork::V4(_) => AF_INET, + IpNetwork::V6(_) => AF_INET6, }; let address = match allowed_ip { IpNetwork::V4(v4_network) => WgIpAddr::from(v4_network.ip()), @@ -908,7 +907,7 @@ unsafe fn deserialize_config( let allowed_ip: WgAllowedIp = *(allowed_ip_data.as_ptr() as *const WgAllowedIp); if let Err(error) = WgAllowedIp::validate( &allowed_ip.address, - allowed_ip.address_family, + u32::from(allowed_ip.address_family), allowed_ip.cidr, ) { log::error!( @@ -1053,13 +1052,16 @@ mod tests { }; } - fn get_proc_fn(_handle: HMODULE, _symbol: &CStr) -> io::Result<FARPROC> { - Ok(std::ptr::null_mut()) + fn get_proc_fn( + _handle: HINSTANCE, + _symbol: &CStr, + ) -> io::Result<unsafe extern "system" fn() -> isize> { + Ok(null_fn) } #[test] fn test_dll_imports() { - WgNtDll::new_inner(ptr::null_mut(), get_proc_fn).unwrap(); + WgNtDll::new_inner(0, get_proc_fn).unwrap(); } #[test] @@ -1085,7 +1087,7 @@ mod tests { #[test] fn test_wg_allowed_ip_v4() { // Valid: /32 prefix - let address_family = AF_INET as u16; + let address_family = AF_INET; let address = WgIpAddr::from("127.0.0.1".parse::<Ipv4Addr>().unwrap()); let cidr = 32; WgAllowedIp::new(address, address_family, cidr).unwrap(); @@ -1113,7 +1115,7 @@ mod tests { #[test] fn test_wg_allowed_ip_v6() { // Valid: /128 prefix - let address_family = AF_INET6 as u16; + let address_family = AF_INET6; let address = WgIpAddr::from("::1".parse::<Ipv6Addr>().unwrap()); let cidr = 128; WgAllowedIp::new(address, address_family, cidr).unwrap(); @@ -1139,4 +1141,8 @@ mod tests { let cidr = 129; assert!(WgAllowedIp::new(address, address_family, cidr).is_err()); } + + unsafe extern "system" fn null_fn() -> isize { + unreachable!("unexpected call of function") + } } diff --git a/talpid-core/src/windows/mod.rs b/talpid-core/src/windows/mod.rs index 6b97ee389e..13efa824f0 100644 --- a/talpid-core/src/windows/mod.rs +++ b/talpid-core/src/windows/mod.rs @@ -1,48 +1,40 @@ +use libc::c_void; use socket2::SockAddr; use std::{ ffi::{OsStr, OsString}, fmt, io, mem::{self, MaybeUninit}, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}, - os::windows::{ - ffi::{OsStrExt, OsStringExt}, - io::RawHandle, - }, + os::windows::ffi::{OsStrExt, OsStringExt}, path::PathBuf, ptr, sync::Mutex, time::{Duration, Instant}, }; use widestring::WideCStr; -use winapi::{ - shared::{ - guiddef::GUID, - ifdef::NET_LUID, - in6addr::IN6_ADDR, - inaddr::IN_ADDR, - netioapi::{ - CancelMibChangeNotify2, ConvertInterfaceAliasToLuid, ConvertInterfaceLuidToAlias, - ConvertInterfaceLuidToGuid, CreateUnicastIpAddressEntry, FreeMibTable, - GetIpInterfaceEntry, GetUnicastIpAddressEntry, GetUnicastIpAddressTable, - InitializeUnicastIpAddressEntry, MibAddInstance, NotifyIpInterfaceChange, - SetIpInterfaceEntry, MIB_IPINTERFACE_ROW, MIB_UNICASTIPADDRESS_ROW, - MIB_UNICASTIPADDRESS_TABLE, +use winapi::shared::ws2def::SOCKADDR_STORAGE as sockaddr_storage; +use windows_sys::{ + core::{GUID, PWSTR}, + Win32::{ + Foundation::{ERROR_NOT_FOUND, HANDLE, NO_ERROR, S_OK}, + NetworkManagement::{ + IpHelper::{ + CancelMibChangeNotify2, ConvertInterfaceAliasToLuid, ConvertInterfaceLuidToAlias, + ConvertInterfaceLuidToGuid, CreateUnicastIpAddressEntry, FreeMibTable, + GetIpInterfaceEntry, GetUnicastIpAddressEntry, GetUnicastIpAddressTable, + InitializeUnicastIpAddressEntry, MibAddInstance, NotifyIpInterfaceChange, + SetIpInterfaceEntry, MIB_IPINTERFACE_ROW, MIB_UNICASTIPADDRESS_ROW, + MIB_UNICASTIPADDRESS_TABLE, NET_LUID_LH, + }, + Ndis::NDIS_IF_MAX_STRING_SIZE, }, - nldef::{IpDadStatePreferred, IpDadStateTentative, NL_DAD_STATE}, - ntddndis::NDIS_IF_MAX_STRING_SIZE, - ntdef::FALSE, - winerror::{ERROR_NOT_FOUND, NO_ERROR, S_OK}, - ws2def::{ - AF_INET, AF_INET6, AF_UNSPEC, SOCKADDR_IN as sockaddr_in, - SOCKADDR_STORAGE as sockaddr_storage, + Networking::WinSock::{ + IpDadStateDeprecated, IpDadStateDuplicate, IpDadStateInvalid, IpDadStatePreferred, + IpDadStateTentative, AF_INET, AF_INET6, AF_UNSPEC, IN6_ADDR, IN_ADDR, NL_DAD_STATE, + SOCKADDR_IN as sockaddr_in, SOCKADDR_IN6 as sockaddr_in6, SOCKADDR_INET, }, - ws2ipdef::{SOCKADDR_IN6_LH as sockaddr_in6, SOCKADDR_INET}, - }, - um::{ - combaseapi::{CoTaskMemFree, StringFromGUID2}, - knownfolders::FOLDERID_System, - shlobj::SHGetKnownFolderPath, - winnt::PWSTR, + System::Com::{CoTaskMemFree, StringFromGUID2}, + UI::Shell::{FOLDERID_System, SHGetKnownFolderPath}, }, }; @@ -95,7 +87,7 @@ pub enum Error { /// Unknown address family #[error(display = "Unknown address family: {}", _0)] - UnknownAddressFamily(i32), + UnknownAddressFamily(u32), } /// Address family. These correspond to the `AF_*` constants. @@ -119,7 +111,7 @@ impl fmt::Display for AddressFamily { impl AddressFamily { /// Convert an [`AddressFamily`] to one of the `AF_*` constants. pub fn try_from_af_family(family: u16) -> Result<AddressFamily> { - match family as i32 { + match u32::from(family) { AF_INET => Ok(AddressFamily::Ipv4), AF_INET6 => Ok(AddressFamily::Ipv6), family => Err(Error::UnknownAddressFamily(family)), @@ -130,22 +122,22 @@ impl AddressFamily { /// Context for [`notify_ip_interface_change`]. When it is dropped, /// the callback is unregistered. pub struct IpNotifierHandle<'a> { - callback: Mutex<Box<dyn FnMut(&MIB_IPINTERFACE_ROW, u32) + Send + 'a>>, - handle: RawHandle, + callback: Mutex<Box<dyn FnMut(&MIB_IPINTERFACE_ROW, i32) + Send + 'a>>, + handle: HANDLE, } unsafe impl Send for IpNotifierHandle<'_> {} impl<'a> Drop for IpNotifierHandle<'a> { fn drop(&mut self) { - unsafe { CancelMibChangeNotify2(self.handle as *mut _) }; + unsafe { CancelMibChangeNotify2(self.handle) }; } } unsafe extern "system" fn inner_callback( - context: *mut winapi::ctypes::c_void, - row: *mut MIB_IPINTERFACE_ROW, - notify_type: u32, + context: *const c_void, + row: *const MIB_IPINTERFACE_ROW, + notify_type: i32, ) { let context = &mut *(context as *mut IpNotifierHandle<'_>); context @@ -156,13 +148,13 @@ unsafe extern "system" fn inner_callback( /// Registers a callback function that is invoked when an interface is added, removed, /// or changed. -pub fn notify_ip_interface_change<'a, T: FnMut(&MIB_IPINTERFACE_ROW, u32) + Send + 'a>( +pub fn notify_ip_interface_change<'a, T: FnMut(&MIB_IPINTERFACE_ROW, i32) + Send + 'a>( callback: T, family: Option<AddressFamily>, ) -> io::Result<Box<IpNotifierHandle<'a>>> { let mut context = Box::new(IpNotifierHandle { callback: Mutex::new(Box::new(callback)), - handle: std::ptr::null_mut(), + handle: 0, }); let status = unsafe { @@ -170,12 +162,12 @@ pub fn notify_ip_interface_change<'a, T: FnMut(&MIB_IPINTERFACE_ROW, u32) + Send af_family_from_family(family), Some(inner_callback), &mut *context as *mut _ as *mut _, - FALSE, + 0, (&mut context.handle) as *mut _, ) }; - if status == NO_ERROR { + if status == NO_ERROR as i32 { Ok(context) } else { Err(io::Error::from_raw_os_error(status as i32)) @@ -185,14 +177,14 @@ pub fn notify_ip_interface_change<'a, T: FnMut(&MIB_IPINTERFACE_ROW, u32) + Send /// Returns information about a network IP interface. pub fn get_ip_interface_entry( family: AddressFamily, - luid: &NET_LUID, + luid: &NET_LUID_LH, ) -> io::Result<MIB_IPINTERFACE_ROW> { let mut row: MIB_IPINTERFACE_ROW = unsafe { mem::zeroed() }; row.Family = family as u16; row.InterfaceLuid = *luid; let result = unsafe { GetIpInterfaceEntry(&mut row) }; - if result == NO_ERROR { + if result == NO_ERROR as i32 { Ok(row) } else { Err(io::Error::from_raw_os_error(result as i32)) @@ -202,14 +194,14 @@ pub fn get_ip_interface_entry( /// Set the properties of an IP interface. pub fn set_ip_interface_entry(row: &mut MIB_IPINTERFACE_ROW) -> io::Result<()> { let result = unsafe { SetIpInterfaceEntry(row as *mut _) }; - if result == NO_ERROR { + if result == NO_ERROR as i32 { Ok(()) } else { Err(io::Error::from_raw_os_error(result as i32)) } } -fn ip_interface_entry_exists(family: AddressFamily, luid: &NET_LUID) -> io::Result<bool> { +fn ip_interface_entry_exists(family: AddressFamily, luid: &NET_LUID_LH) -> io::Result<bool> { match get_ip_interface_entry(family, luid) { Ok(_) => Ok(true), Err(error) if error.raw_os_error() == Some(ERROR_NOT_FOUND as i32) => Ok(false), @@ -218,7 +210,7 @@ fn ip_interface_entry_exists(family: AddressFamily, luid: &NET_LUID) -> io::Resu } /// Waits until the specified IP interfaces have attached to a given network interface. -pub async fn wait_for_interfaces(luid: NET_LUID, ipv4: bool, ipv6: bool) -> io::Result<()> { +pub async fn wait_for_interfaces(luid: NET_LUID_LH, ipv4: bool, ipv6: bool) -> io::Result<()> { let (tx, rx) = futures::channel::oneshot::channel(); let mut found_ipv4 = if ipv4 { false } else { true }; @@ -234,10 +226,10 @@ pub async fn wait_for_interfaces(luid: NET_LUID, ipv4: bool, ipv6: bool) -> io:: if notification_type != MibAddInstance { return; } - if row.InterfaceLuid.Value != luid.Value { + if unsafe { row.InterfaceLuid.Value != luid.Value } { return; } - match row.Family as i32 { + match row.Family as u32 { AF_INET => found_ipv4 = true, AF_INET6 => found_ipv6 = true, _ => (), @@ -263,7 +255,6 @@ pub async fn wait_for_interfaces(luid: NET_LUID, ipv4: bool, ipv6: bool) -> io:: } /// Handles cases where there DAD state is neither tentative nor preferred. -#[cfg(windows)] #[derive(err_derive::Error, Debug)] pub enum DadStateError { /// Invalid DAD state. @@ -280,14 +271,12 @@ pub enum DadStateError { /// Unknown DAD state constant. #[error(display = "Unknown DAD state: {}", _0)] - Unknown(u32), + Unknown(i32), } -#[cfg(windows)] #[allow(non_upper_case_globals)] impl From<NL_DAD_STATE> for DadStateError { fn from(state: NL_DAD_STATE) -> DadStateError { - use winapi::shared::nldef::*; match state { IpDadStateInvalid => DadStateError::Invalid, IpDadStateDuplicate => DadStateError::Duplicate, @@ -298,12 +287,12 @@ impl From<NL_DAD_STATE> for DadStateError { } /// Wait for addresses to be usable on an network adapter. -pub async fn wait_for_addresses(luid: NET_LUID) -> Result<()> { +pub async fn wait_for_addresses(luid: NET_LUID_LH) -> Result<()> { // Obtain unicast IP addresses let mut unicast_rows: Vec<MIB_UNICASTIPADDRESS_ROW> = get_unicast_table(None) .map_err(Error::ObtainUnicastAddress)? .into_iter() - .filter(|row| row.InterfaceLuid.Value == luid.Value) + .filter(|row| unsafe { row.InterfaceLuid.Value == luid.Value }) .collect(); if unicast_rows.is_empty() { return Err(Error::NoUnicastAddress); @@ -320,7 +309,7 @@ pub async fn wait_for_addresses(luid: NET_LUID) -> Result<()> { for row in &mut unicast_rows { let status = unsafe { GetUnicastIpAddressEntry(row) }; - if status != NO_ERROR { + if status != NO_ERROR as i32 { return Err(Error::ObtainUnicastAddress(io::Error::from_raw_os_error( status as i32, ))); @@ -351,12 +340,12 @@ pub async fn wait_for_addresses(luid: NET_LUID) -> Result<()> { /// Returns the first unicast IP address for the given interface. pub fn get_ip_address_for_interface( family: AddressFamily, - luid: NET_LUID, + luid: NET_LUID_LH, ) -> Result<Option<IpAddr>> { match get_unicast_table(Some(family)) .map_err(Error::ObtainUnicastAddress)? .into_iter() - .find(|row| row.InterfaceLuid.Value == luid.Value) + .find(|row| unsafe { row.InterfaceLuid.Value == luid.Value }) { Some(row) => Ok(Some(try_socketaddr_from_inet_sockaddr(row.Address)?.ip())), None => Ok(None), @@ -364,7 +353,7 @@ pub fn get_ip_address_for_interface( } /// Adds a unicast IP address for the given interface. -pub fn add_ip_address_for_interface(luid: NET_LUID, address: IpAddr) -> Result<()> { +pub fn add_ip_address_for_interface(luid: NET_LUID_LH, address: IpAddr) -> Result<()> { let mut row = unsafe { mem::zeroed() }; unsafe { InitializeUnicastIpAddressEntry(&mut row) }; @@ -374,7 +363,7 @@ pub fn add_ip_address_for_interface(luid: NET_LUID, address: IpAddr) -> Result<( row.OnLinkPrefixLength = 255; let status = unsafe { CreateUnicastIpAddressEntry(&row) }; - if status != NO_ERROR { + if status != NO_ERROR as i32 { return Err(Error::CreateUnicastEntry(io::Error::from_raw_os_error( status as i32, ))); @@ -392,7 +381,7 @@ pub fn get_unicast_table( let status = unsafe { GetUnicastIpAddressTable(af_family_from_family(family), &mut unicast_table) }; - if status != NO_ERROR { + if status != NO_ERROR as i32 { return Err(io::Error::from_raw_os_error(status as i32)); } let first_row = unsafe { &(*unicast_table).Table[0] } as *const MIB_UNICASTIPADDRESS_ROW; @@ -416,36 +405,36 @@ pub fn string_from_guid(guid: &GUID) -> String { } /// Returns the GUID of a network interface given its LUID. -pub fn guid_from_luid(luid: &NET_LUID) -> io::Result<GUID> { +pub fn guid_from_luid(luid: &NET_LUID_LH) -> io::Result<GUID> { let mut guid = MaybeUninit::zeroed(); let status = unsafe { ConvertInterfaceLuidToGuid(luid, guid.as_mut_ptr()) }; - if status != NO_ERROR { + if status != NO_ERROR as i32 { return Err(io::Error::from_raw_os_error(status as i32)); } Ok(unsafe { guid.assume_init() }) } /// Returns the LUID of an interface given its alias. -pub fn luid_from_alias<T: AsRef<OsStr>>(alias: T) -> io::Result<NET_LUID> { +pub fn luid_from_alias<T: AsRef<OsStr>>(alias: T) -> io::Result<NET_LUID_LH> { let alias_wide: Vec<u16> = alias .as_ref() .encode_wide() .chain(std::iter::once(0u16)) .collect(); - let mut luid: NET_LUID = unsafe { std::mem::zeroed() }; + let mut luid: NET_LUID_LH = unsafe { std::mem::zeroed() }; let status = unsafe { ConvertInterfaceAliasToLuid(alias_wide.as_ptr(), &mut luid) }; - if status != NO_ERROR { + if status != NO_ERROR as i32 { return Err(io::Error::from_raw_os_error(status as i32)); } Ok(luid) } /// Returns the alias of an interface given its LUID. -pub fn alias_from_luid(luid: &NET_LUID) -> io::Result<OsString> { - let mut buffer = [0u16; NDIS_IF_MAX_STRING_SIZE + 1]; +pub fn alias_from_luid(luid: &NET_LUID_LH) -> io::Result<OsString> { + let mut buffer = [0u16; NDIS_IF_MAX_STRING_SIZE as usize + 1]; let status = unsafe { ConvertInterfaceLuidToAlias(luid, &mut buffer[0] as *mut _, buffer.len()) }; - if status != NO_ERROR { + if status != NO_ERROR as i32 { return Err(io::Error::from_raw_os_error(status as i32)); } let nul = buffer.iter().position(|&c| c == 0u16).unwrap(); @@ -472,12 +461,12 @@ pub fn in6addr_from_ipaddr(addr: Ipv6Addr) -> IN6_ADDR { /// Converts an `IN_ADDR` to `Ipv4Addr` pub fn ipaddr_from_inaddr(addr: IN_ADDR) -> Ipv4Addr { - Ipv4Addr::from(unsafe { *(addr.S_un.S_addr()) }.to_ne_bytes()) + Ipv4Addr::from(unsafe { addr.S_un.S_addr }.to_ne_bytes()) } /// Converts an `IN6_ADDR` to `Ipv6Addr` pub fn ipaddr_from_in6addr(addr: IN6_ADDR) -> Ipv6Addr { - Ipv6Addr::from(*unsafe { addr.u.Byte() }) + Ipv6Addr::from(unsafe { addr.u.Byte }) } /// Converts a `SocketAddr` to `SOCKADDR_INET` @@ -487,12 +476,12 @@ pub fn inet_sockaddr_from_socketaddr(addr: SocketAddr) -> SOCKADDR_INET { // SAFETY: `*const sockaddr` may be treated as `*const sockaddr_in` since we know it's a v4 // address. SocketAddr::V4(_) => unsafe { - *sockaddr.Ipv4_mut() = *(SockAddr::from(addr).as_ptr() as *const _) + sockaddr.Ipv4 = *(SockAddr::from(addr).as_ptr() as *const _) }, // SAFETY: `*const sockaddr` may be treated as `*const sockaddr_in6` since we know it's a v6 // address. SocketAddr::V6(_) => unsafe { - *sockaddr.Ipv6_mut() = *(SockAddr::from(addr).as_ptr() as *const _) + sockaddr.Ipv6 = *(SockAddr::from(addr).as_ptr() as *const _) }, } sockaddr @@ -500,10 +489,11 @@ pub fn inet_sockaddr_from_socketaddr(addr: SocketAddr) -> SOCKADDR_INET { /// Converts a `SOCKADDR_INET` to `SocketAddr`. Returns an error if the address family is invalid. pub fn try_socketaddr_from_inet_sockaddr(addr: SOCKADDR_INET) -> Result<SocketAddr> { - let family = unsafe { *addr.si_family() } as i32; + let family = unsafe { addr.si_family } as u32; unsafe { let mut storage: sockaddr_storage = mem::zeroed(); *(&mut storage as *mut _ as *mut SOCKADDR_INET) = addr; + // TODO: Switch to windows-sys struct once socket2 is updated SockAddr::new(storage, mem::size_of_val(&addr) as i32) } .as_socket() @@ -513,8 +503,7 @@ pub fn try_socketaddr_from_inet_sockaddr(addr: SOCKADDR_INET) -> Result<SocketAd /// Returns the system directory, i.e. `%windir%\system32`. pub fn get_system_dir() -> io::Result<PathBuf> { let mut folder_path: PWSTR = ptr::null_mut(); - let status = - unsafe { SHGetKnownFolderPath(&FOLDERID_System, 0, ptr::null_mut(), &mut folder_path) }; + let status = unsafe { SHGetKnownFolderPath(&FOLDERID_System, 0, 0, &mut folder_path) }; let result = if status == S_OK { let path = unsafe { WideCStr::from_ptr_str(folder_path) }; Ok(path.to_ustring().to_os_string().into()) diff --git a/talpid-core/src/windows/window.rs b/talpid-core/src/windows/window.rs index 006e5b8e16..badb50de1c 100644 --- a/talpid-core/src/windows/window.rs +++ b/talpid-core/src/windows/window.rs @@ -2,26 +2,19 @@ use std::{os::windows::io::AsRawHandle, ptr, sync::Arc, thread}; use tokio::sync::broadcast; -use winapi::{ - shared::{ - basetsd::LONG_PTR, - minwindef::{LPARAM, LRESULT, UINT, WPARAM}, - windef::HWND, - }, - um::{ - libloaderapi::GetModuleHandleW, - processthreadsapi::GetThreadId, - winuser::{ - CreateWindowExW, DefWindowProcW, DestroyWindow, DispatchMessageW, GetMessageW, - GetWindowLongPtrW, PostQuitMessage, PostThreadMessageW, SetWindowLongPtrW, - TranslateMessage, GWLP_USERDATA, GWLP_WNDPROC, PBT_APMRESUMEAUTOMATIC, - PBT_APMRESUMESUSPEND, PBT_APMSUSPEND, WM_DESTROY, WM_POWERBROADCAST, WM_USER, - }, +use windows_sys::Win32::{ + Foundation::{HANDLE, HWND, LPARAM, LRESULT, WPARAM}, + System::{LibraryLoader::GetModuleHandleW, Threading::GetThreadId}, + UI::WindowsAndMessaging::{ + CreateWindowExW, DefWindowProcW, DestroyWindow, DispatchMessageW, GetMessageW, + GetWindowLongPtrW, PostQuitMessage, PostThreadMessageW, SetWindowLongPtrW, + TranslateMessage, GWLP_USERDATA, GWLP_WNDPROC, PBT_APMRESUMEAUTOMATIC, + PBT_APMRESUMESUSPEND, PBT_APMSUSPEND, WM_DESTROY, WM_POWERBROADCAST, WM_USER, }, }; const CLASS_NAME: &[u8] = b"S\0T\0A\0T\0I\0C\0\0\0"; -const REQUEST_THREAD_SHUTDOWN: UINT = WM_USER + 1; +const REQUEST_THREAD_SHUTDOWN: u32 = WM_USER + 1; /// Handle for closing an associated window. /// The window is not destroyed when this is dropped. @@ -33,7 +26,7 @@ impl WindowCloseHandle { /// Close the window and wait for the thread. pub fn close(&mut self) { if let Some(thread) = self.thread.take() { - let thread_id = unsafe { GetThreadId(thread.as_raw_handle()) }; + let thread_id = unsafe { GetThreadId(thread.as_raw_handle() as HANDLE) }; unsafe { PostThreadMessageW(thread_id, REQUEST_THREAD_SHUTDOWN, 0, 0) }; let _ = thread.join(); } @@ -41,7 +34,7 @@ impl WindowCloseHandle { } /// Creates a dummy window whose messages are handled by `wnd_proc`. -pub fn create_hidden_window<F: (Fn(HWND, UINT, WPARAM, LPARAM) -> LRESULT) + Send + 'static>( +pub fn create_hidden_window<F: (Fn(HWND, u32, WPARAM, LPARAM) -> LRESULT) + Send + 'static>( wnd_proc: F, ) -> WindowCloseHandle { let join_handle = thread::spawn(move || { @@ -55,8 +48,8 @@ pub fn create_hidden_window<F: (Fn(HWND, UINT, WPARAM, LPARAM) -> LRESULT) + Sen 0, 0, 0, - ptr::null_mut(), - ptr::null_mut(), + 0, + 0, GetModuleHandleW(ptr::null_mut()), ptr::null_mut(), ) @@ -67,18 +60,14 @@ pub fn create_hidden_window<F: (Fn(HWND, UINT, WPARAM, LPARAM) -> LRESULT) + Sen let raw_callback = Box::into_raw(Box::new(wnd_proc)); unsafe { - SetWindowLongPtrW(dummy_window, GWLP_USERDATA, raw_callback as LONG_PTR); - SetWindowLongPtrW( - dummy_window, - GWLP_WNDPROC, - window_procedure::<F> as LONG_PTR, - ); + SetWindowLongPtrW(dummy_window, GWLP_USERDATA, raw_callback as isize); + SetWindowLongPtrW(dummy_window, GWLP_WNDPROC, window_procedure::<F> as isize); } let mut msg = unsafe { std::mem::zeroed() }; loop { - let status = unsafe { GetMessageW(&mut msg, ptr::null_mut(), 0, 0) }; + let status = unsafe { GetMessageW(&mut msg, 0, 0, 0) }; if status < 0 { continue; @@ -87,7 +76,7 @@ pub fn create_hidden_window<F: (Fn(HWND, UINT, WPARAM, LPARAM) -> LRESULT) + Sen break; } - if msg.hwnd.is_null() { + if msg.hwnd == 0 { if msg.message == REQUEST_THREAD_SHUTDOWN { unsafe { DestroyWindow(dummy_window) }; } @@ -110,12 +99,12 @@ pub fn create_hidden_window<F: (Fn(HWND, UINT, WPARAM, LPARAM) -> LRESULT) + Sen unsafe extern "system" fn window_procedure<F>( window: HWND, - message: UINT, + message: u32, wparam: WPARAM, lparam: LPARAM, ) -> LRESULT where - F: Fn(HWND, UINT, WPARAM, LPARAM) -> LRESULT, + F: Fn(HWND, u32, WPARAM, LPARAM) -> LRESULT, { if message == WM_DESTROY { PostQuitMessage(0); @@ -146,7 +135,7 @@ pub enum PowerManagementEvent { impl PowerManagementEvent { fn try_from_winevent(wparam: usize) -> Option<Self> { use PowerManagementEvent::*; - match wparam { + match wparam as u32 { PBT_APMRESUMEAUTOMATIC => Some(ResumeAutomatic), PBT_APMRESUMESUSPEND => Some(ResumeSuspend), PBT_APMSUSPEND => Some(Suspend), |
