diff options
| author | David Lönnhager <david.l@mullvad.net> | 2022-03-09 11:09:51 +0100 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2022-03-09 11:09:51 +0100 |
| commit | cc2f7c448d34a8befcc0bd89dcaf7d92a0b6f251 (patch) | |
| tree | 1203d757e88bba614190beafb3a5e8fdcb6980af | |
| parent | 9af7973d82ed9047e5d727ce3414c116d97f05cf (diff) | |
| parent | cb2fa8e313c1375da313747bc83916c3737a3fe0 (diff) | |
| download | mullvadvpn-cc2f7c448d34a8befcc0bd89dcaf7d92a0b6f251.tar.xz mullvadvpn-cc2f7c448d34a8befcc0bd89dcaf7d92a0b6f251.zip | |
Merge branch 'fix-multihop-connectivity-monitor'
| -rw-r--r-- | talpid-core/src/tunnel/wireguard/mod.rs | 210 |
1 files changed, 103 insertions, 107 deletions
diff --git a/talpid-core/src/tunnel/wireguard/mod.rs b/talpid-core/src/tunnel/wireguard/mod.rs index cc82527e62..263bbaeabb 100644 --- a/talpid-core/src/tunnel/wireguard/mod.rs +++ b/talpid-core/src/tunnel/wireguard/mod.rs @@ -251,9 +251,19 @@ impl WireguardMonitor { (on_event)(TunnelEvent::InterfaceUp(metadata.clone())).await; - // Add a specific gateway route for the connectivity monitor + // Add non-default routes before establishing the tunnel. + #[cfg(target_os = "linux")] + route_manager + .create_routing_rules(config.enable_ipv6) + .await + .map_err(Error::SetupRoutingError) + .map_err(CloseMsg::SetupError)?; + + let routes = Self::get_pre_tunnel_routes(&iface_name, &config) + .chain(Self::get_endpoint_routes(&endpoint_addrs)) + .collect(); route_manager - .add_routes(Self::gateway_route(&iface_name, &config).collect()) + .add_routes(routes) .await .map_err(Error::SetupRoutingError) .map_err(CloseMsg::SetupError)?; @@ -277,19 +287,9 @@ impl WireguardMonitor { .await .unwrap()?; - // Set up routes once tunnel is established - #[cfg(target_os = "linux")] + // Add any default route(s) that may exist. route_manager - .create_routing_rules(config.enable_ipv6) - .await - .map_err(Error::SetupRoutingError) - .map_err(CloseMsg::SetupError)?; - - let routes = Self::get_in_tunnel_routes(&iface_name, &config) - .chain(Self::get_tunnel_traffic_routes(&endpoint_addrs)); - - route_manager - .add_routes(routes.collect()) + .add_routes(Self::get_post_tunnel_routes(&iface_name, &config).collect()) .await .map_err(Error::SetupRoutingError) .map_err(CloseMsg::SetupError)?; @@ -440,124 +440,120 @@ impl WireguardMonitor { } } - fn get_tunnel_destinations(config: &Config) -> impl Iterator<Item = ipnetwork::IpNetwork> + '_ { - let routes = config - .peers - .iter() - .flat_map(|peer| peer.allowed_ips.iter()) - .cloned(); + /// Returns routes to the peer endpoints (through the physical interface). + #[cfg_attr(target_os = "linux", allow(unused_variables))] + fn get_endpoint_routes<'a>( + endpoints: &'a [IpAddr], + ) -> impl Iterator<Item = RequiredRoute> + 'a { #[cfg(target_os = "linux")] { - routes + // No need due to policy based routing. + std::iter::empty() } #[cfg(not(target_os = "linux"))] + endpoints.iter().map(|ip| { + RequiredRoute::new( + ipnetwork::IpNetwork::from(*ip), + routing::NetNode::DefaultNode, + ) + }) + } + + #[cfg_attr(not(target_os = "windows"), allow(unused_variables))] + fn get_tunnel_nodes(iface_name: &str, config: &Config) -> (routing::Node, routing::Node) { + #[cfg(windows)] { - routes.flat_map(|allowed_ip| { - if allowed_ip.prefix() == 0 { - if allowed_ip.is_ipv4() { - vec!["0.0.0.0/1".parse().unwrap(), "128.0.0.0/1".parse().unwrap()] - } else { - vec!["8000::/1".parse().unwrap(), "::/1".parse().unwrap()] - } - } else { - vec![allowed_ip] - } - }) + let v4 = routing::Node::new(config.ipv4_gateway.clone().into(), iface_name.to_string()); + let v6 = if let Some(ipv6_gateway) = config.ipv6_gateway.as_ref() { + routing::Node::new(ipv6_gateway.clone().into(), iface_name.to_string()) + } else { + routing::Node::device(iface_name.to_string()) + }; + (v4, v6) + } + + #[cfg(not(windows))] + { + let node = routing::Node::device(iface_name.to_string()); + (node.clone(), node) } } - #[cfg(target_os = "windows")] - fn get_in_tunnel_routes<'a>( + /// Return routes for all allowed IPs, as well as the gateway, except 0.0.0.0/0. + fn get_pre_tunnel_routes<'a>( iface_name: &str, config: &'a Config, ) -> impl Iterator<Item = RequiredRoute> + 'a { - let node_v4 = - routing::Node::new(config.ipv4_gateway.clone().into(), iface_name.to_string()); - let node_v6 = if let Some(ipv6_gateway) = config.ipv6_gateway.as_ref() { - routing::Node::new(ipv6_gateway.clone().into(), iface_name.to_string()) - } else { - routing::Node::device(iface_name.to_string()) - }; - Self::get_tunnel_destinations(config).map(move |network| { - if network.is_ipv4() { - RequiredRoute::new(network, node_v4.clone()) - } else { - RequiredRoute::new(network, node_v6.clone()) - } - }) - } + let gateway_node = routing::Node::device(iface_name.to_string()); + let gateway_routes = std::iter::once(RequiredRoute::new( + ipnetwork::Ipv4Network::from(config.ipv4_gateway).into(), + gateway_node.clone(), + )) + .chain(config.ipv6_gateway.map(|gateway| { + RequiredRoute::new(ipnetwork::Ipv6Network::from(gateway).into(), gateway_node) + })); - /// On linux, there is no need - #[cfg(target_os = "linux")] - fn get_tunnel_traffic_routes<'a>( - _endpoints: &'a [IpAddr], - ) -> impl Iterator<Item = RequiredRoute> { - std::iter::empty() - } + let (node_v4, node_v6) = Self::get_tunnel_nodes(iface_name, config); - #[cfg(not(target_os = "linux"))] - fn get_tunnel_traffic_routes<'a>( - endpoints: &'a [IpAddr], - ) -> impl Iterator<Item = RequiredRoute> + 'a { - endpoints.iter().map(|ip| { - RequiredRoute::new( - ipnetwork::IpNetwork::from(*ip), - routing::NetNode::DefaultNode, - ) - }) + let routes = gateway_routes.chain( + Self::get_tunnel_destinations(config) + .filter(|allowed_ip| allowed_ip.prefix() != 0) + .map(move |allowed_ip| { + if allowed_ip.is_ipv4() { + RequiredRoute::new(allowed_ip, node_v4.clone()) + } else { + RequiredRoute::new(allowed_ip, node_v6.clone()) + } + }), + ); + + // The gateway route, as well as the exit endpoint, need to be in the main table. + // Otherwise, DNS will not work for excluded apps, nor will the exit be reachable. + #[cfg(target_os = "linux")] + let routes = routes.map(|route| route.table(u32::from(RT_TABLE_MAIN))); + + routes } - #[cfg(target_os = "linux")] - fn get_in_tunnel_routes<'a>( + /// Return any 0.0.0.0/0 routes specified by the allowed IPs. + fn get_post_tunnel_routes<'a>( iface_name: &str, config: &'a Config, ) -> impl Iterator<Item = RequiredRoute> + 'a { - let node = routing::Node::device(iface_name.to_string()); - let v4_node = node.clone(); - let v6_node = node.clone(); + let (node_v4, node_v6) = Self::get_tunnel_nodes(iface_name, config); Self::get_tunnel_destinations(config) - .map(move |network| { - if network.prefix() == 0 { - RequiredRoute::new(network, node.clone()) + .filter(|allowed_ip| allowed_ip.prefix() == 0) + .map(move |allowed_ip| { + if allowed_ip.is_ipv4() { + RequiredRoute::new(allowed_ip, node_v4.clone()) } else { - RequiredRoute::new(network, node.clone()).table(u32::from(RT_TABLE_MAIN)) + RequiredRoute::new(allowed_ip, node_v6.clone()) } }) - .chain(std::iter::once( - RequiredRoute::new( - ipnetwork::Ipv4Network::from(config.ipv4_gateway).into(), - v4_node, - ) - .table(u32::from(RT_TABLE_MAIN)), - )) - .chain(config.ipv6_gateway.map(|gateway| { - RequiredRoute::new(ipnetwork::Ipv6Network::from(gateway).into(), v6_node) - .table(u32::from(RT_TABLE_MAIN)) - })) } - #[cfg(all(not(target_os = "linux"), not(windows)))] - fn get_in_tunnel_routes<'a>( - iface_name: &str, - config: &'a Config, - ) -> impl Iterator<Item = RequiredRoute> + 'a { - let node = routing::Node::device(iface_name.to_string()); - Self::get_tunnel_destinations(config) - .map(move |network| RequiredRoute::new(network, node.clone())) - } + /// Return routes for all allowed IPs. + fn get_tunnel_destinations(config: &Config) -> impl Iterator<Item = ipnetwork::IpNetwork> + '_ { + let routes = config + .peers + .iter() + .flat_map(|peer| peer.allowed_ips.iter()) + .cloned(); - fn gateway_route<'a>( - iface_name: &str, - config: &'a Config, - ) -> impl Iterator<Item = RequiredRoute> + 'a { - let node = routing::Node::device(iface_name.to_string()); - let r = RequiredRoute::new( - ipnetwork::Ipv4Network::from(config.ipv4_gateway).into(), - node, - ); - #[cfg(target_os = "linux")] - let r = r.table(u32::from(RT_TABLE_MAIN)); - std::iter::once(r) + #[cfg(not(target_os = "linux"))] + let routes = routes.flat_map(|allowed_ip| { + if allowed_ip.prefix() == 0 { + if allowed_ip.is_ipv4() { + vec!["0.0.0.0/1".parse().unwrap(), "128.0.0.0/1".parse().unwrap()] + } else { + vec!["8000::/1".parse().unwrap(), "::/1".parse().unwrap()] + } + } else { + vec![allowed_ip] + } + }); + + routes } fn tunnel_metadata(interface_name: &str, config: &Config) -> TunnelMetadata { |
