diff options
| author | David Lönnhager <david.l@mullvad.net> | 2020-10-05 11:24:33 +0200 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2020-10-05 15:27:47 +0200 |
| commit | 3d584c8054cb185988b8a211156eb270d09a18fa (patch) | |
| tree | a00f4b35c340bbdcbabedfe11dbc7ea9fb89b22c | |
| parent | 68681c2e9a897b4d09353d6665ee7ece7a34236c (diff) | |
| download | mullvadvpn-3d584c8054cb185988b8a211156eb270d09a18fa.tar.xz mullvadvpn-3d584c8054cb185988b8a211156eb270d09a18fa.zip | |
Use tokio runtime for tunnel state machine
| -rw-r--r-- | Cargo.lock | 1 | ||||
| -rw-r--r-- | talpid-core/Cargo.toml | 1 | ||||
| -rw-r--r-- | talpid-core/src/dns/linux/network_manager.rs | 2 | ||||
| -rw-r--r-- | talpid-core/src/dns/linux/systemd_resolved.rs | 2 | ||||
| -rw-r--r-- | talpid-core/src/tunnel_state_machine/connected_state.rs | 26 | ||||
| -rw-r--r-- | talpid-core/src/tunnel_state_machine/connecting_state.rs | 28 | ||||
| -rw-r--r-- | talpid-core/src/tunnel_state_machine/disconnected_state.rs | 8 | ||||
| -rw-r--r-- | talpid-core/src/tunnel_state_machine/disconnecting_state.rs | 26 | ||||
| -rw-r--r-- | talpid-core/src/tunnel_state_machine/error_state.rs | 8 | ||||
| -rw-r--r-- | talpid-core/src/tunnel_state_machine/mod.rs | 32 |
10 files changed, 76 insertions, 58 deletions
diff --git a/Cargo.lock b/Cargo.lock index 8c347f3e8a..cefe826bf8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2301,7 +2301,6 @@ dependencies = [ name = "talpid-core" version = "0.1.0" dependencies = [ - "async-trait 0.1.40 (registry+https://github.com/rust-lang/crates.io-index)", "atty 0.2.14 (registry+https://github.com/rust-lang/crates.io-index)", "byteorder 1.3.4 (registry+https://github.com/rust-lang/crates.io-index)", "cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)", diff --git a/talpid-core/Cargo.toml b/talpid-core/Cargo.toml index 64a767a65e..60acfbe9d0 100644 --- a/talpid-core/Cargo.toml +++ b/talpid-core/Cargo.toml @@ -28,7 +28,6 @@ zeroize = "1" chrono = "0.4" tokio = { version = "0.2", features = [ "process", "rt-threaded", "stream" ] } rand = "0.7" -async-trait = "0.1.4" [target.'cfg(not(target_os="android"))'.dependencies] diff --git a/talpid-core/src/dns/linux/network_manager.rs b/talpid-core/src/dns/linux/network_manager.rs index d17353d6ac..f2b2f6d7fb 100644 --- a/talpid-core/src/dns/linux/network_manager.rs +++ b/talpid-core/src/dns/linux/network_manager.rs @@ -78,8 +78,6 @@ pub struct NetworkManager { settings_backup: Option<HashMap<String, HashMap<String, Variant<Box<dyn RefArg>>>>>, } -unsafe impl Send for NetworkManager {} - impl NetworkManager { pub fn new() -> Result<Self> { diff --git a/talpid-core/src/dns/linux/systemd_resolved.rs b/talpid-core/src/dns/linux/systemd_resolved.rs index 808ade1334..8039b9b112 100644 --- a/talpid-core/src/dns/linux/systemd_resolved.rs +++ b/talpid-core/src/dns/linux/systemd_resolved.rs @@ -83,8 +83,6 @@ pub struct SystemdResolved { interface_link: Option<(String, dbus::Path<'static>)>, } -unsafe impl Send for SystemdResolved {} - impl SystemdResolved { pub fn new() -> Result<Self> { let result = (|| { diff --git a/talpid-core/src/tunnel_state_machine/connected_state.rs b/talpid-core/src/tunnel_state_machine/connected_state.rs index 335e919c68..a36093d51c 100644 --- a/talpid-core/src/tunnel_state_machine/connected_state.rs +++ b/talpid-core/src/tunnel_state_machine/connected_state.rs @@ -1,6 +1,6 @@ use super::{ AfterDisconnect, ConnectingState, DisconnectingState, ErrorState, EventConsequence, - SharedTunnelStateValues, TunnelCommand, TunnelCommandReceiver, TunnelState, + EventResult, SharedTunnelStateValues, TunnelCommand, TunnelCommandReceiver, TunnelState, TunnelStateTransition, TunnelStateWrapper, }; use crate::{ @@ -221,7 +221,6 @@ impl ConnectedState { } } -#[async_trait::async_trait] impl TunnelState for ConnectedState { type Bootstrap = ConnectedStateBootstrap; @@ -260,25 +259,30 @@ impl TunnelState for ConnectedState { } } - async fn handle_event( + fn handle_event( mut self, + runtime: &tokio::runtime::Handle, commands: &mut TunnelCommandReceiver, shared_values: &mut SharedTunnelStateValues, ) -> EventConsequence { - return futures::select! { - command = commands.next() => { - self.handle_commands(command, shared_values) + let result = runtime.block_on(async { + futures::select! { + command = commands.next() => EventResult::Command(command), + event = self.tunnel_events.next() => EventResult::Event(event), + result = &mut self.tunnel_close_event => EventResult::Close(result), } - event = self.tunnel_events.next() => { - self.handle_tunnel_events(event, shared_values) - } - result = &mut self.tunnel_close_event => { + }); + + match result { + EventResult::Command(command) => self.handle_commands(command, shared_values), + EventResult::Event(event) => self.handle_tunnel_events(event, shared_values), + EventResult::Close(result) => { if result.is_err() { log::warn!("Tunnel monitor thread has stopped unexpectedly"); } let block_reason = result.unwrap_or(None); self.handle_tunnel_close_event(block_reason, shared_values) } - }; + } } } diff --git a/talpid-core/src/tunnel_state_machine/connecting_state.rs b/talpid-core/src/tunnel_state_machine/connecting_state.rs index 816dd51665..bfccac6572 100644 --- a/talpid-core/src/tunnel_state_machine/connecting_state.rs +++ b/talpid-core/src/tunnel_state_machine/connecting_state.rs @@ -1,7 +1,7 @@ use super::{ AfterDisconnect, ConnectedState, ConnectedStateBootstrap, DisconnectingState, ErrorState, - EventConsequence, SharedTunnelStateValues, TunnelCommand, TunnelCommandReceiver, TunnelState, - TunnelStateTransition, TunnelStateWrapper, + EventConsequence, EventResult, SharedTunnelStateValues, TunnelCommand, TunnelCommandReceiver, + TunnelState, TunnelStateTransition, TunnelStateWrapper, }; use crate::{ firewall::FirewallPolicy, @@ -322,7 +322,6 @@ fn should_retry(error: &tunnel::Error) -> bool { } } -#[async_trait::async_trait] impl TunnelState for ConnectingState { type Bootstrap = u32; @@ -435,26 +434,31 @@ impl TunnelState for ConnectingState { } } - async fn handle_event( + fn handle_event( mut self, + runtime: &tokio::runtime::Handle, commands: &mut TunnelCommandReceiver, shared_values: &mut SharedTunnelStateValues, ) -> EventConsequence { - return futures::select! { - command = commands.next() => { - self.handle_commands(command, shared_values) + let result = runtime.block_on(async { + futures::select! { + command = commands.next() => EventResult::Command(command), + event = self.tunnel_events.next() => EventResult::Event(event), + result = &mut self.tunnel_close_event => EventResult::Close(result), } - event = self.tunnel_events.next() => { - self.handle_tunnel_events(event, shared_values) - } - result = &mut self.tunnel_close_event => { + }); + + match result { + EventResult::Command(command) => self.handle_commands(command, shared_values), + EventResult::Event(event) => self.handle_tunnel_events(event, shared_values), + EventResult::Close(result) => { if result.is_err() { log::warn!("Tunnel monitor thread has stopped unexpectedly"); } let block_reason = result.unwrap_or(None); self.handle_tunnel_close_event(block_reason, shared_values) } - }; + } } } diff --git a/talpid-core/src/tunnel_state_machine/disconnected_state.rs b/talpid-core/src/tunnel_state_machine/disconnected_state.rs index 36199012b8..faeac0a45f 100644 --- a/talpid-core/src/tunnel_state_machine/disconnected_state.rs +++ b/talpid-core/src/tunnel_state_machine/disconnected_state.rs @@ -37,7 +37,6 @@ impl DisconnectedState { } } -#[async_trait::async_trait] impl TunnelState for DisconnectedState { type Bootstrap = bool; @@ -62,14 +61,15 @@ impl TunnelState for DisconnectedState { ) } - async fn handle_event( - mut self, + fn handle_event( + self, + runtime: &tokio::runtime::Handle, commands: &mut TunnelCommandReceiver, shared_values: &mut SharedTunnelStateValues, ) -> EventConsequence { use self::EventConsequence::*; - match commands.next().await { + match runtime.block_on(commands.next()) { Some(TunnelCommand::AllowLan(allow_lan)) => { if shared_values.allow_lan != allow_lan { // The only platform that can fail is Android, but Android doesn't support the diff --git a/talpid-core/src/tunnel_state_machine/disconnecting_state.rs b/talpid-core/src/tunnel_state_machine/disconnecting_state.rs index 68d567ed42..33a09ca31a 100644 --- a/talpid-core/src/tunnel_state_machine/disconnecting_state.rs +++ b/talpid-core/src/tunnel_state_machine/disconnecting_state.rs @@ -1,7 +1,7 @@ use super::{ connecting_state::TunnelCloseEvent, ConnectingState, DisconnectedState, ErrorState, - EventConsequence, SharedTunnelStateValues, TunnelCommand, TunnelCommandReceiver, TunnelState, - TunnelStateTransition, TunnelStateWrapper, + EventConsequence, EventResult, SharedTunnelStateValues, TunnelCommand, TunnelCommandReceiver, + TunnelState, TunnelStateTransition, TunnelStateWrapper, }; use crate::tunnel::CloseHandle; use futures::{future::FusedFuture, StreamExt}; @@ -111,7 +111,6 @@ impl DisconnectingState { } } -#[async_trait::async_trait] impl TunnelState for DisconnectingState { type Bootstrap = (Option<CloseHandle>, TunnelCloseEvent, AfterDisconnect); @@ -141,8 +140,9 @@ impl TunnelState for DisconnectingState { ) } - async fn handle_event( + fn handle_event( mut self, + runtime: &tokio::runtime::Handle, commands: &mut TunnelCommandReceiver, shared_values: &mut SharedTunnelStateValues, ) -> EventConsequence { @@ -152,15 +152,21 @@ impl TunnelState for DisconnectingState { return NewState(self.after_disconnect(None, shared_values)); } - return futures::select! { - command = commands.next() => { - self.handle_commands(command, shared_values) + let result = runtime.block_on(async { + futures::select! { + command = commands.next() => EventResult::Command(command), + result = &mut self.tunnel_close_event => EventResult::Close(result), } - block_reason = &mut self.tunnel_close_event => { - let block_reason = block_reason.unwrap_or(None); + }); + + match result { + EventResult::Command(command) => self.handle_commands(command, shared_values), + EventResult::Close(result) => { + let block_reason = result.unwrap_or(None); NewState(self.after_disconnect(block_reason, shared_values)) } - }; + _ => unreachable!("unexpected event result"), + } } } diff --git a/talpid-core/src/tunnel_state_machine/error_state.rs b/talpid-core/src/tunnel_state_machine/error_state.rs index b32871466b..a9861e788e 100644 --- a/talpid-core/src/tunnel_state_machine/error_state.rs +++ b/talpid-core/src/tunnel_state_machine/error_state.rs @@ -59,7 +59,6 @@ impl ErrorState { } } -#[async_trait::async_trait] impl TunnelState for ErrorState { type Bootstrap = ErrorStateCause; @@ -86,14 +85,15 @@ impl TunnelState for ErrorState { ) } - async fn handle_event( - mut self, + fn handle_event( + self, + runtime: &tokio::runtime::Handle, commands: &mut TunnelCommandReceiver, shared_values: &mut SharedTunnelStateValues, ) -> EventConsequence { use self::EventConsequence::*; - match commands.next().await { + match runtime.block_on(commands.next()) { Some(TunnelCommand::AllowLan(allow_lan)) => { if let Err(error_state_cause) = shared_values.set_allow_lan(allow_lan) { NewState(Self::enter(shared_values, error_state_cause)) diff --git a/talpid-core/src/tunnel_state_machine/mod.rs b/talpid-core/src/tunnel_state_machine/mod.rs index 51783654bd..37a27ae58f 100644 --- a/talpid-core/src/tunnel_state_machine/mod.rs +++ b/talpid-core/src/tunnel_state_machine/mod.rs @@ -17,7 +17,7 @@ use crate::{ mpsc::Sender, offline, routing::RouteManager, - tunnel::tun_provider::TunProvider, + tunnel::{tun_provider::TunProvider, TunnelEvent}, }; use futures::{ @@ -101,6 +101,8 @@ pub async fn spawn( allow_lan, ); + let runtime = tokio::runtime::Handle::current(); + let (startup_result_tx, startup_result_rx) = sync_mpsc::channel(); std::thread::spawn(move || { let state_machine = TunnelStateMachine::new( @@ -126,8 +128,7 @@ pub async fn spawn( } }; - // TODO: Spawn this on a tokio runtime, and share it with RouteManager, etc. - futures::executor::block_on(state_machine.run(state_change_listener)); + state_machine.run(runtime, state_change_listener); if shutdown_tx.send(()).is_err() { log::error!("Can't send shutdown completion to daemon"); @@ -160,6 +161,12 @@ pub enum TunnelCommand { type TunnelCommandReceiver = stream::Fuse<mpsc::UnboundedReceiver<TunnelCommand>>; +enum EventResult { + Command(Option<TunnelCommand>), + Event(Option<TunnelEvent>), + Close(Result<Option<ErrorStateCause>, oneshot::Canceled>), +} + /// Asynchronous handling of the tunnel state machine. /// /// This type implements `Stream`, and attempts to advance the state machine based on the events @@ -216,13 +223,15 @@ impl TunnelStateMachine { }) } - async fn run(mut self, change_listener: impl Sender<TunnelStateTransition> + Send + 'static) { + fn run( + mut self, + runtime: tokio::runtime::Handle, + change_listener: impl Sender<TunnelStateTransition> + Send + 'static, + ) { use EventConsequence::*; while let Some(state_wrapper) = self.current_state.take() { - match state_wrapper - .handle_event(&mut self.commands, &mut self.shared_values) - .await + match state_wrapper.handle_event(&runtime, &mut self.commands, &mut self.shared_values) { NewState((state, transition)) => { self.current_state = Some(state); @@ -314,7 +323,6 @@ enum EventConsequence { /// Trait that contains the method all states should implement to handle an event and advance the /// state machine. -#[async_trait::async_trait] trait TunnelState: Into<TunnelStateWrapper> + Sized { /// Type representing extra information required for entering the state. type Bootstrap; @@ -338,8 +346,9 @@ trait TunnelState: Into<TunnelStateWrapper> + Sized { /// events received through the provided `commands` stream. /// /// [`EventConsequence`]: enum.EventConsequence.html - async fn handle_event( + fn handle_event( self, + runtime: &tokio::runtime::Handle, commands: &mut TunnelCommandReceiver, shared_values: &mut SharedTunnelStateValues, ) -> EventConsequence; @@ -362,14 +371,15 @@ macro_rules! state_wrapper { })* impl $wrapper_name { - async fn handle_event( + fn handle_event( self, + runtime: &tokio::runtime::Handle, commands: &mut TunnelCommandReceiver, shared_values: &mut SharedTunnelStateValues, ) -> EventConsequence { match self { $($wrapper_name::$state_variant(state) => { - state.handle_event(commands, shared_values).await + state.handle_event(runtime, commands, shared_values) })* } } |
