summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorEmīls <emils@mullvad.net>2021-09-02 15:53:51 +0100
committerEmīls <emils@mullvad.net>2021-09-02 15:53:51 +0100
commit55746a5d6a36d6fed7b84ff3c98778c62b5ee798 (patch)
tree46f91536fadccf99443a12a1bbc6e66c45327c99
parente9ad6159df5080846bd3852237948d0bfade87e3 (diff)
parentfe7861eb6c451c49a20dfcfb321f8ac2c94a7cea (diff)
downloadmullvadvpn-55746a5d6a36d6fed7b84ff3c98778c62b5ee798.tar.xz
mullvadvpn-55746a5d6a36d6fed7b84ff3c98778c62b5ee798.zip
Merge branch 'macos-fix-tcp-wg'
-rw-r--r--talpid-core/src/tunnel/wireguard/mod.rs140
1 files changed, 70 insertions, 70 deletions
diff --git a/talpid-core/src/tunnel/wireguard/mod.rs b/talpid-core/src/tunnel/wireguard/mod.rs
index a104c0037e..ed19f9bf69 100644
--- a/talpid-core/src/tunnel/wireguard/mod.rs
+++ b/talpid-core/src/tunnel/wireguard/mod.rs
@@ -11,8 +11,7 @@ use std::env;
#[cfg(windows)]
use std::io;
use std::{
- collections::HashSet,
- net::SocketAddr,
+ net::{IpAddr, SocketAddr},
path::Path,
sync::{mpsc, Arc, Mutex},
};
@@ -130,7 +129,6 @@ impl TcpProxy {
},
))
.map_err(Error::Udp2TcpError)?;
-
let local_addr = udp2tcp
.local_udp_addr()
.map_err(Error::GetLocalUdpAddress)?;
@@ -172,14 +170,15 @@ impl WireguardMonitor {
route_manager: &mut routing::RouteManager,
) -> Result<WireguardMonitor> {
let mut tcp_proxies = vec![];
+ let mut endpoint_addrs = vec![];
for peer in &mut config.peers {
+ endpoint_addrs.push(peer.endpoint.ip());
if peer.protocol == TransportProtocol::Tcp {
let udp2tcp = TcpProxy::new(&runtime, peer.endpoint.clone())?;
// Replace remote peer with proxy
peer.endpoint = udp2tcp.local_udp_addr();
-
tcp_proxies.push(udp2tcp);
}
}
@@ -277,8 +276,11 @@ impl WireguardMonitor {
.await
.map_err(Error::SetupRoutingError)?;
+ let routes = Self::get_in_tunnel_routes(&iface_name, &config)
+ .chain(Self::get_tunnel_traffic_routes(&endpoint_addrs));
+
route_handle
- .add_routes(Self::get_routes(&iface_name, &config))
+ .add_routes(routes.collect())
.await
.map_err(Error::SetupRoutingError)
})
@@ -364,7 +366,7 @@ impl WireguardMonitor {
&config,
log_path,
tun_provider,
- Self::get_tunnel_routes(config),
+ Self::get_tunnel_destinations(config),
)
.map_err(Error::TunnelError)?,
))
@@ -412,7 +414,7 @@ impl WireguardMonitor {
}
}
- fn get_tunnel_routes(config: &Config) -> impl Iterator<Item = ipnetwork::IpNetwork> + '_ {
+ fn get_tunnel_destinations(config: &Config) -> impl Iterator<Item = ipnetwork::IpNetwork> + '_ {
let routes = config
.peers
.iter()
@@ -439,88 +441,86 @@ impl WireguardMonitor {
}
#[cfg(target_os = "windows")]
- fn get_routes(iface_name: &str, config: &Config) -> HashSet<RequiredRoute> {
- let mut routes: HashSet<RequiredRoute> = {
- let node_v4 =
- routing::Node::new(config.ipv4_gateway.clone().into(), iface_name.to_string());
- let node_v6 = if let Some(ipv6_gateway) = config.ipv6_gateway.as_ref() {
- routing::Node::new(ipv6_gateway.clone().into(), iface_name.to_string())
- } else {
- routing::Node::device(iface_name.to_string())
- };
- Self::get_tunnel_routes(config)
- .map(|network| {
- if network.is_ipv4() {
- RequiredRoute::new(network, node_v4.clone())
- } else {
- RequiredRoute::new(network, node_v6.clone())
- }
- })
- .collect()
+ fn get_in_tunnel_routes<'a>(
+ iface_name: &str,
+ config: &'a Config,
+ ) -> impl Iterator<Item = RequiredRoute> + 'a {
+ let node_v4 =
+ routing::Node::new(config.ipv4_gateway.clone().into(), iface_name.to_string());
+ let node_v6 = if let Some(ipv6_gateway) = config.ipv6_gateway.as_ref() {
+ routing::Node::new(ipv6_gateway.clone().into(), iface_name.to_string())
+ } else {
+ routing::Node::device(iface_name.to_string())
};
+ Self::get_tunnel_destinations(config).map(move |network| {
+ if network.is_ipv4() {
+ RequiredRoute::new(network, node_v4.clone())
+ } else {
+ RequiredRoute::new(network, node_v6.clone())
+ }
+ })
+ }
- // route endpoints with specific routes
- for peer in config.peers.iter() {
- routes.insert(RequiredRoute::new(
- peer.endpoint.ip().into(),
- routing::NetNode::DefaultNode,
- ));
- }
- routes
+ /// On linux, there is no need
+ #[cfg(target_os = "linux")]
+ fn get_tunnel_traffic_routes<'a>(
+ _endpoints: &'a [IpAddr],
+ ) -> impl Iterator<Item = RequiredRoute> {
+ std::iter::empty()
+ }
+
+ #[cfg(not(target_os = "linux"))]
+ fn get_tunnel_traffic_routes<'a>(
+ endpoints: &'a [IpAddr],
+ ) -> impl Iterator<Item = RequiredRoute> + 'a {
+ endpoints.iter().map(|ip| {
+ RequiredRoute::new(
+ ipnetwork::IpNetwork::from(*ip),
+ routing::NetNode::DefaultNode,
+ )
+ })
}
#[cfg(target_os = "linux")]
- fn get_routes(iface_name: &str, config: &Config) -> HashSet<RequiredRoute> {
+ fn get_in_tunnel_routes<'a>(
+ iface_name: &str,
+ config: &'a Config,
+ ) -> impl Iterator<Item = RequiredRoute> + 'a {
use netlink_packet_route::rtnl::constants::RT_TABLE_MAIN;
let node = routing::Node::device(iface_name.to_string());
- let mut routes: HashSet<RequiredRoute> = Self::get_tunnel_routes(config)
- .map(|network| {
+ let v4_node = node.clone();
+ let v6_node = node.clone();
+ Self::get_tunnel_destinations(config)
+ .map(move |network| {
if network.prefix() == 0 {
RequiredRoute::new(network, node.clone())
} else {
RequiredRoute::new(network, node.clone()).table(u32::from(RT_TABLE_MAIN))
}
})
- .collect();
-
- // add routes for the gateway so that DNS requests can be made in the tunnel
- // using `mullvad-exclude`
- routes.insert(
- RequiredRoute::new(
- ipnetwork::Ipv4Network::from(config.ipv4_gateway).into(),
- node.clone(),
- )
- .table(u32::from(RT_TABLE_MAIN)),
- );
-
- if let Some(gateway) = config.ipv6_gateway {
- routes.insert(
- RequiredRoute::new(ipnetwork::Ipv6Network::from(gateway).into(), node.clone())
- .table(u32::from(RT_TABLE_MAIN)),
- );
- }
-
- routes
+ .chain(std::iter::once(
+ RequiredRoute::new(
+ ipnetwork::Ipv4Network::from(config.ipv4_gateway).into(),
+ v4_node,
+ )
+ .table(u32::from(RT_TABLE_MAIN)),
+ ))
+ .chain(config.ipv6_gateway.map(|gateway| {
+ RequiredRoute::new(ipnetwork::Ipv6Network::from(gateway).into(), v6_node)
+ .table(u32::from(RT_TABLE_MAIN))
+ }))
}
#[cfg(all(not(target_os = "linux"), not(windows)))]
- fn get_routes(iface_name: &str, config: &Config) -> HashSet<RequiredRoute> {
+ fn get_in_tunnel_routes<'a>(
+ iface_name: &str,
+ config: &'a Config,
+ ) -> impl Iterator<Item = RequiredRoute> + 'a {
let node = routing::Node::device(iface_name.to_string());
- let mut routes: HashSet<RequiredRoute> = Self::get_tunnel_routes(config)
- .map(|network| RequiredRoute::new(network, node.clone()))
- .collect();
-
- // route endpoints with specific routes
- for peer in config.peers.iter() {
- routes.insert(RequiredRoute::new(
- peer.endpoint.ip().into(),
- routing::NetNode::DefaultNode,
- ));
- }
-
- routes
+ Self::get_tunnel_destinations(config)
+ .map(move |network| RequiredRoute::new(network, node.clone()))
}
fn tunnel_metadata(interface_name: &str, config: &Config) -> TunnelMetadata {