diff options
| author | Emīls Piņķis <emils@mullvad.net> | 2021-10-17 19:31:47 +0100 |
|---|---|---|
| committer | Emīls Piņķis <emils@mullvad.net> | 2021-10-17 19:31:47 +0100 |
| commit | ba817cf569cd3505bdae19d6683cfc158c628e96 (patch) | |
| tree | bcd153adbe156e6e3e841d468b34d0c2ff99025c | |
| parent | f1dc793226358251ff64a3e49fab050dab00dfa2 (diff) | |
| parent | 1f7b488c21047685032c87ec4ad696734f7dd997 (diff) | |
| download | mullvadvpn-ba817cf569cd3505bdae19d6683cfc158c628e96.tar.xz mullvadvpn-ba817cf569cd3505bdae19d6683cfc158c628e96.zip | |
Merge branch 'unify-struct-writing'
| -rw-r--r-- | talpid-core/src/split_tunnel/windows/driver.rs | 45 | ||||
| -rw-r--r-- | talpid-core/src/tunnel/wireguard/wireguard_nt.rs | 39 | ||||
| -rw-r--r-- | talpid-core/src/windows.rs | 6 |
3 files changed, 50 insertions, 40 deletions
diff --git a/talpid-core/src/split_tunnel/windows/driver.rs b/talpid-core/src/split_tunnel/windows/driver.rs index 5a9f27b7e3..7422972dba 100644 --- a/talpid-core/src/split_tunnel/windows/driver.rs +++ b/talpid-core/src/split_tunnel/windows/driver.rs @@ -2,6 +2,7 @@ use super::windows::{ get_device_path, get_process_creation_time, get_process_device_path, open_process, ProcessAccess, ProcessSnapshot, }; +use crate::windows::as_uninit_byte_slice; use memoffset::offset_of; use std::{ cell::RefCell, @@ -9,7 +10,7 @@ use std::{ ffi::{OsStr, OsString}, fs::{self, OpenOptions}, io, - mem::{self, size_of}, + mem::{self, size_of, MaybeUninit}, net::{Ipv4Addr, Ipv6Addr}, os::windows::{ ffi::{OsStrExt, OsStringExt}, @@ -233,9 +234,7 @@ impl DeviceHandle { } } - let buffer = &addresses as *const _ as *const u8; - let buffer = - unsafe { std::slice::from_raw_parts(buffer, size_of::<SplitTunnelAddresses>()) }; + let buffer = as_uninit_byte_slice(&addresses); device_io_control( self.handle.as_raw_handle(), @@ -315,6 +314,7 @@ impl AsRawHandle for DeviceHandle { } } +#[derive(Clone, Copy)] #[repr(C)] struct SplitTunnelAddresses { tunnel_ipv4: IN_ADDR, @@ -324,6 +324,7 @@ struct SplitTunnelAddresses { } #[repr(C)] +#[derive(Clone, Copy)] struct ConfigurationHeader { // Number of entries immediately following the header. num_entries: usize, @@ -332,6 +333,7 @@ struct ConfigurationHeader { } #[repr(C)] +#[derive(Clone, Copy)] struct ConfigurationEntry { // Offset into buffer region that follows all entries. // The image name uses the physical path. @@ -342,7 +344,7 @@ struct ConfigurationEntry { /// Create a buffer containing a `ConfigurationHeader` and number of `ConfigurationEntry`s, /// followed by the same number of paths to those entries. -fn make_process_config<T: AsRef<OsStr>>(apps: &[T]) -> Vec<u8> { +fn make_process_config<T: AsRef<OsStr>>(apps: &[T]) -> Vec<MaybeUninit<u8>> { let apps: Vec<Vec<u16>> = apps .iter() .map(|app| app.as_ref().encode_wide().collect()) @@ -354,8 +356,8 @@ fn make_process_config<T: AsRef<OsStr>>(apps: &[T]) -> Vec<u8> { + size_of::<ConfigurationEntry>() * apps.len() + total_string_size; - let mut buffer = Vec::<u8>::new(); - buffer.resize(total_buffer_size, 0); + let mut buffer = Vec::<MaybeUninit<u8>>::new(); + buffer.resize(total_buffer_size, MaybeUninit::new(0)); let (header, tail) = buffer.split_at_mut(size_of::<ConfigurationHeader>()); @@ -364,7 +366,7 @@ fn make_process_config<T: AsRef<OsStr>>(apps: &[T]) -> Vec<u8> { num_entries: apps.len(), total_length: total_buffer_size, }; - header.copy_from_slice(unsafe { as_u8_slice(&header_struct) }); + header.copy_from_slice(as_uninit_byte_slice(&header_struct)); // Serialize configuration entries and strings let (entries, string_data) = tail.split_at_mut(apps.len() * size_of::<ConfigurationEntry>()); @@ -379,8 +381,9 @@ fn make_process_config<T: AsRef<OsStr>>(apps: &[T]) -> Vec<u8> { name_length: app_bytelen as u16, }; let entry_offset = size_of::<ConfigurationEntry>() * i; + entries[entry_offset..entry_offset + size_of::<ConfigurationEntry>()] - .copy_from_slice(unsafe { as_u8_slice(&entry) }); + .copy_from_slice(as_uninit_byte_slice(&entry)); string_offset += app_bytelen; } @@ -462,6 +465,7 @@ fn build_process_tree() -> io::Result<Vec<ProcessInfo>> { .collect()) } +#[derive(Clone, Copy)] #[repr(C)] struct ProcessRegistryHeader { // Number of entries immediately following the header. @@ -470,6 +474,7 @@ struct ProcessRegistryHeader { total_length: usize, } +#[derive(Clone, Copy)] #[repr(C)] struct ProcessRegistryEntry { pid: RawHandle, @@ -480,7 +485,7 @@ struct ProcessRegistryEntry { image_name_size: u16, } -fn serialize_process_tree(processes: Vec<ProcessInfo>) -> Result<Vec<u8>, io::Error> { +fn serialize_process_tree(processes: Vec<ProcessInfo>) -> Result<Vec<MaybeUninit<u8>>, io::Error> { // Construct a buffer: // ProcessRegistryHeader // ProcessRegistryEntry.. @@ -494,15 +499,15 @@ fn serialize_process_tree(processes: Vec<ProcessInfo>) -> Result<Vec<u8>, io::Er + size_of::<ProcessRegistryEntry>() * processes.len() + total_string_size; - let mut buffer = Vec::<u8>::new(); - buffer.resize(total_buffer_size, 0); + let mut buffer = Vec::new(); + buffer.resize(total_buffer_size, MaybeUninit::new(0u8)); let (header, tail) = buffer.split_at_mut(size_of::<ProcessRegistryHeader>()); let header_struct = ProcessRegistryHeader { num_entries: processes.len(), total_length: total_buffer_size, }; - header.copy_from_slice(unsafe { as_u8_slice(&header_struct) }); + header.copy_from_slice(as_uninit_byte_slice(&header_struct)); let (entries, string_data) = tail.split_at_mut(size_of::<ProcessRegistryEntry>() * processes.len()); @@ -528,7 +533,7 @@ fn serialize_process_tree(processes: Vec<ProcessInfo>) -> Result<Vec<u8>, io::Er let entry_offset = size_of::<ProcessRegistryEntry>() * i; entries[entry_offset..entry_offset + size_of::<ProcessRegistryEntry>()] - .copy_from_slice(unsafe { as_u8_slice(&out_entry) }); + .copy_from_slice(as_uninit_byte_slice(&out_entry)); } Ok(buffer) @@ -702,7 +707,7 @@ pub fn parse_event_buffer(buffer: &Vec<u8>) -> Option<(EventId, EventBody)> { pub fn device_io_control( device: RawHandle, ioctl_code: u32, - input: Option<&[u8]>, + input: Option<&[MaybeUninit<u8>]>, output_size: u32, timeout: Option<Duration>, ) -> Result<Option<Vec<u8>>, io::Error> { @@ -749,7 +754,7 @@ pub fn device_io_control( pub fn device_io_control_buffer( device: RawHandle, ioctl_code: u32, - input: Option<&[u8]>, + input: Option<&[MaybeUninit<u8>]>, mut output: Option<&mut Vec<u8>>, overlapped: &OVERLAPPED, timeout: Option<Duration>, @@ -886,16 +891,12 @@ pub unsafe fn deserialize_buffer<T: Sized>(buffer: &Vec<u8>) -> T { instance } -fn write_string_to_buffer(buffer: &mut [u8], byte_offset: usize, string: &[u16]) { +fn write_string_to_buffer(buffer: &mut [MaybeUninit<u8>], byte_offset: usize, string: &[u16]) { for (i, byte) in string .iter() .flat_map(|word| std::array::IntoIter::new(word.to_ne_bytes())) .enumerate() { - buffer[byte_offset + i] = byte; + buffer[byte_offset + i] = MaybeUninit::new(byte); } } - -unsafe fn as_u8_slice<T: Sized>(object: &T) -> &[u8] { - std::slice::from_raw_parts(object as *const _ as *const _, size_of::<T>()) -} diff --git a/talpid-core/src/tunnel/wireguard/wireguard_nt.rs b/talpid-core/src/tunnel/wireguard/wireguard_nt.rs index 2bd14f644d..b094416ae7 100644 --- a/talpid-core/src/tunnel/wireguard/wireguard_nt.rs +++ b/talpid-core/src/tunnel/wireguard/wireguard_nt.rs @@ -11,6 +11,7 @@ use lazy_static::lazy_static; use std::{ ffi::CStr, fmt, io, iter, mem, + mem::MaybeUninit, os::windows::{ffi::OsStrExt, io::RawHandle}, path::Path, ptr, @@ -68,10 +69,16 @@ type WireGuardGetAdapterLuidFn = unsafe extern "stdcall" fn(adapter: RawHandle, luid: *mut NET_LUID); type WireGuardGetAdapterNameFn = unsafe extern "stdcall" fn(adapter: RawHandle, name: *mut u16) -> BOOL; -type WireGuardSetConfigurationFn = - unsafe extern "stdcall" fn(adapter: RawHandle, config: *const u8, bytes: u32) -> BOOL; -type WireGuardGetConfigurationFn = - unsafe extern "stdcall" fn(adapter: RawHandle, config: *const u8, bytes: *mut u32) -> BOOL; +type WireGuardSetConfigurationFn = unsafe extern "stdcall" fn( + adapter: RawHandle, + config: *const MaybeUninit<u8>, + bytes: u32, +) -> BOOL; +type WireGuardGetConfigurationFn = unsafe extern "stdcall" fn( + adapter: RawHandle, + config: *const MaybeUninit<u8>, + bytes: *mut u32, +) -> BOOL; type WireGuardSetStateFn = unsafe extern "stdcall" fn(adapter: RawHandle, state: WgAdapterState) -> BOOL; @@ -795,7 +802,7 @@ impl WgNtDll { pub unsafe fn set_config( &self, adapter: RawHandle, - config: *const u8, + config: *const MaybeUninit<u8>, config_size: usize, ) -> io::Result<()> { let result = (self.func_set_configuration)(adapter, config, config_size as u32); @@ -805,7 +812,7 @@ impl WgNtDll { Ok(()) } - pub unsafe fn get_config(&self, adapter: RawHandle) -> io::Result<Vec<u8>> { + pub unsafe fn get_config(&self, adapter: RawHandle) -> io::Result<Vec<MaybeUninit<u8>>> { let mut config_size = 0; let mut config = vec![]; loop { @@ -816,7 +823,7 @@ impl WgNtDll { if last_error.raw_os_error() != Some(ERROR_MORE_DATA as i32) { break Err(last_error); } - config.resize(config_size as usize, 0); + config.resize(config_size as usize, MaybeUninit::new(0u8)); } else { break Ok(config); } @@ -869,7 +876,7 @@ fn load_wg_nt_dll(resource_dir: &Path) -> Result<Arc<WgNtDll>> { } } -fn serialize_config(config: &Config) -> Result<Vec<u8>> { +fn serialize_config(config: &Config) -> Result<Vec<MaybeUninit<u8>>> { let mut buffer = vec![]; let header = WgInterface { @@ -880,7 +887,7 @@ fn serialize_config(config: &Config) -> Result<Vec<u8>> { peers_count: config.peers.len() as u32, }; - buffer.extend_from_slice(unsafe { as_u8_slice(&header) }); + buffer.extend(windows::as_uninit_byte_slice(&header)); for peer in &config.peers { let wg_peer = WgPeer { @@ -896,7 +903,7 @@ fn serialize_config(config: &Config) -> Result<Vec<u8>> { allowed_ips_count: peer.allowed_ips.len() as u32, }; - buffer.extend_from_slice(unsafe { as_u8_slice(&wg_peer) }); + buffer.extend(windows::as_uninit_byte_slice(&wg_peer)); for allowed_ip in &peer.allowed_ips { let address_family = match allowed_ip { @@ -915,7 +922,7 @@ fn serialize_config(config: &Config) -> Result<Vec<u8>> { let wg_allowed_ip = WgAllowedIp::new(address, address_family, allowed_ip.prefix() as u8)?; - buffer.extend_from_slice(unsafe { as_u8_slice(&wg_allowed_ip) }); + buffer.extend(windows::as_uninit_byte_slice(&wg_allowed_ip)); } } @@ -923,7 +930,7 @@ fn serialize_config(config: &Config) -> Result<Vec<u8>> { } unsafe fn deserialize_config( - config: &[u8], + config: &[MaybeUninit<u8>], ) -> Result<(WgInterface, Vec<(WgPeer, Vec<WgAllowedIp>)>)> { if config.len() < mem::size_of::<WgInterface>() { return Err(Error::InvalidConfigData); @@ -1043,10 +1050,6 @@ impl Tunnel for WgNtTunnel { } } -unsafe fn as_u8_slice<T: Sized>(object: &T) -> &[u8] { - std::slice::from_raw_parts(object as *const _ as *const _, mem::size_of::<T>()) -} - #[cfg(test)] mod tests { use super::*; @@ -1133,8 +1136,8 @@ mod tests { #[test] fn test_config_deserialization() { - let (iface, peers) = - unsafe { deserialize_config(as_u8_slice(&*WG_STRUCT_CONFIG)) }.unwrap(); + let config_buffer = windows::as_uninit_byte_slice(&*WG_STRUCT_CONFIG); + let (iface, peers) = unsafe { deserialize_config(config_buffer) }.unwrap(); assert_eq!(iface, WG_STRUCT_CONFIG.interface); assert_eq!(peers.len(), 1); let (peer, allowed_ips) = &peers[0]; diff --git a/talpid-core/src/windows.rs b/talpid-core/src/windows.rs index 03cd8a9c74..6236f32c4a 100644 --- a/talpid-core/src/windows.rs +++ b/talpid-core/src/windows.rs @@ -451,6 +451,12 @@ pub fn try_socketaddr_from_inet_sockaddr(addr: SOCKADDR_INET) -> Result<SocketAd } } +/// Casts a struct to a slice of possibly uninitialized bytes. +#[cfg(target_os = "windows")] +pub fn as_uninit_byte_slice<T: Copy + Sized>(value: &T) -> &[mem::MaybeUninit<u8>] { + unsafe { std::slice::from_raw_parts(value as *const _ as *const _, mem::size_of::<T>()) } +} + #[cfg(test)] mod tests { use super::*; |
