diff options
| author | David Lönnhager <david.l@mullvad.net> | 2022-06-13 20:45:06 +0200 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2022-06-14 12:38:45 +0200 |
| commit | c43315be6090d5dfc7d85ae71b62fb13d66bb002 (patch) | |
| tree | f0ca29db3d1bc283bdb0c7a6d6e299fcbe54bf45 | |
| parent | db673926d5b1966218a907d1e56e2f4f715728ee (diff) | |
| download | mullvadvpn-c43315be6090d5dfc7d85ae71b62fb13d66bb002.tar.xz mullvadvpn-c43315be6090d5dfc7d85ae71b62fb13d66bb002.zip | |
Simplify WireGuard monitor constructor
| -rw-r--r-- | talpid-core/src/tunnel/wireguard/mod.rs | 231 | ||||
| -rw-r--r-- | talpid-core/src/winnet.rs | 2 |
2 files changed, 133 insertions, 100 deletions
diff --git a/talpid-core/src/tunnel/wireguard/mod.rs b/talpid-core/src/tunnel/wireguard/mod.rs index 7b3d2a4aa4..3d0de8ba5c 100644 --- a/talpid-core/src/tunnel/wireguard/mod.rs +++ b/talpid-core/src/tunnel/wireguard/mod.rs @@ -19,6 +19,7 @@ use std::env; #[cfg(windows)] use std::io; use std::{ + borrow::Cow, convert::Infallible, net::{IpAddr, SocketAddrV4}, path::Path, @@ -194,7 +195,7 @@ fn maybe_create_obfuscator( impl WireguardMonitor { /// Starts a WireGuard tunnel with the given config pub fn start< - F: (Fn(TunnelEvent) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>) + F: (Fn(TunnelEvent) -> Pin<Box<dyn std::future::Future<Output = ()> + Send>>) + Send + Sync + Clone @@ -218,29 +219,11 @@ impl WireguardMonitor { let obfuscator = maybe_create_obfuscator(&runtime, &mut config, close_msg_sender.clone())?; #[cfg(target_os = "windows")] - let (setup_done_tx, mut setup_done_rx) = mpsc::channel(0); - - // Use allowed IPs to block anything but the v4 gateway, if PSK exchange is on. - let config_ref; - let mut patched_config; - if psk_negotiation.is_some() { - patched_config = config.clone(); - let gateway_net = ipnetwork::IpNetwork::from(IpAddr::from(config.ipv4_gateway)); - for peer in &mut patched_config.peers { - for allowed_ip in &mut peer.allowed_ips { - if allowed_ip.is_ipv4() && allowed_ip.prefix() == 0 { - *allowed_ip = gateway_net; - } - } - } - config_ref = &patched_config; - } else { - config_ref = &config; - } + let (setup_done_tx, setup_done_rx) = mpsc::channel(0); let tunnel = Self::open_tunnel( runtime.clone(), - config_ref, + &Self::patch_allowed_ips(&config, psk_negotiation.is_some()), log_path, resource_dir, tun_provider, @@ -275,26 +258,8 @@ impl WireguardMonitor { let tunnel_fut = async move { #[cfg(windows)] - { - setup_done_rx - .next() - .await - .ok_or_else(|| { - // Tunnel was shut down early - CloseMsg::SetupError(Error::IpInterfacesError) - })? - .map_err(|error| { - log::error!( - "{}", - error.display_chain_with_msg("Failed to configure tunnel interface") - ); - CloseMsg::SetupError(Error::IpInterfacesError) - })?; - - if !crate::winnet::add_device_ip_addresses(&iface_name, &config.tunnel.addresses) { - return Err(CloseMsg::SetupError(Error::SetIpAddressesError)); - } - } + Self::add_device_ip_addresses(&iface_name, &config.tunnel.addresses, setup_done_rx) + .await?; let allowed_traffic = if psk_negotiation.is_some() { AllowedTunnelTraffic::Only( @@ -302,10 +267,7 @@ impl WireguardMonitor { Protocol::Tcp, ) } else { - AllowedTunnelTraffic::Only( - SocketAddrV4::new(config.ipv4_gateway, 0).into(), - Protocol::IcmpV4, - ) + AllowedTunnelTraffic::All }; (on_event)(TunnelEvent::InterfaceUp(metadata.clone(), allowed_traffic)).await; @@ -327,60 +289,12 @@ impl WireguardMonitor { .map_err(CloseMsg::SetupError)?; if let Some(pubkey) = psk_negotiation { - log::debug!("Performing PQ-safe PSK exchange"); - - let timeout = std::cmp::min( - MAX_PSK_EXCHANGE_TIMEOUT, - INITIAL_PSK_EXCHANGE_TIMEOUT.saturating_mul( - PSK_EXCHANGE_TIMEOUT_MULTIPLIER.saturating_pow(retry_attempt), - ), - ); - - let (private_key, psk) = tokio::time::timeout( - timeout, - talpid_tunnel_config_client::push_pq_key( - IpAddr::V4(config.ipv4_gateway), - config.tunnel.private_key.public_key(), - ), - ) - .await - .map_err(|_timeout_err| { - log::warn!("Timeout while negotiating PSK"); - CloseMsg::PskNegotiationTimeout - })? - .map_err(Error::PskNegotiationError) - .map_err(CloseMsg::SetupError)?; - - config.tunnel.private_key = private_key; - - for peer in &mut config.peers { - if pubkey == peer.public_key { - peer.psk = Some(psk); - break; - } - } - - log::trace!( - "Ephemeral pubkey: {}", - config.tunnel.private_key.public_key() - ); - - let set_config_future = tunnel - .lock() - .unwrap() - .as_ref() - .map(|tunnel| tunnel.set_config(config.clone())); - if let Some(f) = set_config_future { - f.await - .map_err(Error::TunnelError) - .map_err(CloseMsg::SetupError)?; - } - - let allowed_traffic = AllowedTunnelTraffic::Only( - SocketAddrV4::new(config.ipv4_gateway, 0).into(), - Protocol::IcmpV4, - ); - (on_event)(TunnelEvent::InterfaceUp(metadata.clone(), allowed_traffic)).await; + Self::perform_psk_negotiation(tunnel, retry_attempt, pubkey, &mut config).await?; + (on_event)(TunnelEvent::InterfaceUp( + metadata.clone(), + AllowedTunnelTraffic::All, + )) + .await; } let mut connectivity_monitor = tokio::task::spawn_blocking(move || { @@ -442,6 +356,125 @@ impl WireguardMonitor { Ok(monitor) } + /// Replace `0.0.0.0/0`/`::/0` with the gateway IPs when `gateway_only` is true. + /// Used to block traffic to other destinations while connecting on Android. + fn patch_allowed_ips<'a>(config: &'a Config, gateway_only: bool) -> Cow<'a, Config> { + if gateway_only { + let mut patched_config = config.clone(); + let gateway_net_v4 = ipnetwork::IpNetwork::from(IpAddr::from(config.ipv4_gateway)); + let gateway_net_v6 = config + .ipv6_gateway + .map(|net| ipnetwork::IpNetwork::from(IpAddr::from(net))); + for peer in &mut patched_config.peers { + peer.allowed_ips = peer + .allowed_ips + .iter() + .cloned() + .filter_map(|mut allowed_ip| { + if allowed_ip.prefix() == 0 { + if allowed_ip.is_ipv4() { + allowed_ip = gateway_net_v4; + } else { + if let Some(net) = gateway_net_v6 { + allowed_ip = net; + } else { + return None; + } + } + } + Some(allowed_ip) + }) + .collect(); + } + Cow::Owned(patched_config) + } else { + Cow::Borrowed(config) + } + } + + #[cfg(windows)] + async fn add_device_ip_addresses( + iface_name: &str, + addresses: &[IpAddr], + mut setup_done_rx: mpsc::Receiver<std::result::Result<(), BoxedError>>, + ) -> std::result::Result<(), CloseMsg> { + setup_done_rx + .next() + .await + .ok_or_else(|| { + // Tunnel was shut down early + CloseMsg::SetupError(Error::IpInterfacesError) + })? + .map_err(|error| { + log::error!( + "{}", + error.display_chain_with_msg("Failed to configure tunnel interface") + ); + CloseMsg::SetupError(Error::IpInterfacesError) + })?; + if !crate::winnet::add_device_ip_addresses(iface_name, addresses) { + return Err(CloseMsg::SetupError(Error::SetIpAddressesError)); + } + Ok(()) + } + + async fn perform_psk_negotiation( + tunnel: Arc<Mutex<Option<Box<dyn Tunnel>>>>, + retry_attempt: u32, + current_pubkey: PublicKey, + config: &mut Config, + ) -> std::result::Result<(), CloseMsg> { + log::debug!("Performing PQ-safe PSK exchange"); + + let timeout = std::cmp::min( + MAX_PSK_EXCHANGE_TIMEOUT, + INITIAL_PSK_EXCHANGE_TIMEOUT + .saturating_mul(PSK_EXCHANGE_TIMEOUT_MULTIPLIER.saturating_pow(retry_attempt)), + ); + + let (private_key, psk) = tokio::time::timeout( + timeout, + talpid_tunnel_config_client::push_pq_key( + IpAddr::V4(config.ipv4_gateway), + config.tunnel.private_key.public_key(), + ), + ) + .await + .map_err(|_timeout_err| { + log::warn!("Timeout while negotiating PSK"); + CloseMsg::PskNegotiationTimeout + })? + .map_err(Error::PskNegotiationError) + .map_err(CloseMsg::SetupError)?; + + config.tunnel.private_key = private_key; + + for peer in &mut config.peers { + if current_pubkey == peer.public_key { + peer.psk = Some(psk); + break; + } + } + + log::trace!( + "Ephemeral pubkey: {}", + config.tunnel.private_key.public_key() + ); + + let set_config_future = tunnel + .lock() + .unwrap() + .as_ref() + .map(|tunnel| tunnel.set_config(config.clone())); + if let Some(f) = set_config_future { + f.await + .map_err(Error::TunnelError) + .map_err(CloseMsg::SetupError)?; + } + + Ok(()) + } + #[allow(unused_variables)] fn open_tunnel( runtime: tokio::runtime::Handle, diff --git a/talpid-core/src/winnet.rs b/talpid-core/src/winnet.rs index 7f31082541..3d0075257b 100644 --- a/talpid-core/src/winnet.rs +++ b/talpid-core/src/winnet.rs @@ -404,7 +404,7 @@ pub fn interface_luid_to_ip( } } -pub fn add_device_ip_addresses(iface: &String, addresses: &Vec<IpAddr>) -> bool { +pub fn add_device_ip_addresses(iface: &str, addresses: &[IpAddr]) -> bool { let raw_iface = WideCString::from_str(iface) .expect("Failed to convert UTF-8 string to null terminated UCS string") .into_raw(); |
