diff options
| -rw-r--r-- | talpid-core/src/tunnel/wireguard/wireguard_nt.rs | 171 |
1 files changed, 151 insertions, 20 deletions
diff --git a/talpid-core/src/tunnel/wireguard/wireguard_nt.rs b/talpid-core/src/tunnel/wireguard/wireguard_nt.rs index 2314ee283f..e2b386b1ed 100644 --- a/talpid-core/src/tunnel/wireguard/wireguard_nt.rs +++ b/talpid-core/src/tunnel/wireguard/wireguard_nt.rs @@ -134,6 +134,10 @@ pub enum Error { #[error(display = "Failed to obtain interface name")] ObtainAliasError(#[error(source)] io::Error), + /// Failed to get WireGuard tunnel config for device + #[error(display = "Failed to get tunnel WireGuard config")] + GetWireGuardConfigError(#[error(source)] io::Error), + /// Failed to set WireGuard tunnel config on device #[error(display = "Failed to set tunnel WireGuard config")] SetWireGuardConfigError(#[error(source)] io::Error), @@ -157,6 +161,14 @@ pub enum Error { /// Failure to set up logging #[error(display = "Failed to set up logging")] InitLoggingError(#[error(source)] logging::Error), + + /// Invalid allowed IP + #[error(display = "Invalid CIDR prefix")] + InvalidAllowedIpCidr, + + /// Allowed IP contains non-zero host bits + #[error(display = "Allowed IP contains non-zero host bits")] + InvalidAllowedIpBits, } pub struct WgNtTunnel { @@ -185,6 +197,46 @@ struct WgAllowedIp { cidr: u8, } +impl WgAllowedIp { + fn new(address: WgIpAddr, address_family: ADDRESS_FAMILY, cidr: u8) -> Result<Self> { + Self::validate(&address, address_family, cidr)?; + Ok(Self { + address, + address_family, + cidr, + }) + } + + fn validate(address: &WgIpAddr, address_family: ADDRESS_FAMILY, cidr: u8) -> Result<()> { + match address_family as i32 { + AF_INET => { + if cidr > 32 { + return Err(Error::InvalidAllowedIpCidr); + } + let host_mask = u32::MAX.checked_shr(u32::from(cidr)).unwrap_or(0); + if host_mask & (unsafe { *(address.v4.S_un.S_addr()) }.to_be()) != 0 { + return Err(Error::InvalidAllowedIpBits); + } + } + AF_INET6 => { + if cidr > 128 { + return Err(Error::InvalidAllowedIpCidr); + } + let mut host_mask = u128::MAX.checked_shr(u32::from(cidr)).unwrap_or(0); + let bytes = unsafe { address.v6.u.Byte() }; + for byte in bytes.iter().rev() { + if byte & ((host_mask & 0xff) as u8) != 0 { + return Err(Error::InvalidAllowedIpBits); + } + host_mask = host_mask >> 8; + } + } + family => return Err(Error::UnknownAddressFamily(family)), + } + Ok(()) + } +} + impl PartialEq for WgAllowedIp { fn eq(&self, other: &Self) -> bool { if self.cidr != other.cidr { @@ -399,9 +451,7 @@ impl WgNtTunnel { error.display_chain_with_msg("Failed to set log state on WireGuard interface") ); } - device - .set_config(config) - .map_err(Error::SetWireGuardConfigError)?; + device.set_config(config)?; set_interface_mtu(&device.luid(), AF_INET as u16, u32::from(config.mtu)) .map_err(Error::SetTunnelIpv4MtuError)?; if config.tunnel.addresses.iter().any(|addr| addr.is_ipv6()) { @@ -515,16 +565,24 @@ impl WgNtAdapter { unsafe { self.dll_handle.get_adapter_luid(self.handle) } } - fn set_config(&self, config: &Config) -> io::Result<()> { - let config_buffer = serialize_config(config); + fn set_config(&self, config: &Config) -> Result<()> { + let config_buffer = serialize_config(config)?; unsafe { self.dll_handle .set_config(self.handle, config_buffer.as_ptr(), config_buffer.len()) + .map_err(Error::SetWireGuardConfigError) } } - fn get_config(&self) -> io::Result<(WgInterface, Vec<(WgPeer, Vec<WgAllowedIp>)>)> { - Ok(unsafe { deserialize_config(&self.dll_handle.get_config(self.handle)?) }) + fn get_config(&self) -> Result<(WgInterface, Vec<(WgPeer, Vec<WgAllowedIp>)>)> { + unsafe { + deserialize_config( + &self + .dll_handle + .get_config(self.handle) + .map_err(Error::GetWireGuardConfigError)?, + ) + } } fn set_state(&self, state: WgAdapterState) -> io::Result<()> { @@ -799,7 +857,7 @@ fn load_wg_nt_dll(resource_dir: &Path) -> Result<Arc<WgNtDll>> { } } -fn serialize_config(config: &Config) -> Vec<u8> { +fn serialize_config(config: &Config) -> Result<Vec<u8>> { let mut buffer = vec![]; let header = WgInterface { @@ -842,20 +900,19 @@ fn serialize_config(config: &Config) -> Vec<u8> { }, }; - let wg_allowed_ip = WgAllowedIp { - address, - address_family, - cidr: allowed_ip.prefix() as u8, - }; + let wg_allowed_ip = + WgAllowedIp::new(address, address_family, allowed_ip.prefix() as u8)?; buffer.extend_from_slice(unsafe { as_u8_slice(&wg_allowed_ip) }); } } - buffer + Ok(buffer) } -unsafe fn deserialize_config(config: &[u8]) -> (WgInterface, Vec<(WgPeer, Vec<WgAllowedIp>)>) { +unsafe fn deserialize_config( + config: &[u8], +) -> Result<(WgInterface, Vec<(WgPeer, Vec<WgAllowedIp>)>)> { let (head, mut tail) = config.split_at(mem::size_of::<WgInterface>()); let interface: WgInterface = *(head.as_ptr() as *const WgInterface); @@ -870,6 +927,11 @@ unsafe fn deserialize_config(config: &[u8]) -> (WgInterface, Vec<(WgPeer, Vec<Wg for _ in 0..peer.allowed_ips_count { let (allowed_ip_data, new_tail) = tail.split_at(mem::size_of::<WgAllowedIp>()); let allowed_ip: WgAllowedIp = *(allowed_ip_data.as_ptr() as *const WgAllowedIp); + WgAllowedIp::validate( + &allowed_ip.address, + allowed_ip.address_family, + allowed_ip.cidr, + )?; tail = new_tail; allowed_ips.push(allowed_ip); } @@ -877,7 +939,7 @@ unsafe fn deserialize_config(config: &[u8]) -> (WgInterface, Vec<(WgPeer, Vec<Wg peers.push((peer, allowed_ips)); } - (interface, peers) + Ok((interface, peers)) } fn convert_v4_address(addr: Ipv4Addr) -> IN_ADDR { @@ -1051,7 +1113,7 @@ mod tests { }, peers: vec![wireguard::PeerConfig { public_key: WG_PUBLIC_KEY.clone(), - allowed_ips: vec!["1.3.3.7/24".parse().unwrap()], + allowed_ips: vec!["1.3.3.0/24".parse().unwrap()], endpoint: "1.2.3.4:1234".parse().unwrap(), protocol: TransportProtocol::Udp, }], @@ -1083,7 +1145,7 @@ mod tests { }, p0_allowed_ip_0: WgAllowedIp { address: WgIpAddr { - v4: convert_v4_address("1.3.3.7".parse().unwrap()), + v4: convert_v4_address("1.3.3.0".parse().unwrap()), }, address_family: AF_INET as u16, cidr: 24, @@ -1125,7 +1187,7 @@ mod tests { #[test] fn test_config_serialization() { - let serialized_data = serialize_config(&*WG_CONFIG); + let serialized_data = serialize_config(&*WG_CONFIG).unwrap(); assert_eq!(mem::size_of::<Interface>(), serialized_data.len()); let serialized_iface = &unsafe { *(serialized_data.as_ptr() as *const Interface) }; assert_eq!(&*WG_STRUCT_CONFIG, serialized_iface); @@ -1133,7 +1195,8 @@ mod tests { #[test] fn test_config_deserialization() { - let (iface, peers) = unsafe { deserialize_config(as_u8_slice(&*WG_STRUCT_CONFIG)) }; + let (iface, peers) = + unsafe { deserialize_config(as_u8_slice(&*WG_STRUCT_CONFIG)) }.unwrap(); assert_eq!(iface, WG_STRUCT_CONFIG.interface); assert_eq!(peers.len(), 1); let (peer, allowed_ips) = &peers[0]; @@ -1141,4 +1204,72 @@ mod tests { assert_eq!(allowed_ips.len(), 1); assert_eq!(allowed_ips[0], WG_STRUCT_CONFIG.p0_allowed_ip_0); } + + #[test] + fn test_wg_allowed_ip_v4() { + // Valid: /32 prefix + let address_family = AF_INET as u16; + let address = WgIpAddr { + v4: convert_v4_address("127.0.0.1".parse().unwrap()), + }; + let cidr = 32; + WgAllowedIp::new(address, address_family, cidr).unwrap(); + + // Invalid host bits + let cidr = 24; + let address = WgIpAddr { + v4: convert_v4_address("0.0.0.1".parse().unwrap()), + }; + assert!(WgAllowedIp::new(address, address_family, cidr).is_err()); + + // Valid host bits + let cidr = 24; + let address = WgIpAddr { + v4: convert_v4_address("255.255.255.0".parse().unwrap()), + }; + WgAllowedIp::new(address, address_family, cidr).unwrap(); + + // 0.0.0.0/0 + let cidr = 0; + let address = WgIpAddr { + v4: convert_v4_address("0.0.0.0".parse().unwrap()), + }; + WgAllowedIp::new(address, address_family, cidr).unwrap(); + + // Invalid CIDR + let cidr = 33; + assert!(WgAllowedIp::new(address, address_family, cidr).is_err()); + } + + #[test] + fn test_wg_allowed_ip_v6() { + // Valid: /128 prefix + let address_family = AF_INET6 as u16; + let address = WgIpAddr { + v6: convert_v6_address("::1".parse().unwrap()), + }; + let cidr = 128; + WgAllowedIp::new(address, address_family, cidr).unwrap(); + + // Invalid host bits + let cidr = 127; + assert!(WgAllowedIp::new(address, address_family, cidr).is_err()); + + // Valid host bits + let address = WgIpAddr { + v6: convert_v6_address("ffff:ffff:ffff:ffff:ffff:ffff:ffff:fffe".parse().unwrap()), + }; + WgAllowedIp::new(address, address_family, cidr).unwrap(); + + // ::/0 + let cidr = 0; + let address = WgIpAddr { + v6: convert_v6_address("::".parse().unwrap()), + }; + WgAllowedIp::new(address, address_family, cidr).unwrap(); + + // Invalid CIDR + let cidr = 129; + assert!(WgAllowedIp::new(address, address_family, cidr).is_err()); + } } |
