diff options
| author | Joakim Hulthe <joakim.hulthe@mullvad.net> | 2025-02-11 15:08:23 +0100 |
|---|---|---|
| committer | Joakim Hulthe <joakim.hulthe@mullvad.net> | 2025-02-25 13:37:32 +0100 |
| commit | 4ec109b4d34ec325f3d53b3f5d18e5b7498242f5 (patch) | |
| tree | 3cb0391699bee43b1731547b1b8642588bb089a8 | |
| parent | 3ac5e020750d2173c899fe8a16b0d5b1db843c9f (diff) | |
| download | mullvadvpn-4ec109b4d34ec325f3d53b3f5d18e5b7498242f5.tar.xz mullvadvpn-4ec109b4d34ec325f3d53b3f5d18e5b7498242f5.zip | |
Add safety comments to talpid_platform_metadata::windows
| -rw-r--r-- | talpid-platform-metadata/src/windows.rs | 49 |
1 files changed, 34 insertions, 15 deletions
diff --git a/talpid-platform-metadata/src/windows.rs b/talpid-platform-metadata/src/windows.rs index 1df2cb0f12..99e2f18b3c 100644 --- a/talpid-platform-metadata/src/windows.rs +++ b/talpid-platform-metadata/src/windows.rs @@ -4,10 +4,13 @@ use std::{ mem::{self, MaybeUninit}, os::windows::ffi::OsStrExt, }; -use windows_sys::Win32::System::{ - LibraryLoader::{GetModuleHandleW, GetProcAddress}, - SystemInformation::OSVERSIONINFOEXW, - SystemServices::VER_NT_WORKSTATION, +use windows_sys::Win32::{ + Foundation::{NTSTATUS, STATUS_SUCCESS}, + System::{ + LibraryLoader::{GetModuleHandleW, GetProcAddress}, + SystemInformation::OSVERSIONINFOEXW, + SystemServices::VER_NT_WORKSTATION, + }, }; #[allow(non_camel_case_types)] @@ -49,27 +52,43 @@ impl WindowsVersion { .chain(iter::once(0u16)) .collect(); + // SAFETY: module_name is a valid UTF-16/WTF-16 null-terminated string. let ntdll = unsafe { GetModuleHandleW(module_name.as_ptr()) }; if ntdll == 0 { return Err(io::Error::last_os_error()); } + // SAFETY: ntdll is a valid pointer, RtlGetVersion is a valid null-terminated ANSI string. let function_address = unsafe { GetProcAddress(ntdll, b"RtlGetVersion\0" as *const u8) } .ok_or_else(io::Error::last_os_error)?; - let rtl_get_version: extern "stdcall" fn(*mut RTL_OSVERSIONINFOEXW) = - unsafe { *(&function_address as *const _ as *const _) }; + // SAFETY: We're correcting this function pointer to the ACTUAL type of RtlGetVersion. + // https://learn.microsoft.com/en-us/windows/win32/devnotes/rtlgetversion + let rtl_get_version = unsafe { + mem::transmute::< + unsafe extern "system" fn() -> isize, + unsafe extern "stdcall" fn(*mut RTL_OSVERSIONINFOEXW) -> NTSTATUS, + >(function_address) + }; - let mut version_info: MaybeUninit<RTL_OSVERSIONINFOEXW> = mem::MaybeUninit::zeroed(); - unsafe { - (*version_info.as_mut_ptr()).dwOSVersionInfoSize = - mem::size_of_val(&version_info) as u32; - rtl_get_version(version_info.as_mut_ptr()); + let mut version_info: RTL_OSVERSIONINFOEXW = + // SAFETY: RTL_OSVERSIONINFOEXW is a C struct and can safely be zeroed. + unsafe { MaybeUninit::zeroed().assume_init() }; - Ok(WindowsVersion { - inner: version_info.assume_init(), - }) - } + version_info.dwOSVersionInfoSize = mem::size_of_val(&version_info) as u32; + + // SAFETY: + // - &mut version_info is a valid pointer. + // - rtl_get_version was provided by GetProcAddress and should be valid. + let status = unsafe { rtl_get_version(&mut version_info) }; + debug_assert_eq!( + status, STATUS_SUCCESS, + "RtlGetVersion always returns success" + ); + + Ok(WindowsVersion { + inner: version_info, + }) } pub fn windows_version_string(&self) -> String { |
