diff options
| author | David Lönnhager <david.l@mullvad.net> | 2022-11-09 11:02:39 +0100 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2022-11-09 11:02:39 +0100 |
| commit | 2af40170885057f6fa6a40ffb071d2b998265a54 (patch) | |
| tree | d9eecb898dbc276764e77013ef69d1c6fb3a31c6 | |
| parent | c51cbe279921485e0ca19085d9e322c07508ced0 (diff) | |
| parent | b71abff37164d23ad8ebed5a8fee6e87d6b71608 (diff) | |
| download | mullvadvpn-2af40170885057f6fa6a40ffb071d2b998265a54.tar.xz mullvadvpn-2af40170885057f6fa6a40ffb071d2b998265a54.zip | |
Merge branch 'simplify-windows-net'
| -rw-r--r-- | talpid-routing/src/windows/get_best_default_route.rs | 65 | ||||
| -rw-r--r-- | talpid-routing/src/windows/route_manager.rs | 43 | ||||
| -rw-r--r-- | talpid-wireguard/src/wireguard_go.rs | 5 |
3 files changed, 49 insertions, 64 deletions
diff --git a/talpid-routing/src/windows/get_best_default_route.rs b/talpid-routing/src/windows/get_best_default_route.rs index 940fd93644..37baf8c44b 100644 --- a/talpid-routing/src/windows/get_best_default_route.rs +++ b/talpid-routing/src/windows/get_best_default_route.rs @@ -1,5 +1,5 @@ use super::{Error, Result}; -use std::{convert::TryInto, io, net::SocketAddr}; +use std::{io, net::SocketAddr, slice}; use talpid_windows_net::{ get_ip_interface_entry, try_socketaddr_from_inet_sockaddr, AddressFamily, }; @@ -22,7 +22,7 @@ const TUNNEL_INTERFACE_DESCS: [&WideCStr; 3] = [ widecstr!("Tunnel"), ]; -fn get_ipforward_rows(family: AddressFamily) -> Result<Vec<MIB_IPFORWARD_ROW2>> { +fn get_ip_forward_table(family: AddressFamily) -> Result<Vec<MIB_IPFORWARD_ROW2>> { let family = family.to_af_family(); let mut table_ptr = std::ptr::null_mut(); @@ -36,32 +36,23 @@ fn get_ipforward_rows(family: AddressFamily) -> Result<Vec<MIB_IPFORWARD_ROW2>> } // SAFETY: table_ptr is valid since GetIpForwardTable2 did not return an error - let num_entries = unsafe { *table_ptr }.NumEntries; - let mut vec = Vec::with_capacity(num_entries.try_into().unwrap_or_default()); + let num_entries = usize::try_from(unsafe { *table_ptr }.NumEntries).unwrap(); + assert!( + num_entries + .checked_mul(std::mem::size_of::<MIB_IPFORWARD_ROW2>()) + .unwrap() + <= usize::try_from(isize::MAX).unwrap() + ); + // SAFETY: num_entries * size_of(MIB_IPFORWARD_ROW2) is at most isize::MAX + let rows = unsafe { slice::from_raw_parts((*table_ptr).Table.as_ptr(), num_entries) }.to_vec(); - for i in 0..num_entries { - assert!( - usize::try_from(i).unwrap() * std::mem::size_of::<MIB_IPFORWARD_ROW2>() - < usize::try_from(isize::MAX).unwrap() - ); - - // SAFETY: table_ptr is valid since GetIpForwardTable2 did not return an error nor have we - // or will we modify the table - let ptr: *const MIB_IPFORWARD_ROW2 = unsafe { (*table_ptr).Table.as_ptr() }; - - // SAFETY: The assert guarantees that the amount of bytes we are jumping is not larger than - // isize::MAX. Win32 guarantees that the resulting pointer is aligned, non-null, - // init. - let row: &MIB_IPFORWARD_ROW2 = - unsafe { ptr.offset(i.try_into().unwrap()).as_ref() }.unwrap(); - vec.push(row.clone()); - } // SAFETY: FreeMibTable does not have clear safety rules but it deallocates the // MIB_IPFORWARD_TABLE2 This pointer is ONLY deallocated here so it is guaranteed to not // have been already deallocated. We have cloned all MIB_IPFORWARD_ROW2s and the rows do not // contain pointers to the table so they will not be dangling after this free. unsafe { FreeMibTable(table_ptr as *const _) } - Ok(vec) + + Ok(rows) } /// General type for passing interface and gateway @@ -81,7 +72,7 @@ impl PartialEq for InterfaceAndGateway { /// Get the best default route for the given address family or None if none exists. pub fn get_best_default_route(family: AddressFamily) -> Result<Option<InterfaceAndGateway>> { - let table = get_ipforward_rows(family)?; + let table = get_ip_forward_table(family)?; // Remove all candidates without a gateway and which are not on a physical interface. // Then get the annotated routes which are active. @@ -92,29 +83,29 @@ pub fn get_best_default_route(family: AddressFamily) -> Result<Option<InterfaceA && route_has_gateway(row) && is_route_on_physical_interface(row).unwrap_or(false) }) - .filter_map(|row| annotate_route(row)) + .filter_map(annotate_route) .collect(); - if annotated.is_empty() { - return Ok(None); - } - // We previously filtered out all inactive routes so we only need to sort by acending // effective_metric annotated.sort_by(|lhs, rhs| lhs.effective_metric.cmp(&rhs.effective_metric)); - Ok(Some(InterfaceAndGateway { - iface: annotated[0].route.InterfaceLuid, - gateway: try_socketaddr_from_inet_sockaddr(annotated[0].route.NextHop) - .map_err(|_| Error::InvalidSiFamily)?, - })) + annotated + .get(0) + .map(|annotated| { + Ok(InterfaceAndGateway { + iface: annotated.route.InterfaceLuid, + gateway: try_socketaddr_from_inet_sockaddr(annotated.route.NextHop) + .map_err(|_| Error::InvalidSiFamily)?, + }) + }) + .transpose() } pub fn route_has_gateway(route: &MIB_IPFORWARD_ROW2) -> bool { - match try_socketaddr_from_inet_sockaddr(route.NextHop) { - Ok(sock) => !sock.ip().is_unspecified(), - Err(_) => false, - } + try_socketaddr_from_inet_sockaddr(route.NextHop) + .map(|addr| !addr.ip().is_unspecified()) + .unwrap_or(false) } // TODO(Jon): It would be more correct to filter for devices that match the known LUID of the tunnel diff --git a/talpid-routing/src/windows/route_manager.rs b/talpid-routing/src/windows/route_manager.rs index ab4c9ef45b..691f39f453 100644 --- a/talpid-routing/src/windows/route_manager.rs +++ b/talpid-routing/src/windows/route_manager.rs @@ -528,26 +528,23 @@ impl RouteManagerInternal { // changed. So removing and adding again is the only option. // - match Self::delete_from_routing_table(&affected_route.registered_route) { - Ok(()) => (), - Err(e) => { - log::error!( - "Failed to delete route when refreshing existing routes: {}", - e - ); - continue; - } + if let Err(error) = Self::delete_from_routing_table(&affected_route.registered_route) { + log::error!( + "Failed to delete route when refreshing existing routes: {}", + error + ); + continue; } affected_route.registered_route.luid = route.iface; affected_route.registered_route.next_hop = route.gateway; - match Self::restore_into_routing_table(&affected_route.registered_route) { - Ok(()) => (), - Err(e) => { - log::error!("Failed to add route when refreshing existing routes: {}", e); - continue; - } + if let Err(error) = Self::restore_into_routing_table(&affected_route.registered_route) { + log::error!( + "Failed to add route when refreshing existing routes: {}", + error + ); + continue; } } } @@ -601,11 +598,6 @@ fn interface_luid_from_gateway(gateway: &SOCKADDR_INET) -> Result<NET_LUID_LH> { }) .collect(); - if matches.is_empty() { - log::error!("Unable to find network adapter with specified gateway"); - return Err(Error::DeviceGatewayNotFound); - } - // Sort matching interfaces ascending by metric. // @@ -621,8 +613,13 @@ fn interface_luid_from_gateway(gateway: &SOCKADDR_INET) -> Result<NET_LUID_LH> { // Select the interface with the best (lowest) metric. // - - Ok(matches[0].Luid) + matches + .get(0) + .map(|interface| interface.Luid) + .ok_or_else(|| { + log::error!("Unable to find network adapter with specified gateway"); + Error::DeviceGatewayNotFound + }) } /// SAFETY: adapter.FirstGatewayAddress must be dereferencable and must live as long as adapter @@ -741,7 +738,7 @@ impl Adapters { GetAdaptersAddresses( family, flags, - std::ptr::null_mut() as *mut _, + std::ptr::null_mut(), buffer_pointer as *mut IP_ADAPTER_ADDRESSES_LH, &mut buffer_size, ) diff --git a/talpid-wireguard/src/wireguard_go.rs b/talpid-wireguard/src/wireguard_go.rs index 4232d99e70..aa962ace7d 100644 --- a/talpid-wireguard/src/wireguard_go.rs +++ b/talpid-wireguard/src/wireguard_go.rs @@ -217,11 +217,8 @@ impl WgGoTunnel { let iface_idx: u32 = match event_type { Updated(default_route) => { let mut iface_idx = 0u32; - // TODO: Make sure unwrap is fine let iface_luid = default_route.iface; - let status = unsafe { - ConvertInterfaceLuidToIndex(&iface_luid as *const _, &mut iface_idx as *mut _) - }; + let status = unsafe { ConvertInterfaceLuidToIndex(&iface_luid, &mut iface_idx) }; if status != 0 { log::error!( "Failed to convert interface LUID to interface index: {}: {}", |
