diff options
| author | David Lönnhager <david.l@mullvad.net> | 2021-10-25 19:40:08 +0200 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2021-10-25 19:40:08 +0200 |
| commit | 6a99850fa3a5d3766dfad18da396194c7c73c33c (patch) | |
| tree | 14f6afada72028cf004fddd91f6d6ae8c1dbf5bd | |
| parent | 721f89300d100da41b3507eb53f495e9a17a2f3e (diff) | |
| parent | 9f7ae3f10e885f98dd9317f13a80bd738d3ff434 (diff) | |
| download | mullvadvpn-6a99850fa3a5d3766dfad18da396194c7c73c33c.tar.xz mullvadvpn-6a99850fa3a5d3766dfad18da396194c7c73c33c.zip | |
Merge remote-tracking branch 'origin/win-refactor-use-socket2'
| -rw-r--r-- | talpid-core/src/tunnel/wireguard/wireguard_nt.rs | 72 | ||||
| -rw-r--r-- | talpid-core/src/windows.rs | 88 |
2 files changed, 67 insertions, 93 deletions
diff --git a/talpid-core/src/tunnel/wireguard/wireguard_nt.rs b/talpid-core/src/tunnel/wireguard/wireguard_nt.rs index b094416ae7..14097baedb 100644 --- a/talpid-core/src/tunnel/wireguard/wireguard_nt.rs +++ b/talpid-core/src/tunnel/wireguard/wireguard_nt.rs @@ -12,6 +12,7 @@ use std::{ ffi::CStr, fmt, io, iter, mem, mem::MaybeUninit, + net::{IpAddr, Ipv4Addr, Ipv6Addr}, os::windows::{ffi::OsStrExt, io::RawHandle}, path::Path, ptr, @@ -201,6 +202,31 @@ union WgIpAddr { v6: IN6_ADDR, } +impl From<IpAddr> for WgIpAddr { + fn from(address: IpAddr) -> Self { + match address { + IpAddr::V4(addr) => WgIpAddr::from(addr), + IpAddr::V6(addr) => WgIpAddr::from(addr), + } + } +} + +impl From<Ipv6Addr> for WgIpAddr { + fn from(address: Ipv6Addr) -> Self { + Self { + v6: windows::in6addr_from_ipaddr(address), + } + } +} + +impl From<Ipv4Addr> for WgIpAddr { + fn from(address: Ipv4Addr) -> Self { + Self { + v4: windows::inaddr_from_ipaddr(address), + } + } +} + /// See `WIREGUARD_ALLOWED_IP` at https://git.zx2c4.com/wireguard-nt/tree/api/wireguard.h. #[derive(Clone, Copy)] #[repr(C, align(8))] @@ -911,12 +937,8 @@ fn serialize_config(config: &Config) -> Result<Vec<MaybeUninit<u8>>> { IpNetwork::V6(_) => AF_INET6 as u16, }; let address = match allowed_ip { - IpNetwork::V4(v4_network) => WgIpAddr { - v4: windows::inaddr_from_ipaddr(v4_network.ip()), - }, - IpNetwork::V6(v6_network) => WgIpAddr { - v6: windows::in6addr_from_ipaddr(v6_network.ip()), - }, + IpNetwork::V4(v4_network) => WgIpAddr::from(v4_network.ip()), + IpNetwork::V6(v6_network) => WgIpAddr::from(v6_network.ip()), }; let wg_allowed_ip = @@ -1108,9 +1130,7 @@ mod tests { allowed_ips_count: 1, }, p0_allowed_ip_0: WgAllowedIp { - address: WgIpAddr { - v4: windows::inaddr_from_ipaddr("1.3.3.0".parse().unwrap()), - }, + address: WgIpAddr::from("1.3.3.0".parse::<Ipv4Addr>().unwrap()), address_family: AF_INET as u16, cidr: 24, }, @@ -1150,31 +1170,23 @@ mod tests { fn test_wg_allowed_ip_v4() { // Valid: /32 prefix let address_family = AF_INET as u16; - let address = WgIpAddr { - v4: windows::inaddr_from_ipaddr("127.0.0.1".parse().unwrap()), - }; + let address = WgIpAddr::from("127.0.0.1".parse::<Ipv4Addr>().unwrap()); let cidr = 32; WgAllowedIp::new(address, address_family, cidr).unwrap(); // Invalid host bits let cidr = 24; - let address = WgIpAddr { - v4: windows::inaddr_from_ipaddr("0.0.0.1".parse().unwrap()), - }; + let address = WgIpAddr::from("0.0.0.1".parse::<Ipv4Addr>().unwrap()); assert!(WgAllowedIp::new(address, address_family, cidr).is_err()); // Valid host bits let cidr = 24; - let address = WgIpAddr { - v4: windows::inaddr_from_ipaddr("255.255.255.0".parse().unwrap()), - }; + let address = WgIpAddr::from("255.255.255.0".parse::<Ipv4Addr>().unwrap()); WgAllowedIp::new(address, address_family, cidr).unwrap(); // 0.0.0.0/0 let cidr = 0; - let address = WgIpAddr { - v4: windows::inaddr_from_ipaddr("0.0.0.0".parse().unwrap()), - }; + let address = WgIpAddr::from("0.0.0.0".parse::<Ipv4Addr>().unwrap()); WgAllowedIp::new(address, address_family, cidr).unwrap(); // Invalid CIDR @@ -1186,9 +1198,7 @@ mod tests { fn test_wg_allowed_ip_v6() { // Valid: /128 prefix let address_family = AF_INET6 as u16; - let address = WgIpAddr { - v6: windows::in6addr_from_ipaddr("::1".parse().unwrap()), - }; + let address = WgIpAddr::from("::1".parse::<Ipv6Addr>().unwrap()); let cidr = 128; WgAllowedIp::new(address, address_family, cidr).unwrap(); @@ -1197,18 +1207,16 @@ mod tests { assert!(WgAllowedIp::new(address, address_family, cidr).is_err()); // Valid host bits - let address = WgIpAddr { - v6: windows::in6addr_from_ipaddr( - "ffff:ffff:ffff:ffff:ffff:ffff:ffff:fffe".parse().unwrap(), - ), - }; + let address = WgIpAddr::from( + "ffff:ffff:ffff:ffff:ffff:ffff:ffff:fffe" + .parse::<Ipv6Addr>() + .unwrap(), + ); WgAllowedIp::new(address, address_family, cidr).unwrap(); // ::/0 let cidr = 0; - let address = WgIpAddr { - v6: windows::in6addr_from_ipaddr("::".parse().unwrap()), - }; + let address = WgIpAddr::from("::".parse::<Ipv6Addr>().unwrap()); WgAllowedIp::new(address, address_family, cidr).unwrap(); // Invalid CIDR diff --git a/talpid-core/src/windows.rs b/talpid-core/src/windows.rs index 6236f32c4a..7648441a91 100644 --- a/talpid-core/src/windows.rs +++ b/talpid-core/src/windows.rs @@ -1,9 +1,9 @@ +use socket2::SockAddr; use std::{ ffi::OsStr, fmt, io, mem, net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}, os::windows::{ffi::OsStrExt, io::RawHandle}, - ptr, sync::Mutex, time::{Duration, Instant}, }; @@ -20,8 +20,11 @@ use winapi::shared::{ nldef::{IpDadStatePreferred, IpDadStateTentative, NL_DAD_STATE}, ntdef::FALSE, winerror::{ERROR_NOT_FOUND, NO_ERROR}, - ws2def::{AF_INET, AF_INET6, AF_UNSPEC}, - ws2ipdef::SOCKADDR_INET, + ws2def::{ + AF_INET, AF_INET6, AF_UNSPEC, SOCKADDR_IN as sockaddr_in, + SOCKADDR_STORAGE as sockaddr_storage, + }, + ws2ipdef::{SOCKADDR_IN6_LH as sockaddr_in6, SOCKADDR_INET}, }; /// Result type for this module. @@ -364,35 +367,19 @@ fn af_family_from_family(family: Option<AddressFamily>) -> u16 { /// Converts an `Ipv4Addr` to `IN_ADDR` pub fn inaddr_from_ipaddr(addr: Ipv4Addr) -> IN_ADDR { - let mut in_addr: IN_ADDR = unsafe { mem::zeroed() }; - let addr_octets = addr.octets(); - unsafe { - ptr::copy_nonoverlapping( - &addr_octets as *const _, - in_addr.S_un.S_addr_mut() as *mut _ as *mut u8, - addr_octets.len(), - ); - } - in_addr + let sockaddr = SockAddr::from(SocketAddr::V4(SocketAddrV4::new(addr, 0))); + (&unsafe { *(sockaddr.as_ptr() as *const sockaddr_in) }).sin_addr } /// Converts an `Ipv6Addr` to `IN6_ADDR` pub fn in6addr_from_ipaddr(addr: Ipv6Addr) -> IN6_ADDR { - let mut in_addr: IN6_ADDR = unsafe { mem::zeroed() }; - let addr_octets = addr.octets(); - unsafe { - ptr::copy_nonoverlapping( - &addr_octets as *const _, - in_addr.u.Byte_mut() as *mut _, - addr_octets.len(), - ); - } - in_addr + let sockaddr = SockAddr::from(SocketAddr::V6(SocketAddrV6::new(addr, 0, 0, 0))); + (&unsafe { *(sockaddr.as_ptr() as *const sockaddr_in6) }).sin6_addr } /// Converts an `IN_ADDR` to `Ipv4Addr` pub fn ipaddr_from_inaddr(addr: IN_ADDR) -> Ipv4Addr { - Ipv4Addr::from(unsafe { *(addr.S_un.S_addr()) }.to_be()) + Ipv4Addr::from(unsafe { *(addr.S_un.S_addr()) }.to_ne_bytes()) } /// Converts an `IN6_ADDR` to `Ipv6Addr` @@ -403,52 +390,31 @@ pub fn ipaddr_from_in6addr(addr: IN6_ADDR) -> Ipv6Addr { /// Converts a `SocketAddr` to `SOCKADDR_INET` pub fn inet_sockaddr_from_socketaddr(addr: SocketAddr) -> SOCKADDR_INET { let mut sockaddr: SOCKADDR_INET = unsafe { mem::zeroed() }; - match addr { - SocketAddr::V4(v4_addr) => { - unsafe { - *sockaddr.si_family_mut() = AF_INET as u16; - } - - let mut v4sockaddr = unsafe { sockaddr.Ipv4_mut() }; - v4sockaddr.sin_family = AF_INET as u16; - v4sockaddr.sin_port = v4_addr.port().to_be(); - v4sockaddr.sin_addr = inaddr_from_ipaddr(*v4_addr.ip()); - } - SocketAddr::V6(v6_addr) => { - unsafe { - *sockaddr.si_family_mut() = AF_INET6 as u16; - } - - let mut v6sockaddr = unsafe { sockaddr.Ipv6_mut() }; - v6sockaddr.sin6_family = AF_INET6 as u16; - v6sockaddr.sin6_port = v6_addr.port().to_be(); - v6sockaddr.sin6_addr = in6addr_from_ipaddr(*v6_addr.ip()); - v6sockaddr.sin6_flowinfo = v6_addr.flowinfo(); - *unsafe { v6sockaddr.u.sin6_scope_id_mut() } = v6_addr.scope_id(); - } + // SAFETY: `*const sockaddr` may be treated as `*const sockaddr_in` since we know it's a v4 + // address. + SocketAddr::V4(_) => unsafe { + *sockaddr.Ipv4_mut() = *(SockAddr::from(addr).as_ptr() as *const _) + }, + // SAFETY: `*const sockaddr` may be treated as `*const sockaddr_in6` since we know it's a v6 + // address. + SocketAddr::V6(_) => unsafe { + *sockaddr.Ipv6_mut() = *(SockAddr::from(addr).as_ptr() as *const _) + }, } - sockaddr } /// Converts a `SOCKADDR_INET` to `SocketAddr`. Returns an error if the address family is invalid. pub fn try_socketaddr_from_inet_sockaddr(addr: SOCKADDR_INET) -> Result<SocketAddr> { + let family = unsafe { *addr.si_family() } as i32; unsafe { - match *addr.si_family() as i32 { - AF_INET => Ok(SocketAddr::V4(SocketAddrV4::new( - ipaddr_from_inaddr(addr.Ipv4().sin_addr), - u16::from_be(addr.Ipv4().sin_port), - ))), - AF_INET6 => Ok(SocketAddr::V6(SocketAddrV6::new( - ipaddr_from_in6addr(addr.Ipv6().sin6_addr), - u16::from_be(addr.Ipv6().sin6_port), - addr.Ipv6().sin6_flowinfo, - *addr.Ipv6().u.sin6_scope_id(), - ))), - family => Err(Error::UnknownAddressFamily(family)), - } + let mut storage: sockaddr_storage = mem::zeroed(); + *(&mut storage as *mut _ as *mut SOCKADDR_INET) = addr; + SockAddr::new(storage, mem::size_of_val(&addr) as i32) } + .as_socket() + .ok_or(Error::UnknownAddressFamily(family)) } /// Casts a struct to a slice of possibly uninitialized bytes. |
