summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--talpid-wireguard/src/wireguard_kernel/parsers.rs15
-rw-r--r--talpid-wireguard/src/wireguard_kernel/wg_message.rs41
2 files changed, 36 insertions, 20 deletions
diff --git a/talpid-wireguard/src/wireguard_kernel/parsers.rs b/talpid-wireguard/src/wireguard_kernel/parsers.rs
index a8c16d7d68..a21155bdf9 100644
--- a/talpid-wireguard/src/wireguard_kernel/parsers.rs
+++ b/talpid-wireguard/src/wireguard_kernel/parsers.rs
@@ -1,9 +1,12 @@
use byteorder::{ByteOrder, NativeEndian};
-use nix::sys::{socket::InetAddr, time::TimeSpec};
+use nix::sys::{
+ socket::{SockaddrIn, SockaddrIn6},
+ time::TimeSpec,
+};
use std::{
ffi::{CStr, CString},
mem::{self, transmute},
- net::IpAddr,
+ net::{IpAddr, SocketAddr},
};
pub use netlink_packet_utils::parsers::*;
@@ -35,7 +38,7 @@ pub fn parse_wg_key(buffer: &[u8]) -> Result<[u8; 32], DecodeError> {
}
}
-pub fn parse_inet_sockaddr(buffer: &[u8]) -> Result<InetAddr, DecodeError> {
+pub fn parse_inet_sockaddr(buffer: &[u8]) -> Result<SocketAddr, DecodeError> {
let wrong_len = || {
format!(
"Unexpected length for sockaddr_in: {}, expected {} or {}",
@@ -59,8 +62,9 @@ pub fn parse_inet_sockaddr(buffer: &[u8]) -> Result<InetAddr, DecodeError> {
// SAFETY: sockaddr_in has a defined repr(C) layout and is valid for all bit patterns
let sockaddr: libc::sockaddr_in = unsafe { transmute(*buffer) };
+ let sockaddr = SockaddrIn::from(sockaddr);
- Ok(InetAddr::V4(sockaddr))
+ Ok(SocketAddr::from(sockaddr))
}
AF_INET6 => {
let buffer: &[u8; size_of::<libc::sockaddr_in6>()] =
@@ -68,8 +72,9 @@ pub fn parse_inet_sockaddr(buffer: &[u8]) -> Result<InetAddr, DecodeError> {
// SAFETY: sockaddr_in6 has a defined repr(C) layout and is valid for all bit patterns
let sockaddr: libc::sockaddr_in6 = unsafe { transmute(*buffer) };
+ let sockaddr = SockaddrIn6::from(sockaddr);
- Ok(InetAddr::V6(sockaddr))
+ Ok(SocketAddr::from(sockaddr))
}
unexpected_addr_family => {
Err(format!("Unexpected address family: {unexpected_addr_family}").into())
diff --git a/talpid-wireguard/src/wireguard_kernel/wg_message.rs b/talpid-wireguard/src/wireguard_kernel/wg_message.rs
index e7f81fdc1f..1c9e206a68 100644
--- a/talpid-wireguard/src/wireguard_kernel/wg_message.rs
+++ b/talpid-wireguard/src/wireguard_kernel/wg_message.rs
@@ -9,8 +9,16 @@ use netlink_packet_utils::{
traits::{Emitable, Parseable},
DecodeError,
};
-use nix::sys::{socket::InetAddr, time::TimeSpec};
-use std::{ffi::CString, io::Write, mem, net::IpAddr};
+use nix::sys::{
+ socket::{SockaddrIn, SockaddrIn6},
+ time::TimeSpec,
+};
+use std::{
+ ffi::CString,
+ io::Write,
+ mem,
+ net::{IpAddr, SocketAddr},
+};
/// WireGuard netlink constants
mod constants {
@@ -78,11 +86,10 @@ impl DeviceMessage {
let mut peers = vec![];
for peer in config.peers() {
- let peer_endpoint = InetAddr::from_std(&peer.endpoint);
let allowed_ips = peer.allowed_ips.iter().map(From::from).collect();
let mut peer_nlas = vec![
PeerNla::PublicKey(*peer.public_key.as_bytes()),
- PeerNla::Endpoint(peer_endpoint),
+ PeerNla::Endpoint(peer.endpoint),
PeerNla::AllowedIps(allowed_ips),
PeerNla::Flags(WGPEER_F_REPLACE_ALLOWEDIPS),
];
@@ -333,7 +340,7 @@ pub enum PeerNla {
PublicKey(PublicKey),
PresharedKey(PresharedKey),
Flags(u32),
- Endpoint(InetAddr),
+ Endpoint(SocketAddr),
PersistentKeepaliveInterval(u16),
LastHandshakeTime(TimeSpec),
RxBytes(u64),
@@ -348,8 +355,8 @@ impl Nla for PeerNla {
match self {
PublicKey(key) | PresharedKey(key) => key.len(),
Endpoint(endpoint) => match &endpoint {
- InetAddr::V4(_) => mem::size_of::<libc::sockaddr_in>(),
- InetAddr::V6(_) => mem::size_of::<libc::sockaddr_in6>(),
+ SocketAddr::V4(_) => mem::size_of::<libc::sockaddr_in>(),
+ SocketAddr::V6(_) => mem::size_of::<libc::sockaddr_in6>(),
},
PersistentKeepaliveInterval(_) => 2,
LastHandshakeTime(_) => mem::size_of::<libc::timespec>(),
@@ -384,14 +391,18 @@ impl Nla for PeerNla {
let _ = buffer.write(key).expect("Buffer too small for a key");
}
Flags(value) | ProtocolVersion(value) => NativeEndian::write_u32(buffer, *value),
- Endpoint(endpoint) => match &endpoint {
- InetAddr::V4(sockaddr_in) => {
+ &Endpoint(endpoint) => match endpoint {
+ SocketAddr::V4(addr) => {
+ let sockaddr_in = SockaddrIn::from(addr);
+ let sockaddr_in: &libc::sockaddr_in = sockaddr_in.as_ref();
buffer
// SAFETY: `sockaddr_in` has no padding bytes
.write_all(unsafe { struct_as_slice(sockaddr_in) })
.expect("Buffer too small for sockaddr_in");
}
- InetAddr::V6(sockaddr_in6) => {
+ SocketAddr::V6(addr) => {
+ let sockaddr_in6 = SockaddrIn6::from(addr);
+ let sockaddr_in6: &libc::sockaddr_in6 = sockaddr_in6.as_ref();
buffer
// SAFETY: `sockaddr_in` has no padding bytes
.write_all(unsafe { struct_as_slice(sockaddr_in6) })
@@ -589,7 +600,7 @@ fn ip_addr_to_bytes(addr: &IpAddr) -> Vec<u8> {
mod test {
use super::*;
use nix::sys::time::TimeValLike;
- use std::net::Ipv4Addr;
+ use std::{net::Ipv4Addr, str::FromStr};
#[test]
fn deserialize_netlink_message() {
@@ -721,7 +732,7 @@ mod test {
TxBytes(0),
RxBytes(0),
ProtocolVersion(1),
- Endpoint(InetAddr::from_std(&"192.168.40.1:9797".parse().unwrap())),
+ Endpoint(SocketAddr::from_str("192.168.40.1:9797").unwrap()),
AllowedIps(vec![AllowedIpMessage(vec![
CidrMask(32),
AddressFamily(2),
@@ -743,7 +754,7 @@ mod test {
TxBytes(0),
RxBytes(0),
ProtocolVersion(1),
- Endpoint(InetAddr::from_std(&"192.168.40.2:9797".parse().unwrap())),
+ Endpoint(SocketAddr::from_str("192.168.40.2:9797").unwrap()),
AllowedIps(vec![AllowedIpMessage(vec![
CidrMask(32),
AddressFamily(2),
@@ -784,7 +795,7 @@ mod test {
32, 224, 68, 5, 23, 136, 103, 229, 206, 59, 34, 231, 215, 139, 214, 236, 80, 81,
187, 7, 154, 197, 251, 36, 171, 156, 48, 73, 145, 47, 134, 54,
]),
- Endpoint(InetAddr::from_std(&"192.168.40.1:9797".parse().unwrap())),
+ Endpoint(SocketAddr::from_str("192.168.40.1:9797").unwrap()),
PeerNla::Flags(WGPEER_F_REPLACE_ALLOWEDIPS),
AllowedIps(vec![AllowedIpMessage(vec![
AddressFamily(2),
@@ -798,7 +809,7 @@ mod test {
244, 28, 206, 12, 79, 36, 88, 183, 194, 157, 54, 38, 54, 183, 127, 32, 142, 24,
251, 158, 217, 56, 12, 146, 208, 21, 132, 157, 162, 68, 2, 44,
]),
- Endpoint(InetAddr::from_std(&"192.168.40.2:9797".parse().unwrap())),
+ Endpoint(SocketAddr::from_str("192.168.40.2:9797").unwrap()),
PeerNla::Flags(WGPEER_F_REPLACE_ALLOWEDIPS),
AllowedIps(vec![AllowedIpMessage(vec![
AddressFamily(2),