summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorDavid Lönnhager <david.l@mullvad.net>2021-10-25 19:40:08 +0200
committerDavid Lönnhager <david.l@mullvad.net>2021-10-25 19:40:08 +0200
commit6a99850fa3a5d3766dfad18da396194c7c73c33c (patch)
tree14f6afada72028cf004fddd91f6d6ae8c1dbf5bd
parent721f89300d100da41b3507eb53f495e9a17a2f3e (diff)
parent9f7ae3f10e885f98dd9317f13a80bd738d3ff434 (diff)
downloadmullvadvpn-6a99850fa3a5d3766dfad18da396194c7c73c33c.tar.xz
mullvadvpn-6a99850fa3a5d3766dfad18da396194c7c73c33c.zip
Merge remote-tracking branch 'origin/win-refactor-use-socket2'
-rw-r--r--talpid-core/src/tunnel/wireguard/wireguard_nt.rs72
-rw-r--r--talpid-core/src/windows.rs88
2 files changed, 67 insertions, 93 deletions
diff --git a/talpid-core/src/tunnel/wireguard/wireguard_nt.rs b/talpid-core/src/tunnel/wireguard/wireguard_nt.rs
index b094416ae7..14097baedb 100644
--- a/talpid-core/src/tunnel/wireguard/wireguard_nt.rs
+++ b/talpid-core/src/tunnel/wireguard/wireguard_nt.rs
@@ -12,6 +12,7 @@ use std::{
ffi::CStr,
fmt, io, iter, mem,
mem::MaybeUninit,
+ net::{IpAddr, Ipv4Addr, Ipv6Addr},
os::windows::{ffi::OsStrExt, io::RawHandle},
path::Path,
ptr,
@@ -201,6 +202,31 @@ union WgIpAddr {
v6: IN6_ADDR,
}
+impl From<IpAddr> for WgIpAddr {
+ fn from(address: IpAddr) -> Self {
+ match address {
+ IpAddr::V4(addr) => WgIpAddr::from(addr),
+ IpAddr::V6(addr) => WgIpAddr::from(addr),
+ }
+ }
+}
+
+impl From<Ipv6Addr> for WgIpAddr {
+ fn from(address: Ipv6Addr) -> Self {
+ Self {
+ v6: windows::in6addr_from_ipaddr(address),
+ }
+ }
+}
+
+impl From<Ipv4Addr> for WgIpAddr {
+ fn from(address: Ipv4Addr) -> Self {
+ Self {
+ v4: windows::inaddr_from_ipaddr(address),
+ }
+ }
+}
+
/// See `WIREGUARD_ALLOWED_IP` at https://git.zx2c4.com/wireguard-nt/tree/api/wireguard.h.
#[derive(Clone, Copy)]
#[repr(C, align(8))]
@@ -911,12 +937,8 @@ fn serialize_config(config: &Config) -> Result<Vec<MaybeUninit<u8>>> {
IpNetwork::V6(_) => AF_INET6 as u16,
};
let address = match allowed_ip {
- IpNetwork::V4(v4_network) => WgIpAddr {
- v4: windows::inaddr_from_ipaddr(v4_network.ip()),
- },
- IpNetwork::V6(v6_network) => WgIpAddr {
- v6: windows::in6addr_from_ipaddr(v6_network.ip()),
- },
+ IpNetwork::V4(v4_network) => WgIpAddr::from(v4_network.ip()),
+ IpNetwork::V6(v6_network) => WgIpAddr::from(v6_network.ip()),
};
let wg_allowed_ip =
@@ -1108,9 +1130,7 @@ mod tests {
allowed_ips_count: 1,
},
p0_allowed_ip_0: WgAllowedIp {
- address: WgIpAddr {
- v4: windows::inaddr_from_ipaddr("1.3.3.0".parse().unwrap()),
- },
+ address: WgIpAddr::from("1.3.3.0".parse::<Ipv4Addr>().unwrap()),
address_family: AF_INET as u16,
cidr: 24,
},
@@ -1150,31 +1170,23 @@ mod tests {
fn test_wg_allowed_ip_v4() {
// Valid: /32 prefix
let address_family = AF_INET as u16;
- let address = WgIpAddr {
- v4: windows::inaddr_from_ipaddr("127.0.0.1".parse().unwrap()),
- };
+ let address = WgIpAddr::from("127.0.0.1".parse::<Ipv4Addr>().unwrap());
let cidr = 32;
WgAllowedIp::new(address, address_family, cidr).unwrap();
// Invalid host bits
let cidr = 24;
- let address = WgIpAddr {
- v4: windows::inaddr_from_ipaddr("0.0.0.1".parse().unwrap()),
- };
+ let address = WgIpAddr::from("0.0.0.1".parse::<Ipv4Addr>().unwrap());
assert!(WgAllowedIp::new(address, address_family, cidr).is_err());
// Valid host bits
let cidr = 24;
- let address = WgIpAddr {
- v4: windows::inaddr_from_ipaddr("255.255.255.0".parse().unwrap()),
- };
+ let address = WgIpAddr::from("255.255.255.0".parse::<Ipv4Addr>().unwrap());
WgAllowedIp::new(address, address_family, cidr).unwrap();
// 0.0.0.0/0
let cidr = 0;
- let address = WgIpAddr {
- v4: windows::inaddr_from_ipaddr("0.0.0.0".parse().unwrap()),
- };
+ let address = WgIpAddr::from("0.0.0.0".parse::<Ipv4Addr>().unwrap());
WgAllowedIp::new(address, address_family, cidr).unwrap();
// Invalid CIDR
@@ -1186,9 +1198,7 @@ mod tests {
fn test_wg_allowed_ip_v6() {
// Valid: /128 prefix
let address_family = AF_INET6 as u16;
- let address = WgIpAddr {
- v6: windows::in6addr_from_ipaddr("::1".parse().unwrap()),
- };
+ let address = WgIpAddr::from("::1".parse::<Ipv6Addr>().unwrap());
let cidr = 128;
WgAllowedIp::new(address, address_family, cidr).unwrap();
@@ -1197,18 +1207,16 @@ mod tests {
assert!(WgAllowedIp::new(address, address_family, cidr).is_err());
// Valid host bits
- let address = WgIpAddr {
- v6: windows::in6addr_from_ipaddr(
- "ffff:ffff:ffff:ffff:ffff:ffff:ffff:fffe".parse().unwrap(),
- ),
- };
+ let address = WgIpAddr::from(
+ "ffff:ffff:ffff:ffff:ffff:ffff:ffff:fffe"
+ .parse::<Ipv6Addr>()
+ .unwrap(),
+ );
WgAllowedIp::new(address, address_family, cidr).unwrap();
// ::/0
let cidr = 0;
- let address = WgIpAddr {
- v6: windows::in6addr_from_ipaddr("::".parse().unwrap()),
- };
+ let address = WgIpAddr::from("::".parse::<Ipv6Addr>().unwrap());
WgAllowedIp::new(address, address_family, cidr).unwrap();
// Invalid CIDR
diff --git a/talpid-core/src/windows.rs b/talpid-core/src/windows.rs
index 6236f32c4a..7648441a91 100644
--- a/talpid-core/src/windows.rs
+++ b/talpid-core/src/windows.rs
@@ -1,9 +1,9 @@
+use socket2::SockAddr;
use std::{
ffi::OsStr,
fmt, io, mem,
net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
os::windows::{ffi::OsStrExt, io::RawHandle},
- ptr,
sync::Mutex,
time::{Duration, Instant},
};
@@ -20,8 +20,11 @@ use winapi::shared::{
nldef::{IpDadStatePreferred, IpDadStateTentative, NL_DAD_STATE},
ntdef::FALSE,
winerror::{ERROR_NOT_FOUND, NO_ERROR},
- ws2def::{AF_INET, AF_INET6, AF_UNSPEC},
- ws2ipdef::SOCKADDR_INET,
+ ws2def::{
+ AF_INET, AF_INET6, AF_UNSPEC, SOCKADDR_IN as sockaddr_in,
+ SOCKADDR_STORAGE as sockaddr_storage,
+ },
+ ws2ipdef::{SOCKADDR_IN6_LH as sockaddr_in6, SOCKADDR_INET},
};
/// Result type for this module.
@@ -364,35 +367,19 @@ fn af_family_from_family(family: Option<AddressFamily>) -> u16 {
/// Converts an `Ipv4Addr` to `IN_ADDR`
pub fn inaddr_from_ipaddr(addr: Ipv4Addr) -> IN_ADDR {
- let mut in_addr: IN_ADDR = unsafe { mem::zeroed() };
- let addr_octets = addr.octets();
- unsafe {
- ptr::copy_nonoverlapping(
- &addr_octets as *const _,
- in_addr.S_un.S_addr_mut() as *mut _ as *mut u8,
- addr_octets.len(),
- );
- }
- in_addr
+ let sockaddr = SockAddr::from(SocketAddr::V4(SocketAddrV4::new(addr, 0)));
+ (&unsafe { *(sockaddr.as_ptr() as *const sockaddr_in) }).sin_addr
}
/// Converts an `Ipv6Addr` to `IN6_ADDR`
pub fn in6addr_from_ipaddr(addr: Ipv6Addr) -> IN6_ADDR {
- let mut in_addr: IN6_ADDR = unsafe { mem::zeroed() };
- let addr_octets = addr.octets();
- unsafe {
- ptr::copy_nonoverlapping(
- &addr_octets as *const _,
- in_addr.u.Byte_mut() as *mut _,
- addr_octets.len(),
- );
- }
- in_addr
+ let sockaddr = SockAddr::from(SocketAddr::V6(SocketAddrV6::new(addr, 0, 0, 0)));
+ (&unsafe { *(sockaddr.as_ptr() as *const sockaddr_in6) }).sin6_addr
}
/// Converts an `IN_ADDR` to `Ipv4Addr`
pub fn ipaddr_from_inaddr(addr: IN_ADDR) -> Ipv4Addr {
- Ipv4Addr::from(unsafe { *(addr.S_un.S_addr()) }.to_be())
+ Ipv4Addr::from(unsafe { *(addr.S_un.S_addr()) }.to_ne_bytes())
}
/// Converts an `IN6_ADDR` to `Ipv6Addr`
@@ -403,52 +390,31 @@ pub fn ipaddr_from_in6addr(addr: IN6_ADDR) -> Ipv6Addr {
/// Converts a `SocketAddr` to `SOCKADDR_INET`
pub fn inet_sockaddr_from_socketaddr(addr: SocketAddr) -> SOCKADDR_INET {
let mut sockaddr: SOCKADDR_INET = unsafe { mem::zeroed() };
-
match addr {
- SocketAddr::V4(v4_addr) => {
- unsafe {
- *sockaddr.si_family_mut() = AF_INET as u16;
- }
-
- let mut v4sockaddr = unsafe { sockaddr.Ipv4_mut() };
- v4sockaddr.sin_family = AF_INET as u16;
- v4sockaddr.sin_port = v4_addr.port().to_be();
- v4sockaddr.sin_addr = inaddr_from_ipaddr(*v4_addr.ip());
- }
- SocketAddr::V6(v6_addr) => {
- unsafe {
- *sockaddr.si_family_mut() = AF_INET6 as u16;
- }
-
- let mut v6sockaddr = unsafe { sockaddr.Ipv6_mut() };
- v6sockaddr.sin6_family = AF_INET6 as u16;
- v6sockaddr.sin6_port = v6_addr.port().to_be();
- v6sockaddr.sin6_addr = in6addr_from_ipaddr(*v6_addr.ip());
- v6sockaddr.sin6_flowinfo = v6_addr.flowinfo();
- *unsafe { v6sockaddr.u.sin6_scope_id_mut() } = v6_addr.scope_id();
- }
+ // SAFETY: `*const sockaddr` may be treated as `*const sockaddr_in` since we know it's a v4
+ // address.
+ SocketAddr::V4(_) => unsafe {
+ *sockaddr.Ipv4_mut() = *(SockAddr::from(addr).as_ptr() as *const _)
+ },
+ // SAFETY: `*const sockaddr` may be treated as `*const sockaddr_in6` since we know it's a v6
+ // address.
+ SocketAddr::V6(_) => unsafe {
+ *sockaddr.Ipv6_mut() = *(SockAddr::from(addr).as_ptr() as *const _)
+ },
}
-
sockaddr
}
/// Converts a `SOCKADDR_INET` to `SocketAddr`. Returns an error if the address family is invalid.
pub fn try_socketaddr_from_inet_sockaddr(addr: SOCKADDR_INET) -> Result<SocketAddr> {
+ let family = unsafe { *addr.si_family() } as i32;
unsafe {
- match *addr.si_family() as i32 {
- AF_INET => Ok(SocketAddr::V4(SocketAddrV4::new(
- ipaddr_from_inaddr(addr.Ipv4().sin_addr),
- u16::from_be(addr.Ipv4().sin_port),
- ))),
- AF_INET6 => Ok(SocketAddr::V6(SocketAddrV6::new(
- ipaddr_from_in6addr(addr.Ipv6().sin6_addr),
- u16::from_be(addr.Ipv6().sin6_port),
- addr.Ipv6().sin6_flowinfo,
- *addr.Ipv6().u.sin6_scope_id(),
- ))),
- family => Err(Error::UnknownAddressFamily(family)),
- }
+ let mut storage: sockaddr_storage = mem::zeroed();
+ *(&mut storage as *mut _ as *mut SOCKADDR_INET) = addr;
+ SockAddr::new(storage, mem::size_of_val(&addr) as i32)
}
+ .as_socket()
+ .ok_or(Error::UnknownAddressFamily(family))
}
/// Casts a struct to a slice of possibly uninitialized bytes.