diff options
| -rw-r--r-- | talpid-wireguard/src/wireguard_kernel/parsers.rs | 43 |
1 files changed, 27 insertions, 16 deletions
diff --git a/talpid-wireguard/src/wireguard_kernel/parsers.rs b/talpid-wireguard/src/wireguard_kernel/parsers.rs index 86ecd025b5..a8c16d7d68 100644 --- a/talpid-wireguard/src/wireguard_kernel/parsers.rs +++ b/talpid-wireguard/src/wireguard_kernel/parsers.rs @@ -2,7 +2,7 @@ use byteorder::{ByteOrder, NativeEndian}; use nix::sys::{socket::InetAddr, time::TimeSpec}; use std::{ ffi::{CStr, CString}, - mem, + mem::{self, transmute}, net::IpAddr, }; @@ -36,30 +36,41 @@ pub fn parse_wg_key(buffer: &[u8]) -> Result<[u8; 32], DecodeError> { } pub fn parse_inet_sockaddr(buffer: &[u8]) -> Result<InetAddr, DecodeError> { - if buffer.len() != mem::size_of::<libc::sockaddr_in6>() - && buffer.len() != mem::size_of::<libc::sockaddr_in>() - { - return Err(format!( + let wrong_len = || { + format!( "Unexpected length for sockaddr_in: {}, expected {} or {}", buffer.len(), mem::size_of::<libc::sockaddr_in6>(), mem::size_of::<libc::sockaddr_in>() ) - .into()); - } - let ptr = buffer.as_ptr(); + }; + const AF_INET: u16 = libc::AF_INET as u16; const AF_INET6: u16 = libc::AF_INET6 as u16; + if buffer.len() < size_of::<u16>() { + return Err(wrong_len().into()); + } + match NativeEndian::read_u16(buffer) { - AF_INET => unsafe { - let sockaddr: *const libc::sockaddr_in = ptr as *const _; - Ok(InetAddr::V4(*sockaddr)) - }, - AF_INET6 => unsafe { - let sockaddr: *const libc::sockaddr_in6 = ptr as *const _; - Ok(InetAddr::V6(*sockaddr)) - }, + AF_INET => { + let buffer: &[u8; size_of::<libc::sockaddr_in>()] = + buffer.try_into().map_err(|_| wrong_len())?; + + // SAFETY: sockaddr_in has a defined repr(C) layout and is valid for all bit patterns + let sockaddr: libc::sockaddr_in = unsafe { transmute(*buffer) }; + + Ok(InetAddr::V4(sockaddr)) + } + AF_INET6 => { + let buffer: &[u8; size_of::<libc::sockaddr_in6>()] = + buffer.try_into().map_err(|_| wrong_len())?; + + // SAFETY: sockaddr_in6 has a defined repr(C) layout and is valid for all bit patterns + let sockaddr: libc::sockaddr_in6 = unsafe { transmute(*buffer) }; + + Ok(InetAddr::V6(sockaddr)) + } unexpected_addr_family => { Err(format!("Unexpected address family: {unexpected_addr_family}").into()) } |
