diff options
| author | Emīls <emils@mullvad.net> | 2020-08-26 11:38:27 +0100 |
|---|---|---|
| committer | Emīls <emils@mullvad.net> | 2020-09-02 10:58:47 +0100 |
| commit | 68ef511986648dfac7893dafb00aca89a077e743 (patch) | |
| tree | 0e61905e1d8657e088ebdc15679d5944008e6e59 | |
| parent | f39d70ea03376f4a67955303a6fbc7fd559d98b7 (diff) | |
| download | mullvadvpn-68ef511986648dfac7893dafb00aca89a077e743.tar.xz mullvadvpn-68ef511986648dfac7893dafb00aca89a077e743.zip | |
Add WireGuard kernel implementation
6 files changed, 1650 insertions, 6 deletions
diff --git a/talpid-core/src/dns/linux/network_manager.rs b/talpid-core/src/dns/linux/network_manager.rs index 5a17c3c59e..a8c4efccd3 100644 --- a/talpid-core/src/dns/linux/network_manager.rs +++ b/talpid-core/src/dns/linux/network_manager.rs @@ -194,6 +194,12 @@ impl NetworkManager { .get(NM_DEVICE, "Ip6Config") .map_err(Error::Dbus)?; + let device_addresses6: Vec<(Vec<u8>, u32, Vec<u8>)> = self + .dbus_connection + .with_path(NM_BUS, &device_ip6_config, RPC_TIMEOUT_MS) + .get(NM_IP6_CONFIG, "Addresses") + .map_err(Error::Dbus)?; + let device_routes6: Vec<(Vec<u8>, u32, Vec<u8>, u32)> = self .dbus_connection .with_path(NM_BUS, &device_ip6_config, RPC_TIMEOUT_MS) @@ -209,6 +215,7 @@ impl NetworkManager { ipv6_settings.insert("route-metric", Variant(Box::new(0u32))); ipv6_settings.insert("routes", Variant(Box::new(device_routes6))); ipv6_settings.insert("route-data", Variant(Box::new(device_route6_data))); + ipv6_settings.insert("addresses", Variant(Box::new(device_addresses6))); } let mut settings_backup = diff --git a/talpid-core/src/tunnel/wireguard/mod.rs b/talpid-core/src/tunnel/wireguard/mod.rs index ef9b2cabf8..987274acb5 100644 --- a/talpid-core/src/tunnel/wireguard/mod.rs +++ b/talpid-core/src/tunnel/wireguard/mod.rs @@ -16,6 +16,8 @@ mod connectivity_check; mod logging; mod stats; mod wireguard_go; +#[cfg(target_os = "linux")] +mod wireguard_kernel; use self::wireguard_go::WgGoTunnel; @@ -62,12 +64,7 @@ impl WireguardMonitor { tun_provider: &mut TunProvider, route_manager: &mut routing::RouteManager, ) -> Result<WireguardMonitor> { - let tunnel = Box::new(WgGoTunnel::start_tunnel( - &config, - log_path, - tun_provider, - Self::get_tunnel_routes(config), - )?); + let tunnel = Self::open_tunnel(&config, log_path, tun_provider, route_manager)?; let iface_name = tunnel.get_interface_name().to_string(); route_manager .add_routes(Self::get_routes(&iface_name, &config)) @@ -125,6 +122,34 @@ impl WireguardMonitor { Ok(monitor) } + #[cfg_attr(not(target_os = "linux"), allow(unused_variables))] + fn open_tunnel( + config: &Config, + log_path: Option<&Path>, + tun_provider: &mut TunProvider, + route_manager: &mut routing::RouteManager, + ) -> Result<Box<dyn Tunnel>> { + #[cfg(target_os = "linux")] + match wireguard_kernel::KernelTunnel::new(route_manager.runtime_handle(), config) { + Ok(tunnel) => { + return Ok(Box::new(tunnel)); + } + Err(err) => { + log::error!( + "Failed to setup kernel WireGuard device, falling back to userspace: {}", + err + ); + } + }; + + Ok(Box::new(WgGoTunnel::start_tunnel( + &config, + log_path, + tun_provider, + Self::get_tunnel_routes(config), + )?)) + } + /// Returns a close handle for the tunnel pub fn close_handle(&self) -> CloseHandle { CloseHandle { diff --git a/talpid-core/src/tunnel/wireguard/wireguard_kernel/mod.rs b/talpid-core/src/tunnel/wireguard/wireguard_kernel/mod.rs new file mode 100644 index 0000000000..b1fa3a0f65 --- /dev/null +++ b/talpid-core/src/tunnel/wireguard/wireguard_kernel/mod.rs @@ -0,0 +1,479 @@ +use super::{stats::Stats, Config, Tunnel, TunnelError}; +use futures::future::{abortable, AbortHandle}; +use netlink_packet_core::{constants::*, NetlinkDeserializable}; +use netlink_packet_route::{ + rtnl::{ + address::nlas::Nla as AddressNla, + link::nlas::{Info, InfoKind, Nla as LinkNla}, + AddressMessage, LinkMessage, RtnlMessage, RT_SCOPE_UNIVERSE, + }, + NetlinkMessage, NetlinkPayload, +}; +use netlink_packet_utils::DecodeError; +use netlink_proto::{ + sys::{Protocol, SocketAddr}, + ConnectionHandle, Error as NetlinkError, +}; +use std::{ffi::CString, net::IpAddr}; +use tokio::stream::StreamExt; + +mod parsers; + +mod wg_message; +use wg_message::{DeviceMessage, DeviceNla, PeerNla}; +mod nl_message; +use nl_message::{ControlNla, NetlinkControlMessage}; + + +#[derive(err_derive::Error, Debug)] +#[error(no_from)] +pub enum Error { + #[error(display = "Failed to decode netlink message")] + DecodeError(#[error(source)] DecodeError), + + #[error(display = "Failed to execute netlink control request")] + NetlinkControlMessageError(#[error(source)] nl_message::Error), + + #[error(display = "Failed to open netlink socket")] + NetlinkSocketError(#[error(source)] std::io::Error), + + #[error(display = "Failed to send netlink control request")] + NetlinkRequestError(#[error(source)] netlink_proto::Error<NetlinkControlMessage>), + + #[error(display = "WireGuard netlink interface unavailable. Is the kernel module loaded?")] + WireguardNetlinkInterfaceUnavailable, + + #[error(display = "Unknown WireGuard command _0")] + UnnkownWireguardCommmand(u8), + + #[error(display = "Received no response")] + NoResponse, + + #[error(display = "Received truncated message")] + Truncated, + + #[error(display = "WireGuard device does not exist")] + NoDevice, + + #[error(display = "Failed to get config: _0")] + WgGetConfError(netlink_packet_core::error::ErrorMessage), + + #[error(display = "Failed to apply config: _0")] + WgSetConfError(netlink_packet_core::error::ErrorMessage), + + #[error(display = "Interface name too long")] + InterfaceNameError, + + #[error(display = "Send request error")] + SendRequestError(#[error(source)] NetlinkError<DeviceMessage>), + + #[error(display = "Create device error")] + NetlinkCreateDeviceError(#[error(source)] rtnetlink::Error), + + #[error(display = "Add IP to device error")] + NetlinkSetIpError(rtnetlink::Error), + + #[error(display = "Failed to delete device")] + DeleteDeviceError(#[error(source)] rtnetlink::Error), +} + +pub struct KernelTunnel { + interface_index: u32, + netlink_connections: Handle, + tokio_handle: tokio::runtime::Handle, +} + +const MULLVAD_INTERFACE_NAME: &str = "wg-mullvad"; + +impl KernelTunnel { + pub fn new(tokio_handle: tokio::runtime::Handle, config: &Config) -> Result<Self, Error> { + tokio_handle.clone().block_on(async { + let mut netlink_connections = Handle::connect().await?; + let interface_index = netlink_connections + .create_device(MULLVAD_INTERFACE_NAME.to_string(), config.mtu as u32) + .await?; + + let mut tunnel = Self { + interface_index, + netlink_connections, + tokio_handle, + }; + + if let Err(err) = tunnel.setup(config).await { + if let Err(teardown_err) = tunnel + .netlink_connections + .delete_device(interface_index) + .await + { + log::error!( + "Failed to tear down WireGuard interface after failing to apply config: {}", + teardown_err + ); + } + return Err(err); + } + + + Ok(tunnel) + }) + } + + async fn setup(&mut self, config: &Config) -> Result<(), Error> { + self.netlink_connections + .wg_handle + .set_config(self.interface_index, config) + .await?; + + for tunnel_ip in config.tunnel.addresses.iter() { + self.netlink_connections + .set_ip_address(self.interface_index, *tunnel_ip) + .await?; + } + + Ok(()) + } +} + +impl Tunnel for KernelTunnel { + fn get_interface_name(&self) -> String { + let mut wg = self.netlink_connections.wg_handle.clone(); + let result = self.tokio_handle.block_on(async move { + let device = wg.get_by_index(self.interface_index).await?; + for nla in device.nlas { + if let DeviceNla::IfName(name) = nla { + return Ok(name); + } + } + return Err(Error::Truncated); + }); + + match result { + Ok(name) => name.to_string_lossy().to_string(), + Err(err) => { + log::error!("Failed to deduce interface name at runtime, will attempt to use the default name. {}", err); + MULLVAD_INTERFACE_NAME.to_string() + } + } + } + + fn stop(self: Box<Self>) -> std::result::Result<(), TunnelError> { + let Self { + mut netlink_connections, + interface_index, + tokio_handle, + } = *self; + tokio_handle.block_on(async move { + if let Err(err) = netlink_connections.delete_device(interface_index).await { + log::error!("Failed to remove WireGuard device - {}", err); + Err(TunnelError::FatalStartWireguardError) + } else { + Ok(()) + } + }) + } + + fn get_tunnel_stats(&self) -> std::result::Result<Stats, TunnelError> { + let mut wg = self.netlink_connections.wg_handle.clone(); + let interface_index = self.interface_index; + let result = self.tokio_handle.block_on(async move { + let device = wg.get_by_index(interface_index).await.map_err(|err| { + log::error!("Failed to fetch WireGuard device config: {}", err); + TunnelError::GetConfigError + })?; + + // iterate over device attributes + let mut tx_bytes = 0; + let mut rx_bytes = 0; + for nla in device.nlas { + if let DeviceNla::Peers(peers) = nla { + // iterate over all peer attributes + let peer_iter = peers.iter().map(|peer| peer.0.as_slice()).flatten(); + + for peer_nla in peer_iter { + match peer_nla { + PeerNla::TxBytes(bytes) => tx_bytes += *bytes, + PeerNla::RxBytes(bytes) => rx_bytes += *bytes, + _ => continue, + }; + } + } + } + + Ok(Stats { tx_bytes, rx_bytes }) + }); + + result + } +} + + +#[derive(Debug)] +pub struct Handle { + wg_handle: WireguardConnection, + route_handle: rtnetlink::Handle, + wg_abort_handle: AbortHandle, + route_abort_handle: AbortHandle, + message_type: u16, +} + + +impl Handle { + pub async fn connect() -> Result<Self, Error> { + let message_type = Self::get_wireguard_message_type().await?; + let (conn, wireguard_connection, _messages) = + netlink_proto::new_connection(Protocol::Generic).map_err(Error::NetlinkSocketError)?; + let wg_handle = WireguardConnection { + message_type, + connection: wireguard_connection, + }; + let (abortable_connection, wg_abort_handle) = abortable(conn); + tokio::spawn(abortable_connection); + let (conn, route_handle, _messages) = + rtnetlink::new_connection().map_err(Error::NetlinkSocketError)?; + let (abortable_connection, route_abort_handle) = abortable(conn); + tokio::spawn(abortable_connection); + + + Ok(Self { + wg_handle, + route_handle, + message_type, + wg_abort_handle, + route_abort_handle, + }) + } + + async fn get_wireguard_message_type() -> Result<u16, Error> { + let (conn, mut handle, _messages) = + netlink_proto::new_connection(Protocol::Generic).map_err(Error::NetlinkSocketError)?; + let (conn, abort_handle) = abortable(conn); + tokio::spawn(conn); + + let result = async move { + let mut message: NetlinkMessage<NetlinkControlMessage> = + NetlinkControlMessage::get_netlink_family_id(CString::new("wireguard").unwrap()) + .map_err(Error::NetlinkControlMessageError)? + .into(); + + message.header.flags = NLM_F_REQUEST | NLM_F_ACK; + + let mut req = handle + .request(message, SocketAddr::new(0, 0)) + .map_err(Error::NetlinkRequestError)?; + let response = req.next().await; + if let Some(response) = response { + if let NetlinkPayload::InnerMessage(msg) = response.payload { + for nla in msg.nlas.into_iter() { + if let ControlNla::FamilyId(id) = nla { + return Ok(id); + } + } + } + } + Err(Error::WireguardNetlinkInterfaceUnavailable) + } + .await; + + abort_handle.abort(); + result + } + + // create a wireguard device with the given name. + pub async fn create_device(&mut self, name: String, mtu: u32) -> Result<u32, Error> { + let mut message = LinkMessage::default(); + + // set link to be up + message.header.flags = netlink_packet_route::IFF_UP; + // message.header.change_mask = netlink_packet_route::IFF_UP; + // set link name + message.nlas.push(LinkNla::IfName(name.clone())); + // set link MTU + message.nlas.push(LinkNla::Mtu(mtu)); + // set link type + message + .nlas + .push(LinkNla::Info(vec![Info::Kind(InfoKind::Other( + "wireguard".to_string(), + ))])); + + let mut add_request = NetlinkMessage::from(RtnlMessage::NewLink(message)); + add_request.header.flags = + NLM_F_REQUEST | NLM_F_ACK | NLM_F_REPLACE | NLM_F_CREATE | NLM_F_MATCH; + let mut response = self + .route_handle + .request(add_request) + .map_err(Error::NetlinkCreateDeviceError)?; + while let Some(response_message) = response.next().await { + if let NetlinkPayload::Error(err) = response_message.payload { + // if the device exists, verify that it's a wireguard device + if -err.code != libc::EEXIST { + return Err(Error::NetlinkCreateDeviceError( + rtnetlink::Error::NetlinkError(err), + )); + } + } + } + + // fetch interface index of new device + let new_device = self.wg_handle.get_by_name(name).await?; + for nla in new_device.nlas { + if let DeviceNla::IfIndex(index) = nla { + return Ok(index); + } + } + + + Err(Error::NoDevice) + } + + pub async fn set_ip_address(&mut self, index: u32, addr: IpAddr) -> Result<(), Error> { + let address_message = add_ip_addr_message(index, addr); + let mut request = NetlinkMessage::from(RtnlMessage::NewAddress(address_message)); + request.header.flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_CREATE | NLM_F_REPLACE; + + + let mut response = self + .route_handle + .request(request) + .map_err(Error::NetlinkSetIpError)?; + while let Some(response_message) = response.next().await { + consume_netlink_error(response_message, Error::NetlinkSetIpError)?; + } + + Ok(()) + } + + pub async fn delete_device(&mut self, index: u32) -> Result<(), Error> { + let mut link_message = LinkMessage::default(); + link_message.header.index = index; + + let mut request = NetlinkMessage::from(RtnlMessage::DelLink(link_message)); + request.header.flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_EXCL | NLM_F_CREATE; + + let mut response = self + .route_handle + .request(request) + .map_err(Error::DeleteDeviceError)?; + while let Some(message) = response.next().await { + consume_netlink_error(message, Error::DeleteDeviceError)?; + } + + Ok(()) + } +} + +impl Drop for Handle { + fn drop(&mut self) { + self.wg_abort_handle.abort(); + self.route_abort_handle.abort(); + } +} + +#[derive(Debug, Clone)] +struct WireguardConnection { + connection: ConnectionHandle<DeviceMessage>, + message_type: u16, +} + +impl WireguardConnection { + pub async fn get_by_name(&mut self, name: String) -> Result<DeviceMessage, Error> { + self.fetch_device(DeviceMessage::get_by_name(self.message_type, name)?) + .await + } + + pub async fn get_by_index(&mut self, index: u32) -> Result<DeviceMessage, Error> { + self.fetch_device(DeviceMessage::get_by_index(self.message_type, index)) + .await + } + + pub async fn fetch_device( + &mut self, + device_message: DeviceMessage, + ) -> Result<DeviceMessage, Error> { + let mut netlink_message = NetlinkMessage::from(device_message); + netlink_message.header.flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_DUMP; + + let mut response = self + .connection + .request(netlink_message, SocketAddr::new(0, 0)) + .map_err(Error::SendRequestError)?; + match response.next().await { + Some(received_message) => match received_message.payload { + NetlinkPayload::InnerMessage(inner) => Ok(inner), + NetlinkPayload::Error(err) => { + if err.code == -libc::ENODEV { + Err(Error::NoDevice) + } else { + Err(Error::WgGetConfError(err)) + } + } + anything_else => { + log::error!("Received unexpected response - {:?}", anything_else); + Err(Error::NoResponse) + } + }, + None => Err(Error::NoResponse), + } + } + + pub async fn set_config(&mut self, interface_index: u32, config: &Config) -> Result<(), Error> { + let message = DeviceMessage::reset_config(self.message_type, interface_index, config); + let mut netlink_message = NetlinkMessage::from(message); + netlink_message.header.flags = NLM_F_REQUEST | NLM_F_ACK; + + let mut request = self + .connection + .request(netlink_message, SocketAddr::new(0, 0)) + .map_err(Error::SendRequestError)?; + + while let Some(response) = request.next().await { + if let NetlinkPayload::Error(err) = response.payload { + return Err(Error::WgSetConfError(err)); + } + } + Ok(()) + } +} + + +fn consume_netlink_error< + T, + I: NetlinkDeserializable<T> + Clone + Eq + std::fmt::Debug, + F: Fn(rtnetlink::Error) -> Error, +>( + message: NetlinkMessage<I>, + err_constructor: F, +) -> Result<(), Error> { + if let NetlinkPayload::Error(err) = message.payload { + return Err(err_constructor(rtnetlink::Error::NetlinkError(err))); + } + Ok(()) +} + +// the built-in support for adding addresses is too helpful, so a simple AddressMessage with a +// single Address nla is created +fn add_ip_addr_message(if_index: u32, addr: IpAddr) -> AddressMessage { + let prefix_len = if addr.is_ipv4() { 32 } else { 128 }; + let mut message = AddressMessage::default(); + message.header.prefix_len = prefix_len; + message.header.index = if_index; + message.header.scope = RT_SCOPE_UNIVERSE; + + match addr { + IpAddr::V4(ipv4) => { + message.header.family = libc::AF_INET as u8; + let ip_bytes = ipv4.octets().to_vec(); + + message.nlas.push(AddressNla::Address(ip_bytes.clone())); + message.nlas.push(AddressNla::Local(ip_bytes)); + } + IpAddr::V6(ipv6) => { + message.header.family = libc::AF_INET6 as u8; + message + .nlas + .push(AddressNla::Address(ipv6.octets().to_vec())); + } + }; + + message +} diff --git a/talpid-core/src/tunnel/wireguard/wireguard_kernel/nl_message.rs b/talpid-core/src/tunnel/wireguard/wireguard_kernel/nl_message.rs new file mode 100644 index 0000000000..7fc8be8304 --- /dev/null +++ b/talpid-core/src/tunnel/wireguard/wireguard_kernel/nl_message.rs @@ -0,0 +1,135 @@ +use super::parsers; +use byteorder::{ByteOrder, NativeEndian}; +use netlink_packet_core::{ + NetlinkDeserializable, NetlinkHeader, NetlinkPayload, NetlinkSerializable, +}; +use netlink_packet_utils::{ + nla::{Nla, NlaBuffer, NlasIterator}, + traits::{Emitable, Parseable}, + DecodeError, +}; +use std::{ffi::CString, io::Write, mem}; + + +#[derive(err_derive::Error, Debug)] +pub enum Error { + #[error(display = "Family name too long")] + FamilyNameTooLong, +} + +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct NetlinkControlMessage { + cmd: u8, + version: u8, + pub nlas: Vec<ControlNla>, +} + +impl NetlinkControlMessage { + pub fn get_netlink_family_id(name: CString) -> Result<Self, Error> { + if name.as_bytes_with_nul().len() > (libc::GENL_NAMSIZ as usize) { + return Err(Error::FamilyNameTooLong); + } + Ok(Self { + nlas: vec![ControlNla::FamilyName(name)], + cmd: libc::CTRL_CMD_GETFAMILY as u8, + version: 1, + }) + } +} + + +impl NetlinkSerializable<NetlinkControlMessage> for NetlinkControlMessage { + fn message_type(&self) -> u16 { + libc::GENL_ID_CTRL as u16 + } + + fn buffer_len(&self) -> usize { + mem::size_of::<libc::genlmsghdr>() + self.nlas.as_slice().buffer_len() + } + + fn serialize(&self, mut buffer: &mut [u8]) { + let _ = buffer.write(&[self.cmd, self.version, 0u8, 0u8]).unwrap(); + self.nlas.as_slice().emit(&mut buffer); + } +} + +impl Into<NetlinkPayload<NetlinkControlMessage>> for NetlinkControlMessage { + fn into(self) -> NetlinkPayload<NetlinkControlMessage> { + NetlinkPayload::InnerMessage(self) + } +} + +impl NetlinkDeserializable<NetlinkControlMessage> for NetlinkControlMessage { + type Error = DecodeError; + fn deserialize( + _header: &NetlinkHeader, + payload: &[u8], + ) -> Result<NetlinkControlMessage, Self::Error> { + // skip the genlmsghdr + let (cmd, version) = parsers::parse_genlmsghdr(payload)?; + let nla_buffer = &payload[mem::size_of::<libc::genlmsghdr>()..]; + let nlas = NlasIterator::new(nla_buffer) + .map(|buffer| ControlNla::parse(&buffer?)) + .collect::<Result<Vec<_>, DecodeError>>()?; + + Ok(NetlinkControlMessage { nlas, cmd, version }) + } +} + +#[derive(Debug, PartialEq, Eq, Clone)] +pub enum ControlNla { + FamilyName(CString), + FamilyId(u16), + Unknown(u16, Vec<u8>), +} + +impl Nla for ControlNla { + fn value_len(&self) -> usize { + use ControlNla::*; + match self { + FamilyName(name) => name.as_bytes_with_nul().len(), + FamilyId(_id) => 2, + Unknown(_, buffer) => buffer.len(), + } + } + + fn kind(&self) -> u16 { + use ControlNla::*; + match self { + FamilyName(_) => libc::CTRL_ATTR_FAMILY_NAME as u16, + FamilyId(_) => libc::CTRL_ATTR_FAMILY_ID as u16, + Unknown(kind, _) => *kind, + } + } + + fn emit_value(&self, mut buffer: &mut [u8]) { + use ControlNla::*; + match self { + FamilyName(name) => { + let _ = buffer.write(name.as_bytes()).unwrap(); + } + FamilyId(id) => { + NativeEndian::write_u16(buffer, *id); + } + + Unknown(_, value) => { + let _ = buffer.write(value).unwrap(); + } + } + } +} + +impl<'a, T: AsRef<[u8]> + 'a + ?Sized + std::fmt::Debug> Parseable<NlaBuffer<&'a T>> + for ControlNla +{ + fn parse(buf: &NlaBuffer<&'a T>) -> Result<Self, DecodeError> { + let nla = match buf.kind() as i32 { + libc::CTRL_ATTR_FAMILY_NAME => { + ControlNla::FamilyName(parsers::parse_cstring(buf.value())?) + } + libc::CTRL_ATTR_FAMILY_ID => ControlNla::FamilyId(parsers::parse_u16(buf.value())?), + _unknown_kind => ControlNla::Unknown(buf.kind(), buf.value().to_vec()), + }; + Ok(nla) + } +} diff --git a/talpid-core/src/tunnel/wireguard/wireguard_kernel/parsers.rs b/talpid-core/src/tunnel/wireguard/wireguard_kernel/parsers.rs new file mode 100644 index 0000000000..b34c82d342 --- /dev/null +++ b/talpid-core/src/tunnel/wireguard/wireguard_kernel/parsers.rs @@ -0,0 +1,99 @@ +use byteorder::{ByteOrder, NativeEndian}; +use nix::sys::{socket::InetAddr, time::TimeSpec}; +use std::{ + ffi::{CStr, CString}, + mem, + net::IpAddr, +}; + +pub use netlink_packet_utils::parsers::*; +use netlink_packet_utils::DecodeError; + +pub fn parse_ip_addr(bytes: &[u8]) -> Result<IpAddr, DecodeError> { + if bytes.len() == 4 { + let mut ipv4_bytes = [0u8; 4]; + ipv4_bytes.copy_from_slice(bytes); + Ok(IpAddr::from(ipv4_bytes)) + } else if bytes.len() == 16 { + let mut ipv6_bytes = [0u8; 16]; + ipv6_bytes.copy_from_slice(bytes); + Ok(IpAddr::from(ipv6_bytes)) + } else { + log::error!("Expected either 4 or 16 bytes, got {} bytes", bytes.len()); + Err(format!("Invalid bytes for IP address: {:?}", bytes).into()) + } +} + +pub fn parse_wg_key(buffer: &[u8]) -> Result<[u8; 32], DecodeError> { + match buffer.len() { + 32 => { + let mut key = [0u8; 32]; + key.clone_from_slice(buffer); + Ok(key) + } + anything_else => Err(format!("Unexpected length of key: {}", anything_else).into()), + } +} + +pub fn parse_inet_sockaddr(buffer: &[u8]) -> Result<InetAddr, DecodeError> { + if buffer.len() != mem::size_of::<libc::sockaddr_in6>() + && buffer.len() != mem::size_of::<libc::sockaddr_in>() + { + return Err(format!( + "Unexpected length for sockaddr_in: {}, expected {} or {}", + buffer.len(), + mem::size_of::<libc::sockaddr_in6>(), + mem::size_of::<libc::sockaddr_in>() + ) + .into()); + } + let ptr = buffer.as_ptr(); + const AF_INET: u16 = libc::AF_INET as u16; + const AF_INET6: u16 = libc::AF_INET6 as u16; + + match NativeEndian::read_u16(buffer) { + AF_INET => unsafe { + let sockaddr: *const libc::sockaddr_in = ptr as *const _; + Ok(InetAddr::V4(*sockaddr).into()) + }, + AF_INET6 => unsafe { + let sockaddr: *const libc::sockaddr_in6 = ptr as *const _; + Ok(InetAddr::V6(*sockaddr)) + }, + unexpected_addr_family => { + Err(format!("Unexpected address family: {}", unexpected_addr_family).into()) + } + } +} + +pub fn parse_timespec(buffer: &[u8]) -> Result<TimeSpec, DecodeError> { + if buffer.len() != mem::size_of::<libc::timespec>() { + return Err(format!("Unexpected size for timespec: {}", buffer.len()).into()); + } + + Ok(TimeSpec::from(libc::timespec { + tv_sec: NativeEndian::read_i64(buffer), + // TODO: become compatible with 32-bit systems maybe? + tv_nsec: NativeEndian::read_i64(buffer), + })) +} + +pub fn parse_cstring(buffer: &[u8]) -> Result<CString, DecodeError> { + Ok(CStr::from_bytes_with_nul(buffer) + .map_err(|err| format!("{}", err))? + .into()) +} + +pub fn parse_genlmsghdr(buffer: &[u8]) -> Result<(u8, u8), DecodeError> { + const GENLMSGHDR_SIZE: usize = mem::size_of::<libc::genlmsghdr>(); + if buffer.len() < GENLMSGHDR_SIZE { + return Err(format!( + "Expected at least {}, got {}", + GENLMSGHDR_SIZE, + buffer.len() + ) + .into()); + } + + Ok((buffer[0], buffer[1])) +} diff --git a/talpid-core/src/tunnel/wireguard/wireguard_kernel/wg_message.rs b/talpid-core/src/tunnel/wireguard/wireguard_kernel/wg_message.rs new file mode 100644 index 0000000000..1e587712cc --- /dev/null +++ b/talpid-core/src/tunnel/wireguard/wireguard_kernel/wg_message.rs @@ -0,0 +1,899 @@ +use super::{super::config::Config, parsers, Error}; +use byteorder::{ByteOrder, NativeEndian}; +use ipnetwork::IpNetwork; +use netlink_packet_core::{ + NetlinkDeserializable, NetlinkHeader, NetlinkPayload, NetlinkSerializable, +}; +use netlink_packet_utils::{ + nla::{Nla, NlaBuffer, NlasIterator, NLA_F_NESTED}, + traits::{Emitable, Parseable}, + DecodeError, +}; +use nix::sys::{socket::InetAddr, time::TimeSpec}; +use std::{ffi::CString, io::Write, mem, net::IpAddr}; + +/// WireGuard netlink constants +mod constants { + #![allow(dead_code)] + pub const WG_GENL_VERSION: u8 = 1; + + /// Command constants + pub const WG_CMD_GET_DEVICE: u8 = 0; + pub const WG_CMD_SET_DEVICE: u8 = 1; + + // wgdevice_flag + pub const WGDEVICE_F_REPLACE_PEERS: u32 = 1 << 0; + + // wgdevice_attribute + pub const WGDEVICE_A_UNSPEC: u16 = 0; + pub const WGDEVICE_A_IFINDEX: u16 = 1; + pub const WGDEVICE_A_IFNAME: u16 = 2; + pub const WGDEVICE_A_PRIVATE_KEY: u16 = 3; + pub const WGDEVICE_A_PUBLIC_KEY: u16 = 4; + pub const WGDEVICE_A_FLAGS: u16 = 5; + pub const WGDEVICE_A_LISTEN_PORT: u16 = 6; + pub const WGDEVICE_A_FWMARK: u16 = 7; + pub const WGDEVICE_A_PEERS: u16 = 8; + + // wgpeer_flag + pub const WGPEER_F_REMOVE_ME: u32 = 1 << 0; + pub const WGPEER_F_REPLACE_ALLOWEDIPS: u32 = 1 << 1; + pub const WGPEER_F_UPDATE_ONLY: u32 = 1 << 2; + + // wgpeer_attribute + pub const WGPEER_A_UNSPEC: u16 = 0; + pub const WGPEER_A_PUBLIC_KEY: u16 = 1; + pub const WGPEER_A_PRESHARED_KEY: u16 = 2; + pub const WGPEER_A_FLAGS: u16 = 3; + pub const WGPEER_A_ENDPOINT: u16 = 4; + pub const WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL: u16 = 5; + pub const WGPEER_A_LAST_HANDSHAKE_TIME: u16 = 6; + pub const WGPEER_A_RX_BYTES: u16 = 7; + pub const WGPEER_A_TX_BYTES: u16 = 8; + pub const WGPEER_A_ALLOWEDIPS: u16 = 9; + pub const WGPEER_A_PROTOCOL_VERSION: u16 = 10; + + // wgallowedip_attribute + pub const WGALLOWEDIP_A_UNSPEC: u16 = 0; + pub const WGALLOWEDIP_A_FAMILY: u16 = 1; + pub const WGALLOWEDIP_A_IPADDR: u16 = 2; + pub const WGALLOWEDIP_A_CIDR_MASK: u16 = 3; +} + +use constants::*; +pub use constants::{WG_CMD_GET_DEVICE, WG_CMD_SET_DEVICE}; + +type PrivateKey = [u8; 32]; +type PublicKey = [u8; 32]; +type PresharedKey = [u8; 32]; + + +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct DeviceMessage { + pub nlas: Vec<DeviceNla>, + pub message_type: u16, + pub command: u8, +} + +impl DeviceMessage { + pub fn reset_config(message_type: u16, interface_index: u32, config: &Config) -> DeviceMessage { + let mut peers = vec![]; + + for peer in config.peers.iter() { + let peer_endpoint = InetAddr::from_std(&peer.endpoint); + let allowed_ips = peer.allowed_ips.iter().map(From::from).collect(); + peers.push(PeerMessage(vec![ + PeerNla::PublicKey(*peer.public_key.as_bytes()), + PeerNla::Endpoint(peer_endpoint), + PeerNla::AllowedIps(allowed_ips), + PeerNla::Flags(WGPEER_F_REPLACE_ALLOWEDIPS), + ])); + } + + let nlas = vec![ + DeviceNla::IfIndex(interface_index), + DeviceNla::ListenPort(0), + DeviceNla::Fwmark(crate::linux::TUNNEL_FW_MARK), + DeviceNla::PrivateKey(config.tunnel.private_key.to_bytes()), + DeviceNla::Flags(WGDEVICE_F_REPLACE_PEERS), + DeviceNla::Peers(peers), + ]; + + + Self { + nlas, + message_type, + command: WG_CMD_SET_DEVICE, + } + } + + pub fn get_by_name(message_type: u16, name: String) -> Result<Self, Error> { + let c_name = CString::new(name).map_err(|_| Error::InterfaceNameError)?; + if c_name.as_bytes_with_nul().len() > libc::IFNAMSIZ { + return Err(Error::InterfaceNameError); + } + + Ok(Self { + message_type, + nlas: vec![DeviceNla::IfName(c_name)], + command: WG_CMD_GET_DEVICE, + }) + } + + pub fn get_by_index(message_type: u16, index: u32) -> Self { + Self { + message_type, + nlas: vec![DeviceNla::IfIndex(index)], + command: WG_CMD_GET_DEVICE, + } + } + + // All WireGuard netlink messages should start with a libc::genlmsghdr, for which the first + // byte contains the command. + fn read_genlmsghdr(buff: &[u8]) -> Result<u8, Error> { + if buff.len() < mem::size_of::<libc::genlmsghdr>() { + return Err(Error::Truncated); + } + + let cmd = buff[0]; + if cmd == WG_CMD_GET_DEVICE || cmd == WG_CMD_SET_DEVICE { + Ok(cmd) + } else { + Err(Error::UnnkownWireguardCommmand(cmd)) + } + } +} + +impl NetlinkSerializable<DeviceMessage> for DeviceMessage { + fn message_type(&self) -> u16 { + self.message_type + } + + fn buffer_len(&self) -> usize { + // add the genlmsghdr + mem::size_of::<libc::genlmsghdr>() + + // size of all of the NLAs + self.nlas.as_slice().buffer_len() + } + + fn serialize(&self, mut buffer: &mut [u8]) { + let command_buf = [self.command, WG_GENL_VERSION, 0u8, 0u8]; + let _ = buffer.write(&command_buf).unwrap(); + self.nlas.as_slice().emit(&mut buffer) + } +} +impl Into<NetlinkPayload<DeviceMessage>> for DeviceMessage { + fn into(self) -> NetlinkPayload<DeviceMessage> { + NetlinkPayload::InnerMessage(self) + } +} + +impl NetlinkDeserializable<DeviceMessage> for DeviceMessage { + type Error = Error; + fn deserialize(header: &NetlinkHeader, payload: &[u8]) -> Result<DeviceMessage, Self::Error> { + let command = Self::read_genlmsghdr(payload)?; + let new_payload = &payload[mem::size_of::<libc::genlmsghdr>()..]; + let mut nlas = vec![]; + for buf in NlasIterator::new(new_payload) { + nlas.push( + DeviceNla::parse(&buf.map_err(Error::DecodeError)?).map_err(Error::DecodeError)?, + ); + } + + Ok(DeviceMessage { + nlas, + command, + message_type: header.message_type, + }) + } +} + +#[derive(Debug, PartialEq, Eq, Clone)] +pub enum DeviceNla { + IfIndex(u32), + IfName(CString), + Flags(u32), + PrivateKey(PrivateKey), + PublicKey(PublicKey), + ListenPort(u16), + Fwmark(u32), + Peers(Vec<PeerMessage>), + Unspec(Vec<u8>), +} + +impl Nla for DeviceNla { + fn value_len(&self) -> usize { + use DeviceNla::*; + match self { + IfIndex(_) | Fwmark(_) | Flags(_) => 4, + IfName(name) => name.as_bytes_with_nul().len(), + PrivateKey(key) | PublicKey(key) => key.len(), + ListenPort(_) => 2, + Peers(peers) => peers.as_slice().buffer_len(), + Unspec(payload) => payload.len(), + } + } + + fn kind(&self) -> u16 { + use DeviceNla::*; + match self { + IfIndex(_) => WGDEVICE_A_IFINDEX, + IfName(_) => WGDEVICE_A_IFNAME, + PrivateKey(_) => WGDEVICE_A_PRIVATE_KEY, + PublicKey(_) => WGDEVICE_A_PUBLIC_KEY, + Flags(_) => WGDEVICE_A_FLAGS, + ListenPort(_) => WGDEVICE_A_LISTEN_PORT, + Fwmark(_) => WGDEVICE_A_FWMARK, + Peers(_) => WGDEVICE_A_PEERS | NLA_F_NESTED, + Unspec(_) => WGDEVICE_A_UNSPEC, + } + } + + fn emit_value(&self, mut buffer: &mut [u8]) { + use DeviceNla::*; + match self { + IfIndex(value) | Fwmark(value) | Flags(value) => { + NativeEndian::write_u32(buffer, *value) + } + IfName(interface_name) => { + let _ = buffer + .write(interface_name.as_bytes_with_nul()) + .expect("Failed to write interface name"); + } + PrivateKey(key) | PublicKey(key) => { + let _ = buffer.write(key).expect("Failed to write key"); + } + ListenPort(port) => NativeEndian::write_u16(buffer, *port), + Peers(peers) => { + peers.as_slice().emit(buffer); + } + Unspec(payload) => { + let _ = buffer.write(&payload).expect("Failed to write "); + } + } + } +} + +impl<'a, T: AsRef<[u8]> + 'a + ?Sized + core::fmt::Debug> Parseable<NlaBuffer<&'a T>> + for DeviceNla +{ + fn parse(buf: &NlaBuffer<&'a T>) -> Result<Self, DecodeError> { + use DeviceNla::*; + let value = buf.value(); + let kind = buf.kind(); + let nla = match kind { + WGDEVICE_A_IFINDEX => IfIndex(parsers::parse_u32(value)?), + WGDEVICE_A_IFNAME => IfName(parsers::parse_cstring(value)?), + WGDEVICE_A_PRIVATE_KEY => PrivateKey(parsers::parse_wg_key(value)?.into()), + WGDEVICE_A_PUBLIC_KEY => PublicKey(parsers::parse_wg_key(value)?.into()), + WGDEVICE_A_FLAGS => Flags(parsers::parse_u32(value)?), + WGDEVICE_A_LISTEN_PORT => ListenPort(parsers::parse_u16(value)?), + WGDEVICE_A_FWMARK => Fwmark(parsers::parse_u32(value)?), + WGDEVICE_A_PEERS => { + let peers = NlasIterator::new(value) + .map(|nla_bytes| { + let buf = nla_bytes?; + let val = buf.value(); + PeerMessage::parse(&val) + }) + .collect::<Result<Vec<PeerMessage>, DecodeError>>()?; + Peers(peers) + } + WGDEVICE_A_UNSPEC => Unspec(value.to_vec()), + _ => { + return Err(format!("Unexpected device attribute kind: {}", buf.kind()).into()); + } + }; + Ok(nla) + } +} + +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct PeerMessage(pub Vec<PeerNla>); + +impl PeerMessage { + fn parse(payload: &[u8]) -> Result<Self, DecodeError> { + let mut nlas = vec![]; + + let nla_iter = NlasIterator::new(&payload); + for buffer in nla_iter { + nlas.push(PeerNla::parse(&buffer?)?) + } + Ok(Self(nlas)) + } +} + +impl<'a, T: AsRef<[u8]> + 'a + ?Sized> Parseable<NlaBuffer<&'a T>> for PeerMessage { + fn parse(payload: &NlaBuffer<&'a T>) -> Result<Self, DecodeError> { + Ok(Self( + NlasIterator::new(&payload.into_inner()) + .map(|buffer| PeerNla::parse(&buffer?)) + .collect::<Result<Vec<PeerNla>, DecodeError>>()?, + )) + } +} + +impl Nla for PeerMessage { + fn value_len(&self) -> usize { + self.0.as_slice().buffer_len() + } + + fn kind(&self) -> u16 { + NLA_F_NESTED + } + + fn emit_value(&self, buffer: &mut [u8]) { + self.0.as_slice().emit(buffer); + } +} + +#[derive(Debug, PartialEq, Eq, Clone)] +pub enum PeerNla { + Unspec(Vec<u8>), + PublicKey(PublicKey), + PresharedKey(PresharedKey), + Flags(u32), + Endpoint(InetAddr), + PersistentKeepaliveInterval(u16), + LastHandshakeTime(TimeSpec), + RxBytes(u64), + TxBytes(u64), + AllowedIps(Vec<AllowedIpMessage>), + ProtocolVersion(u32), +} + +impl Nla for PeerNla { + fn value_len(&self) -> usize { + use PeerNla::*; + match self { + PublicKey(key) | PresharedKey(key) => key.len(), + Endpoint(endpoint) => match &endpoint { + InetAddr::V4(_) => mem::size_of::<libc::sockaddr_in>(), + InetAddr::V6(_) => mem::size_of::<libc::sockaddr_in6>(), + }, + PersistentKeepaliveInterval(_) => 2, + LastHandshakeTime(_) => mem::size_of::<libc::timespec>(), + RxBytes(_) | TxBytes(_) => 8, + AllowedIps(ips) => ips.as_slice().buffer_len(), + Flags(_) | ProtocolVersion(_) => 4, + Unspec(payload) => payload.len(), + } + } + + fn kind(&self) -> u16 { + use PeerNla::*; + match self { + PublicKey(_) => WGPEER_A_PUBLIC_KEY, + PresharedKey(_) => WGPEER_A_PRESHARED_KEY, + Flags(_) => WGPEER_A_FLAGS, + Endpoint(_) => WGPEER_A_ENDPOINT, + PersistentKeepaliveInterval(_) => WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL, + LastHandshakeTime(_) => WGPEER_A_LAST_HANDSHAKE_TIME, + RxBytes(_) => WGPEER_A_RX_BYTES, + TxBytes(_) => WGPEER_A_TX_BYTES, + AllowedIps(_) => WGPEER_A_ALLOWEDIPS | NLA_F_NESTED, + ProtocolVersion(_) => WGPEER_A_PROTOCOL_VERSION, + Unspec(_) => WGPEER_A_UNSPEC, + } + } + + fn emit_value(&self, mut buffer: &mut [u8]) { + use PeerNla::*; + match self { + PublicKey(key) | PresharedKey(key) => { + let _ = buffer.write(key).expect("Buffer too small for a key"); + } + Flags(value) | ProtocolVersion(value) => NativeEndian::write_u32(buffer, *value), + Endpoint(endpoint) => match &endpoint { + InetAddr::V4(sockaddr_in) => { + let slice = unsafe { struct_as_slice(sockaddr_in) }; + buffer + .write(slice) + .expect("Buffer too small for sockaddr_in"); + } + InetAddr::V6(sockaddr_in6) => { + buffer + .write(unsafe { struct_as_slice(sockaddr_in6) }) + .expect("Buffer too small for sockaddr_in6"); + } + }, + PersistentKeepaliveInterval(interval) => { + NativeEndian::write_u16(buffer, *interval); + } + LastHandshakeTime(last_handshake) => { + let timespec: &libc::timespec = last_handshake.as_ref(); + buffer + .write(unsafe { struct_as_slice(timespec) }) + .expect("Buffer too small for timespec"); + } + RxBytes(num_bytes) | TxBytes(num_bytes) => NativeEndian::write_u64(buffer, *num_bytes), + AllowedIps(ips) => ips.as_slice().emit(buffer), + Unspec(payload) => { + let _ = buffer + .write(&payload) + .expect("Buffer too small for unspecified payload"); + } + } + } +} + +impl<'a, T: AsRef<[u8]> + 'a + ?Sized> Parseable<NlaBuffer<&'a T>> for PeerNla { + fn parse(buf: &NlaBuffer<&'a T>) -> Result<Self, DecodeError> { + use PeerNla::*; + let value = buf.value(); + let nla = match buf.kind() { + WGPEER_A_PUBLIC_KEY => PublicKey(parsers::parse_wg_key(value)?.into()), + WGPEER_A_PRESHARED_KEY => PresharedKey(parsers::parse_wg_key(value)?.into()), + WGPEER_A_FLAGS => Flags(parsers::parse_u32(value)?), + WGPEER_A_ENDPOINT => Endpoint(parsers::parse_inet_sockaddr(value)?), + WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL => { + PersistentKeepaliveInterval(parsers::parse_u16(value)?) + } + + WGPEER_A_LAST_HANDSHAKE_TIME => LastHandshakeTime(parsers::parse_timespec(value)?), + WGPEER_A_RX_BYTES => RxBytes(parsers::parse_u64(value)?), + WGPEER_A_TX_BYTES => TxBytes(parsers::parse_u64(value)?), + WGPEER_A_ALLOWEDIPS => { + let nlas = NlasIterator::new(value) + .map(|nla_buffer| AllowedIpMessage::parse(&nla_buffer?)) + .collect::<Result<Vec<_>, DecodeError>>()?; + + AllowedIps(nlas) + } + WGPEER_A_PROTOCOL_VERSION => ProtocolVersion(parsers::parse_u32(value)?), + WGPEER_A_UNSPEC => Unspec(value.to_vec()), + _ => { + return Err(format!("Unexpected peer attribute kind: {}", buf.kind()).into()); + } + }; + Ok(nla) + } +} + +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct AllowedIpMessage(Vec<AllowedIpNla>); + +impl From<&IpNetwork> for AllowedIpMessage { + fn from(ip: &IpNetwork) -> Self { + use AllowedIpNla::*; + let address_family = if ip.is_ipv4() { + libc::AF_INET + } else { + libc::AF_INET6 + }; + + AllowedIpMessage(vec![ + AddressFamily(address_family as u16), + CidrMask(ip.prefix()), + IpAddr(ip.ip().into()), + ]) + } +} + +impl Nla for AllowedIpMessage { + fn value_len(&self) -> usize { + self.0.as_slice().buffer_len() + } + + fn kind(&self) -> u16 { + NLA_F_NESTED + } + + fn emit_value(&self, buffer: &mut [u8]) { + self.0.as_slice().emit(buffer); + } +} + +impl<'a, T: AsRef<[u8]> + 'a + ?Sized> Parseable<NlaBuffer<&'a T>> for AllowedIpMessage { + fn parse(buf: &NlaBuffer<&'a T>) -> Result<Self, DecodeError> { + let nlas = NlasIterator::new(buf.value()) + .map(|buffer| AllowedIpNla::parse(&buffer?)) + .collect::<Result<Vec<_>, _>>()?; + Ok(AllowedIpMessage(nlas)) + } +} + + +#[derive(Debug, PartialEq, Eq, Clone)] +pub enum AllowedIpNla { + AddressFamily(u16), + IpAddr(IpAddr), + CidrMask(u8), + Unspec(Vec<u8>), +} + +impl Nla for AllowedIpNla { + fn value_len(&self) -> usize { + use AllowedIpNla::*; + match &self { + AddressFamily(_) => 2, + IpAddr(addr) => ip_addr_to_bytes(addr).len(), + CidrMask(_) => 1, + Unspec(payload) => payload.len(), + } + } + + fn kind(&self) -> u16 { + use AllowedIpNla::*; + match &self { + AddressFamily(_) => WGALLOWEDIP_A_FAMILY, + IpAddr(_) => WGALLOWEDIP_A_IPADDR, + CidrMask(_) => WGALLOWEDIP_A_CIDR_MASK, + Unspec(_) => WGALLOWEDIP_A_UNSPEC, + } + } + + fn emit_value(&self, mut buffer: &mut [u8]) { + use AllowedIpNla::*; + match self { + AddressFamily(af) => { + NativeEndian::write_u16(buffer, *af); + } + IpAddr(ip_addr) => { + buffer + .write(&ip_addr_to_bytes(ip_addr)) + .expect("Buffer too small for AllowedIpNla::IpAddr"); + } + CidrMask(cidr_mask) => buffer[0] = *cidr_mask, + Unspec(payload) => { + let _ = buffer + .write(&payload) + .expect("Buffer too small for unspec payload"); + } + } + } +} + +impl<'a, T: AsRef<[u8]> + 'a + ?Sized> Parseable<NlaBuffer<&'a T>> for AllowedIpNla { + fn parse(buf: &NlaBuffer<&'a T>) -> Result<Self, DecodeError> { + use AllowedIpNla::*; + let value = buf.value(); + let nla = match buf.kind() { + WGALLOWEDIP_A_FAMILY => AddressFamily(parsers::parse_u16(value)?), + WGALLOWEDIP_A_IPADDR => IpAddr(parsers::parse_ip_addr(value)?), + WGALLOWEDIP_A_CIDR_MASK => CidrMask(parsers::parse_u8(value)?), + WGALLOWEDIP_A_UNSPEC => Unspec(value.to_vec()), + _ => Err(format!( + "Unexpected allowed IP attribute kind: {}", + buf.kind() + ))?, + }; + Ok(nla) + } +} + +unsafe fn struct_as_slice<T: Sized>(t: &T) -> &[u8] { + let s = mem::size_of::<T>(); + let ptr = t as *const T as *const u8; + std::slice::from_raw_parts(ptr, s) +} + +fn ip_addr_to_bytes(addr: &IpAddr) -> Vec<u8> { + match addr { + IpAddr::V4(addr) => addr.octets().to_vec(), + IpAddr::V6(addr) => addr.octets().to_vec(), + } +} + +#[cfg(test)] +mod test { + use super::*; + use nix::sys::time::TimeValLike; + use std::net::Ipv4Addr; + + + #[test] + fn deserialize_netlink_message() { + #[rustfmt::skip] + let payload = vec![ + 0x00, 0x01, 0x00, 0x00, + // 6 bytes of WGDEVICE_A_LISTEN_PORT 51820 + 2 bytes of padding + 0x06, 0x00, 0x06, 0x00, 0x6c, 0xca, 0x00, 0x00, + // 8 bytes of WGDEVICE_A_FWMARK 0 + 0x08, 0x00, 0x07, 0x00, 0x00, 0x00, 0x00, 0x00, + // 8 bytes of WGDEVIEC_A_IFINDEX 320 + 0x08, 0x00, 0x01, 0x00, 0x40, 0x01, 0x00, 0x00, + // 12 bytes of WGDEVICE_A_IFNAME "wg-test\0" + 0x0c, 0x00, 0x02, 0x00, 0x77, 0x67, 0x2d, 0x74, 0x65, 0x73, 0x74, 0x00, + // 36 bytes of WGDEVICE_A_PRIVATE_KEY OEf0rWXfVRarrw8nNbTBxkk3NTu8GjRKrbMW1aFH/H0= + 0x24, 0x00, 0x03, 0x00, 0x38, 0x47, 0xf4, 0xad, 0x65, 0xdf, 0x55, 0x16, 0xab, 0xaf, + 0x0f, 0x27, 0x35, 0xb4, 0xc1, 0xc6, 0x49, 0x37, 0x35, 0x3b, 0xbc, 0x1a, 0x34, 0x4a, + 0xad, 0xb3, 0x16, 0xd5, 0xa1, 0x47, 0xfc, 0x7d, + // 36 bytes of WGDEVICE_A_PUBLIC_KEY Ztqy3r8VO1N8tHwpWwqGx1S6G9o12BRdy1JESr2OYzs= + 0x24, 0x00, 0x04, 0x00, 0x66, 0xda, 0xb2, 0xde, 0xbf, 0x15, 0x3b, 0x53, 0x7c, 0xb4, + 0x7c, 0x29, 0x5b, 0x0a, 0x86, 0xc7, 0x54, 0xba, 0x1b, 0xda, 0x35, 0xd8, 0x14, 0x5d, + 0xcb, 0x52, 0x44, 0x4a, 0xbd, 0x8e, 0x63, 0x3b, + // 380 bytes of WGDEVICE_A_PEERS + 0x7c, 0x01, 0x08, 0x80, + // 188 bytes of WGPEER attributes + 0xbc, 0x00, 0x00, 0x80, + // 36 bytes of WGPEER_A_PUBLIC_KEY IOBEBReIZ+XOOyLn14vW7FBRuweaxfskq5wwSZEvhjY= + 0x24, 0x00, 0x01, 0x00, 0x20, 0xe0, 0x44, 0x05, 0x17, 0x88, 0x67, 0xe5, + 0xce, 0x3b, 0x22, 0xe7, 0xd7, 0x8b, 0xd6, 0xec, 0x50, 0x51, 0xbb, 0x07, + 0x9a, 0xc5, 0xfb, 0x24, 0xab, 0x9c, 0x30, 0x49, 0x91, 0x2f, 0x86, 0x36, + // 36 bytes of WGPEER_A_PRESHARED_KEY (all zeroes) + 0x24, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + // 20 bytes of WGPEER_A_LAST_HANDSHAKE_TIME 0 + 0x14, 0x00, 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + // 6 bytes of WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL 0 + 0x06, 0x00, 0x05, 0x00, 0x00, 0x00, 0x00, 0x00, + // 12 bytes of WGPEER_A_TX_BYTES 0 + 0x0c, 0x00, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + // 12 bytes of WGPEER_A_RX_BYTES 0 + 0x0c, 0x00, 0x07, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + // 8 bytes of WGPEER_A_PROTOCOL_VERSION 1 + 0x08, 0x00, 0x0a, 0x00, 0x01, 0x00, 0x00, 0x00, + // 20 bytes of WGPEER_A_ENDPOINT 192.168.39.2:9797 + 0x14, 0x00, 0x04, 0x00, 0x02, 0x00, 0x26, 0x45, 0xc0, 0xa8, 0x28, 0x01, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + // 32 bytes of WGPEER_A_ALLOWEDIPS + 0x20, 0x00, 0x09, 0x80, + // 28 bytes of WGALLOWDIP_A_* + 0x1c, 0x00,0x00, 0x80, + // 5 bytes of WGALLOWEDIP_A_CIDR_MASK + 3 bytes of padding 32 + 0x05, 0x00, 0x03, 0x00, 0x20, 0x00, 0x00, 0x00, + // 6 bytes of WGALLOWEDIP_A_FAMILY + 2 bytes of padding 2 (IPv4) + 0x06, 0x00, 0x01, 0x00, 0x02, 0x00, 0x00, 0x00, + // 8 bytes of WGALLOWEDIP_A_IPADDR 192.168.40.1 + 0x08, 0x00, 0x02, 0x00, 0xc0, 0xa8, 0x27, 0x01, + // 188 bvytes of WGPEER attributes + 0xbc, 0x00, 0x00, 0x80, + // 36 bytes of WGPEER_A_PUBLIC_KEY + 0x24, 0x00, 0x01, 0x00, 0xf4, 0x1c, 0xce, 0x0c, 0x4f, 0x24, 0x58, 0xb7, + 0xc2, 0x9d, 0x36, 0x26, 0x36, 0xb7, 0x7f, 0x20, 0x8e, 0x18, 0xfb, 0x9e, + 0xd9, 0x38, 0x0c, 0x92, 0xd0, 0x15, 0x84, 0x9d, 0xa2, 0x44, 0x02, 0x2c, + // 36 bytes of WGPEER_A_PRESHARED_KEY + 0x24, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + // 20 bytes of WGPEER_A_LAST_HANDSHAKE_TIME + 0x14, 0x00, 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + // 6 bytes of WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL + 2 bytes of padding + 0x06, 0x00, 0x05, 0x00, 0x00, 0x00, 0x00, 0x00, + // 12 bytes of WGPEER_A_TX_BYTES + 0x0c, 0x00, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + // 12 bytes of WGPEER_A_RX_BYTES + 0x0c, 0x00, 0x07, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + // 8 bytes of WGPEER_A_PROTOCOL_VERSION + 0x08, 0x00, 0x0a, 0x00, 0x01, 0x00, 0x00, 0x00, + // 20 bytes of WGPEER_A_ENDPOINT + 0x14, 0x00, 0x04, 0x00, 0x02, 0x00, 0x26, 0x45, 0xc0, 0xa8, 0x28, 0x02, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + // 32 bytes of WGPEER_A_ALLOWEDIPS + 0x20, 0x00, 0x09, 0x80, + // 28 bytes of WGALLOWDIP_A_* + 0x1c, 0x00, 0x00, 0x80, + // 5 bytes of WGALLOWEDIP_A_CIDR_MASK + 3 bytes of padding 32 + 0x05, 0x00, 0x03, 0x00, 0x20, 0x00, 0x00, 0x00, + // 6 bytes of WGALLOWEDIP_A_FAMILY + 2 bytes of padding 2 (IPv4) + 0x06, 0x00, 0x01, 0x00, 0x02, 0x00, 0x00, 0x00, + // 8 bytes of WGALLOWEDIP_A_IPADDR 192.168.40.2 + 0x08, 0x00, 0x02, 0x00, 0xc0, 0xa8, 0x27, 0x02, + ]; + let header = NetlinkHeader { + length: payload.len() as u32, + message_type: 0, + flags: 0, + sequence_number: 0, + port_number: 0, + }; + let message = DeviceMessage::deserialize(&header, &payload).unwrap(); + + let mut serialized_message = vec![0u8; payload.len()]; + + message.serialize(&mut serialized_message); + + assert_eq!(message, sample_get_message()); + assert_eq!(&payload, &serialized_message) + } + + fn sample_get_message() -> DeviceMessage { + use AllowedIpNla::*; + use DeviceNla::*; + use PeerNla::*; + + let if_name = CString::new(b"wg-test".to_vec()).unwrap(); + + let peer_1 = PeerMessage( + [ + PeerNla::PublicKey([ + 32, 224, 68, 5, 23, 136, 103, 229, 206, 59, 34, 231, 215, 139, 214, 236, 80, + 81, 187, 7, 154, 197, 251, 36, 171, 156, 48, 73, 145, 47, 134, 54, + ]), + PeerNla::PresharedKey([ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, + ]), + LastHandshakeTime(TimeSpec::seconds(0)), + PersistentKeepaliveInterval(0), + TxBytes(0), + RxBytes(0), + ProtocolVersion(1), + Endpoint(InetAddr::from_std(&"192.168.40.1:9797".parse().unwrap())), + AllowedIps( + [AllowedIpMessage( + [ + CidrMask(32), + AddressFamily(2), + IpAddr(Ipv4Addr::new(192, 168, 39, 1).into()), + ] + .to_vec(), + )] + .to_vec() + .to_vec(), + ), + ] + .to_vec(), + ); + + let peer_2 = PeerMessage( + [ + PeerNla::PublicKey([ + 244, 28, 206, 12, 79, 36, 88, 183, 194, 157, 54, 38, 54, 183, 127, 32, 142, 24, + 251, 158, 217, 56, 12, 146, 208, 21, 132, 157, 162, 68, 2, 44, + ]), + PresharedKey([ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, + ]), + LastHandshakeTime(TimeSpec::seconds(0)), + PersistentKeepaliveInterval(0), + TxBytes(0), + RxBytes(0), + ProtocolVersion(1), + Endpoint(InetAddr::from_std(&"192.168.40.2:9797".parse().unwrap())), + AllowedIps( + [AllowedIpMessage( + vec![ + CidrMask(32), + AddressFamily(2), + IpAddr(Ipv4Addr::new(192, 168, 39, 2).into()), + ] + .to_vec(), + )] + .to_vec(), + ), + ] + .to_vec(), + ); + + DeviceMessage { + command: WG_CMD_GET_DEVICE, + message_type: 0, + nlas: [ + ListenPort(51820), + Fwmark(0), + IfIndex(320), + IfName(if_name), + PrivateKey([ + 56, 71, 244, 173, 101, 223, 85, 22, 171, 175, 15, 39, 53, 180, 193, 198, 73, + 55, 53, 59, 188, 26, 52, 74, 173, 179, 22, 213, 161, 71, 252, 125, + ]), + DeviceNla::PublicKey([ + 102, 218, 178, 222, 191, 21, 59, 83, 124, 180, 124, 41, 91, 10, 134, 199, 84, + 186, 27, 218, 53, 216, 20, 93, 203, 82, 68, 74, 189, 142, 99, 59, + ]), + Peers([peer_1, peer_2].to_vec()), + ] + .to_vec(), + } + } + + pub fn sample_set_message() -> DeviceMessage { + use AllowedIpNla::*; + use DeviceNla::*; + use PeerNla::*; + + let if_name = CString::new("wg-test".to_string()).unwrap(); + + let peer_1 = PeerMessage( + [ + PeerNla::PublicKey([ + 32, 224, 68, 5, 23, 136, 103, 229, 206, 59, 34, 231, 215, 139, 214, 236, 80, + 81, 187, 7, 154, 197, 251, 36, 171, 156, 48, 73, 145, 47, 134, 54, + ]), + Endpoint(InetAddr::from_std(&"192.168.40.1:9797".parse().unwrap())), + PeerNla::Flags(WGPEER_F_REPLACE_ALLOWEDIPS), + AllowedIps( + [AllowedIpMessage( + [ + AddressFamily(2), + IpAddr(Ipv4Addr::new(192, 168, 39, 1).into()), + CidrMask(32), + ] + .to_vec(), + )] + .to_vec() + .to_vec(), + ), + ] + .to_vec(), + ); + + let peer_2 = PeerMessage( + [ + PeerNla::PublicKey([ + 244, 28, 206, 12, 79, 36, 88, 183, 194, 157, 54, 38, 54, 183, 127, 32, 142, 24, + 251, 158, 217, 56, 12, 146, 208, 21, 132, 157, 162, 68, 2, 44, + ]), + Endpoint(InetAddr::from_std(&"192.168.40.2:9797".parse().unwrap())), + PeerNla::Flags(WGPEER_F_REPLACE_ALLOWEDIPS), + AllowedIps( + [AllowedIpMessage( + vec![ + AddressFamily(2), + IpAddr(Ipv4Addr::new(192, 168, 39, 2).into()), + CidrMask(32), + ] + .to_vec(), + )] + .to_vec(), + ), + ] + .to_vec(), + ); + + DeviceMessage { + command: WG_CMD_SET_DEVICE, + message_type: 0, + nlas: [ + IfName(if_name), + PrivateKey([ + 56, 71, 244, 173, 101, 223, 85, 22, 171, 175, 15, 39, 53, 180, 193, 198, 73, + 55, 53, 59, 188, 26, 52, 74, 173, 179, 22, 213, 161, 71, 252, 125, + ]), + ListenPort(51820), + Peers([peer_1, peer_2].to_vec()), + ] + .to_vec(), + } + } + + + #[test] + fn serialize_netlink_message() { + let expected_payload: &[u8] = &[ + 0x01, 0x01, 0x00, 0x00, 0x0c, 0x00, 0x02, 0x00, 0x77, 0x67, 0x2d, 0x74, 0x65, 0x73, + 0x74, 0x00, 0x24, 0x00, 0x03, 0x00, 0x38, 0x47, 0xf4, 0xad, 0x65, 0xdf, 0x55, 0x16, + 0xab, 0xaf, 0x0f, 0x27, 0x35, 0xb4, 0xc1, 0xc6, 0x49, 0x37, 0x35, 0x3b, 0xbc, 0x1a, + 0x34, 0x4a, 0xad, 0xb3, 0x16, 0xd5, 0xa1, 0x47, 0xfc, 0x7d, 0x06, 0x00, 0x06, 0x00, + 0x6c, 0xca, 0x00, 0x00, 0xcc, 0x00, 0x08, 0x80, 0x64, 0x00, 0x00, 0x80, 0x24, 0x00, + 0x01, 0x00, 0x20, 0xe0, 0x44, 0x05, 0x17, 0x88, 0x67, 0xe5, 0xce, 0x3b, 0x22, 0xe7, + 0xd7, 0x8b, 0xd6, 0xec, 0x50, 0x51, 0xbb, 0x07, 0x9a, 0xc5, 0xfb, 0x24, 0xab, 0x9c, + 0x30, 0x49, 0x91, 0x2f, 0x86, 0x36, 0x14, 0x00, 0x04, 0x00, 0x02, 0x00, 0x26, 0x45, + 0xc0, 0xa8, 0x28, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08, 0x00, + 0x03, 0x00, 0x02, 0x00, 0x00, 0x00, 0x20, 0x00, 0x09, 0x80, 0x1c, 0x00, 0x00, 0x80, + 0x06, 0x00, 0x01, 0x00, 0x02, 0x00, 0x00, 0x00, 0x08, 0x00, 0x02, 0x00, 0xc0, 0xa8, + 0x27, 0x01, 0x05, 0x00, 0x03, 0x00, 0x20, 0x00, 0x00, 0x00, 0x64, 0x00, 0x00, 0x80, + 0x24, 0x00, 0x01, 0x00, 0xf4, 0x1c, 0xce, 0x0c, 0x4f, 0x24, 0x58, 0xb7, 0xc2, 0x9d, + 0x36, 0x26, 0x36, 0xb7, 0x7f, 0x20, 0x8e, 0x18, 0xfb, 0x9e, 0xd9, 0x38, 0x0c, 0x92, + 0xd0, 0x15, 0x84, 0x9d, 0xa2, 0x44, 0x02, 0x2c, 0x14, 0x00, 0x04, 0x00, 0x02, 0x00, + 0x26, 0x45, 0xc0, 0xa8, 0x28, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x03, 0x00, 0x02, 0x00, 0x00, 0x00, 0x20, 0x00, 0x09, 0x80, 0x1c, 0x00, + 0x00, 0x80, 0x06, 0x00, 0x01, 0x00, 0x02, 0x00, 0x00, 0x00, 0x08, 0x00, 0x02, 0x00, + 0xc0, 0xa8, 0x27, 0x02, 0x05, 0x00, 0x03, 0x00, 0x20, 0x00, 0x00, 0x00, + ]; + + let mut message = sample_set_message(); + message.command = WG_CMD_SET_DEVICE; + + + let mut payload_buffer = vec![0u8; message.buffer_len()]; + message.serialize(&mut payload_buffer); + let header = NetlinkHeader { + length: payload_buffer.len() as u32, + message_type: 0, + flags: 0, + sequence_number: 0, + port_number: 0, + }; + let deserialized_device = DeviceMessage::deserialize(&header, &payload_buffer).unwrap(); + + assert_eq!(message, deserialized_device); + assert_eq!(payload_buffer, expected_payload); + } +} |
