summaryrefslogtreecommitdiffhomepage
path: root/talpid-wireguard/src/wireguard_kernel/parsers.rs
blob: a3a9eeffbb5886b8ce9e0ba1d0365de2a24c6509 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
use byteorder::{ByteOrder, NativeEndian};
use nix::sys::{
    socket::{SockaddrIn, SockaddrIn6},
    time::TimeSpec,
};
use std::{
    ffi::{CStr, CString},
    mem::{self, transmute},
    net::{IpAddr, SocketAddr},
};

use netlink_packet_core::DecodeError;

pub fn parse_ip_addr(bytes: &[u8]) -> Result<IpAddr, DecodeError> {
    if bytes.len() == 4 {
        let mut ipv4_bytes = [0u8; 4];
        ipv4_bytes.copy_from_slice(bytes);
        Ok(IpAddr::from(ipv4_bytes))
    } else if bytes.len() == 16 {
        let mut ipv6_bytes = [0u8; 16];
        ipv6_bytes.copy_from_slice(bytes);
        Ok(IpAddr::from(ipv6_bytes))
    } else {
        log::error!("Expected either 4 or 16 bytes, got {} bytes", bytes.len());
        Err(format!("Invalid bytes for IP address: {bytes:?}").into())
    }
}

pub fn parse_wg_key(buffer: &[u8]) -> Result<[u8; 32], DecodeError> {
    match buffer.len() {
        32 => {
            let mut key = [0u8; 32];
            key.clone_from_slice(buffer);
            Ok(key)
        }
        anything_else => Err(format!("Unexpected length of key: {anything_else}").into()),
    }
}

pub fn parse_inet_sockaddr(buffer: &[u8]) -> Result<SocketAddr, DecodeError> {
    let wrong_len = || {
        format!(
            "Unexpected length for sockaddr_in: {}, expected {} or {}",
            buffer.len(),
            mem::size_of::<libc::sockaddr_in6>(),
            mem::size_of::<libc::sockaddr_in>()
        )
    };

    const AF_INET: u16 = libc::AF_INET as u16;
    const AF_INET6: u16 = libc::AF_INET6 as u16;

    if buffer.len() < size_of::<u16>() {
        return Err(wrong_len().into());
    }

    match NativeEndian::read_u16(buffer) {
        AF_INET => {
            let buffer: &[u8; size_of::<libc::sockaddr_in>()] =
                buffer.try_into().map_err(|_| wrong_len())?;

            // SAFETY: sockaddr_in has a defined repr(C) layout and is valid for all bit patterns
            let sockaddr: libc::sockaddr_in = unsafe { transmute(*buffer) };
            let sockaddr = SockaddrIn::from(sockaddr);

            Ok(SocketAddr::from(sockaddr))
        }
        AF_INET6 => {
            let buffer: &[u8; size_of::<libc::sockaddr_in6>()] =
                buffer.try_into().map_err(|_| wrong_len())?;

            // SAFETY: sockaddr_in6 has a defined repr(C) layout and is valid for all bit patterns
            let sockaddr: libc::sockaddr_in6 = unsafe { transmute(*buffer) };
            let sockaddr = SockaddrIn6::from(sockaddr);

            Ok(SocketAddr::from(sockaddr))
        }
        unexpected_addr_family => {
            Err(format!("Unexpected address family: {unexpected_addr_family}").into())
        }
    }
}

pub fn parse_timespec(buffer: &[u8]) -> Result<TimeSpec, DecodeError> {
    if buffer.len() != mem::size_of::<libc::timespec>() {
        return Err(format!("Unexpected size for timespec: {}", buffer.len()).into());
    }

    Ok(TimeSpec::from(libc::timespec {
        tv_sec: NativeEndian::read_i64(buffer),
        // TODO: become compatible with 32-bit systems maybe?
        tv_nsec: NativeEndian::read_i64(buffer),
    }))
}

pub fn parse_cstring(buffer: &[u8]) -> Result<CString, DecodeError> {
    Ok(CStr::from_bytes_with_nul(buffer)
        .map_err(|err| format!("{err}"))?
        .into())
}

pub fn parse_genlmsghdr(buffer: &[u8]) -> Result<(u8, u8), DecodeError> {
    const GENLMSGHDR_SIZE: usize = mem::size_of::<libc::genlmsghdr>();
    if buffer.len() < GENLMSGHDR_SIZE {
        return Err(format!(
            "Expected at least {}, got {}",
            GENLMSGHDR_SIZE,
            buffer.len()
        )
        .into());
    }

    Ok((buffer[0], buffer[1]))
}