summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorDavid Lönnhager <david.l@mullvad.net>2021-09-03 10:25:13 +0200
committerDavid Lönnhager <david.l@mullvad.net>2021-09-28 12:41:48 +0200
commit3ec74812eb2f52c302009e867f2b30712caf72c4 (patch)
tree4f79d7a997a924330f0ebce44e20ed9fc4135eec
parent00973733707c669bc51c48c9eebc481c4e3f7b82 (diff)
downloadmullvadvpn-3ec74812eb2f52c302009e867f2b30712caf72c4.tar.xz
mullvadvpn-3ec74812eb2f52c302009e867f2b30712caf72c4.zip
Validate allowed IPs for wgnt
-rw-r--r--talpid-core/src/tunnel/wireguard/wireguard_nt.rs171
1 files changed, 151 insertions, 20 deletions
diff --git a/talpid-core/src/tunnel/wireguard/wireguard_nt.rs b/talpid-core/src/tunnel/wireguard/wireguard_nt.rs
index 2314ee283f..e2b386b1ed 100644
--- a/talpid-core/src/tunnel/wireguard/wireguard_nt.rs
+++ b/talpid-core/src/tunnel/wireguard/wireguard_nt.rs
@@ -134,6 +134,10 @@ pub enum Error {
#[error(display = "Failed to obtain interface name")]
ObtainAliasError(#[error(source)] io::Error),
+ /// Failed to get WireGuard tunnel config for device
+ #[error(display = "Failed to get tunnel WireGuard config")]
+ GetWireGuardConfigError(#[error(source)] io::Error),
+
/// Failed to set WireGuard tunnel config on device
#[error(display = "Failed to set tunnel WireGuard config")]
SetWireGuardConfigError(#[error(source)] io::Error),
@@ -157,6 +161,14 @@ pub enum Error {
/// Failure to set up logging
#[error(display = "Failed to set up logging")]
InitLoggingError(#[error(source)] logging::Error),
+
+ /// Invalid allowed IP
+ #[error(display = "Invalid CIDR prefix")]
+ InvalidAllowedIpCidr,
+
+ /// Allowed IP contains non-zero host bits
+ #[error(display = "Allowed IP contains non-zero host bits")]
+ InvalidAllowedIpBits,
}
pub struct WgNtTunnel {
@@ -185,6 +197,46 @@ struct WgAllowedIp {
cidr: u8,
}
+impl WgAllowedIp {
+ fn new(address: WgIpAddr, address_family: ADDRESS_FAMILY, cidr: u8) -> Result<Self> {
+ Self::validate(&address, address_family, cidr)?;
+ Ok(Self {
+ address,
+ address_family,
+ cidr,
+ })
+ }
+
+ fn validate(address: &WgIpAddr, address_family: ADDRESS_FAMILY, cidr: u8) -> Result<()> {
+ match address_family as i32 {
+ AF_INET => {
+ if cidr > 32 {
+ return Err(Error::InvalidAllowedIpCidr);
+ }
+ let host_mask = u32::MAX.checked_shr(u32::from(cidr)).unwrap_or(0);
+ if host_mask & (unsafe { *(address.v4.S_un.S_addr()) }.to_be()) != 0 {
+ return Err(Error::InvalidAllowedIpBits);
+ }
+ }
+ AF_INET6 => {
+ if cidr > 128 {
+ return Err(Error::InvalidAllowedIpCidr);
+ }
+ let mut host_mask = u128::MAX.checked_shr(u32::from(cidr)).unwrap_or(0);
+ let bytes = unsafe { address.v6.u.Byte() };
+ for byte in bytes.iter().rev() {
+ if byte & ((host_mask & 0xff) as u8) != 0 {
+ return Err(Error::InvalidAllowedIpBits);
+ }
+ host_mask = host_mask >> 8;
+ }
+ }
+ family => return Err(Error::UnknownAddressFamily(family)),
+ }
+ Ok(())
+ }
+}
+
impl PartialEq for WgAllowedIp {
fn eq(&self, other: &Self) -> bool {
if self.cidr != other.cidr {
@@ -399,9 +451,7 @@ impl WgNtTunnel {
error.display_chain_with_msg("Failed to set log state on WireGuard interface")
);
}
- device
- .set_config(config)
- .map_err(Error::SetWireGuardConfigError)?;
+ device.set_config(config)?;
set_interface_mtu(&device.luid(), AF_INET as u16, u32::from(config.mtu))
.map_err(Error::SetTunnelIpv4MtuError)?;
if config.tunnel.addresses.iter().any(|addr| addr.is_ipv6()) {
@@ -515,16 +565,24 @@ impl WgNtAdapter {
unsafe { self.dll_handle.get_adapter_luid(self.handle) }
}
- fn set_config(&self, config: &Config) -> io::Result<()> {
- let config_buffer = serialize_config(config);
+ fn set_config(&self, config: &Config) -> Result<()> {
+ let config_buffer = serialize_config(config)?;
unsafe {
self.dll_handle
.set_config(self.handle, config_buffer.as_ptr(), config_buffer.len())
+ .map_err(Error::SetWireGuardConfigError)
}
}
- fn get_config(&self) -> io::Result<(WgInterface, Vec<(WgPeer, Vec<WgAllowedIp>)>)> {
- Ok(unsafe { deserialize_config(&self.dll_handle.get_config(self.handle)?) })
+ fn get_config(&self) -> Result<(WgInterface, Vec<(WgPeer, Vec<WgAllowedIp>)>)> {
+ unsafe {
+ deserialize_config(
+ &self
+ .dll_handle
+ .get_config(self.handle)
+ .map_err(Error::GetWireGuardConfigError)?,
+ )
+ }
}
fn set_state(&self, state: WgAdapterState) -> io::Result<()> {
@@ -799,7 +857,7 @@ fn load_wg_nt_dll(resource_dir: &Path) -> Result<Arc<WgNtDll>> {
}
}
-fn serialize_config(config: &Config) -> Vec<u8> {
+fn serialize_config(config: &Config) -> Result<Vec<u8>> {
let mut buffer = vec![];
let header = WgInterface {
@@ -842,20 +900,19 @@ fn serialize_config(config: &Config) -> Vec<u8> {
},
};
- let wg_allowed_ip = WgAllowedIp {
- address,
- address_family,
- cidr: allowed_ip.prefix() as u8,
- };
+ let wg_allowed_ip =
+ WgAllowedIp::new(address, address_family, allowed_ip.prefix() as u8)?;
buffer.extend_from_slice(unsafe { as_u8_slice(&wg_allowed_ip) });
}
}
- buffer
+ Ok(buffer)
}
-unsafe fn deserialize_config(config: &[u8]) -> (WgInterface, Vec<(WgPeer, Vec<WgAllowedIp>)>) {
+unsafe fn deserialize_config(
+ config: &[u8],
+) -> Result<(WgInterface, Vec<(WgPeer, Vec<WgAllowedIp>)>)> {
let (head, mut tail) = config.split_at(mem::size_of::<WgInterface>());
let interface: WgInterface = *(head.as_ptr() as *const WgInterface);
@@ -870,6 +927,11 @@ unsafe fn deserialize_config(config: &[u8]) -> (WgInterface, Vec<(WgPeer, Vec<Wg
for _ in 0..peer.allowed_ips_count {
let (allowed_ip_data, new_tail) = tail.split_at(mem::size_of::<WgAllowedIp>());
let allowed_ip: WgAllowedIp = *(allowed_ip_data.as_ptr() as *const WgAllowedIp);
+ WgAllowedIp::validate(
+ &allowed_ip.address,
+ allowed_ip.address_family,
+ allowed_ip.cidr,
+ )?;
tail = new_tail;
allowed_ips.push(allowed_ip);
}
@@ -877,7 +939,7 @@ unsafe fn deserialize_config(config: &[u8]) -> (WgInterface, Vec<(WgPeer, Vec<Wg
peers.push((peer, allowed_ips));
}
- (interface, peers)
+ Ok((interface, peers))
}
fn convert_v4_address(addr: Ipv4Addr) -> IN_ADDR {
@@ -1051,7 +1113,7 @@ mod tests {
},
peers: vec![wireguard::PeerConfig {
public_key: WG_PUBLIC_KEY.clone(),
- allowed_ips: vec!["1.3.3.7/24".parse().unwrap()],
+ allowed_ips: vec!["1.3.3.0/24".parse().unwrap()],
endpoint: "1.2.3.4:1234".parse().unwrap(),
protocol: TransportProtocol::Udp,
}],
@@ -1083,7 +1145,7 @@ mod tests {
},
p0_allowed_ip_0: WgAllowedIp {
address: WgIpAddr {
- v4: convert_v4_address("1.3.3.7".parse().unwrap()),
+ v4: convert_v4_address("1.3.3.0".parse().unwrap()),
},
address_family: AF_INET as u16,
cidr: 24,
@@ -1125,7 +1187,7 @@ mod tests {
#[test]
fn test_config_serialization() {
- let serialized_data = serialize_config(&*WG_CONFIG);
+ let serialized_data = serialize_config(&*WG_CONFIG).unwrap();
assert_eq!(mem::size_of::<Interface>(), serialized_data.len());
let serialized_iface = &unsafe { *(serialized_data.as_ptr() as *const Interface) };
assert_eq!(&*WG_STRUCT_CONFIG, serialized_iface);
@@ -1133,7 +1195,8 @@ mod tests {
#[test]
fn test_config_deserialization() {
- let (iface, peers) = unsafe { deserialize_config(as_u8_slice(&*WG_STRUCT_CONFIG)) };
+ let (iface, peers) =
+ unsafe { deserialize_config(as_u8_slice(&*WG_STRUCT_CONFIG)) }.unwrap();
assert_eq!(iface, WG_STRUCT_CONFIG.interface);
assert_eq!(peers.len(), 1);
let (peer, allowed_ips) = &peers[0];
@@ -1141,4 +1204,72 @@ mod tests {
assert_eq!(allowed_ips.len(), 1);
assert_eq!(allowed_ips[0], WG_STRUCT_CONFIG.p0_allowed_ip_0);
}
+
+ #[test]
+ fn test_wg_allowed_ip_v4() {
+ // Valid: /32 prefix
+ let address_family = AF_INET as u16;
+ let address = WgIpAddr {
+ v4: convert_v4_address("127.0.0.1".parse().unwrap()),
+ };
+ let cidr = 32;
+ WgAllowedIp::new(address, address_family, cidr).unwrap();
+
+ // Invalid host bits
+ let cidr = 24;
+ let address = WgIpAddr {
+ v4: convert_v4_address("0.0.0.1".parse().unwrap()),
+ };
+ assert!(WgAllowedIp::new(address, address_family, cidr).is_err());
+
+ // Valid host bits
+ let cidr = 24;
+ let address = WgIpAddr {
+ v4: convert_v4_address("255.255.255.0".parse().unwrap()),
+ };
+ WgAllowedIp::new(address, address_family, cidr).unwrap();
+
+ // 0.0.0.0/0
+ let cidr = 0;
+ let address = WgIpAddr {
+ v4: convert_v4_address("0.0.0.0".parse().unwrap()),
+ };
+ WgAllowedIp::new(address, address_family, cidr).unwrap();
+
+ // Invalid CIDR
+ let cidr = 33;
+ assert!(WgAllowedIp::new(address, address_family, cidr).is_err());
+ }
+
+ #[test]
+ fn test_wg_allowed_ip_v6() {
+ // Valid: /128 prefix
+ let address_family = AF_INET6 as u16;
+ let address = WgIpAddr {
+ v6: convert_v6_address("::1".parse().unwrap()),
+ };
+ let cidr = 128;
+ WgAllowedIp::new(address, address_family, cidr).unwrap();
+
+ // Invalid host bits
+ let cidr = 127;
+ assert!(WgAllowedIp::new(address, address_family, cidr).is_err());
+
+ // Valid host bits
+ let address = WgIpAddr {
+ v6: convert_v6_address("ffff:ffff:ffff:ffff:ffff:ffff:ffff:fffe".parse().unwrap()),
+ };
+ WgAllowedIp::new(address, address_family, cidr).unwrap();
+
+ // ::/0
+ let cidr = 0;
+ let address = WgIpAddr {
+ v6: convert_v6_address("::".parse().unwrap()),
+ };
+ WgAllowedIp::new(address, address_family, cidr).unwrap();
+
+ // Invalid CIDR
+ let cidr = 129;
+ assert!(WgAllowedIp::new(address, address_family, cidr).is_err());
+ }
}