diff options
| -rw-r--r-- | talpid-core/src/tunnel/wireguard/mod.rs | 110 |
1 files changed, 46 insertions, 64 deletions
diff --git a/talpid-core/src/tunnel/wireguard/mod.rs b/talpid-core/src/tunnel/wireguard/mod.rs index 6cc2d85fdb..404b2dec68 100644 --- a/talpid-core/src/tunnel/wireguard/mod.rs +++ b/talpid-core/src/tunnel/wireguard/mod.rs @@ -3,11 +3,9 @@ use self::config::Config; use super::tun_provider; use super::{tun_provider::TunProvider, TunnelEvent, TunnelMetadata}; use crate::routing::{self, RequiredRoute}; -#[cfg(windows)] -use futures::channel::{mpsc, oneshot}; use futures::future::abortable; #[cfg(windows)] -use futures::{FutureExt, StreamExt}; +use futures::{channel::mpsc, StreamExt}; #[cfg(target_os = "linux")] use lazy_static::lazy_static; #[cfg(target_os = "linux")] @@ -92,8 +90,6 @@ pub struct WireguardMonitor { >, close_msg_sender: sync_mpsc::Sender<CloseMsg>, close_msg_receiver: sync_mpsc::Receiver<CloseMsg>, - #[cfg(target_os = "windows")] - stop_setup_tx: Option<oneshot::Sender<()>>, pinger_stop_sender: sync_mpsc::Sender<()>, _tcp_proxies: Vec<TcpProxy>, } @@ -191,7 +187,7 @@ impl WireguardMonitor { } #[cfg(target_os = "windows")] - let (setup_done_tx, setup_done_rx) = mpsc::channel(0); + let (setup_done_tx, mut setup_done_rx) = mpsc::channel(0); let tunnel = Self::open_tunnel( runtime.clone(), &config, @@ -207,16 +203,12 @@ impl WireguardMonitor { let event_callback = Box::new(on_event.clone()); let (close_msg_sender, close_msg_receiver) = sync_mpsc::channel(); let (pinger_tx, pinger_rx) = sync_mpsc::channel(); - #[cfg(target_os = "windows")] - let (stop_setup_tx, stop_setup_rx) = oneshot::channel(); let monitor = WireguardMonitor { runtime: runtime.clone(), tunnel: Arc::new(Mutex::new(Some(tunnel))), event_callback, close_msg_sender, close_msg_receiver, - #[cfg(target_os = "windows")] - stop_setup_tx: Some(stop_setup_tx), pinger_stop_sender: pinger_tx, _tcp_proxies: tcp_proxies, }; @@ -236,84 +228,78 @@ impl WireguardMonitor { let metadata = Self::tunnel_metadata(&iface_name, &config); - std::thread::spawn(move || { + tokio::spawn(async move { #[cfg(windows)] { - let mut done_rx = setup_done_rx.fuse(); let iface_close_sender = close_sender.clone(); - let result = runtime.block_on(async move { - futures::select! { - result = done_rx.next() => { - match result { - Some(result) => { - result.map_err(|error| { - log::error!("{}", error.display_chain_with_msg("Failed to configure tunnel interface")); - iface_close_sender.send(CloseMsg::SetupError( - Error::IpInterfacesError - )) - .unwrap_or(()) - }) - } - None => Err(()), - } - } - _ = stop_setup_rx.fuse() => Err(()), - } - }); + let result = match setup_done_rx.next().await { + Some(result) => result.map_err(|error| { + log::error!( + "{}", + error.display_chain_with_msg("Failed to configure tunnel interface") + ); + iface_close_sender + .send(CloseMsg::SetupError(Error::IpInterfacesError)) + .unwrap_or(()) + }), + None => Err(()), + }; if result.is_err() { return; } } - runtime.block_on((on_event)(TunnelEvent::InterfaceUp(metadata.clone()))); + (on_event)(TunnelEvent::InterfaceUp(metadata.clone())).await; - let setup_iface_routes = || -> Result<()> { + let setup_iface_routes = async move { #[cfg(target_os = "windows")] if !crate::winnet::add_device_ip_addresses(&iface_name, &config.tunnel.addresses) { return Err(Error::SetIpAddressesError); } - runtime.block_on(async move { - #[cfg(target_os = "linux")] - route_handle - .create_routing_rules(config.enable_ipv6) - .await - .map_err(Error::SetupRoutingError)?; + #[cfg(target_os = "linux")] + route_handle + .create_routing_rules(config.enable_ipv6) + .await + .map_err(Error::SetupRoutingError)?; - let routes = Self::get_in_tunnel_routes(&iface_name, &config) - .chain(Self::get_tunnel_traffic_routes(&endpoint_addrs)); + let routes = Self::get_in_tunnel_routes(&iface_name, &config) + .chain(Self::get_tunnel_traffic_routes(&endpoint_addrs)); - route_handle - .add_routes(routes.collect()) - .await - .map_err(Error::SetupRoutingError) - }) + route_handle + .add_routes(routes.collect()) + .await + .map_err(Error::SetupRoutingError) }; - if let Err(error) = setup_iface_routes() { + if let Err(error) = setup_iface_routes.await { let _ = close_sender.send(CloseMsg::SetupError(error)); return; } - match connectivity_monitor.establish_connectivity(retry_attempt) { - Ok(true) => { - runtime.block_on((on_event)(TunnelEvent::Up(metadata))); + tokio::task::spawn_blocking(move || { + match connectivity_monitor.establish_connectivity(retry_attempt) { + Ok(true) => { + tokio::spawn((on_event)(TunnelEvent::Up(metadata))); - if let Err(error) = connectivity_monitor.run() { + if let Err(error) = connectivity_monitor.run() { + log::error!( + "{}", + error.display_chain_with_msg("Connectivity monitor failed") + ); + } + } + Ok(false) => log::warn!("Timeout while checking tunnel connection"), + Err(error) => { log::error!( "{}", - error.display_chain_with_msg("Connectivity monitor failed") + error.display_chain_with_msg("Failed to check tunnel connection") ); } } - Ok(false) => log::warn!("Timeout while checking tunnel connection"), - Err(error) => { - log::error!( - "{}", - error.display_chain_with_msg("Failed to check tunnel connection") - ); - } - } + }) + .await + .expect("connectivity monitor thread panicked"); let _ = close_sender.send(CloseMsg::PingErr); }); @@ -422,10 +408,6 @@ impl WireguardMonitor { Err(_) => Ok(()), }; - #[cfg(windows)] - if let Some(stop_tx) = self.stop_setup_tx.take() { - let _ = stop_tx.send(()); - } let _ = self.pinger_stop_sender.send(()); self.stop_tunnel(); |
