summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorJoakim Hulthe <joakim.hulthe@mullvad.net>2025-02-11 15:08:23 +0100
committerJoakim Hulthe <joakim.hulthe@mullvad.net>2025-02-25 13:37:32 +0100
commit4ec109b4d34ec325f3d53b3f5d18e5b7498242f5 (patch)
tree3cb0391699bee43b1731547b1b8642588bb089a8
parent3ac5e020750d2173c899fe8a16b0d5b1db843c9f (diff)
downloadmullvadvpn-4ec109b4d34ec325f3d53b3f5d18e5b7498242f5.tar.xz
mullvadvpn-4ec109b4d34ec325f3d53b3f5d18e5b7498242f5.zip
Add safety comments to talpid_platform_metadata::windows
-rw-r--r--talpid-platform-metadata/src/windows.rs49
1 files changed, 34 insertions, 15 deletions
diff --git a/talpid-platform-metadata/src/windows.rs b/talpid-platform-metadata/src/windows.rs
index 1df2cb0f12..99e2f18b3c 100644
--- a/talpid-platform-metadata/src/windows.rs
+++ b/talpid-platform-metadata/src/windows.rs
@@ -4,10 +4,13 @@ use std::{
mem::{self, MaybeUninit},
os::windows::ffi::OsStrExt,
};
-use windows_sys::Win32::System::{
- LibraryLoader::{GetModuleHandleW, GetProcAddress},
- SystemInformation::OSVERSIONINFOEXW,
- SystemServices::VER_NT_WORKSTATION,
+use windows_sys::Win32::{
+ Foundation::{NTSTATUS, STATUS_SUCCESS},
+ System::{
+ LibraryLoader::{GetModuleHandleW, GetProcAddress},
+ SystemInformation::OSVERSIONINFOEXW,
+ SystemServices::VER_NT_WORKSTATION,
+ },
};
#[allow(non_camel_case_types)]
@@ -49,27 +52,43 @@ impl WindowsVersion {
.chain(iter::once(0u16))
.collect();
+ // SAFETY: module_name is a valid UTF-16/WTF-16 null-terminated string.
let ntdll = unsafe { GetModuleHandleW(module_name.as_ptr()) };
if ntdll == 0 {
return Err(io::Error::last_os_error());
}
+ // SAFETY: ntdll is a valid pointer, RtlGetVersion is a valid null-terminated ANSI string.
let function_address = unsafe { GetProcAddress(ntdll, b"RtlGetVersion\0" as *const u8) }
.ok_or_else(io::Error::last_os_error)?;
- let rtl_get_version: extern "stdcall" fn(*mut RTL_OSVERSIONINFOEXW) =
- unsafe { *(&function_address as *const _ as *const _) };
+ // SAFETY: We're correcting this function pointer to the ACTUAL type of RtlGetVersion.
+ // https://learn.microsoft.com/en-us/windows/win32/devnotes/rtlgetversion
+ let rtl_get_version = unsafe {
+ mem::transmute::<
+ unsafe extern "system" fn() -> isize,
+ unsafe extern "stdcall" fn(*mut RTL_OSVERSIONINFOEXW) -> NTSTATUS,
+ >(function_address)
+ };
- let mut version_info: MaybeUninit<RTL_OSVERSIONINFOEXW> = mem::MaybeUninit::zeroed();
- unsafe {
- (*version_info.as_mut_ptr()).dwOSVersionInfoSize =
- mem::size_of_val(&version_info) as u32;
- rtl_get_version(version_info.as_mut_ptr());
+ let mut version_info: RTL_OSVERSIONINFOEXW =
+ // SAFETY: RTL_OSVERSIONINFOEXW is a C struct and can safely be zeroed.
+ unsafe { MaybeUninit::zeroed().assume_init() };
- Ok(WindowsVersion {
- inner: version_info.assume_init(),
- })
- }
+ version_info.dwOSVersionInfoSize = mem::size_of_val(&version_info) as u32;
+
+ // SAFETY:
+ // - &mut version_info is a valid pointer.
+ // - rtl_get_version was provided by GetProcAddress and should be valid.
+ let status = unsafe { rtl_get_version(&mut version_info) };
+ debug_assert_eq!(
+ status, STATUS_SUCCESS,
+ "RtlGetVersion always returns success"
+ );
+
+ Ok(WindowsVersion {
+ inner: version_info,
+ })
}
pub fn windows_version_string(&self) -> String {