diff options
| author | David Lönnhager <david.l@mullvad.net> | 2020-10-05 15:51:33 +0200 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2020-10-05 15:51:33 +0200 |
| commit | 53e167439ea6b263a2a399591a82a22a2d7a6343 (patch) | |
| tree | 44cf8bd0cbc434b55a2d035426dc79cfe06fa3aa | |
| parent | 779e5267e3c03ff511c36518f0323347d675cce3 (diff) | |
| parent | 00097409bf285e08de927d1f934e32208fd4bb5e (diff) | |
| download | mullvadvpn-53e167439ea6b263a2a399591a82a22a2d7a6343.tar.xz mullvadvpn-53e167439ea6b263a2a399591a82a22a2d7a6343.zip | |
Merge branch 'refactor-tunnel-sm'
| -rw-r--r-- | Cargo.lock | 1 | ||||
| -rw-r--r-- | talpid-core/Cargo.toml | 1 | ||||
| -rw-r--r-- | talpid-core/src/routing/linux.rs | 6 | ||||
| -rw-r--r-- | talpid-core/src/tunnel_state_machine/connected_state.rs | 110 | ||||
| -rw-r--r-- | talpid-core/src/tunnel_state_machine/connecting_state.rs | 127 | ||||
| -rw-r--r-- | talpid-core/src/tunnel_state_machine/disconnected_state.rs | 33 | ||||
| -rw-r--r-- | talpid-core/src/tunnel_state_machine/disconnecting_state.rs | 123 | ||||
| -rw-r--r-- | talpid-core/src/tunnel_state_machine/error_state.rs | 31 | ||||
| -rw-r--r-- | talpid-core/src/tunnel_state_machine/macros.rs | 21 | ||||
| -rw-r--r-- | talpid-core/src/tunnel_state_machine/mod.rs | 150 |
10 files changed, 274 insertions, 329 deletions
diff --git a/Cargo.lock b/Cargo.lock index 41b4eba987..cefe826bf8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2309,7 +2309,6 @@ dependencies = [ "duct 0.13.4 (registry+https://github.com/rust-lang/crates.io-index)", "err-derive 0.2.4 (registry+https://github.com/rust-lang/crates.io-index)", "failure 0.1.8 (registry+https://github.com/rust-lang/crates.io-index)", - "futures 0.1.29 (registry+https://github.com/rust-lang/crates.io-index)", "futures 0.3.5 (registry+https://github.com/rust-lang/crates.io-index)", "hex 0.4.2 (registry+https://github.com/rust-lang/crates.io-index)", "ipnetwork 0.16.0 (registry+https://github.com/rust-lang/crates.io-index)", diff --git a/talpid-core/Cargo.toml b/talpid-core/Cargo.toml index 1f374152c6..60acfbe9d0 100644 --- a/talpid-core/Cargo.toml +++ b/talpid-core/Cargo.toml @@ -12,7 +12,6 @@ atty = "0.2" cfg-if = "0.1" duct = "0.13" err-derive = "0.2.1" -futures01 = { package = "futures", version = "0.1" } futures = { package = "futures", version = "0.3", features = [ "compat" ]} hex = "0.4" ipnetwork = "0.16" diff --git a/talpid-core/src/routing/linux.rs b/talpid-core/src/routing/linux.rs index 025804d3e1..4516f19127 100644 --- a/talpid-core/src/routing/linux.rs +++ b/talpid-core/src/routing/linux.rs @@ -972,12 +972,6 @@ impl RouteManagerImpl { } } -impl Drop for RouteManagerImpl { - fn drop(&mut self) { - futures::executor::block_on(self.destructor()); - } -} - fn ip_to_bytes(addr: IpAddr) -> Vec<u8> { match addr { IpAddr::V4(addr) => addr.octets().to_vec(), diff --git a/talpid-core/src/tunnel_state_machine/connected_state.rs b/talpid-core/src/tunnel_state_machine/connected_state.rs index 33e1463167..a36093d51c 100644 --- a/talpid-core/src/tunnel_state_machine/connected_state.rs +++ b/talpid-core/src/tunnel_state_machine/connected_state.rs @@ -1,15 +1,13 @@ use super::{ AfterDisconnect, ConnectingState, DisconnectingState, ErrorState, EventConsequence, - SharedTunnelStateValues, TunnelCommand, TunnelState, TunnelStateTransition, TunnelStateWrapper, + EventResult, SharedTunnelStateValues, TunnelCommand, TunnelCommandReceiver, TunnelState, + TunnelStateTransition, TunnelStateWrapper, }; use crate::{ firewall::FirewallPolicy, tunnel::{CloseHandle, TunnelEvent, TunnelMetadata}, }; -use futures01::{ - sync::{mpsc, oneshot}, - Async, Future, Stream, -}; +use futures::{channel::mpsc, stream::Fuse, StreamExt}; use talpid_types::{ net::TunnelParameters, tunnel::{ErrorStateCause, FirewallPolicyError}, @@ -19,21 +17,25 @@ use talpid_types::{ #[cfg(windows)] use crate::tunnel::TunnelMonitor; +use super::connecting_state::TunnelCloseEvent; + +pub(crate) type TunnelEventsReceiver = Fuse<mpsc::UnboundedReceiver<TunnelEvent>>; + pub struct ConnectedStateBootstrap { pub metadata: TunnelMetadata, - pub tunnel_events: mpsc::UnboundedReceiver<TunnelEvent>, + pub tunnel_events: TunnelEventsReceiver, pub tunnel_parameters: TunnelParameters, - pub tunnel_close_event: Option<oneshot::Receiver<Option<ErrorStateCause>>>, + pub tunnel_close_event: TunnelCloseEvent, pub close_handle: Option<CloseHandle>, } /// The tunnel is up and working. pub struct ConnectedState { metadata: TunnelMetadata, - tunnel_events: mpsc::UnboundedReceiver<TunnelEvent>, + tunnel_events: TunnelEventsReceiver, tunnel_parameters: TunnelParameters, - tunnel_close_event: Option<oneshot::Receiver<Option<ErrorStateCause>>>, + tunnel_close_event: TunnelCloseEvent, close_handle: Option<CloseHandle>, } @@ -126,7 +128,7 @@ impl ConnectedState { self, shared_values: &mut SharedTunnelStateValues, after_disconnect: AfterDisconnect, - ) -> EventConsequence<Self> { + ) -> EventConsequence { Self::reset_dns(shared_values); Self::reset_routes(shared_values); @@ -138,18 +140,18 @@ impl ConnectedState { fn handle_commands( self, - commands: &mut mpsc::UnboundedReceiver<TunnelCommand>, + command: Option<TunnelCommand>, shared_values: &mut SharedTunnelStateValues, - ) -> EventConsequence<Self> { + ) -> EventConsequence { use self::EventConsequence::*; - match try_handle_event!(self, commands.poll()) { - Ok(TunnelCommand::AllowLan(allow_lan)) => { + match command { + Some(TunnelCommand::AllowLan(allow_lan)) => { if let Err(error_cause) = shared_values.set_allow_lan(allow_lan) { self.disconnect(shared_values, AfterDisconnect::Block(error_cause)) } else { match self.set_firewall_policy(shared_values) { - Ok(()) => SameState(self), + Ok(()) => SameState(self.into()), Err(error) => self.disconnect( shared_values, AfterDisconnect::Block(ErrorStateCause::SetFirewallPolicyError(error)), @@ -157,11 +159,11 @@ impl ConnectedState { } } } - Ok(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => { + Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => { shared_values.block_when_disconnected = block_when_disconnected; - SameState(self) + SameState(self.into()) } - Ok(TunnelCommand::IsOffline(is_offline)) => { + Some(TunnelCommand::IsOffline(is_offline)) => { shared_values.is_offline = is_offline; if is_offline { self.disconnect( @@ -169,56 +171,47 @@ impl ConnectedState { AfterDisconnect::Block(ErrorStateCause::IsOffline), ) } else { - SameState(self) + SameState(self.into()) } } - Ok(TunnelCommand::Connect) => { + Some(TunnelCommand::Connect) => { self.disconnect(shared_values, AfterDisconnect::Reconnect(0)) } - Ok(TunnelCommand::Disconnect) | Err(_) => { + Some(TunnelCommand::Disconnect) | None => { self.disconnect(shared_values, AfterDisconnect::Nothing) } - Ok(TunnelCommand::Block(reason)) => { + Some(TunnelCommand::Block(reason)) => { self.disconnect(shared_values, AfterDisconnect::Block(reason)) } } } fn handle_tunnel_events( - mut self, + self, + event: Option<TunnelEvent>, shared_values: &mut SharedTunnelStateValues, - ) -> EventConsequence<Self> { + ) -> EventConsequence { use self::EventConsequence::*; - match try_handle_event!(self, self.tunnel_events.poll()) { - Ok(TunnelEvent::Down) | Err(_) => { + match event { + Some(TunnelEvent::Down) | None => { self.disconnect(shared_values, AfterDisconnect::Reconnect(0)) } - Ok(_) => SameState(self), + Some(_) => SameState(self.into()), } } fn handle_tunnel_close_event( - mut self, + self, + block_reason: Option<ErrorStateCause>, shared_values: &mut SharedTunnelStateValues, - ) -> EventConsequence<Self> { + ) -> EventConsequence { use self::EventConsequence::*; - let poll_result = match &mut self.tunnel_close_event { - Some(tunnel_close_event) => tunnel_close_event.poll(), - None => Ok(Async::NotReady), - }; - - match poll_result { - Ok(Async::Ready(block_reason)) => { - if let Some(reason) = block_reason { - Self::reset_dns(shared_values); - Self::reset_routes(shared_values); - return NewState(ErrorState::enter(shared_values, reason)); - } - } - Ok(Async::NotReady) => return NoEvents(self), - Err(_cancelled) => log::warn!("Tunnel monitor thread has stopped unexpectedly"), + if let Some(block_reason) = block_reason { + Self::reset_dns(shared_values); + Self::reset_routes(shared_values); + return NewState(ErrorState::enter(shared_values, block_reason)); } log::info!("Tunnel closed. Reconnecting."); @@ -267,12 +260,29 @@ impl TunnelState for ConnectedState { } fn handle_event( - self, - commands: &mut mpsc::UnboundedReceiver<TunnelCommand>, + mut self, + runtime: &tokio::runtime::Handle, + commands: &mut TunnelCommandReceiver, shared_values: &mut SharedTunnelStateValues, - ) -> EventConsequence<Self> { - self.handle_commands(commands, shared_values) - .or_else(Self::handle_tunnel_events, shared_values) - .or_else(Self::handle_tunnel_close_event, shared_values) + ) -> EventConsequence { + let result = runtime.block_on(async { + futures::select! { + command = commands.next() => EventResult::Command(command), + event = self.tunnel_events.next() => EventResult::Event(event), + result = &mut self.tunnel_close_event => EventResult::Close(result), + } + }); + + match result { + EventResult::Command(command) => self.handle_commands(command, shared_values), + EventResult::Event(event) => self.handle_tunnel_events(event, shared_values), + EventResult::Close(result) => { + if result.is_err() { + log::warn!("Tunnel monitor thread has stopped unexpectedly"); + } + let block_reason = result.unwrap_or(None); + self.handle_tunnel_close_event(block_reason, shared_values) + } + } } } diff --git a/talpid-core/src/tunnel_state_machine/connecting_state.rs b/talpid-core/src/tunnel_state_machine/connecting_state.rs index 859ef52eb0..bfccac6572 100644 --- a/talpid-core/src/tunnel_state_machine/connecting_state.rs +++ b/talpid-core/src/tunnel_state_machine/connecting_state.rs @@ -1,7 +1,7 @@ use super::{ AfterDisconnect, ConnectedState, ConnectedStateBootstrap, DisconnectingState, ErrorState, - EventConsequence, SharedTunnelStateValues, TunnelCommand, TunnelState, TunnelStateTransition, - TunnelStateWrapper, + EventConsequence, EventResult, SharedTunnelStateValues, TunnelCommand, TunnelCommandReceiver, + TunnelState, TunnelStateTransition, TunnelStateWrapper, }; use crate::{ firewall::FirewallPolicy, @@ -10,9 +10,10 @@ use crate::{ self, tun_provider::TunProvider, CloseHandle, TunnelEvent, TunnelMetadata, TunnelMonitor, }, }; -use futures01::{ - sync::{mpsc, oneshot}, - Async, Future, Stream, +use futures::{ + channel::{mpsc, oneshot}, + future::Fuse, + FutureExt, StreamExt, }; use log::{debug, error, info, trace, warn}; use std::{ @@ -30,15 +31,19 @@ use talpid_types::{ #[cfg(target_os = "android")] use crate::tunnel::tun_provider; +use super::connected_state::TunnelEventsReceiver; + +pub(crate) type TunnelCloseEvent = Fuse<oneshot::Receiver<Option<ErrorStateCause>>>; + #[cfg(target_os = "android")] const MAX_ATTEMPTS_WITH_SAME_TUN: u32 = 5; const MIN_TUNNEL_ALIVE_TIME: Duration = Duration::from_millis(1000); /// The tunnel has been started, but it is not established/functional. pub struct ConnectingState { - tunnel_events: mpsc::UnboundedReceiver<TunnelEvent>, + tunnel_events: TunnelEventsReceiver, tunnel_parameters: TunnelParameters, - tunnel_close_event: Option<oneshot::Receiver<Option<ErrorStateCause>>>, + tunnel_close_event: TunnelCloseEvent, close_handle: Option<CloseHandle>, retry_attempt: u32, } @@ -102,7 +107,7 @@ impl ConnectingState { let tunnel_close_event = Self::spawn_tunnel_monitor_wait_thread(monitor); Ok(ConnectingState { - tunnel_events: event_rx, + tunnel_events: event_rx.fuse(), tunnel_parameters: parameters, tunnel_close_event, close_handle, @@ -110,9 +115,7 @@ impl ConnectingState { }) } - fn spawn_tunnel_monitor_wait_thread( - tunnel_monitor: TunnelMonitor, - ) -> Option<oneshot::Receiver<Option<ErrorStateCause>>> { + fn spawn_tunnel_monitor_wait_thread(tunnel_monitor: TunnelMonitor) -> TunnelCloseEvent { let (tunnel_close_event_tx, tunnel_close_event_rx) = oneshot::channel(); thread::spawn(move || { @@ -137,7 +140,7 @@ impl ConnectingState { trace!("Tunnel monitor thread exit"); }); - Some(tunnel_close_event_rx) + tunnel_close_event_rx.fuse() } fn wait_for_tunnel_monitor(tunnel_monitor: TunnelMonitor) -> Option<ErrorStateCause> { @@ -194,7 +197,7 @@ impl ConnectingState { self, shared_values: &mut SharedTunnelStateValues, after_disconnect: AfterDisconnect, - ) -> EventConsequence<Self> { + ) -> EventConsequence { Self::reset_routes(shared_values); EventConsequence::NewState(DisconnectingState::enter( @@ -205,18 +208,18 @@ impl ConnectingState { fn handle_commands( self, - commands: &mut mpsc::UnboundedReceiver<TunnelCommand>, + command: Option<TunnelCommand>, shared_values: &mut SharedTunnelStateValues, - ) -> EventConsequence<Self> { + ) -> EventConsequence { use self::EventConsequence::*; - match try_handle_event!(self, commands.poll()) { - Ok(TunnelCommand::AllowLan(allow_lan)) => { + match command { + Some(TunnelCommand::AllowLan(allow_lan)) => { if let Err(error_cause) = shared_values.set_allow_lan(allow_lan) { self.disconnect(shared_values, AfterDisconnect::Block(error_cause)) } else { match Self::set_firewall_policy(shared_values, &self.tunnel_parameters) { - Ok(()) => SameState(self), + Ok(()) => SameState(self.into()), Err(error) => self.disconnect( shared_values, AfterDisconnect::Block(ErrorStateCause::SetFirewallPolicyError(error)), @@ -224,11 +227,11 @@ impl ConnectingState { } } } - Ok(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => { + Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => { shared_values.block_when_disconnected = block_when_disconnected; - SameState(self) + SameState(self.into()) } - Ok(TunnelCommand::IsOffline(is_offline)) => { + Some(TunnelCommand::IsOffline(is_offline)) => { shared_values.is_offline = is_offline; if is_offline { self.disconnect( @@ -236,39 +239,40 @@ impl ConnectingState { AfterDisconnect::Block(ErrorStateCause::IsOffline), ) } else { - SameState(self) + SameState(self.into()) } } - Ok(TunnelCommand::Connect) => { + Some(TunnelCommand::Connect) => { self.disconnect(shared_values, AfterDisconnect::Reconnect(0)) } - Ok(TunnelCommand::Disconnect) | Err(_) => { + Some(TunnelCommand::Disconnect) | None => { self.disconnect(shared_values, AfterDisconnect::Nothing) } - Ok(TunnelCommand::Block(reason)) => { + Some(TunnelCommand::Block(reason)) => { self.disconnect(shared_values, AfterDisconnect::Block(reason)) } } } - fn handle_tunnel_events( - mut self, + self, + event: Option<tunnel::TunnelEvent>, shared_values: &mut SharedTunnelStateValues, - ) -> EventConsequence<Self> { + ) -> EventConsequence { use self::EventConsequence::*; - match try_handle_event!(self, self.tunnel_events.poll()) { - Ok(TunnelEvent::AuthFailed(reason)) => self.disconnect( + match event { + Some(TunnelEvent::AuthFailed(reason)) => self.disconnect( shared_values, AfterDisconnect::Block(ErrorStateCause::AuthFailed(reason)), ), - Ok(TunnelEvent::Up(metadata)) => NewState(ConnectedState::enter( + Some(TunnelEvent::Up(metadata)) => NewState(ConnectedState::enter( shared_values, self.into_connected_state_bootstrap(metadata), )), - Ok(_) => SameState(self), - Err(_) => { + Some(TunnelEvent::Down) => SameState(self.into()), + None => { + // The channel was closed debug!("The tunnel disconnected unexpectedly"); let retry_attempt = self.retry_attempt + 1; self.disconnect(shared_values, AfterDisconnect::Reconnect(retry_attempt)) @@ -277,23 +281,15 @@ impl ConnectingState { } fn handle_tunnel_close_event( - mut self, + self, + block_reason: Option<ErrorStateCause>, shared_values: &mut SharedTunnelStateValues, - ) -> EventConsequence<Self> { - let poll_result = match &mut self.tunnel_close_event { - Some(tunnel_close_event) => tunnel_close_event.poll(), - None => Ok(Async::NotReady), - }; + ) -> EventConsequence { + use self::EventConsequence::*; - match poll_result { - Ok(Async::Ready(block_reason)) => { - if let Some(reason) = block_reason { - Self::reset_routes(shared_values); - return EventConsequence::NewState(ErrorState::enter(shared_values, reason)); - } - } - Ok(Async::NotReady) => return EventConsequence::NoEvents(self), - Err(_cancelled) => warn!("Tunnel monitor thread has stopped unexpectedly"), + if let Some(block_reason) = block_reason { + Self::reset_routes(shared_values); + return NewState(ErrorState::enter(shared_values, block_reason)); } info!( @@ -395,7 +391,11 @@ impl TunnelState for ConnectingState { ); DisconnectingState::enter( shared_values, - (None, None, AfterDisconnect::Reconnect(retry_attempt + 1)), + ( + None, + Fuse::terminated(), + AfterDisconnect::Reconnect(retry_attempt + 1), + ), ) } else { log::error!( @@ -435,13 +435,30 @@ impl TunnelState for ConnectingState { } fn handle_event( - self, - commands: &mut mpsc::UnboundedReceiver<TunnelCommand>, + mut self, + runtime: &tokio::runtime::Handle, + commands: &mut TunnelCommandReceiver, shared_values: &mut SharedTunnelStateValues, - ) -> EventConsequence<Self> { - self.handle_commands(commands, shared_values) - .or_else(Self::handle_tunnel_events, shared_values) - .or_else(Self::handle_tunnel_close_event, shared_values) + ) -> EventConsequence { + let result = runtime.block_on(async { + futures::select! { + command = commands.next() => EventResult::Command(command), + event = self.tunnel_events.next() => EventResult::Event(event), + result = &mut self.tunnel_close_event => EventResult::Close(result), + } + }); + + match result { + EventResult::Command(command) => self.handle_commands(command, shared_values), + EventResult::Event(event) => self.handle_tunnel_events(event, shared_values), + EventResult::Close(result) => { + if result.is_err() { + log::warn!("Tunnel monitor thread has stopped unexpectedly"); + } + let block_reason = result.unwrap_or(None); + self.handle_tunnel_close_event(block_reason, shared_values) + } + } } } diff --git a/talpid-core/src/tunnel_state_machine/disconnected_state.rs b/talpid-core/src/tunnel_state_machine/disconnected_state.rs index b9a13c7b16..faeac0a45f 100644 --- a/talpid-core/src/tunnel_state_machine/disconnected_state.rs +++ b/talpid-core/src/tunnel_state_machine/disconnected_state.rs @@ -1,9 +1,9 @@ use super::{ ConnectingState, ErrorState, EventConsequence, SharedTunnelStateValues, TunnelCommand, - TunnelState, TunnelStateTransition, TunnelStateWrapper, + TunnelCommandReceiver, TunnelState, TunnelStateTransition, TunnelStateWrapper, }; use crate::firewall::FirewallPolicy; -use futures01::{sync::mpsc, Stream}; +use futures::StreamExt; use talpid_types::ErrorExt; /// No tunnel is running. @@ -63,13 +63,14 @@ impl TunnelState for DisconnectedState { fn handle_event( self, - commands: &mut mpsc::UnboundedReceiver<TunnelCommand>, + runtime: &tokio::runtime::Handle, + commands: &mut TunnelCommandReceiver, shared_values: &mut SharedTunnelStateValues, - ) -> EventConsequence<Self> { + ) -> EventConsequence { use self::EventConsequence::*; - match try_handle_event!(self, commands.poll()) { - Ok(TunnelCommand::AllowLan(allow_lan)) => { + match runtime.block_on(commands.next()) { + Some(TunnelCommand::AllowLan(allow_lan)) => { if shared_values.allow_lan != allow_lan { // The only platform that can fail is Android, but Android doesn't support the // "block when disconnected" option, so the following call never fails. @@ -79,23 +80,25 @@ impl TunnelState for DisconnectedState { Self::set_firewall_policy(shared_values, true); } - SameState(self) + SameState(self.into()) } - Ok(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => { + Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => { if shared_values.block_when_disconnected != block_when_disconnected { shared_values.block_when_disconnected = block_when_disconnected; Self::set_firewall_policy(shared_values, true); } - SameState(self) + SameState(self.into()) } - Ok(TunnelCommand::IsOffline(is_offline)) => { + Some(TunnelCommand::IsOffline(is_offline)) => { shared_values.is_offline = is_offline; - SameState(self) + SameState(self.into()) } - Ok(TunnelCommand::Connect) => NewState(ConnectingState::enter(shared_values, 0)), - Ok(TunnelCommand::Block(reason)) => NewState(ErrorState::enter(shared_values, reason)), - Ok(_) => SameState(self), - Err(_) => Finished, + Some(TunnelCommand::Connect) => NewState(ConnectingState::enter(shared_values, 0)), + Some(TunnelCommand::Block(reason)) => { + NewState(ErrorState::enter(shared_values, reason)) + } + Some(_) => SameState(self.into()), + None => Finished, } } } diff --git a/talpid-core/src/tunnel_state_machine/disconnecting_state.rs b/talpid-core/src/tunnel_state_machine/disconnecting_state.rs index 634cae45f2..33a09ca31a 100644 --- a/talpid-core/src/tunnel_state_machine/disconnecting_state.rs +++ b/talpid-core/src/tunnel_state_machine/disconnecting_state.rs @@ -1,12 +1,10 @@ use super::{ - ConnectingState, DisconnectedState, ErrorState, EventConsequence, SharedTunnelStateValues, - TunnelCommand, TunnelState, TunnelStateTransition, TunnelStateWrapper, + connecting_state::TunnelCloseEvent, ConnectingState, DisconnectedState, ErrorState, + EventConsequence, EventResult, SharedTunnelStateValues, TunnelCommand, TunnelCommandReceiver, + TunnelState, TunnelStateTransition, TunnelStateWrapper, }; use crate::tunnel::CloseHandle; -use futures01::{ - sync::{mpsc, oneshot}, - Async, Future, Stream, -}; +use futures::{future::FusedFuture, StreamExt}; use std::thread; use talpid_types::{ tunnel::{ActionAfterDisconnect, ErrorStateCause}, @@ -16,47 +14,46 @@ use talpid_types::{ /// This state is active from when we manually trigger a tunnel kill until the tunnel wait /// operation (TunnelExit) returned. pub struct DisconnectingState { - exited: Option<oneshot::Receiver<Option<ErrorStateCause>>>, + tunnel_close_event: TunnelCloseEvent, after_disconnect: AfterDisconnect, } impl DisconnectingState { fn handle_commands( mut self, - commands: &mut mpsc::UnboundedReceiver<TunnelCommand>, + command: Option<TunnelCommand>, shared_values: &mut SharedTunnelStateValues, - ) -> EventConsequence<Self> { - let event = try_handle_event!(self, commands.poll()); + ) -> EventConsequence { let after_disconnect = self.after_disconnect; self.after_disconnect = match after_disconnect { - AfterDisconnect::Nothing => match event { - Ok(TunnelCommand::AllowLan(allow_lan)) => { + AfterDisconnect::Nothing => match command { + Some(TunnelCommand::AllowLan(allow_lan)) => { let _ = shared_values.set_allow_lan(allow_lan); AfterDisconnect::Nothing } - Ok(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => { + Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => { shared_values.block_when_disconnected = block_when_disconnected; AfterDisconnect::Nothing } - Ok(TunnelCommand::IsOffline(is_offline)) => { + Some(TunnelCommand::IsOffline(is_offline)) => { shared_values.is_offline = is_offline; AfterDisconnect::Nothing } - Ok(TunnelCommand::Connect) => AfterDisconnect::Reconnect(0), - Ok(TunnelCommand::Block(reason)) => AfterDisconnect::Block(reason), - _ => AfterDisconnect::Nothing, + Some(TunnelCommand::Connect) => AfterDisconnect::Reconnect(0), + Some(TunnelCommand::Disconnect) | None => AfterDisconnect::Nothing, + Some(TunnelCommand::Block(reason)) => AfterDisconnect::Block(reason), }, - AfterDisconnect::Block(reason) => match event { - Ok(TunnelCommand::AllowLan(allow_lan)) => { + AfterDisconnect::Block(reason) => match command { + Some(TunnelCommand::AllowLan(allow_lan)) => { let _ = shared_values.set_allow_lan(allow_lan); AfterDisconnect::Block(reason) } - Ok(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => { + Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => { shared_values.block_when_disconnected = block_when_disconnected; AfterDisconnect::Block(reason) } - Ok(TunnelCommand::IsOffline(is_offline)) => { + Some(TunnelCommand::IsOffline(is_offline)) => { shared_values.is_offline = is_offline; if !is_offline && reason == ErrorStateCause::IsOffline { AfterDisconnect::Reconnect(0) @@ -64,21 +61,21 @@ impl DisconnectingState { AfterDisconnect::Block(reason) } } - Ok(TunnelCommand::Connect) => AfterDisconnect::Reconnect(0), - Ok(TunnelCommand::Disconnect) => AfterDisconnect::Nothing, - Ok(TunnelCommand::Block(new_reason)) => AfterDisconnect::Block(new_reason), - Err(_) => AfterDisconnect::Block(reason), + Some(TunnelCommand::Connect) => AfterDisconnect::Reconnect(0), + Some(TunnelCommand::Disconnect) => AfterDisconnect::Nothing, + Some(TunnelCommand::Block(new_reason)) => AfterDisconnect::Block(new_reason), + None => AfterDisconnect::Block(reason), }, - AfterDisconnect::Reconnect(retry_attempt) => match event { - Ok(TunnelCommand::AllowLan(allow_lan)) => { + AfterDisconnect::Reconnect(retry_attempt) => match command { + Some(TunnelCommand::AllowLan(allow_lan)) => { let _ = shared_values.set_allow_lan(allow_lan); AfterDisconnect::Reconnect(retry_attempt) } - Ok(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => { + Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => { shared_values.block_when_disconnected = block_when_disconnected; AfterDisconnect::Reconnect(retry_attempt) } - Ok(TunnelCommand::IsOffline(is_offline)) => { + Some(TunnelCommand::IsOffline(is_offline)) => { shared_values.is_offline = is_offline; if is_offline { AfterDisconnect::Block(ErrorStateCause::IsOffline) @@ -86,33 +83,13 @@ impl DisconnectingState { AfterDisconnect::Reconnect(retry_attempt) } } - Ok(TunnelCommand::Connect) => AfterDisconnect::Reconnect(retry_attempt), - Ok(TunnelCommand::Disconnect) | Err(_) => AfterDisconnect::Nothing, - Ok(TunnelCommand::Block(reason)) => AfterDisconnect::Block(reason), + Some(TunnelCommand::Connect) => AfterDisconnect::Reconnect(retry_attempt), + Some(TunnelCommand::Disconnect) | None => AfterDisconnect::Nothing, + Some(TunnelCommand::Block(reason)) => AfterDisconnect::Block(reason), }, }; - EventConsequence::SameState(self) - } - - fn handle_exit_event( - mut self, - shared_values: &mut SharedTunnelStateValues, - ) -> EventConsequence<Self> { - use self::EventConsequence::*; - - let poll_result = match &mut self.exited { - Some(exited) => exited.poll(), - None => Ok(Async::Ready(None)), - }; - - match poll_result { - Ok(Async::NotReady) => NoEvents(self), - Ok(Async::Ready(block_reason)) => { - NewState(self.after_disconnect(block_reason, shared_values)) - } - Err(_) => NewState(self.after_disconnect(None, shared_values)), - } + EventConsequence::SameState(self.into()) } fn after_disconnect( @@ -135,15 +112,11 @@ impl DisconnectingState { } impl TunnelState for DisconnectingState { - type Bootstrap = ( - Option<CloseHandle>, - Option<oneshot::Receiver<Option<ErrorStateCause>>>, - AfterDisconnect, - ); + type Bootstrap = (Option<CloseHandle>, TunnelCloseEvent, AfterDisconnect); fn enter( _: &mut SharedTunnelStateValues, - (close_handle, exited, after_disconnect): Self::Bootstrap, + (close_handle, tunnel_close_event, after_disconnect): Self::Bootstrap, ) -> (TunnelStateWrapper, TunnelStateTransition) { if let Some(close_handle) = close_handle { thread::spawn(move || { @@ -160,7 +133,7 @@ impl TunnelState for DisconnectingState { ( TunnelStateWrapper::from(DisconnectingState { - exited, + tunnel_close_event, after_disconnect, }), TunnelStateTransition::Disconnecting(action_after_disconnect), @@ -168,12 +141,32 @@ impl TunnelState for DisconnectingState { } fn handle_event( - self, - commands: &mut mpsc::UnboundedReceiver<TunnelCommand>, + mut self, + runtime: &tokio::runtime::Handle, + commands: &mut TunnelCommandReceiver, shared_values: &mut SharedTunnelStateValues, - ) -> EventConsequence<Self> { - self.handle_commands(commands, shared_values) - .or_else(Self::handle_exit_event, shared_values) + ) -> EventConsequence { + use self::EventConsequence::*; + + if self.tunnel_close_event.is_terminated() { + return NewState(self.after_disconnect(None, shared_values)); + } + + let result = runtime.block_on(async { + futures::select! { + command = commands.next() => EventResult::Command(command), + result = &mut self.tunnel_close_event => EventResult::Close(result), + } + }); + + match result { + EventResult::Command(command) => self.handle_commands(command, shared_values), + EventResult::Close(result) => { + let block_reason = result.unwrap_or(None); + NewState(self.after_disconnect(block_reason, shared_values)) + } + _ => unreachable!("unexpected event result"), + } } } diff --git a/talpid-core/src/tunnel_state_machine/error_state.rs b/talpid-core/src/tunnel_state_machine/error_state.rs index d6a434b055..a9861e788e 100644 --- a/talpid-core/src/tunnel_state_machine/error_state.rs +++ b/talpid-core/src/tunnel_state_machine/error_state.rs @@ -1,9 +1,9 @@ use super::{ ConnectingState, DisconnectedState, EventConsequence, SharedTunnelStateValues, TunnelCommand, - TunnelState, TunnelStateTransition, TunnelStateWrapper, + TunnelCommandReceiver, TunnelState, TunnelStateTransition, TunnelStateWrapper, }; use crate::firewall::FirewallPolicy; -use futures01::{sync::mpsc, Stream}; +use futures::StreamExt; use talpid_types::{ tunnel::{self as talpid_tunnel, ErrorStateCause, FirewallPolicyError}, ErrorExt, @@ -87,37 +87,40 @@ impl TunnelState for ErrorState { fn handle_event( self, - commands: &mut mpsc::UnboundedReceiver<TunnelCommand>, + runtime: &tokio::runtime::Handle, + commands: &mut TunnelCommandReceiver, shared_values: &mut SharedTunnelStateValues, - ) -> EventConsequence<Self> { + ) -> EventConsequence { use self::EventConsequence::*; - match try_handle_event!(self, commands.poll()) { - Ok(TunnelCommand::AllowLan(allow_lan)) => { + match runtime.block_on(commands.next()) { + Some(TunnelCommand::AllowLan(allow_lan)) => { if let Err(error_state_cause) = shared_values.set_allow_lan(allow_lan) { NewState(Self::enter(shared_values, error_state_cause)) } else { let _ = Self::set_firewall_policy(shared_values); - SameState(self) + SameState(self.into()) } } - Ok(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => { + Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => { shared_values.block_when_disconnected = block_when_disconnected; - SameState(self) + SameState(self.into()) } - Ok(TunnelCommand::IsOffline(is_offline)) => { + Some(TunnelCommand::IsOffline(is_offline)) => { shared_values.is_offline = is_offline; if !is_offline && self.block_reason == ErrorStateCause::IsOffline { NewState(ConnectingState::enter(shared_values, 0)) } else { - SameState(self) + SameState(self.into()) } } - Ok(TunnelCommand::Connect) => NewState(ConnectingState::enter(shared_values, 0)), - Ok(TunnelCommand::Disconnect) | Err(_) => { + Some(TunnelCommand::Connect) => NewState(ConnectingState::enter(shared_values, 0)), + Some(TunnelCommand::Disconnect) | None => { NewState(DisconnectedState::enter(shared_values, true)) } - Ok(TunnelCommand::Block(reason)) => NewState(ErrorState::enter(shared_values, reason)), + Some(TunnelCommand::Block(reason)) => { + NewState(ErrorState::enter(shared_values, reason)) + } } } } diff --git a/talpid-core/src/tunnel_state_machine/macros.rs b/talpid-core/src/tunnel_state_machine/macros.rs deleted file mode 100644 index deccd29fdd..0000000000 --- a/talpid-core/src/tunnel_state_machine/macros.rs +++ /dev/null @@ -1,21 +0,0 @@ -/// Try to receive an event from a `Stream`'s asynchronous poll expression. -/// -/// This macro is similar to the `try_ready!` macro provided in `futures`. If there is an event -/// ready, it will be returned wrapped in a `Result`. If there are no events ready to be received, -/// the outer function will return with a transition that indicates that no events were received, -/// which is analogous to `Async::NotReady`. -/// -/// When the asynchronous event indicates that the stream has finished or that it has failed, an -/// error type is returned so that either close scenario can be handled in a similar way. -macro_rules! try_handle_event { - ($same_state:expr, $event:expr) => { - match $event { - Ok(futures01::Async::Ready(Some(event))) => Ok(event), - Ok(futures01::Async::Ready(None)) => Err(None), - Ok(futures01::Async::NotReady) => { - return crate::tunnel_state_machine::EventConsequence::NoEvents($same_state); - } - Err(error) => Err(Some(error)), - } - }; -} diff --git a/talpid-core/src/tunnel_state_machine/mod.rs b/talpid-core/src/tunnel_state_machine/mod.rs index 8a3b588927..37a27ae58f 100644 --- a/talpid-core/src/tunnel_state_machine/mod.rs +++ b/talpid-core/src/tunnel_state_machine/mod.rs @@ -1,6 +1,3 @@ -#[macro_use] -mod macros; - mod connected_state; mod connecting_state; mod disconnected_state; @@ -20,14 +17,13 @@ use crate::{ mpsc::Sender, offline, routing::RouteManager, - tunnel::tun_provider::TunProvider, + tunnel::{tun_provider::TunProvider, TunnelEvent}, }; use futures::{ channel::{mpsc, oneshot}, - StreamExt, + stream, StreamExt, }; -use futures01::{sync::mpsc as old_mpsc, Async, Poll, Stream}; use std::{ collections::HashSet, io, @@ -87,7 +83,7 @@ pub async fn spawn( reset_firewall: bool, #[cfg(target_os = "android")] android_context: AndroidContext, ) -> Result<Arc<mpsc::UnboundedSender<TunnelCommand>>, Error> { - let (command_tx, mut command_rx) = mpsc::unbounded(); + let (command_tx, command_rx) = mpsc::unbounded(); let command_tx = Arc::new(command_tx); let mut offline_monitor = offline::spawn_monitor( Arc::downgrade(&command_tx), @@ -105,15 +101,7 @@ pub async fn spawn( allow_lan, ); - // Hide internal 0.1 futures from the client - let (command_adapter_tx, command_adapter_rx) = old_mpsc::unbounded(); - tokio::spawn(async move { - while let Some(command) = command_rx.next().await { - if command_adapter_tx.unbounded_send(command).is_err() { - log::error!("Failed to forward daemon command"); - } - } - }); + let runtime = tokio::runtime::Handle::current(); let (startup_result_tx, startup_result_rx) = sync_mpsc::channel(); std::thread::spawn(move || { @@ -126,7 +114,7 @@ pub async fn spawn( log_dir, resource_dir, cache_dir, - command_adapter_rx, + command_rx, reset_firewall, ); let state_machine = match state_machine { @@ -140,16 +128,8 @@ pub async fn spawn( } }; - let mut iter = state_machine.wait(); - while let Some(Ok(change_event)) = iter.next() { - if let Err(error) = state_change_listener - .send(change_event) - .map_err(|_| Error::SendStateChange) - { - log::error!("{}", error); - break; - } - } + state_machine.run(runtime, state_change_listener); + if shutdown_tx.send(()).is_err() { log::error!("Can't send shutdown completion to daemon"); } @@ -179,6 +159,14 @@ pub enum TunnelCommand { Block(ErrorStateCause), } +type TunnelCommandReceiver = stream::Fuse<mpsc::UnboundedReceiver<TunnelCommand>>; + +enum EventResult { + Command(Option<TunnelCommand>), + Event(Option<TunnelEvent>), + Close(Result<Option<ErrorStateCause>, oneshot::Canceled>), +} + /// Asynchronous handling of the tunnel state machine. /// /// This type implements `Stream`, and attempts to advance the state machine based on the events @@ -187,7 +175,7 @@ pub enum TunnelCommand { /// by the stream. struct TunnelStateMachine { current_state: Option<TunnelStateWrapper>, - commands: old_mpsc::UnboundedReceiver<TunnelCommand>, + commands: TunnelCommandReceiver, shared_values: SharedTunnelStateValues, } @@ -201,7 +189,7 @@ impl TunnelStateMachine { log_dir: Option<PathBuf>, resource_dir: PathBuf, cache_dir: impl AsRef<Path>, - commands: old_mpsc::UnboundedReceiver<TunnelCommand>, + commands: mpsc::UnboundedReceiver<TunnelCommand>, reset_firewall: bool, ) -> Result<Self, Error> { let args = FirewallArguments { @@ -230,62 +218,43 @@ impl TunnelStateMachine { Ok(TunnelStateMachine { current_state: Some(initial_state), - commands, + commands: commands.fuse(), shared_values, }) } -} -impl Stream for TunnelStateMachine { - type Item = TunnelStateTransition; - type Error = Error; + fn run( + mut self, + runtime: tokio::runtime::Handle, + change_listener: impl Sender<TunnelStateTransition> + Send + 'static, + ) { + use EventConsequence::*; - fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> { while let Some(state_wrapper) = self.current_state.take() { - match state_wrapper.handle_event(&mut self.commands, &mut self.shared_values) { - TunnelStateMachineAction::Repeat(repeat_state_wrapper) => { - self.current_state = Some(repeat_state_wrapper); + match state_wrapper.handle_event(&runtime, &mut self.commands, &mut self.shared_values) + { + NewState((state, transition)) => { + self.current_state = Some(state); + + if let Err(error) = change_listener + .send(transition) + .map_err(|_| Error::SendStateChange) + { + log::error!("{}", error); + break; + } } - TunnelStateMachineAction::Notify(state_wrapper, result) => { - self.current_state = state_wrapper; - return result; + SameState(state) => { + self.current_state = Some(state); } + Finished => (), } } - Ok(Async::Ready(None)) - } -} - -/// Action the state machine should take, which is discovered base on an event consequence. -/// -/// The action can be to execute another iteration or to notify that something happened. Executing -/// another iteration happens when an event is received and ignored, which causes the tunnel state -/// machine to stay in the same state. The state machine can notify its caller that a state -/// transition has occurred, that it has finished, or that it has paused to wait for new events. -enum TunnelStateMachineAction { - Repeat(TunnelStateWrapper), - Notify( - Option<TunnelStateWrapper>, - Poll<Option<TunnelStateTransition>, Error>, - ), -} -impl<T: TunnelState> From<EventConsequence<T>> for TunnelStateMachineAction { - fn from(event_consequence: EventConsequence<T>) -> Self { - use self::{EventConsequence::*, TunnelStateMachineAction::*}; - - match event_consequence { - NewState((state_wrapper, transition)) => { - Notify(Some(state_wrapper), Ok(Async::Ready(Some(transition)))) - } - SameState(state) => Repeat(state.into()), - NoEvents(state) => Notify(Some(state.into()), Ok(Async::NotReady)), - Finished => Notify(None, Ok(Async::Ready(None))), - } + log::debug!("Exiting tunnel state machine loop"); } } - /// Trait for any type that can provide a stream of `TunnelParameters` to the `TunnelStateMachine`. pub trait TunnelParametersGenerator: Send + 'static { /// Given the number of consecutive failed retry attempts, it should yield a `TunnelParameters` @@ -343,37 +312,15 @@ impl SharedTunnelStateValues { } /// Asynchronous result of an attempt to progress a state. -enum EventConsequence<T: TunnelState> { +enum EventConsequence { /// Transition to a new state. NewState((TunnelStateWrapper, TunnelStateTransition)), /// An event was received, but it was ignored by the state so no transition is performed. - SameState(T), - /// No events were received, the event loop should block until one becomes available. - NoEvents(T), + SameState(TunnelStateWrapper), /// The state machine has finished its execution. Finished, } -impl<T> EventConsequence<T> -where - T: TunnelState, -{ - /// Helper method to chain handling multiple different event types. - /// - /// The `handle_event` is only called if no events were handled so far. - pub fn or_else<F>(self, handle_event: F, shared_values: &mut SharedTunnelStateValues) -> Self - where - F: FnOnce(T, &mut SharedTunnelStateValues) -> Self, - { - use self::EventConsequence::*; - - match self { - NoEvents(state) => handle_event(state, shared_values), - consequence => consequence, - } - } -} - /// Trait that contains the method all states should implement to handle an event and advance the /// state machine. trait TunnelState: Into<TunnelStateWrapper> + Sized { @@ -401,9 +348,10 @@ trait TunnelState: Into<TunnelStateWrapper> + Sized { /// [`EventConsequence`]: enum.EventConsequence.html fn handle_event( self, - commands: &mut old_mpsc::UnboundedReceiver<TunnelCommand>, + runtime: &tokio::runtime::Handle, + commands: &mut TunnelCommandReceiver, shared_values: &mut SharedTunnelStateValues, - ) -> EventConsequence<Self>; + ) -> EventConsequence; } macro_rules! state_wrapper { @@ -425,13 +373,13 @@ macro_rules! state_wrapper { impl $wrapper_name { fn handle_event( self, - commands: &mut old_mpsc::UnboundedReceiver<TunnelCommand>, + runtime: &tokio::runtime::Handle, + commands: &mut TunnelCommandReceiver, shared_values: &mut SharedTunnelStateValues, - ) -> TunnelStateMachineAction { + ) -> EventConsequence { match self { $($wrapper_name::$state_variant(state) => { - let event_consequence = state.handle_event(commands, shared_values); - TunnelStateMachineAction::from(event_consequence) + state.handle_event(runtime, commands, shared_values) })* } } |
