summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorDavid Lönnhager <david.l@mullvad.net>2022-06-13 20:45:06 +0200
committerDavid Lönnhager <david.l@mullvad.net>2022-06-14 12:38:45 +0200
commitc43315be6090d5dfc7d85ae71b62fb13d66bb002 (patch)
treef0ca29db3d1bc283bdb0c7a6d6e299fcbe54bf45
parentdb673926d5b1966218a907d1e56e2f4f715728ee (diff)
downloadmullvadvpn-c43315be6090d5dfc7d85ae71b62fb13d66bb002.tar.xz
mullvadvpn-c43315be6090d5dfc7d85ae71b62fb13d66bb002.zip
Simplify WireGuard monitor constructor
-rw-r--r--talpid-core/src/tunnel/wireguard/mod.rs231
-rw-r--r--talpid-core/src/winnet.rs2
2 files changed, 133 insertions, 100 deletions
diff --git a/talpid-core/src/tunnel/wireguard/mod.rs b/talpid-core/src/tunnel/wireguard/mod.rs
index 7b3d2a4aa4..3d0de8ba5c 100644
--- a/talpid-core/src/tunnel/wireguard/mod.rs
+++ b/talpid-core/src/tunnel/wireguard/mod.rs
@@ -19,6 +19,7 @@ use std::env;
#[cfg(windows)]
use std::io;
use std::{
+ borrow::Cow,
convert::Infallible,
net::{IpAddr, SocketAddrV4},
path::Path,
@@ -194,7 +195,7 @@ fn maybe_create_obfuscator(
impl WireguardMonitor {
/// Starts a WireGuard tunnel with the given config
pub fn start<
- F: (Fn(TunnelEvent) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>)
+ F: (Fn(TunnelEvent) -> Pin<Box<dyn std::future::Future<Output = ()> + Send>>)
+ Send
+ Sync
+ Clone
@@ -218,29 +219,11 @@ impl WireguardMonitor {
let obfuscator = maybe_create_obfuscator(&runtime, &mut config, close_msg_sender.clone())?;
#[cfg(target_os = "windows")]
- let (setup_done_tx, mut setup_done_rx) = mpsc::channel(0);
-
- // Use allowed IPs to block anything but the v4 gateway, if PSK exchange is on.
- let config_ref;
- let mut patched_config;
- if psk_negotiation.is_some() {
- patched_config = config.clone();
- let gateway_net = ipnetwork::IpNetwork::from(IpAddr::from(config.ipv4_gateway));
- for peer in &mut patched_config.peers {
- for allowed_ip in &mut peer.allowed_ips {
- if allowed_ip.is_ipv4() && allowed_ip.prefix() == 0 {
- *allowed_ip = gateway_net;
- }
- }
- }
- config_ref = &patched_config;
- } else {
- config_ref = &config;
- }
+ let (setup_done_tx, setup_done_rx) = mpsc::channel(0);
let tunnel = Self::open_tunnel(
runtime.clone(),
- config_ref,
+ &Self::patch_allowed_ips(&config, psk_negotiation.is_some()),
log_path,
resource_dir,
tun_provider,
@@ -275,26 +258,8 @@ impl WireguardMonitor {
let tunnel_fut = async move {
#[cfg(windows)]
- {
- setup_done_rx
- .next()
- .await
- .ok_or_else(|| {
- // Tunnel was shut down early
- CloseMsg::SetupError(Error::IpInterfacesError)
- })?
- .map_err(|error| {
- log::error!(
- "{}",
- error.display_chain_with_msg("Failed to configure tunnel interface")
- );
- CloseMsg::SetupError(Error::IpInterfacesError)
- })?;
-
- if !crate::winnet::add_device_ip_addresses(&iface_name, &config.tunnel.addresses) {
- return Err(CloseMsg::SetupError(Error::SetIpAddressesError));
- }
- }
+ Self::add_device_ip_addresses(&iface_name, &config.tunnel.addresses, setup_done_rx)
+ .await?;
let allowed_traffic = if psk_negotiation.is_some() {
AllowedTunnelTraffic::Only(
@@ -302,10 +267,7 @@ impl WireguardMonitor {
Protocol::Tcp,
)
} else {
- AllowedTunnelTraffic::Only(
- SocketAddrV4::new(config.ipv4_gateway, 0).into(),
- Protocol::IcmpV4,
- )
+ AllowedTunnelTraffic::All
};
(on_event)(TunnelEvent::InterfaceUp(metadata.clone(), allowed_traffic)).await;
@@ -327,60 +289,12 @@ impl WireguardMonitor {
.map_err(CloseMsg::SetupError)?;
if let Some(pubkey) = psk_negotiation {
- log::debug!("Performing PQ-safe PSK exchange");
-
- let timeout = std::cmp::min(
- MAX_PSK_EXCHANGE_TIMEOUT,
- INITIAL_PSK_EXCHANGE_TIMEOUT.saturating_mul(
- PSK_EXCHANGE_TIMEOUT_MULTIPLIER.saturating_pow(retry_attempt),
- ),
- );
-
- let (private_key, psk) = tokio::time::timeout(
- timeout,
- talpid_tunnel_config_client::push_pq_key(
- IpAddr::V4(config.ipv4_gateway),
- config.tunnel.private_key.public_key(),
- ),
- )
- .await
- .map_err(|_timeout_err| {
- log::warn!("Timeout while negotiating PSK");
- CloseMsg::PskNegotiationTimeout
- })?
- .map_err(Error::PskNegotiationError)
- .map_err(CloseMsg::SetupError)?;
-
- config.tunnel.private_key = private_key;
-
- for peer in &mut config.peers {
- if pubkey == peer.public_key {
- peer.psk = Some(psk);
- break;
- }
- }
-
- log::trace!(
- "Ephemeral pubkey: {}",
- config.tunnel.private_key.public_key()
- );
-
- let set_config_future = tunnel
- .lock()
- .unwrap()
- .as_ref()
- .map(|tunnel| tunnel.set_config(config.clone()));
- if let Some(f) = set_config_future {
- f.await
- .map_err(Error::TunnelError)
- .map_err(CloseMsg::SetupError)?;
- }
-
- let allowed_traffic = AllowedTunnelTraffic::Only(
- SocketAddrV4::new(config.ipv4_gateway, 0).into(),
- Protocol::IcmpV4,
- );
- (on_event)(TunnelEvent::InterfaceUp(metadata.clone(), allowed_traffic)).await;
+ Self::perform_psk_negotiation(tunnel, retry_attempt, pubkey, &mut config).await?;
+ (on_event)(TunnelEvent::InterfaceUp(
+ metadata.clone(),
+ AllowedTunnelTraffic::All,
+ ))
+ .await;
}
let mut connectivity_monitor = tokio::task::spawn_blocking(move || {
@@ -442,6 +356,125 @@ impl WireguardMonitor {
Ok(monitor)
}
+ /// Replace `0.0.0.0/0`/`::/0` with the gateway IPs when `gateway_only` is true.
+ /// Used to block traffic to other destinations while connecting on Android.
+ fn patch_allowed_ips<'a>(config: &'a Config, gateway_only: bool) -> Cow<'a, Config> {
+ if gateway_only {
+ let mut patched_config = config.clone();
+ let gateway_net_v4 = ipnetwork::IpNetwork::from(IpAddr::from(config.ipv4_gateway));
+ let gateway_net_v6 = config
+ .ipv6_gateway
+ .map(|net| ipnetwork::IpNetwork::from(IpAddr::from(net)));
+ for peer in &mut patched_config.peers {
+ peer.allowed_ips = peer
+ .allowed_ips
+ .iter()
+ .cloned()
+ .filter_map(|mut allowed_ip| {
+ if allowed_ip.prefix() == 0 {
+ if allowed_ip.is_ipv4() {
+ allowed_ip = gateway_net_v4;
+ } else {
+ if let Some(net) = gateway_net_v6 {
+ allowed_ip = net;
+ } else {
+ return None;
+ }
+ }
+ }
+ Some(allowed_ip)
+ })
+ .collect();
+ }
+ Cow::Owned(patched_config)
+ } else {
+ Cow::Borrowed(config)
+ }
+ }
+
+ #[cfg(windows)]
+ async fn add_device_ip_addresses(
+ iface_name: &str,
+ addresses: &[IpAddr],
+ mut setup_done_rx: mpsc::Receiver<std::result::Result<(), BoxedError>>,
+ ) -> std::result::Result<(), CloseMsg> {
+ setup_done_rx
+ .next()
+ .await
+ .ok_or_else(|| {
+ // Tunnel was shut down early
+ CloseMsg::SetupError(Error::IpInterfacesError)
+ })?
+ .map_err(|error| {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Failed to configure tunnel interface")
+ );
+ CloseMsg::SetupError(Error::IpInterfacesError)
+ })?;
+ if !crate::winnet::add_device_ip_addresses(iface_name, addresses) {
+ return Err(CloseMsg::SetupError(Error::SetIpAddressesError));
+ }
+ Ok(())
+ }
+
+ async fn perform_psk_negotiation(
+ tunnel: Arc<Mutex<Option<Box<dyn Tunnel>>>>,
+ retry_attempt: u32,
+ current_pubkey: PublicKey,
+ config: &mut Config,
+ ) -> std::result::Result<(), CloseMsg> {
+ log::debug!("Performing PQ-safe PSK exchange");
+
+ let timeout = std::cmp::min(
+ MAX_PSK_EXCHANGE_TIMEOUT,
+ INITIAL_PSK_EXCHANGE_TIMEOUT
+ .saturating_mul(PSK_EXCHANGE_TIMEOUT_MULTIPLIER.saturating_pow(retry_attempt)),
+ );
+
+ let (private_key, psk) = tokio::time::timeout(
+ timeout,
+ talpid_tunnel_config_client::push_pq_key(
+ IpAddr::V4(config.ipv4_gateway),
+ config.tunnel.private_key.public_key(),
+ ),
+ )
+ .await
+ .map_err(|_timeout_err| {
+ log::warn!("Timeout while negotiating PSK");
+ CloseMsg::PskNegotiationTimeout
+ })?
+ .map_err(Error::PskNegotiationError)
+ .map_err(CloseMsg::SetupError)?;
+
+ config.tunnel.private_key = private_key;
+
+ for peer in &mut config.peers {
+ if current_pubkey == peer.public_key {
+ peer.psk = Some(psk);
+ break;
+ }
+ }
+
+ log::trace!(
+ "Ephemeral pubkey: {}",
+ config.tunnel.private_key.public_key()
+ );
+
+ let set_config_future = tunnel
+ .lock()
+ .unwrap()
+ .as_ref()
+ .map(|tunnel| tunnel.set_config(config.clone()));
+ if let Some(f) = set_config_future {
+ f.await
+ .map_err(Error::TunnelError)
+ .map_err(CloseMsg::SetupError)?;
+ }
+
+ Ok(())
+ }
+
#[allow(unused_variables)]
fn open_tunnel(
runtime: tokio::runtime::Handle,
diff --git a/talpid-core/src/winnet.rs b/talpid-core/src/winnet.rs
index 7f31082541..3d0075257b 100644
--- a/talpid-core/src/winnet.rs
+++ b/talpid-core/src/winnet.rs
@@ -404,7 +404,7 @@ pub fn interface_luid_to_ip(
}
}
-pub fn add_device_ip_addresses(iface: &String, addresses: &Vec<IpAddr>) -> bool {
+pub fn add_device_ip_addresses(iface: &str, addresses: &[IpAddr]) -> bool {
let raw_iface = WideCString::from_str(iface)
.expect("Failed to convert UTF-8 string to null terminated UCS string")
.into_raw();