diff options
| author | Emīls <emils@mullvad.net> | 2021-09-02 15:53:51 +0100 |
|---|---|---|
| committer | Emīls <emils@mullvad.net> | 2021-09-02 15:53:51 +0100 |
| commit | 55746a5d6a36d6fed7b84ff3c98778c62b5ee798 (patch) | |
| tree | 46f91536fadccf99443a12a1bbc6e66c45327c99 | |
| parent | e9ad6159df5080846bd3852237948d0bfade87e3 (diff) | |
| parent | fe7861eb6c451c49a20dfcfb321f8ac2c94a7cea (diff) | |
| download | mullvadvpn-55746a5d6a36d6fed7b84ff3c98778c62b5ee798.tar.xz mullvadvpn-55746a5d6a36d6fed7b84ff3c98778c62b5ee798.zip | |
Merge branch 'macos-fix-tcp-wg'
| -rw-r--r-- | talpid-core/src/tunnel/wireguard/mod.rs | 140 |
1 files changed, 70 insertions, 70 deletions
diff --git a/talpid-core/src/tunnel/wireguard/mod.rs b/talpid-core/src/tunnel/wireguard/mod.rs index a104c0037e..ed19f9bf69 100644 --- a/talpid-core/src/tunnel/wireguard/mod.rs +++ b/talpid-core/src/tunnel/wireguard/mod.rs @@ -11,8 +11,7 @@ use std::env; #[cfg(windows)] use std::io; use std::{ - collections::HashSet, - net::SocketAddr, + net::{IpAddr, SocketAddr}, path::Path, sync::{mpsc, Arc, Mutex}, }; @@ -130,7 +129,6 @@ impl TcpProxy { }, )) .map_err(Error::Udp2TcpError)?; - let local_addr = udp2tcp .local_udp_addr() .map_err(Error::GetLocalUdpAddress)?; @@ -172,14 +170,15 @@ impl WireguardMonitor { route_manager: &mut routing::RouteManager, ) -> Result<WireguardMonitor> { let mut tcp_proxies = vec![]; + let mut endpoint_addrs = vec![]; for peer in &mut config.peers { + endpoint_addrs.push(peer.endpoint.ip()); if peer.protocol == TransportProtocol::Tcp { let udp2tcp = TcpProxy::new(&runtime, peer.endpoint.clone())?; // Replace remote peer with proxy peer.endpoint = udp2tcp.local_udp_addr(); - tcp_proxies.push(udp2tcp); } } @@ -277,8 +276,11 @@ impl WireguardMonitor { .await .map_err(Error::SetupRoutingError)?; + let routes = Self::get_in_tunnel_routes(&iface_name, &config) + .chain(Self::get_tunnel_traffic_routes(&endpoint_addrs)); + route_handle - .add_routes(Self::get_routes(&iface_name, &config)) + .add_routes(routes.collect()) .await .map_err(Error::SetupRoutingError) }) @@ -364,7 +366,7 @@ impl WireguardMonitor { &config, log_path, tun_provider, - Self::get_tunnel_routes(config), + Self::get_tunnel_destinations(config), ) .map_err(Error::TunnelError)?, )) @@ -412,7 +414,7 @@ impl WireguardMonitor { } } - fn get_tunnel_routes(config: &Config) -> impl Iterator<Item = ipnetwork::IpNetwork> + '_ { + fn get_tunnel_destinations(config: &Config) -> impl Iterator<Item = ipnetwork::IpNetwork> + '_ { let routes = config .peers .iter() @@ -439,88 +441,86 @@ impl WireguardMonitor { } #[cfg(target_os = "windows")] - fn get_routes(iface_name: &str, config: &Config) -> HashSet<RequiredRoute> { - let mut routes: HashSet<RequiredRoute> = { - let node_v4 = - routing::Node::new(config.ipv4_gateway.clone().into(), iface_name.to_string()); - let node_v6 = if let Some(ipv6_gateway) = config.ipv6_gateway.as_ref() { - routing::Node::new(ipv6_gateway.clone().into(), iface_name.to_string()) - } else { - routing::Node::device(iface_name.to_string()) - }; - Self::get_tunnel_routes(config) - .map(|network| { - if network.is_ipv4() { - RequiredRoute::new(network, node_v4.clone()) - } else { - RequiredRoute::new(network, node_v6.clone()) - } - }) - .collect() + fn get_in_tunnel_routes<'a>( + iface_name: &str, + config: &'a Config, + ) -> impl Iterator<Item = RequiredRoute> + 'a { + let node_v4 = + routing::Node::new(config.ipv4_gateway.clone().into(), iface_name.to_string()); + let node_v6 = if let Some(ipv6_gateway) = config.ipv6_gateway.as_ref() { + routing::Node::new(ipv6_gateway.clone().into(), iface_name.to_string()) + } else { + routing::Node::device(iface_name.to_string()) }; + Self::get_tunnel_destinations(config).map(move |network| { + if network.is_ipv4() { + RequiredRoute::new(network, node_v4.clone()) + } else { + RequiredRoute::new(network, node_v6.clone()) + } + }) + } - // route endpoints with specific routes - for peer in config.peers.iter() { - routes.insert(RequiredRoute::new( - peer.endpoint.ip().into(), - routing::NetNode::DefaultNode, - )); - } - routes + /// On linux, there is no need + #[cfg(target_os = "linux")] + fn get_tunnel_traffic_routes<'a>( + _endpoints: &'a [IpAddr], + ) -> impl Iterator<Item = RequiredRoute> { + std::iter::empty() + } + + #[cfg(not(target_os = "linux"))] + fn get_tunnel_traffic_routes<'a>( + endpoints: &'a [IpAddr], + ) -> impl Iterator<Item = RequiredRoute> + 'a { + endpoints.iter().map(|ip| { + RequiredRoute::new( + ipnetwork::IpNetwork::from(*ip), + routing::NetNode::DefaultNode, + ) + }) } #[cfg(target_os = "linux")] - fn get_routes(iface_name: &str, config: &Config) -> HashSet<RequiredRoute> { + fn get_in_tunnel_routes<'a>( + iface_name: &str, + config: &'a Config, + ) -> impl Iterator<Item = RequiredRoute> + 'a { use netlink_packet_route::rtnl::constants::RT_TABLE_MAIN; let node = routing::Node::device(iface_name.to_string()); - let mut routes: HashSet<RequiredRoute> = Self::get_tunnel_routes(config) - .map(|network| { + let v4_node = node.clone(); + let v6_node = node.clone(); + Self::get_tunnel_destinations(config) + .map(move |network| { if network.prefix() == 0 { RequiredRoute::new(network, node.clone()) } else { RequiredRoute::new(network, node.clone()).table(u32::from(RT_TABLE_MAIN)) } }) - .collect(); - - // add routes for the gateway so that DNS requests can be made in the tunnel - // using `mullvad-exclude` - routes.insert( - RequiredRoute::new( - ipnetwork::Ipv4Network::from(config.ipv4_gateway).into(), - node.clone(), - ) - .table(u32::from(RT_TABLE_MAIN)), - ); - - if let Some(gateway) = config.ipv6_gateway { - routes.insert( - RequiredRoute::new(ipnetwork::Ipv6Network::from(gateway).into(), node.clone()) - .table(u32::from(RT_TABLE_MAIN)), - ); - } - - routes + .chain(std::iter::once( + RequiredRoute::new( + ipnetwork::Ipv4Network::from(config.ipv4_gateway).into(), + v4_node, + ) + .table(u32::from(RT_TABLE_MAIN)), + )) + .chain(config.ipv6_gateway.map(|gateway| { + RequiredRoute::new(ipnetwork::Ipv6Network::from(gateway).into(), v6_node) + .table(u32::from(RT_TABLE_MAIN)) + })) } #[cfg(all(not(target_os = "linux"), not(windows)))] - fn get_routes(iface_name: &str, config: &Config) -> HashSet<RequiredRoute> { + fn get_in_tunnel_routes<'a>( + iface_name: &str, + config: &'a Config, + ) -> impl Iterator<Item = RequiredRoute> + 'a { let node = routing::Node::device(iface_name.to_string()); - let mut routes: HashSet<RequiredRoute> = Self::get_tunnel_routes(config) - .map(|network| RequiredRoute::new(network, node.clone())) - .collect(); - - // route endpoints with specific routes - for peer in config.peers.iter() { - routes.insert(RequiredRoute::new( - peer.endpoint.ip().into(), - routing::NetNode::DefaultNode, - )); - } - - routes + Self::get_tunnel_destinations(config) + .map(move |network| RequiredRoute::new(network, node.clone())) } fn tunnel_metadata(interface_name: &str, config: &Config) -> TunnelMetadata { |
