summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorDavid Lönnhager <david.l@mullvad.net>2022-02-08 10:39:11 +0100
committerDavid Lönnhager <david.l@mullvad.net>2022-02-14 11:00:00 +0100
commit6453491bb01bf4ad31a5377a9f16358bf185bc3b (patch)
tree8b24b86182165ae48e366f464b89b38c44262fe9
parent4100a6c46d51e7735a028a0dd4b6dff7c4201638 (diff)
downloadmullvadvpn-6453491bb01bf4ad31a5377a9f16358bf185bc3b.tar.xz
mullvadvpn-6453491bb01bf4ad31a5377a9f16358bf185bc3b.zip
Set up tunnel monitor in separate thread
-rw-r--r--talpid-core/src/tunnel/mod.rs81
-rw-r--r--talpid-core/src/tunnel/openvpn/mod.rs57
-rw-r--r--talpid-core/src/tunnel/wireguard/mod.rs30
-rw-r--r--talpid-core/src/tunnel/wireguard/wireguard_go.rs20
-rw-r--r--talpid-core/src/tunnel_state_machine/connected_state.rs18
-rw-r--r--talpid-core/src/tunnel_state_machine/connecting_state.rs206
-rw-r--r--talpid-core/src/tunnel_state_machine/disconnected_state.rs2
-rw-r--r--talpid-core/src/tunnel_state_machine/disconnecting_state.rs25
-rw-r--r--talpid-core/src/tunnel_state_machine/error_state.rs7
-rw-r--r--talpid-core/src/tunnel_state_machine/mod.rs25
10 files changed, 248 insertions, 223 deletions
diff --git a/talpid-core/src/tunnel/mod.rs b/talpid-core/src/tunnel/mod.rs
index 0ed39d7bee..ca6665114f 100644
--- a/talpid-core/src/tunnel/mod.rs
+++ b/talpid-core/src/tunnel/mod.rs
@@ -1,9 +1,10 @@
use self::tun_provider::TunProvider;
-use crate::{logging, routing::RouteManager};
+use crate::{logging, routing::RouteManagerHandle};
+use futures::channel::oneshot;
use std::{
- io,
net::{IpAddr, Ipv4Addr, Ipv6Addr},
path::{Path, PathBuf},
+ sync::{Arc, Mutex},
};
#[cfg(not(target_os = "android"))]
use talpid_types::net::openvpn as openvpn_types;
@@ -104,9 +105,10 @@ impl TunnelMonitor {
log_dir: &Option<PathBuf>,
resource_dir: &Path,
on_event: L,
- tun_provider: &mut TunProvider,
- route_manager: &mut RouteManager,
+ tun_provider: Arc<Mutex<TunProvider>>,
+ route_manager: RouteManagerHandle,
retry_attempt: u32,
+ tunnel_close_rx: oneshot::Receiver<()>,
) -> Result<Self>
where
L: (Fn(TunnelEvent) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>)
@@ -120,9 +122,15 @@ impl TunnelMonitor {
match tunnel_parameters {
#[cfg(not(target_os = "android"))]
- TunnelParameters::OpenVpn(config) => {
- Self::start_openvpn_tunnel(&config, log_file, resource_dir, on_event, route_manager)
- }
+ TunnelParameters::OpenVpn(config) => Self::start_openvpn_tunnel(
+ &config,
+ log_file,
+ resource_dir,
+ on_event,
+ tunnel_close_rx,
+ #[cfg(target_os = "linux")]
+ route_manager,
+ ),
#[cfg(target_os = "android")]
TunnelParameters::OpenVpn(_) => Err(Error::UnsupportedPlatform),
@@ -135,6 +143,7 @@ impl TunnelMonitor {
tun_provider,
route_manager,
retry_attempt,
+ tunnel_close_rx,
),
}
}
@@ -165,9 +174,10 @@ impl TunnelMonitor {
log: Option<PathBuf>,
resource_dir: &Path,
on_event: L,
- tun_provider: &mut TunProvider,
- route_manager: &mut RouteManager,
+ tun_provider: Arc<Mutex<TunProvider>>,
+ route_manager: RouteManagerHandle,
retry_attempt: u32,
+ tunnel_close_rx: oneshot::Receiver<()>,
) -> Result<Self>
where
L: (Fn(TunnelEvent) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>)
@@ -186,6 +196,7 @@ impl TunnelMonitor {
tun_provider,
route_manager,
retry_attempt,
+ tunnel_close_rx,
)?;
Ok(TunnelMonitor {
monitor: InternalTunnelMonitor::Wireguard(monitor),
@@ -198,7 +209,8 @@ impl TunnelMonitor {
log: Option<PathBuf>,
resource_dir: &Path,
on_event: L,
- route_manager: &mut RouteManager,
+ tunnel_close_rx: oneshot::Receiver<()>,
+ #[cfg(target_os = "linux")] route_manager: RouteManagerHandle,
) -> Result<Self>
where
L: (Fn(TunnelEvent) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>)
@@ -206,8 +218,15 @@ impl TunnelMonitor {
+ Sync
+ 'static,
{
- let monitor =
- openvpn::OpenVpnMonitor::start(on_event, config, log, resource_dir, route_manager)?;
+ let monitor = openvpn::OpenVpnMonitor::start(
+ on_event,
+ config,
+ log,
+ resource_dir,
+ tunnel_close_rx,
+ #[cfg(target_os = "linux")]
+ route_manager,
+ )?;
Ok(TunnelMonitor {
monitor: InternalTunnelMonitor::OpenVpn(monitor),
})
@@ -263,42 +282,12 @@ impl TunnelMonitor {
}
}
- /// Creates a handle to this monitor, allowing the tunnel to be closed while some other
- /// thread
- /// is blocked in `wait`.
- pub fn close_handle(&self) -> CloseHandle {
- self.monitor.close_handle()
- }
-
/// Consumes the monitor and blocks until the tunnel exits or there is an error.
pub fn wait(self) -> Result<()> {
self.monitor.wait().map_err(Error::from)
}
}
-/// A handle to a `TunnelMonitor`
-pub enum CloseHandle {
- #[cfg(not(target_os = "android"))]
- /// OpenVpn close handle
- OpenVpn(openvpn::OpenVpnCloseHandle),
- /// Wireguard close handle
- Wireguard(wireguard::CloseHandle),
-}
-
-impl CloseHandle {
- /// Closes the underlying tunnel, making the `TunnelMonitor::wait` method return.
- pub fn close(self) -> io::Result<()> {
- match self {
- #[cfg(not(target_os = "android"))]
- CloseHandle::OpenVpn(handle) => handle.close(),
- CloseHandle::Wireguard(mut handle) => {
- handle.close();
- Ok(())
- }
- }
- }
-}
-
enum InternalTunnelMonitor {
#[cfg(not(target_os = "android"))]
OpenVpn(openvpn::OpenVpnMonitor),
@@ -306,14 +295,6 @@ enum InternalTunnelMonitor {
}
impl InternalTunnelMonitor {
- fn close_handle(&self) -> CloseHandle {
- match self {
- #[cfg(not(target_os = "android"))]
- InternalTunnelMonitor::OpenVpn(tun) => CloseHandle::OpenVpn(tun.close_handle()),
- InternalTunnelMonitor::Wireguard(tun) => CloseHandle::Wireguard(tun.close_handle()),
- }
- }
-
fn wait(self) -> Result<()> {
match self {
#[cfg(not(target_os = "android"))]
diff --git a/talpid-core/src/tunnel/openvpn/mod.rs b/talpid-core/src/tunnel/openvpn/mod.rs
index ffca275f68..446ce13d79 100644
--- a/talpid-core/src/tunnel/openvpn/mod.rs
+++ b/talpid-core/src/tunnel/openvpn/mod.rs
@@ -1,6 +1,6 @@
use super::TunnelEvent;
#[cfg(target_os = "linux")]
-use crate::routing::RequiredRoute;
+use crate::routing::{self, RequiredRoute};
use crate::{
mktemp,
process::{
@@ -8,8 +8,8 @@ use crate::{
stoppable_process::StoppableProcess,
},
proxy::{self, ProxyMonitor, ProxyResourceData},
- routing,
};
+use futures::channel::oneshot;
#[cfg(windows)]
use lazy_static::lazy_static;
#[cfg(target_os = "linux")]
@@ -65,11 +65,6 @@ pub enum Error {
#[error(display = "Failed to initialize the tokio runtime")]
RuntimeError(#[error(source)] io::Error),
- /// Failed to set up routing.
- #[cfg(target_os = "linux")]
- #[error(display = "Failed to setup routing")]
- SetupRoutingError(#[error(source)] routing::Error),
-
/// Unable to start, wait for or kill the OpenVPN process.
#[error(display = "Error in OpenVPN process management: {}", _0)]
ChildProcessError(&'static str, #[error(source)] io::Error),
@@ -254,8 +249,8 @@ impl OpenVpnMonitor<OpenVpnCommand> {
params: &openvpn::TunnelParameters,
log_path: Option<PathBuf>,
resource_dir: &Path,
- #[cfg(target_os = "linux")] route_manager: &mut routing::RouteManager,
- #[cfg(not(target_os = "linux"))] _route_manager: &mut routing::RouteManager,
+ tunnel_close_rx: oneshot::Receiver<()>,
+ #[cfg(target_os = "linux")] route_manager: routing::RouteManagerHandle,
) -> Result<Self>
where
L: (Fn(TunnelEvent) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>)
@@ -323,8 +318,6 @@ impl OpenVpnMonitor<OpenVpnCommand> {
#[cfg(target_os = "linux")]
let ipv6_enabled = params.generic_options.enable_ipv6;
- #[cfg(target_os = "linux")]
- let route_manager_handle = route_manager.handle().map_err(Error::SetupRoutingError)?;
let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger();
@@ -338,7 +331,7 @@ impl OpenVpnMonitor<OpenVpnCommand> {
proxy_auth_file_path: proxy_auth_file_path.clone(),
abort_server_tx: event_server_abort_tx,
#[cfg(target_os = "linux")]
- route_manager_handle,
+ route_manager_handle: route_manager,
#[cfg(target_os = "linux")]
ipv6_enabled,
},
@@ -347,6 +340,7 @@ impl OpenVpnMonitor<OpenVpnCommand> {
user_pass_file,
proxy_auth_file,
proxy_monitor,
+ tunnel_close_rx,
#[cfg(windows)]
Box::new(WintunContextImpl {
adapter: wintun_adapter,
@@ -379,6 +373,7 @@ impl<C: OpenVpnBuilder + Send + 'static> OpenVpnMonitor<C> {
user_pass_file: mktemp::TempFile,
proxy_auth_file: Option<mktemp::TempFile>,
proxy_monitor: Option<Box<dyn ProxyMonitor>>,
+ tunnel_close_rx: oneshot::Receiver<()>,
#[cfg(windows)] wintun: Box<dyn WintunContext>,
) -> Result<OpenVpnMonitor<C>>
where
@@ -424,7 +419,9 @@ impl<C: OpenVpnBuilder + Send + 'static> OpenVpnMonitor<C> {
));
let spawn_task = runtime.spawn(spawn_task);
- Ok(OpenVpnMonitor {
+ let handle = runtime.handle().clone();
+
+ let monitor = OpenVpnMonitor {
spawn_task: Some(spawn_task),
abort_spawn,
child: Arc::new(Mutex::new(None)),
@@ -439,7 +436,25 @@ impl<C: OpenVpnBuilder + Send + 'static> OpenVpnMonitor<C> {
#[cfg(windows)]
_wintun: wintun,
- })
+ };
+
+ let close_handle = monitor.close_handle();
+ handle.spawn(async move {
+ if tunnel_close_rx.await.is_ok() {
+ tokio::task::spawn_blocking(move || {
+ if let Err(error) = close_handle.close() {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Failed to close the tunnel")
+ );
+ }
+ })
+ .await
+ .expect("close handle panic");
+ }
+ });
+
+ Ok(monitor)
}
async fn prepare_process(
@@ -457,7 +472,7 @@ impl<C: OpenVpnBuilder + Send + 'static> OpenVpnMonitor<C> {
/// Creates a handle to this monitor, allowing the tunnel to be closed while some other
/// thread is blocked in `wait`.
- pub fn close_handle(&self) -> OpenVpnCloseHandle<C::ProcessHandle> {
+ fn close_handle(&self) -> OpenVpnCloseHandle<C::ProcessHandle> {
OpenVpnCloseHandle {
child: self.child.clone(),
abort_spawn: self.abort_spawn.clone(),
@@ -1212,6 +1227,7 @@ mod tests {
fn sets_plugin() {
let builder = TestOpenVpnBuilder::default();
let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger();
+ let (_close_tx, close_rx) = oneshot::channel();
let _ = OpenVpnMonitor::new_internal(
builder.clone(),
event_server_abort_tx,
@@ -1222,6 +1238,7 @@ mod tests {
TempFile::new(),
None,
None,
+ close_rx,
#[cfg(windows)]
Box::new(TestWintunContext {}),
);
@@ -1235,6 +1252,7 @@ mod tests {
fn sets_log() {
let builder = TestOpenVpnBuilder::default();
let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger();
+ let (_close_tx, close_rx) = oneshot::channel();
let _ = OpenVpnMonitor::new_internal(
builder.clone(),
event_server_abort_tx,
@@ -1245,6 +1263,7 @@ mod tests {
TempFile::new(),
None,
None,
+ close_rx,
#[cfg(windows)]
Box::new(TestWintunContext {}),
);
@@ -1259,6 +1278,7 @@ mod tests {
let mut builder = TestOpenVpnBuilder::default();
builder.process_handle = Some(TestProcessHandle(0));
let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger();
+ let (_close_tx, close_rx) = oneshot::channel();
let testee = OpenVpnMonitor::new_internal(
builder,
event_server_abort_tx,
@@ -1269,6 +1289,7 @@ mod tests {
TempFile::new(),
None,
None,
+ close_rx,
#[cfg(windows)]
Box::new(TestWintunContext {}),
)
@@ -1281,6 +1302,7 @@ mod tests {
let mut builder = TestOpenVpnBuilder::default();
builder.process_handle = Some(TestProcessHandle(1));
let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger();
+ let (_close_tx, close_rx) = oneshot::channel();
let testee = OpenVpnMonitor::new_internal(
builder,
event_server_abort_tx,
@@ -1291,6 +1313,7 @@ mod tests {
TempFile::new(),
None,
None,
+ close_rx,
#[cfg(windows)]
Box::new(TestWintunContext {}),
)
@@ -1303,6 +1326,7 @@ mod tests {
let mut builder = TestOpenVpnBuilder::default();
builder.process_handle = Some(TestProcessHandle(1));
let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger();
+ let (_close_tx, close_rx) = oneshot::channel();
let testee = OpenVpnMonitor::new_internal(
builder,
event_server_abort_tx,
@@ -1313,6 +1337,7 @@ mod tests {
TempFile::new(),
None,
None,
+ close_rx,
#[cfg(windows)]
Box::new(TestWintunContext {}),
)
@@ -1325,6 +1350,7 @@ mod tests {
fn failed_process_start() {
let builder = TestOpenVpnBuilder::default();
let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger();
+ let (_close_tx, close_rx) = oneshot::channel();
let result = OpenVpnMonitor::new_internal(
builder,
event_server_abort_tx,
@@ -1335,6 +1361,7 @@ mod tests {
TempFile::new(),
None,
None,
+ close_rx,
#[cfg(windows)]
Box::new(TestWintunContext {}),
)
diff --git a/talpid-core/src/tunnel/wireguard/mod.rs b/talpid-core/src/tunnel/wireguard/mod.rs
index 404b2dec68..8b5e0cfa4a 100644
--- a/talpid-core/src/tunnel/wireguard/mod.rs
+++ b/talpid-core/src/tunnel/wireguard/mod.rs
@@ -2,10 +2,10 @@ use self::config::Config;
#[cfg(not(windows))]
use super::tun_provider;
use super::{tun_provider::TunProvider, TunnelEvent, TunnelMetadata};
-use crate::routing::{self, RequiredRoute};
-use futures::future::abortable;
+use crate::routing::{self, RequiredRoute, RouteManagerHandle};
#[cfg(windows)]
use futures::{channel::mpsc, StreamExt};
+use futures::{channel::oneshot, future::abortable};
#[cfg(target_os = "linux")]
use lazy_static::lazy_static;
#[cfg(target_os = "linux")]
@@ -168,9 +168,10 @@ impl WireguardMonitor {
log_path: Option<&Path>,
resource_dir: &Path,
on_event: F,
- tun_provider: &mut TunProvider,
- route_manager: &mut routing::RouteManager,
+ tun_provider: Arc<Mutex<TunProvider>>,
+ route_manager: RouteManagerHandle,
retry_attempt: u32,
+ tunnel_close_rx: oneshot::Receiver<()>,
) -> Result<WireguardMonitor> {
let mut tcp_proxies = vec![];
let mut endpoint_addrs = vec![];
@@ -194,7 +195,6 @@ impl WireguardMonitor {
log_path,
resource_dir,
tun_provider,
- route_manager,
#[cfg(target_os = "windows")]
setup_done_tx,
)?;
@@ -224,8 +224,6 @@ impl WireguardMonitor {
)
.map_err(Error::ConnectivityMonitorError)?;
- let route_handle = route_manager.handle().map_err(Error::SetupRoutingError)?;
-
let metadata = Self::tunnel_metadata(&iface_name, &config);
tokio::spawn(async move {
@@ -258,7 +256,7 @@ impl WireguardMonitor {
}
#[cfg(target_os = "linux")]
- route_handle
+ route_manager
.create_routing_rules(config.enable_ipv6)
.await
.map_err(Error::SetupRoutingError)?;
@@ -266,7 +264,7 @@ impl WireguardMonitor {
let routes = Self::get_in_tunnel_routes(&iface_name, &config)
.chain(Self::get_tunnel_traffic_routes(&endpoint_addrs));
- route_handle
+ route_manager
.add_routes(routes.collect())
.await
.map_err(Error::SetupRoutingError)
@@ -304,6 +302,13 @@ impl WireguardMonitor {
let _ = close_sender.send(CloseMsg::PingErr);
});
+ let mut close_handle = monitor.close_handle();
+ tokio::spawn(async move {
+ if tunnel_close_rx.await.is_ok() {
+ close_handle.close();
+ }
+ });
+
Ok(monitor)
}
@@ -313,8 +318,7 @@ impl WireguardMonitor {
config: &Config,
log_path: Option<&Path>,
resource_dir: &Path,
- tun_provider: &mut TunProvider,
- route_manager: &mut routing::RouteManager,
+ tun_provider: Arc<Mutex<TunProvider>>,
#[cfg(windows)] setup_done_tx: mpsc::Sender<std::result::Result<(), BoxedError>>,
) -> Result<Box<dyn Tunnel>> {
#[cfg(target_os = "linux")]
@@ -384,8 +388,6 @@ impl WireguardMonitor {
#[cfg(not(windows))]
Self::get_tunnel_destinations(config),
#[cfg(windows)]
- route_manager,
- #[cfg(windows)]
setup_done_tx,
)
.map_err(Error::TunnelError)?,
@@ -393,7 +395,7 @@ impl WireguardMonitor {
}
/// Returns a close handle for the tunnel
- pub fn close_handle(&self) -> CloseHandle {
+ fn close_handle(&self) -> CloseHandle {
CloseHandle {
chan: self.close_msg_sender.clone(),
}
diff --git a/talpid-core/src/tunnel/wireguard/wireguard_go.rs b/talpid-core/src/tunnel/wireguard/wireguard_go.rs
index e53b0cfbb3..28666b6506 100644
--- a/talpid-core/src/tunnel/wireguard/wireguard_go.rs
+++ b/talpid-core/src/tunnel/wireguard/wireguard_go.rs
@@ -2,8 +2,6 @@ use super::{
stats::{Stats, StatsMap},
Config, Tunnel, TunnelError,
};
-#[cfg(windows)]
-use crate::routing;
#[cfg(not(windows))]
use crate::tunnel::tun_provider::TunProvider;
use crate::tunnel::wireguard::logging::{
@@ -43,6 +41,9 @@ type Result<T> = std::result::Result<T, TunnelError>;
use crate::winnet;
#[cfg(not(target_os = "windows"))]
+use std::sync::{Arc, Mutex};
+
+#[cfg(not(target_os = "windows"))]
const MAX_PREPARE_TUN_ATTEMPTS: usize = 4;
struct LoggingContext(u32);
@@ -73,7 +74,7 @@ impl WgGoTunnel {
pub fn start_tunnel(
config: &Config,
log_path: Option<&Path>,
- tun_provider: &mut TunProvider,
+ tun_provider: Arc<Mutex<TunProvider>>,
routes: impl Iterator<Item = IpNetwork>,
) -> Result<Self> {
#[cfg_attr(not(target_os = "android"), allow(unused_mut))]
@@ -114,12 +115,13 @@ impl WgGoTunnel {
pub fn start_tunnel(
config: &Config,
log_path: Option<&Path>,
- route_manager: &mut routing::RouteManager,
mut done_tx: futures::channel::mpsc::Sender<std::result::Result<(), BoxedError>>,
) -> Result<Self> {
- let route_callback_handle = route_manager
- .add_default_route_callback(Some(WgGoTunnel::default_route_changed_callback), ())
- .ok();
+ let route_callback_handle = winnet::add_default_route_change_callback(
+ Some(WgGoTunnel::default_route_changed_callback),
+ (),
+ )
+ .ok();
if route_callback_handle.is_none() {
log::warn!("Failed to register default route callback");
}
@@ -275,13 +277,15 @@ impl WgGoTunnel {
#[cfg(not(target_os = "windows"))]
fn get_tunnel(
- tun_provider: &mut TunProvider,
+ tun_provider: Arc<Mutex<TunProvider>>,
config: &Config,
routes: impl Iterator<Item = IpNetwork>,
) -> Result<(Tun, RawFd)> {
let mut last_error = None;
let tunnel_config = Self::create_tunnel_config(config, routes);
+ let mut tun_provider = tun_provider.lock().unwrap();
+
for _ in 1..=MAX_PREPARE_TUN_ATTEMPTS {
let tunnel_device = tun_provider
.get_tun(tunnel_config.clone())
diff --git a/talpid-core/src/tunnel_state_machine/connected_state.rs b/talpid-core/src/tunnel_state_machine/connected_state.rs
index 546f9e92ab..52c410fdf6 100644
--- a/talpid-core/src/tunnel_state_machine/connected_state.rs
+++ b/talpid-core/src/tunnel_state_machine/connected_state.rs
@@ -5,7 +5,7 @@ use super::{
};
use crate::{
firewall::FirewallPolicy,
- tunnel::{CloseHandle, TunnelEvent, TunnelMetadata},
+ tunnel::{TunnelEvent, TunnelMetadata},
};
use cfg_if::cfg_if;
use futures::{
@@ -33,7 +33,7 @@ pub struct ConnectedStateBootstrap {
pub tunnel_events: TunnelEventsReceiver,
pub tunnel_parameters: TunnelParameters,
pub tunnel_close_event: TunnelCloseEvent,
- pub close_handle: Option<CloseHandle>,
+ pub tunnel_close_tx: oneshot::Sender<()>,
}
/// The tunnel is up and working.
@@ -42,7 +42,7 @@ pub struct ConnectedState {
tunnel_events: TunnelEventsReceiver,
tunnel_parameters: TunnelParameters,
tunnel_close_event: TunnelCloseEvent,
- close_handle: Option<CloseHandle>,
+ tunnel_close_tx: oneshot::Sender<()>,
}
impl ConnectedState {
@@ -52,7 +52,7 @@ impl ConnectedState {
tunnel_events: bootstrap.tunnel_events,
tunnel_parameters: bootstrap.tunnel_parameters,
tunnel_close_event: bootstrap.tunnel_close_event,
- close_handle: bootstrap.close_handle,
+ tunnel_close_tx: bootstrap.tunnel_close_tx,
}
}
@@ -173,7 +173,11 @@ impl ConnectedState {
EventConsequence::NewState(DisconnectingState::enter(
shared_values,
- (self.close_handle, self.tunnel_close_event, after_disconnect),
+ (
+ self.tunnel_close_tx,
+ self.tunnel_close_event,
+ after_disconnect,
+ ),
))
}
@@ -328,7 +332,7 @@ impl TunnelState for ConnectedState {
DisconnectingState::enter(
shared_values,
(
- connected_state.close_handle,
+ connected_state.tunnel_close_tx,
connected_state.tunnel_close_event,
AfterDisconnect::Block(ErrorStateCause::SetFirewallPolicyError(error)),
),
@@ -338,7 +342,7 @@ impl TunnelState for ConnectedState {
DisconnectingState::enter(
shared_values,
(
- connected_state.close_handle,
+ connected_state.tunnel_close_tx,
connected_state.tunnel_close_event,
AfterDisconnect::Block(ErrorStateCause::SetDnsError),
),
diff --git a/talpid-core/src/tunnel_state_machine/connecting_state.rs b/talpid-core/src/tunnel_state_machine/connecting_state.rs
index 2ae1924988..3c3be4c7f1 100644
--- a/talpid-core/src/tunnel_state_machine/connecting_state.rs
+++ b/talpid-core/src/tunnel_state_machine/connecting_state.rs
@@ -6,9 +6,7 @@ use super::{
use crate::{
firewall::FirewallPolicy,
routing::RouteManager,
- tunnel::{
- self, tun_provider::TunProvider, CloseHandle, TunnelEvent, TunnelMetadata, TunnelMonitor,
- },
+ tunnel::{self, tun_provider::TunProvider, TunnelEvent, TunnelMetadata, TunnelMonitor},
};
use cfg_if::cfg_if;
use futures::{
@@ -18,6 +16,7 @@ use futures::{
};
use std::{
path::{Path, PathBuf},
+ sync::{Arc, Mutex},
thread,
time::{Duration, Instant},
};
@@ -49,7 +48,7 @@ pub struct ConnectingState {
tunnel_parameters: TunnelParameters,
tunnel_metadata: Option<TunnelMetadata>,
tunnel_close_event: TunnelCloseEvent,
- close_handle: Option<CloseHandle>,
+ tunnel_close_tx: oneshot::Sender<()>,
retry_attempt: u32,
}
@@ -95,10 +94,10 @@ impl ConnectingState {
parameters: TunnelParameters,
log_dir: &Option<PathBuf>,
resource_dir: &Path,
- tun_provider: &mut TunProvider,
+ tun_provider: Arc<Mutex<TunProvider>>,
route_manager: &mut RouteManager,
retry_attempt: u32,
- ) -> crate::tunnel::Result<Self> {
+ ) -> Self {
let (event_tx, event_rx) = mpsc::unbounded();
let on_tunnel_event =
move |event| -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>> {
@@ -109,45 +108,86 @@ impl ConnectingState {
})
};
- let monitor = TunnelMonitor::start(
- runtime,
- &parameters,
- log_dir,
- resource_dir,
- on_tunnel_event,
- tun_provider,
- route_manager,
- retry_attempt,
- )?;
- let close_handle = Some(monitor.close_handle());
- let tunnel_close_event =
- Self::spawn_tunnel_monitor_wait_thread(Some(monitor), retry_attempt);
-
- Ok(ConnectingState {
- tunnel_events: event_rx.fuse(),
- tunnel_parameters: parameters,
- tunnel_metadata: None,
- tunnel_close_event,
- close_handle,
- retry_attempt,
- })
- }
+ let route_manager_handle = route_manager.handle();
+ let log_dir = log_dir.clone();
+ let resource_dir = resource_dir.to_path_buf();
- fn spawn_tunnel_monitor_wait_thread(
- tunnel_monitor: Option<TunnelMonitor>,
- retry_attempt: u32,
- ) -> TunnelCloseEvent {
+ let (tunnel_close_tx, tunnel_close_rx) = oneshot::channel();
let (tunnel_close_event_tx, tunnel_close_event_rx) = oneshot::channel();
- thread::spawn(move || {
+ let tunnel_parameters = parameters.clone();
+
+ tokio::task::spawn_blocking(move || {
let start = Instant::now();
- let block_reason = if let Some(monitor) = tunnel_monitor {
- let reason = Self::wait_for_tunnel_monitor(monitor, retry_attempt);
- log::debug!("Tunnel monitor exited with block reason: {:?}", reason);
- reason
- } else {
- None
+ let route_manager_handle = match route_manager_handle {
+ Ok(handle) => handle,
+ Err(error) => {
+ if tunnel_close_event_tx
+ .send(Some(ErrorStateCause::StartTunnelError))
+ .is_err()
+ {
+ log::warn!(
+ "Tunnel state machine stopped before receiving tunnel closed event"
+ );
+ }
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Failed to obtain route monitor handle")
+ );
+ return;
+ }
+ };
+
+ let block_reason = match TunnelMonitor::start(
+ runtime,
+ &tunnel_parameters,
+ &log_dir,
+ &resource_dir,
+ on_tunnel_event,
+ tun_provider,
+ route_manager_handle,
+ retry_attempt,
+ tunnel_close_rx,
+ ) {
+ Ok(monitor) => {
+ let reason = Self::wait_for_tunnel_monitor(monitor, retry_attempt);
+ log::debug!("Tunnel monitor exited with block reason: {:?}", reason);
+ reason
+ }
+ Err(error) if should_retry(&error, retry_attempt) => {
+ log::warn!(
+ "{}",
+ error.display_chain_with_msg(
+ "Retrying to connect after failing to start tunnel"
+ )
+ );
+ None
+ }
+ Err(error) => {
+ log::error!("{}", error.display_chain_with_msg("Failed to start tunnel"));
+ let block_reason = match error {
+ tunnel::Error::EnableIpv6Error => ErrorStateCause::Ipv6Unavailable,
+ #[cfg(target_os = "android")]
+ tunnel::Error::WireguardTunnelMonitoringError(
+ tunnel::wireguard::Error::TunnelError(
+ tunnel::wireguard::TunnelError::SetupTunnelDeviceError(
+ tun_provider::Error::PermissionDenied,
+ ),
+ ),
+ ) => ErrorStateCause::VpnPermissionDenied,
+ #[cfg(target_os = "android")]
+ tunnel::Error::WireguardTunnelMonitoringError(
+ tunnel::wireguard::Error::TunnelError(
+ tunnel::wireguard::TunnelError::SetupTunnelDeviceError(
+ tun_provider::Error::InvalidDnsServers(addresses),
+ ),
+ ),
+ ) => ErrorStateCause::InvalidDnsServers(addresses),
+ _ => ErrorStateCause::StartTunnelError,
+ };
+ Some(block_reason)
+ }
};
if block_reason.is_none() {
@@ -163,7 +203,14 @@ impl ConnectingState {
log::trace!("Tunnel monitor thread exit");
});
- tunnel_close_event_rx.fuse()
+ ConnectingState {
+ tunnel_events: event_rx.fuse(),
+ tunnel_parameters: parameters,
+ tunnel_metadata: None,
+ tunnel_close_event: tunnel_close_event_rx.fuse(),
+ tunnel_close_tx,
+ retry_attempt,
+ }
}
fn wait_for_tunnel_monitor(
@@ -205,7 +252,7 @@ impl ConnectingState {
tunnel_events: self.tunnel_events,
tunnel_parameters: self.tunnel_parameters,
tunnel_close_event: self.tunnel_close_event,
- close_handle: self.close_handle,
+ tunnel_close_tx: self.tunnel_close_tx,
}
}
@@ -234,7 +281,11 @@ impl ConnectingState {
EventConsequence::NewState(DisconnectingState::enter(
shared_values,
- (self.close_handle, self.tunnel_close_event, after_disconnect),
+ (
+ self.tunnel_close_tx,
+ self.tunnel_close_event,
+ after_disconnect,
+ ),
))
}
@@ -512,7 +563,9 @@ impl TunnelState for ConnectingState {
#[cfg(target_os = "android")]
{
if retry_attempt > 0 && retry_attempt % MAX_ATTEMPTS_WITH_SAME_TUN == 0 {
- if let Err(error) = shared_values.tun_provider.create_tun() {
+ if let Err(error) =
+ { shared_values.tun_provider.lock().unwrap().create_tun() }
+ {
log::error!(
"{}",
error.display_chain_with_msg("Failed to recreate tun device")
@@ -521,69 +574,20 @@ impl TunnelState for ConnectingState {
}
}
- match Self::start_tunnel(
+ let connecting_state = Self::start_tunnel(
shared_values.runtime.clone(),
tunnel_parameters,
&shared_values.log_dir,
&shared_values.resource_dir,
- &mut shared_values.tun_provider,
+ shared_values.tun_provider.clone(),
&mut shared_values.route_manager,
retry_attempt,
- ) {
- Ok(connecting_state) => {
- let params = connecting_state.tunnel_parameters.clone();
- (
- TunnelStateWrapper::from(connecting_state),
- TunnelStateTransition::Connecting(params.get_tunnel_endpoint()),
- )
- }
- Err(error) => {
- if should_retry(&error, retry_attempt) {
- log::warn!(
- "{}",
- error.display_chain_with_msg(
- "Retrying to connect after failing to start tunnel"
- )
- );
- DisconnectingState::enter(
- shared_values,
- (
- None,
- Self::spawn_tunnel_monitor_wait_thread(None, retry_attempt),
- AfterDisconnect::Reconnect(retry_attempt + 1),
- ),
- )
- } else {
- log::error!(
- "{}",
- error.display_chain_with_msg("Failed to start tunnel")
- );
- let block_reason = match error {
- tunnel::Error::EnableIpv6Error => {
- ErrorStateCause::Ipv6Unavailable
- }
- #[cfg(target_os = "android")]
- tunnel::Error::WireguardTunnelMonitoringError(
- tunnel::wireguard::Error::TunnelError(
- tunnel::wireguard::TunnelError::SetupTunnelDeviceError(
- tun_provider::Error::PermissionDenied,
- ),
- ),
- ) => ErrorStateCause::VpnPermissionDenied,
- #[cfg(target_os = "android")]
- tunnel::Error::WireguardTunnelMonitoringError(
- tunnel::wireguard::Error::TunnelError(
- tunnel::wireguard::TunnelError::SetupTunnelDeviceError(
- tun_provider::Error::InvalidDnsServers(addresses),
- ),
- ),
- ) => ErrorStateCause::InvalidDnsServers(addresses),
- _ => ErrorStateCause::StartTunnelError,
- };
- ErrorState::enter(shared_values, block_reason)
- }
- }
- }
+ );
+ let params = connecting_state.tunnel_parameters.clone();
+ (
+ TunnelStateWrapper::from(connecting_state),
+ TunnelStateTransition::Connecting(params.get_tunnel_endpoint()),
+ )
}
}
}
diff --git a/talpid-core/src/tunnel_state_machine/disconnected_state.rs b/talpid-core/src/tunnel_state_machine/disconnected_state.rs
index 3682accc0b..6d0af09aee 100644
--- a/talpid-core/src/tunnel_state_machine/disconnected_state.rs
+++ b/talpid-core/src/tunnel_state_machine/disconnected_state.rs
@@ -120,7 +120,7 @@ impl TunnelState for DisconnectedState {
#[cfg(target_os = "linux")]
shared_values.reset_connectivity_check();
#[cfg(target_os = "android")]
- shared_values.tun_provider.close_tun();
+ shared_values.tun_provider.lock().unwrap().close_tun();
(
TunnelStateWrapper::from(DisconnectedState),
diff --git a/talpid-core/src/tunnel_state_machine/disconnecting_state.rs b/talpid-core/src/tunnel_state_machine/disconnecting_state.rs
index 8f6f6ae68b..2d3444f44a 100644
--- a/talpid-core/src/tunnel_state_machine/disconnecting_state.rs
+++ b/talpid-core/src/tunnel_state_machine/disconnecting_state.rs
@@ -3,13 +3,8 @@ use super::{
EventConsequence, EventResult, SharedTunnelStateValues, TunnelCommand, TunnelCommandReceiver,
TunnelState, TunnelStateTransition, TunnelStateWrapper,
};
-use crate::tunnel::CloseHandle;
-use futures::{future::FusedFuture, StreamExt};
-use std::thread;
-use talpid_types::{
- tunnel::{ActionAfterDisconnect, ErrorStateCause},
- ErrorExt,
-};
+use futures::{channel::oneshot, future::FusedFuture, StreamExt};
+use talpid_types::tunnel::{ActionAfterDisconnect, ErrorStateCause};
/// This state is active from when we manually trigger a tunnel kill until the tunnel wait
/// operation (TunnelExit) returned.
@@ -175,23 +170,13 @@ impl DisconnectingState {
}
impl TunnelState for DisconnectingState {
- type Bootstrap = (Option<CloseHandle>, TunnelCloseEvent, AfterDisconnect);
+ type Bootstrap = (oneshot::Sender<()>, TunnelCloseEvent, AfterDisconnect);
fn enter(
_: &mut SharedTunnelStateValues,
- (close_handle, tunnel_close_event, after_disconnect): Self::Bootstrap,
+ (tunnel_close_tx, tunnel_close_event, after_disconnect): Self::Bootstrap,
) -> (TunnelStateWrapper, TunnelStateTransition) {
- if let Some(close_handle) = close_handle {
- thread::spawn(move || {
- if let Err(error) = close_handle.close() {
- log::error!(
- "{}",
- error.display_chain_with_msg("Failed to close the tunnel")
- );
- }
- });
- }
-
+ let _ = tunnel_close_tx.send(());
let action_after_disconnect = after_disconnect.action();
(
diff --git a/talpid-core/src/tunnel_state_machine/error_state.rs b/talpid-core/src/tunnel_state_machine/error_state.rs
index a501b21f92..5464acac12 100644
--- a/talpid-core/src/tunnel_state_machine/error_state.rs
+++ b/talpid-core/src/tunnel_state_machine/error_state.rs
@@ -51,7 +51,12 @@ impl ErrorState {
/// Returns true if a new tunnel device was successfully created.
#[cfg(target_os = "android")]
fn create_blocking_tun(shared_values: &mut SharedTunnelStateValues) -> bool {
- match shared_values.tun_provider.create_blocking_tun() {
+ match shared_values
+ .tun_provider
+ .lock()
+ .unwrap()
+ .create_blocking_tun()
+ {
Ok(()) => true,
Err(error) => {
log::error!(
diff --git a/talpid-core/src/tunnel_state_machine/mod.rs b/talpid-core/src/tunnel_state_machine/mod.rs
index e22975a827..8c4446f3f3 100644
--- a/talpid-core/src/tunnel_state_machine/mod.rs
+++ b/talpid-core/src/tunnel_state_machine/mod.rs
@@ -30,7 +30,13 @@ use futures::{
};
#[cfg(target_os = "android")]
use std::os::unix::io::RawFd;
-use std::{collections::HashSet, io, net::IpAddr, path::PathBuf, sync::Arc};
+use std::{
+ collections::HashSet,
+ io,
+ net::IpAddr,
+ path::PathBuf,
+ sync::{Arc, Mutex},
+};
#[cfg(target_os = "android")]
use talpid_types::{android::AndroidContext, ErrorExt};
use talpid_types::{
@@ -294,7 +300,7 @@ impl TunnelStateMachine {
dns_servers: settings.dns_servers,
allowed_endpoint: settings.allowed_endpoint,
tunnel_parameters_generator: Box::new(tunnel_parameters_generator),
- tun_provider,
+ tun_provider: Arc::new(Mutex::new(tun_provider)),
log_dir,
resource_dir,
#[cfg(target_os = "linux")]
@@ -383,7 +389,7 @@ struct SharedTunnelStateValues {
/// The generator of new `TunnelParameter`s
tunnel_parameters_generator: Box<dyn TunnelParametersGenerator>,
/// The provider of tunnel devices.
- tun_provider: TunProvider,
+ tun_provider: Arc<Mutex<TunProvider>>,
/// Directory to store tunnel log file.
log_dir: Option<PathBuf>,
/// Resource directory path.
@@ -405,7 +411,7 @@ impl SharedTunnelStateValues {
#[cfg(target_os = "android")]
{
- if let Err(error) = self.tun_provider.set_allow_lan(allow_lan) {
+ if let Err(error) = self.tun_provider.lock().unwrap().set_allow_lan(allow_lan) {
log::error!(
"{}",
error.display_chain_with_msg(&format!(
@@ -425,6 +431,8 @@ impl SharedTunnelStateValues {
if self.allowed_endpoint != endpoint {
#[cfg(target_os = "android")]
self.tun_provider
+ .lock()
+ .unwrap()
.set_allowed_endpoint(endpoint.endpoint.address.ip());
self.allowed_endpoint = endpoint;
@@ -444,7 +452,12 @@ impl SharedTunnelStateValues {
#[cfg(target_os = "android")]
{
- if let Err(error) = self.tun_provider.set_dns_servers(dns_servers) {
+ if let Err(error) = self
+ .tun_provider
+ .lock()
+ .unwrap()
+ .set_dns_servers(dns_servers)
+ {
log::error!(
"{}",
error.display_chain_with_msg(
@@ -489,7 +502,7 @@ impl SharedTunnelStateValues {
#[cfg(target_os = "android")]
pub fn bypass_socket(&mut self, fd: RawFd, tx: oneshot::Sender<()>) {
- if let Err(err) = self.tun_provider.bypass(fd) {
+ if let Err(err) = self.tun_provider.lock().unwrap().bypass(fd) {
log::error!("Failed to bypass socket {}", err);
}
let _ = tx.send(());