summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorDavid Lönnhager <david.l@mullvad.net>2020-10-05 15:51:33 +0200
committerDavid Lönnhager <david.l@mullvad.net>2020-10-05 15:51:33 +0200
commit53e167439ea6b263a2a399591a82a22a2d7a6343 (patch)
tree44cf8bd0cbc434b55a2d035426dc79cfe06fa3aa
parent779e5267e3c03ff511c36518f0323347d675cce3 (diff)
parent00097409bf285e08de927d1f934e32208fd4bb5e (diff)
downloadmullvadvpn-53e167439ea6b263a2a399591a82a22a2d7a6343.tar.xz
mullvadvpn-53e167439ea6b263a2a399591a82a22a2d7a6343.zip
Merge branch 'refactor-tunnel-sm'
-rw-r--r--Cargo.lock1
-rw-r--r--talpid-core/Cargo.toml1
-rw-r--r--talpid-core/src/routing/linux.rs6
-rw-r--r--talpid-core/src/tunnel_state_machine/connected_state.rs110
-rw-r--r--talpid-core/src/tunnel_state_machine/connecting_state.rs127
-rw-r--r--talpid-core/src/tunnel_state_machine/disconnected_state.rs33
-rw-r--r--talpid-core/src/tunnel_state_machine/disconnecting_state.rs123
-rw-r--r--talpid-core/src/tunnel_state_machine/error_state.rs31
-rw-r--r--talpid-core/src/tunnel_state_machine/macros.rs21
-rw-r--r--talpid-core/src/tunnel_state_machine/mod.rs150
10 files changed, 274 insertions, 329 deletions
diff --git a/Cargo.lock b/Cargo.lock
index 41b4eba987..cefe826bf8 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2309,7 +2309,6 @@ dependencies = [
"duct 0.13.4 (registry+https://github.com/rust-lang/crates.io-index)",
"err-derive 0.2.4 (registry+https://github.com/rust-lang/crates.io-index)",
"failure 0.1.8 (registry+https://github.com/rust-lang/crates.io-index)",
- "futures 0.1.29 (registry+https://github.com/rust-lang/crates.io-index)",
"futures 0.3.5 (registry+https://github.com/rust-lang/crates.io-index)",
"hex 0.4.2 (registry+https://github.com/rust-lang/crates.io-index)",
"ipnetwork 0.16.0 (registry+https://github.com/rust-lang/crates.io-index)",
diff --git a/talpid-core/Cargo.toml b/talpid-core/Cargo.toml
index 1f374152c6..60acfbe9d0 100644
--- a/talpid-core/Cargo.toml
+++ b/talpid-core/Cargo.toml
@@ -12,7 +12,6 @@ atty = "0.2"
cfg-if = "0.1"
duct = "0.13"
err-derive = "0.2.1"
-futures01 = { package = "futures", version = "0.1" }
futures = { package = "futures", version = "0.3", features = [ "compat" ]}
hex = "0.4"
ipnetwork = "0.16"
diff --git a/talpid-core/src/routing/linux.rs b/talpid-core/src/routing/linux.rs
index 025804d3e1..4516f19127 100644
--- a/talpid-core/src/routing/linux.rs
+++ b/talpid-core/src/routing/linux.rs
@@ -972,12 +972,6 @@ impl RouteManagerImpl {
}
}
-impl Drop for RouteManagerImpl {
- fn drop(&mut self) {
- futures::executor::block_on(self.destructor());
- }
-}
-
fn ip_to_bytes(addr: IpAddr) -> Vec<u8> {
match addr {
IpAddr::V4(addr) => addr.octets().to_vec(),
diff --git a/talpid-core/src/tunnel_state_machine/connected_state.rs b/talpid-core/src/tunnel_state_machine/connected_state.rs
index 33e1463167..a36093d51c 100644
--- a/talpid-core/src/tunnel_state_machine/connected_state.rs
+++ b/talpid-core/src/tunnel_state_machine/connected_state.rs
@@ -1,15 +1,13 @@
use super::{
AfterDisconnect, ConnectingState, DisconnectingState, ErrorState, EventConsequence,
- SharedTunnelStateValues, TunnelCommand, TunnelState, TunnelStateTransition, TunnelStateWrapper,
+ EventResult, SharedTunnelStateValues, TunnelCommand, TunnelCommandReceiver, TunnelState,
+ TunnelStateTransition, TunnelStateWrapper,
};
use crate::{
firewall::FirewallPolicy,
tunnel::{CloseHandle, TunnelEvent, TunnelMetadata},
};
-use futures01::{
- sync::{mpsc, oneshot},
- Async, Future, Stream,
-};
+use futures::{channel::mpsc, stream::Fuse, StreamExt};
use talpid_types::{
net::TunnelParameters,
tunnel::{ErrorStateCause, FirewallPolicyError},
@@ -19,21 +17,25 @@ use talpid_types::{
#[cfg(windows)]
use crate::tunnel::TunnelMonitor;
+use super::connecting_state::TunnelCloseEvent;
+
+pub(crate) type TunnelEventsReceiver = Fuse<mpsc::UnboundedReceiver<TunnelEvent>>;
+
pub struct ConnectedStateBootstrap {
pub metadata: TunnelMetadata,
- pub tunnel_events: mpsc::UnboundedReceiver<TunnelEvent>,
+ pub tunnel_events: TunnelEventsReceiver,
pub tunnel_parameters: TunnelParameters,
- pub tunnel_close_event: Option<oneshot::Receiver<Option<ErrorStateCause>>>,
+ pub tunnel_close_event: TunnelCloseEvent,
pub close_handle: Option<CloseHandle>,
}
/// The tunnel is up and working.
pub struct ConnectedState {
metadata: TunnelMetadata,
- tunnel_events: mpsc::UnboundedReceiver<TunnelEvent>,
+ tunnel_events: TunnelEventsReceiver,
tunnel_parameters: TunnelParameters,
- tunnel_close_event: Option<oneshot::Receiver<Option<ErrorStateCause>>>,
+ tunnel_close_event: TunnelCloseEvent,
close_handle: Option<CloseHandle>,
}
@@ -126,7 +128,7 @@ impl ConnectedState {
self,
shared_values: &mut SharedTunnelStateValues,
after_disconnect: AfterDisconnect,
- ) -> EventConsequence<Self> {
+ ) -> EventConsequence {
Self::reset_dns(shared_values);
Self::reset_routes(shared_values);
@@ -138,18 +140,18 @@ impl ConnectedState {
fn handle_commands(
self,
- commands: &mut mpsc::UnboundedReceiver<TunnelCommand>,
+ command: Option<TunnelCommand>,
shared_values: &mut SharedTunnelStateValues,
- ) -> EventConsequence<Self> {
+ ) -> EventConsequence {
use self::EventConsequence::*;
- match try_handle_event!(self, commands.poll()) {
- Ok(TunnelCommand::AllowLan(allow_lan)) => {
+ match command {
+ Some(TunnelCommand::AllowLan(allow_lan)) => {
if let Err(error_cause) = shared_values.set_allow_lan(allow_lan) {
self.disconnect(shared_values, AfterDisconnect::Block(error_cause))
} else {
match self.set_firewall_policy(shared_values) {
- Ok(()) => SameState(self),
+ Ok(()) => SameState(self.into()),
Err(error) => self.disconnect(
shared_values,
AfterDisconnect::Block(ErrorStateCause::SetFirewallPolicyError(error)),
@@ -157,11 +159,11 @@ impl ConnectedState {
}
}
}
- Ok(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => {
+ Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => {
shared_values.block_when_disconnected = block_when_disconnected;
- SameState(self)
+ SameState(self.into())
}
- Ok(TunnelCommand::IsOffline(is_offline)) => {
+ Some(TunnelCommand::IsOffline(is_offline)) => {
shared_values.is_offline = is_offline;
if is_offline {
self.disconnect(
@@ -169,56 +171,47 @@ impl ConnectedState {
AfterDisconnect::Block(ErrorStateCause::IsOffline),
)
} else {
- SameState(self)
+ SameState(self.into())
}
}
- Ok(TunnelCommand::Connect) => {
+ Some(TunnelCommand::Connect) => {
self.disconnect(shared_values, AfterDisconnect::Reconnect(0))
}
- Ok(TunnelCommand::Disconnect) | Err(_) => {
+ Some(TunnelCommand::Disconnect) | None => {
self.disconnect(shared_values, AfterDisconnect::Nothing)
}
- Ok(TunnelCommand::Block(reason)) => {
+ Some(TunnelCommand::Block(reason)) => {
self.disconnect(shared_values, AfterDisconnect::Block(reason))
}
}
}
fn handle_tunnel_events(
- mut self,
+ self,
+ event: Option<TunnelEvent>,
shared_values: &mut SharedTunnelStateValues,
- ) -> EventConsequence<Self> {
+ ) -> EventConsequence {
use self::EventConsequence::*;
- match try_handle_event!(self, self.tunnel_events.poll()) {
- Ok(TunnelEvent::Down) | Err(_) => {
+ match event {
+ Some(TunnelEvent::Down) | None => {
self.disconnect(shared_values, AfterDisconnect::Reconnect(0))
}
- Ok(_) => SameState(self),
+ Some(_) => SameState(self.into()),
}
}
fn handle_tunnel_close_event(
- mut self,
+ self,
+ block_reason: Option<ErrorStateCause>,
shared_values: &mut SharedTunnelStateValues,
- ) -> EventConsequence<Self> {
+ ) -> EventConsequence {
use self::EventConsequence::*;
- let poll_result = match &mut self.tunnel_close_event {
- Some(tunnel_close_event) => tunnel_close_event.poll(),
- None => Ok(Async::NotReady),
- };
-
- match poll_result {
- Ok(Async::Ready(block_reason)) => {
- if let Some(reason) = block_reason {
- Self::reset_dns(shared_values);
- Self::reset_routes(shared_values);
- return NewState(ErrorState::enter(shared_values, reason));
- }
- }
- Ok(Async::NotReady) => return NoEvents(self),
- Err(_cancelled) => log::warn!("Tunnel monitor thread has stopped unexpectedly"),
+ if let Some(block_reason) = block_reason {
+ Self::reset_dns(shared_values);
+ Self::reset_routes(shared_values);
+ return NewState(ErrorState::enter(shared_values, block_reason));
}
log::info!("Tunnel closed. Reconnecting.");
@@ -267,12 +260,29 @@ impl TunnelState for ConnectedState {
}
fn handle_event(
- self,
- commands: &mut mpsc::UnboundedReceiver<TunnelCommand>,
+ mut self,
+ runtime: &tokio::runtime::Handle,
+ commands: &mut TunnelCommandReceiver,
shared_values: &mut SharedTunnelStateValues,
- ) -> EventConsequence<Self> {
- self.handle_commands(commands, shared_values)
- .or_else(Self::handle_tunnel_events, shared_values)
- .or_else(Self::handle_tunnel_close_event, shared_values)
+ ) -> EventConsequence {
+ let result = runtime.block_on(async {
+ futures::select! {
+ command = commands.next() => EventResult::Command(command),
+ event = self.tunnel_events.next() => EventResult::Event(event),
+ result = &mut self.tunnel_close_event => EventResult::Close(result),
+ }
+ });
+
+ match result {
+ EventResult::Command(command) => self.handle_commands(command, shared_values),
+ EventResult::Event(event) => self.handle_tunnel_events(event, shared_values),
+ EventResult::Close(result) => {
+ if result.is_err() {
+ log::warn!("Tunnel monitor thread has stopped unexpectedly");
+ }
+ let block_reason = result.unwrap_or(None);
+ self.handle_tunnel_close_event(block_reason, shared_values)
+ }
+ }
}
}
diff --git a/talpid-core/src/tunnel_state_machine/connecting_state.rs b/talpid-core/src/tunnel_state_machine/connecting_state.rs
index 859ef52eb0..bfccac6572 100644
--- a/talpid-core/src/tunnel_state_machine/connecting_state.rs
+++ b/talpid-core/src/tunnel_state_machine/connecting_state.rs
@@ -1,7 +1,7 @@
use super::{
AfterDisconnect, ConnectedState, ConnectedStateBootstrap, DisconnectingState, ErrorState,
- EventConsequence, SharedTunnelStateValues, TunnelCommand, TunnelState, TunnelStateTransition,
- TunnelStateWrapper,
+ EventConsequence, EventResult, SharedTunnelStateValues, TunnelCommand, TunnelCommandReceiver,
+ TunnelState, TunnelStateTransition, TunnelStateWrapper,
};
use crate::{
firewall::FirewallPolicy,
@@ -10,9 +10,10 @@ use crate::{
self, tun_provider::TunProvider, CloseHandle, TunnelEvent, TunnelMetadata, TunnelMonitor,
},
};
-use futures01::{
- sync::{mpsc, oneshot},
- Async, Future, Stream,
+use futures::{
+ channel::{mpsc, oneshot},
+ future::Fuse,
+ FutureExt, StreamExt,
};
use log::{debug, error, info, trace, warn};
use std::{
@@ -30,15 +31,19 @@ use talpid_types::{
#[cfg(target_os = "android")]
use crate::tunnel::tun_provider;
+use super::connected_state::TunnelEventsReceiver;
+
+pub(crate) type TunnelCloseEvent = Fuse<oneshot::Receiver<Option<ErrorStateCause>>>;
+
#[cfg(target_os = "android")]
const MAX_ATTEMPTS_WITH_SAME_TUN: u32 = 5;
const MIN_TUNNEL_ALIVE_TIME: Duration = Duration::from_millis(1000);
/// The tunnel has been started, but it is not established/functional.
pub struct ConnectingState {
- tunnel_events: mpsc::UnboundedReceiver<TunnelEvent>,
+ tunnel_events: TunnelEventsReceiver,
tunnel_parameters: TunnelParameters,
- tunnel_close_event: Option<oneshot::Receiver<Option<ErrorStateCause>>>,
+ tunnel_close_event: TunnelCloseEvent,
close_handle: Option<CloseHandle>,
retry_attempt: u32,
}
@@ -102,7 +107,7 @@ impl ConnectingState {
let tunnel_close_event = Self::spawn_tunnel_monitor_wait_thread(monitor);
Ok(ConnectingState {
- tunnel_events: event_rx,
+ tunnel_events: event_rx.fuse(),
tunnel_parameters: parameters,
tunnel_close_event,
close_handle,
@@ -110,9 +115,7 @@ impl ConnectingState {
})
}
- fn spawn_tunnel_monitor_wait_thread(
- tunnel_monitor: TunnelMonitor,
- ) -> Option<oneshot::Receiver<Option<ErrorStateCause>>> {
+ fn spawn_tunnel_monitor_wait_thread(tunnel_monitor: TunnelMonitor) -> TunnelCloseEvent {
let (tunnel_close_event_tx, tunnel_close_event_rx) = oneshot::channel();
thread::spawn(move || {
@@ -137,7 +140,7 @@ impl ConnectingState {
trace!("Tunnel monitor thread exit");
});
- Some(tunnel_close_event_rx)
+ tunnel_close_event_rx.fuse()
}
fn wait_for_tunnel_monitor(tunnel_monitor: TunnelMonitor) -> Option<ErrorStateCause> {
@@ -194,7 +197,7 @@ impl ConnectingState {
self,
shared_values: &mut SharedTunnelStateValues,
after_disconnect: AfterDisconnect,
- ) -> EventConsequence<Self> {
+ ) -> EventConsequence {
Self::reset_routes(shared_values);
EventConsequence::NewState(DisconnectingState::enter(
@@ -205,18 +208,18 @@ impl ConnectingState {
fn handle_commands(
self,
- commands: &mut mpsc::UnboundedReceiver<TunnelCommand>,
+ command: Option<TunnelCommand>,
shared_values: &mut SharedTunnelStateValues,
- ) -> EventConsequence<Self> {
+ ) -> EventConsequence {
use self::EventConsequence::*;
- match try_handle_event!(self, commands.poll()) {
- Ok(TunnelCommand::AllowLan(allow_lan)) => {
+ match command {
+ Some(TunnelCommand::AllowLan(allow_lan)) => {
if let Err(error_cause) = shared_values.set_allow_lan(allow_lan) {
self.disconnect(shared_values, AfterDisconnect::Block(error_cause))
} else {
match Self::set_firewall_policy(shared_values, &self.tunnel_parameters) {
- Ok(()) => SameState(self),
+ Ok(()) => SameState(self.into()),
Err(error) => self.disconnect(
shared_values,
AfterDisconnect::Block(ErrorStateCause::SetFirewallPolicyError(error)),
@@ -224,11 +227,11 @@ impl ConnectingState {
}
}
}
- Ok(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => {
+ Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => {
shared_values.block_when_disconnected = block_when_disconnected;
- SameState(self)
+ SameState(self.into())
}
- Ok(TunnelCommand::IsOffline(is_offline)) => {
+ Some(TunnelCommand::IsOffline(is_offline)) => {
shared_values.is_offline = is_offline;
if is_offline {
self.disconnect(
@@ -236,39 +239,40 @@ impl ConnectingState {
AfterDisconnect::Block(ErrorStateCause::IsOffline),
)
} else {
- SameState(self)
+ SameState(self.into())
}
}
- Ok(TunnelCommand::Connect) => {
+ Some(TunnelCommand::Connect) => {
self.disconnect(shared_values, AfterDisconnect::Reconnect(0))
}
- Ok(TunnelCommand::Disconnect) | Err(_) => {
+ Some(TunnelCommand::Disconnect) | None => {
self.disconnect(shared_values, AfterDisconnect::Nothing)
}
- Ok(TunnelCommand::Block(reason)) => {
+ Some(TunnelCommand::Block(reason)) => {
self.disconnect(shared_values, AfterDisconnect::Block(reason))
}
}
}
-
fn handle_tunnel_events(
- mut self,
+ self,
+ event: Option<tunnel::TunnelEvent>,
shared_values: &mut SharedTunnelStateValues,
- ) -> EventConsequence<Self> {
+ ) -> EventConsequence {
use self::EventConsequence::*;
- match try_handle_event!(self, self.tunnel_events.poll()) {
- Ok(TunnelEvent::AuthFailed(reason)) => self.disconnect(
+ match event {
+ Some(TunnelEvent::AuthFailed(reason)) => self.disconnect(
shared_values,
AfterDisconnect::Block(ErrorStateCause::AuthFailed(reason)),
),
- Ok(TunnelEvent::Up(metadata)) => NewState(ConnectedState::enter(
+ Some(TunnelEvent::Up(metadata)) => NewState(ConnectedState::enter(
shared_values,
self.into_connected_state_bootstrap(metadata),
)),
- Ok(_) => SameState(self),
- Err(_) => {
+ Some(TunnelEvent::Down) => SameState(self.into()),
+ None => {
+ // The channel was closed
debug!("The tunnel disconnected unexpectedly");
let retry_attempt = self.retry_attempt + 1;
self.disconnect(shared_values, AfterDisconnect::Reconnect(retry_attempt))
@@ -277,23 +281,15 @@ impl ConnectingState {
}
fn handle_tunnel_close_event(
- mut self,
+ self,
+ block_reason: Option<ErrorStateCause>,
shared_values: &mut SharedTunnelStateValues,
- ) -> EventConsequence<Self> {
- let poll_result = match &mut self.tunnel_close_event {
- Some(tunnel_close_event) => tunnel_close_event.poll(),
- None => Ok(Async::NotReady),
- };
+ ) -> EventConsequence {
+ use self::EventConsequence::*;
- match poll_result {
- Ok(Async::Ready(block_reason)) => {
- if let Some(reason) = block_reason {
- Self::reset_routes(shared_values);
- return EventConsequence::NewState(ErrorState::enter(shared_values, reason));
- }
- }
- Ok(Async::NotReady) => return EventConsequence::NoEvents(self),
- Err(_cancelled) => warn!("Tunnel monitor thread has stopped unexpectedly"),
+ if let Some(block_reason) = block_reason {
+ Self::reset_routes(shared_values);
+ return NewState(ErrorState::enter(shared_values, block_reason));
}
info!(
@@ -395,7 +391,11 @@ impl TunnelState for ConnectingState {
);
DisconnectingState::enter(
shared_values,
- (None, None, AfterDisconnect::Reconnect(retry_attempt + 1)),
+ (
+ None,
+ Fuse::terminated(),
+ AfterDisconnect::Reconnect(retry_attempt + 1),
+ ),
)
} else {
log::error!(
@@ -435,13 +435,30 @@ impl TunnelState for ConnectingState {
}
fn handle_event(
- self,
- commands: &mut mpsc::UnboundedReceiver<TunnelCommand>,
+ mut self,
+ runtime: &tokio::runtime::Handle,
+ commands: &mut TunnelCommandReceiver,
shared_values: &mut SharedTunnelStateValues,
- ) -> EventConsequence<Self> {
- self.handle_commands(commands, shared_values)
- .or_else(Self::handle_tunnel_events, shared_values)
- .or_else(Self::handle_tunnel_close_event, shared_values)
+ ) -> EventConsequence {
+ let result = runtime.block_on(async {
+ futures::select! {
+ command = commands.next() => EventResult::Command(command),
+ event = self.tunnel_events.next() => EventResult::Event(event),
+ result = &mut self.tunnel_close_event => EventResult::Close(result),
+ }
+ });
+
+ match result {
+ EventResult::Command(command) => self.handle_commands(command, shared_values),
+ EventResult::Event(event) => self.handle_tunnel_events(event, shared_values),
+ EventResult::Close(result) => {
+ if result.is_err() {
+ log::warn!("Tunnel monitor thread has stopped unexpectedly");
+ }
+ let block_reason = result.unwrap_or(None);
+ self.handle_tunnel_close_event(block_reason, shared_values)
+ }
+ }
}
}
diff --git a/talpid-core/src/tunnel_state_machine/disconnected_state.rs b/talpid-core/src/tunnel_state_machine/disconnected_state.rs
index b9a13c7b16..faeac0a45f 100644
--- a/talpid-core/src/tunnel_state_machine/disconnected_state.rs
+++ b/talpid-core/src/tunnel_state_machine/disconnected_state.rs
@@ -1,9 +1,9 @@
use super::{
ConnectingState, ErrorState, EventConsequence, SharedTunnelStateValues, TunnelCommand,
- TunnelState, TunnelStateTransition, TunnelStateWrapper,
+ TunnelCommandReceiver, TunnelState, TunnelStateTransition, TunnelStateWrapper,
};
use crate::firewall::FirewallPolicy;
-use futures01::{sync::mpsc, Stream};
+use futures::StreamExt;
use talpid_types::ErrorExt;
/// No tunnel is running.
@@ -63,13 +63,14 @@ impl TunnelState for DisconnectedState {
fn handle_event(
self,
- commands: &mut mpsc::UnboundedReceiver<TunnelCommand>,
+ runtime: &tokio::runtime::Handle,
+ commands: &mut TunnelCommandReceiver,
shared_values: &mut SharedTunnelStateValues,
- ) -> EventConsequence<Self> {
+ ) -> EventConsequence {
use self::EventConsequence::*;
- match try_handle_event!(self, commands.poll()) {
- Ok(TunnelCommand::AllowLan(allow_lan)) => {
+ match runtime.block_on(commands.next()) {
+ Some(TunnelCommand::AllowLan(allow_lan)) => {
if shared_values.allow_lan != allow_lan {
// The only platform that can fail is Android, but Android doesn't support the
// "block when disconnected" option, so the following call never fails.
@@ -79,23 +80,25 @@ impl TunnelState for DisconnectedState {
Self::set_firewall_policy(shared_values, true);
}
- SameState(self)
+ SameState(self.into())
}
- Ok(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => {
+ Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => {
if shared_values.block_when_disconnected != block_when_disconnected {
shared_values.block_when_disconnected = block_when_disconnected;
Self::set_firewall_policy(shared_values, true);
}
- SameState(self)
+ SameState(self.into())
}
- Ok(TunnelCommand::IsOffline(is_offline)) => {
+ Some(TunnelCommand::IsOffline(is_offline)) => {
shared_values.is_offline = is_offline;
- SameState(self)
+ SameState(self.into())
}
- Ok(TunnelCommand::Connect) => NewState(ConnectingState::enter(shared_values, 0)),
- Ok(TunnelCommand::Block(reason)) => NewState(ErrorState::enter(shared_values, reason)),
- Ok(_) => SameState(self),
- Err(_) => Finished,
+ Some(TunnelCommand::Connect) => NewState(ConnectingState::enter(shared_values, 0)),
+ Some(TunnelCommand::Block(reason)) => {
+ NewState(ErrorState::enter(shared_values, reason))
+ }
+ Some(_) => SameState(self.into()),
+ None => Finished,
}
}
}
diff --git a/talpid-core/src/tunnel_state_machine/disconnecting_state.rs b/talpid-core/src/tunnel_state_machine/disconnecting_state.rs
index 634cae45f2..33a09ca31a 100644
--- a/talpid-core/src/tunnel_state_machine/disconnecting_state.rs
+++ b/talpid-core/src/tunnel_state_machine/disconnecting_state.rs
@@ -1,12 +1,10 @@
use super::{
- ConnectingState, DisconnectedState, ErrorState, EventConsequence, SharedTunnelStateValues,
- TunnelCommand, TunnelState, TunnelStateTransition, TunnelStateWrapper,
+ connecting_state::TunnelCloseEvent, ConnectingState, DisconnectedState, ErrorState,
+ EventConsequence, EventResult, SharedTunnelStateValues, TunnelCommand, TunnelCommandReceiver,
+ TunnelState, TunnelStateTransition, TunnelStateWrapper,
};
use crate::tunnel::CloseHandle;
-use futures01::{
- sync::{mpsc, oneshot},
- Async, Future, Stream,
-};
+use futures::{future::FusedFuture, StreamExt};
use std::thread;
use talpid_types::{
tunnel::{ActionAfterDisconnect, ErrorStateCause},
@@ -16,47 +14,46 @@ use talpid_types::{
/// This state is active from when we manually trigger a tunnel kill until the tunnel wait
/// operation (TunnelExit) returned.
pub struct DisconnectingState {
- exited: Option<oneshot::Receiver<Option<ErrorStateCause>>>,
+ tunnel_close_event: TunnelCloseEvent,
after_disconnect: AfterDisconnect,
}
impl DisconnectingState {
fn handle_commands(
mut self,
- commands: &mut mpsc::UnboundedReceiver<TunnelCommand>,
+ command: Option<TunnelCommand>,
shared_values: &mut SharedTunnelStateValues,
- ) -> EventConsequence<Self> {
- let event = try_handle_event!(self, commands.poll());
+ ) -> EventConsequence {
let after_disconnect = self.after_disconnect;
self.after_disconnect = match after_disconnect {
- AfterDisconnect::Nothing => match event {
- Ok(TunnelCommand::AllowLan(allow_lan)) => {
+ AfterDisconnect::Nothing => match command {
+ Some(TunnelCommand::AllowLan(allow_lan)) => {
let _ = shared_values.set_allow_lan(allow_lan);
AfterDisconnect::Nothing
}
- Ok(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => {
+ Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => {
shared_values.block_when_disconnected = block_when_disconnected;
AfterDisconnect::Nothing
}
- Ok(TunnelCommand::IsOffline(is_offline)) => {
+ Some(TunnelCommand::IsOffline(is_offline)) => {
shared_values.is_offline = is_offline;
AfterDisconnect::Nothing
}
- Ok(TunnelCommand::Connect) => AfterDisconnect::Reconnect(0),
- Ok(TunnelCommand::Block(reason)) => AfterDisconnect::Block(reason),
- _ => AfterDisconnect::Nothing,
+ Some(TunnelCommand::Connect) => AfterDisconnect::Reconnect(0),
+ Some(TunnelCommand::Disconnect) | None => AfterDisconnect::Nothing,
+ Some(TunnelCommand::Block(reason)) => AfterDisconnect::Block(reason),
},
- AfterDisconnect::Block(reason) => match event {
- Ok(TunnelCommand::AllowLan(allow_lan)) => {
+ AfterDisconnect::Block(reason) => match command {
+ Some(TunnelCommand::AllowLan(allow_lan)) => {
let _ = shared_values.set_allow_lan(allow_lan);
AfterDisconnect::Block(reason)
}
- Ok(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => {
+ Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => {
shared_values.block_when_disconnected = block_when_disconnected;
AfterDisconnect::Block(reason)
}
- Ok(TunnelCommand::IsOffline(is_offline)) => {
+ Some(TunnelCommand::IsOffline(is_offline)) => {
shared_values.is_offline = is_offline;
if !is_offline && reason == ErrorStateCause::IsOffline {
AfterDisconnect::Reconnect(0)
@@ -64,21 +61,21 @@ impl DisconnectingState {
AfterDisconnect::Block(reason)
}
}
- Ok(TunnelCommand::Connect) => AfterDisconnect::Reconnect(0),
- Ok(TunnelCommand::Disconnect) => AfterDisconnect::Nothing,
- Ok(TunnelCommand::Block(new_reason)) => AfterDisconnect::Block(new_reason),
- Err(_) => AfterDisconnect::Block(reason),
+ Some(TunnelCommand::Connect) => AfterDisconnect::Reconnect(0),
+ Some(TunnelCommand::Disconnect) => AfterDisconnect::Nothing,
+ Some(TunnelCommand::Block(new_reason)) => AfterDisconnect::Block(new_reason),
+ None => AfterDisconnect::Block(reason),
},
- AfterDisconnect::Reconnect(retry_attempt) => match event {
- Ok(TunnelCommand::AllowLan(allow_lan)) => {
+ AfterDisconnect::Reconnect(retry_attempt) => match command {
+ Some(TunnelCommand::AllowLan(allow_lan)) => {
let _ = shared_values.set_allow_lan(allow_lan);
AfterDisconnect::Reconnect(retry_attempt)
}
- Ok(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => {
+ Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => {
shared_values.block_when_disconnected = block_when_disconnected;
AfterDisconnect::Reconnect(retry_attempt)
}
- Ok(TunnelCommand::IsOffline(is_offline)) => {
+ Some(TunnelCommand::IsOffline(is_offline)) => {
shared_values.is_offline = is_offline;
if is_offline {
AfterDisconnect::Block(ErrorStateCause::IsOffline)
@@ -86,33 +83,13 @@ impl DisconnectingState {
AfterDisconnect::Reconnect(retry_attempt)
}
}
- Ok(TunnelCommand::Connect) => AfterDisconnect::Reconnect(retry_attempt),
- Ok(TunnelCommand::Disconnect) | Err(_) => AfterDisconnect::Nothing,
- Ok(TunnelCommand::Block(reason)) => AfterDisconnect::Block(reason),
+ Some(TunnelCommand::Connect) => AfterDisconnect::Reconnect(retry_attempt),
+ Some(TunnelCommand::Disconnect) | None => AfterDisconnect::Nothing,
+ Some(TunnelCommand::Block(reason)) => AfterDisconnect::Block(reason),
},
};
- EventConsequence::SameState(self)
- }
-
- fn handle_exit_event(
- mut self,
- shared_values: &mut SharedTunnelStateValues,
- ) -> EventConsequence<Self> {
- use self::EventConsequence::*;
-
- let poll_result = match &mut self.exited {
- Some(exited) => exited.poll(),
- None => Ok(Async::Ready(None)),
- };
-
- match poll_result {
- Ok(Async::NotReady) => NoEvents(self),
- Ok(Async::Ready(block_reason)) => {
- NewState(self.after_disconnect(block_reason, shared_values))
- }
- Err(_) => NewState(self.after_disconnect(None, shared_values)),
- }
+ EventConsequence::SameState(self.into())
}
fn after_disconnect(
@@ -135,15 +112,11 @@ impl DisconnectingState {
}
impl TunnelState for DisconnectingState {
- type Bootstrap = (
- Option<CloseHandle>,
- Option<oneshot::Receiver<Option<ErrorStateCause>>>,
- AfterDisconnect,
- );
+ type Bootstrap = (Option<CloseHandle>, TunnelCloseEvent, AfterDisconnect);
fn enter(
_: &mut SharedTunnelStateValues,
- (close_handle, exited, after_disconnect): Self::Bootstrap,
+ (close_handle, tunnel_close_event, after_disconnect): Self::Bootstrap,
) -> (TunnelStateWrapper, TunnelStateTransition) {
if let Some(close_handle) = close_handle {
thread::spawn(move || {
@@ -160,7 +133,7 @@ impl TunnelState for DisconnectingState {
(
TunnelStateWrapper::from(DisconnectingState {
- exited,
+ tunnel_close_event,
after_disconnect,
}),
TunnelStateTransition::Disconnecting(action_after_disconnect),
@@ -168,12 +141,32 @@ impl TunnelState for DisconnectingState {
}
fn handle_event(
- self,
- commands: &mut mpsc::UnboundedReceiver<TunnelCommand>,
+ mut self,
+ runtime: &tokio::runtime::Handle,
+ commands: &mut TunnelCommandReceiver,
shared_values: &mut SharedTunnelStateValues,
- ) -> EventConsequence<Self> {
- self.handle_commands(commands, shared_values)
- .or_else(Self::handle_exit_event, shared_values)
+ ) -> EventConsequence {
+ use self::EventConsequence::*;
+
+ if self.tunnel_close_event.is_terminated() {
+ return NewState(self.after_disconnect(None, shared_values));
+ }
+
+ let result = runtime.block_on(async {
+ futures::select! {
+ command = commands.next() => EventResult::Command(command),
+ result = &mut self.tunnel_close_event => EventResult::Close(result),
+ }
+ });
+
+ match result {
+ EventResult::Command(command) => self.handle_commands(command, shared_values),
+ EventResult::Close(result) => {
+ let block_reason = result.unwrap_or(None);
+ NewState(self.after_disconnect(block_reason, shared_values))
+ }
+ _ => unreachable!("unexpected event result"),
+ }
}
}
diff --git a/talpid-core/src/tunnel_state_machine/error_state.rs b/talpid-core/src/tunnel_state_machine/error_state.rs
index d6a434b055..a9861e788e 100644
--- a/talpid-core/src/tunnel_state_machine/error_state.rs
+++ b/talpid-core/src/tunnel_state_machine/error_state.rs
@@ -1,9 +1,9 @@
use super::{
ConnectingState, DisconnectedState, EventConsequence, SharedTunnelStateValues, TunnelCommand,
- TunnelState, TunnelStateTransition, TunnelStateWrapper,
+ TunnelCommandReceiver, TunnelState, TunnelStateTransition, TunnelStateWrapper,
};
use crate::firewall::FirewallPolicy;
-use futures01::{sync::mpsc, Stream};
+use futures::StreamExt;
use talpid_types::{
tunnel::{self as talpid_tunnel, ErrorStateCause, FirewallPolicyError},
ErrorExt,
@@ -87,37 +87,40 @@ impl TunnelState for ErrorState {
fn handle_event(
self,
- commands: &mut mpsc::UnboundedReceiver<TunnelCommand>,
+ runtime: &tokio::runtime::Handle,
+ commands: &mut TunnelCommandReceiver,
shared_values: &mut SharedTunnelStateValues,
- ) -> EventConsequence<Self> {
+ ) -> EventConsequence {
use self::EventConsequence::*;
- match try_handle_event!(self, commands.poll()) {
- Ok(TunnelCommand::AllowLan(allow_lan)) => {
+ match runtime.block_on(commands.next()) {
+ Some(TunnelCommand::AllowLan(allow_lan)) => {
if let Err(error_state_cause) = shared_values.set_allow_lan(allow_lan) {
NewState(Self::enter(shared_values, error_state_cause))
} else {
let _ = Self::set_firewall_policy(shared_values);
- SameState(self)
+ SameState(self.into())
}
}
- Ok(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => {
+ Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => {
shared_values.block_when_disconnected = block_when_disconnected;
- SameState(self)
+ SameState(self.into())
}
- Ok(TunnelCommand::IsOffline(is_offline)) => {
+ Some(TunnelCommand::IsOffline(is_offline)) => {
shared_values.is_offline = is_offline;
if !is_offline && self.block_reason == ErrorStateCause::IsOffline {
NewState(ConnectingState::enter(shared_values, 0))
} else {
- SameState(self)
+ SameState(self.into())
}
}
- Ok(TunnelCommand::Connect) => NewState(ConnectingState::enter(shared_values, 0)),
- Ok(TunnelCommand::Disconnect) | Err(_) => {
+ Some(TunnelCommand::Connect) => NewState(ConnectingState::enter(shared_values, 0)),
+ Some(TunnelCommand::Disconnect) | None => {
NewState(DisconnectedState::enter(shared_values, true))
}
- Ok(TunnelCommand::Block(reason)) => NewState(ErrorState::enter(shared_values, reason)),
+ Some(TunnelCommand::Block(reason)) => {
+ NewState(ErrorState::enter(shared_values, reason))
+ }
}
}
}
diff --git a/talpid-core/src/tunnel_state_machine/macros.rs b/talpid-core/src/tunnel_state_machine/macros.rs
deleted file mode 100644
index deccd29fdd..0000000000
--- a/talpid-core/src/tunnel_state_machine/macros.rs
+++ /dev/null
@@ -1,21 +0,0 @@
-/// Try to receive an event from a `Stream`'s asynchronous poll expression.
-///
-/// This macro is similar to the `try_ready!` macro provided in `futures`. If there is an event
-/// ready, it will be returned wrapped in a `Result`. If there are no events ready to be received,
-/// the outer function will return with a transition that indicates that no events were received,
-/// which is analogous to `Async::NotReady`.
-///
-/// When the asynchronous event indicates that the stream has finished or that it has failed, an
-/// error type is returned so that either close scenario can be handled in a similar way.
-macro_rules! try_handle_event {
- ($same_state:expr, $event:expr) => {
- match $event {
- Ok(futures01::Async::Ready(Some(event))) => Ok(event),
- Ok(futures01::Async::Ready(None)) => Err(None),
- Ok(futures01::Async::NotReady) => {
- return crate::tunnel_state_machine::EventConsequence::NoEvents($same_state);
- }
- Err(error) => Err(Some(error)),
- }
- };
-}
diff --git a/talpid-core/src/tunnel_state_machine/mod.rs b/talpid-core/src/tunnel_state_machine/mod.rs
index 8a3b588927..37a27ae58f 100644
--- a/talpid-core/src/tunnel_state_machine/mod.rs
+++ b/talpid-core/src/tunnel_state_machine/mod.rs
@@ -1,6 +1,3 @@
-#[macro_use]
-mod macros;
-
mod connected_state;
mod connecting_state;
mod disconnected_state;
@@ -20,14 +17,13 @@ use crate::{
mpsc::Sender,
offline,
routing::RouteManager,
- tunnel::tun_provider::TunProvider,
+ tunnel::{tun_provider::TunProvider, TunnelEvent},
};
use futures::{
channel::{mpsc, oneshot},
- StreamExt,
+ stream, StreamExt,
};
-use futures01::{sync::mpsc as old_mpsc, Async, Poll, Stream};
use std::{
collections::HashSet,
io,
@@ -87,7 +83,7 @@ pub async fn spawn(
reset_firewall: bool,
#[cfg(target_os = "android")] android_context: AndroidContext,
) -> Result<Arc<mpsc::UnboundedSender<TunnelCommand>>, Error> {
- let (command_tx, mut command_rx) = mpsc::unbounded();
+ let (command_tx, command_rx) = mpsc::unbounded();
let command_tx = Arc::new(command_tx);
let mut offline_monitor = offline::spawn_monitor(
Arc::downgrade(&command_tx),
@@ -105,15 +101,7 @@ pub async fn spawn(
allow_lan,
);
- // Hide internal 0.1 futures from the client
- let (command_adapter_tx, command_adapter_rx) = old_mpsc::unbounded();
- tokio::spawn(async move {
- while let Some(command) = command_rx.next().await {
- if command_adapter_tx.unbounded_send(command).is_err() {
- log::error!("Failed to forward daemon command");
- }
- }
- });
+ let runtime = tokio::runtime::Handle::current();
let (startup_result_tx, startup_result_rx) = sync_mpsc::channel();
std::thread::spawn(move || {
@@ -126,7 +114,7 @@ pub async fn spawn(
log_dir,
resource_dir,
cache_dir,
- command_adapter_rx,
+ command_rx,
reset_firewall,
);
let state_machine = match state_machine {
@@ -140,16 +128,8 @@ pub async fn spawn(
}
};
- let mut iter = state_machine.wait();
- while let Some(Ok(change_event)) = iter.next() {
- if let Err(error) = state_change_listener
- .send(change_event)
- .map_err(|_| Error::SendStateChange)
- {
- log::error!("{}", error);
- break;
- }
- }
+ state_machine.run(runtime, state_change_listener);
+
if shutdown_tx.send(()).is_err() {
log::error!("Can't send shutdown completion to daemon");
}
@@ -179,6 +159,14 @@ pub enum TunnelCommand {
Block(ErrorStateCause),
}
+type TunnelCommandReceiver = stream::Fuse<mpsc::UnboundedReceiver<TunnelCommand>>;
+
+enum EventResult {
+ Command(Option<TunnelCommand>),
+ Event(Option<TunnelEvent>),
+ Close(Result<Option<ErrorStateCause>, oneshot::Canceled>),
+}
+
/// Asynchronous handling of the tunnel state machine.
///
/// This type implements `Stream`, and attempts to advance the state machine based on the events
@@ -187,7 +175,7 @@ pub enum TunnelCommand {
/// by the stream.
struct TunnelStateMachine {
current_state: Option<TunnelStateWrapper>,
- commands: old_mpsc::UnboundedReceiver<TunnelCommand>,
+ commands: TunnelCommandReceiver,
shared_values: SharedTunnelStateValues,
}
@@ -201,7 +189,7 @@ impl TunnelStateMachine {
log_dir: Option<PathBuf>,
resource_dir: PathBuf,
cache_dir: impl AsRef<Path>,
- commands: old_mpsc::UnboundedReceiver<TunnelCommand>,
+ commands: mpsc::UnboundedReceiver<TunnelCommand>,
reset_firewall: bool,
) -> Result<Self, Error> {
let args = FirewallArguments {
@@ -230,62 +218,43 @@ impl TunnelStateMachine {
Ok(TunnelStateMachine {
current_state: Some(initial_state),
- commands,
+ commands: commands.fuse(),
shared_values,
})
}
-}
-impl Stream for TunnelStateMachine {
- type Item = TunnelStateTransition;
- type Error = Error;
+ fn run(
+ mut self,
+ runtime: tokio::runtime::Handle,
+ change_listener: impl Sender<TunnelStateTransition> + Send + 'static,
+ ) {
+ use EventConsequence::*;
- fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
while let Some(state_wrapper) = self.current_state.take() {
- match state_wrapper.handle_event(&mut self.commands, &mut self.shared_values) {
- TunnelStateMachineAction::Repeat(repeat_state_wrapper) => {
- self.current_state = Some(repeat_state_wrapper);
+ match state_wrapper.handle_event(&runtime, &mut self.commands, &mut self.shared_values)
+ {
+ NewState((state, transition)) => {
+ self.current_state = Some(state);
+
+ if let Err(error) = change_listener
+ .send(transition)
+ .map_err(|_| Error::SendStateChange)
+ {
+ log::error!("{}", error);
+ break;
+ }
}
- TunnelStateMachineAction::Notify(state_wrapper, result) => {
- self.current_state = state_wrapper;
- return result;
+ SameState(state) => {
+ self.current_state = Some(state);
}
+ Finished => (),
}
}
- Ok(Async::Ready(None))
- }
-}
-
-/// Action the state machine should take, which is discovered base on an event consequence.
-///
-/// The action can be to execute another iteration or to notify that something happened. Executing
-/// another iteration happens when an event is received and ignored, which causes the tunnel state
-/// machine to stay in the same state. The state machine can notify its caller that a state
-/// transition has occurred, that it has finished, or that it has paused to wait for new events.
-enum TunnelStateMachineAction {
- Repeat(TunnelStateWrapper),
- Notify(
- Option<TunnelStateWrapper>,
- Poll<Option<TunnelStateTransition>, Error>,
- ),
-}
-impl<T: TunnelState> From<EventConsequence<T>> for TunnelStateMachineAction {
- fn from(event_consequence: EventConsequence<T>) -> Self {
- use self::{EventConsequence::*, TunnelStateMachineAction::*};
-
- match event_consequence {
- NewState((state_wrapper, transition)) => {
- Notify(Some(state_wrapper), Ok(Async::Ready(Some(transition))))
- }
- SameState(state) => Repeat(state.into()),
- NoEvents(state) => Notify(Some(state.into()), Ok(Async::NotReady)),
- Finished => Notify(None, Ok(Async::Ready(None))),
- }
+ log::debug!("Exiting tunnel state machine loop");
}
}
-
/// Trait for any type that can provide a stream of `TunnelParameters` to the `TunnelStateMachine`.
pub trait TunnelParametersGenerator: Send + 'static {
/// Given the number of consecutive failed retry attempts, it should yield a `TunnelParameters`
@@ -343,37 +312,15 @@ impl SharedTunnelStateValues {
}
/// Asynchronous result of an attempt to progress a state.
-enum EventConsequence<T: TunnelState> {
+enum EventConsequence {
/// Transition to a new state.
NewState((TunnelStateWrapper, TunnelStateTransition)),
/// An event was received, but it was ignored by the state so no transition is performed.
- SameState(T),
- /// No events were received, the event loop should block until one becomes available.
- NoEvents(T),
+ SameState(TunnelStateWrapper),
/// The state machine has finished its execution.
Finished,
}
-impl<T> EventConsequence<T>
-where
- T: TunnelState,
-{
- /// Helper method to chain handling multiple different event types.
- ///
- /// The `handle_event` is only called if no events were handled so far.
- pub fn or_else<F>(self, handle_event: F, shared_values: &mut SharedTunnelStateValues) -> Self
- where
- F: FnOnce(T, &mut SharedTunnelStateValues) -> Self,
- {
- use self::EventConsequence::*;
-
- match self {
- NoEvents(state) => handle_event(state, shared_values),
- consequence => consequence,
- }
- }
-}
-
/// Trait that contains the method all states should implement to handle an event and advance the
/// state machine.
trait TunnelState: Into<TunnelStateWrapper> + Sized {
@@ -401,9 +348,10 @@ trait TunnelState: Into<TunnelStateWrapper> + Sized {
/// [`EventConsequence`]: enum.EventConsequence.html
fn handle_event(
self,
- commands: &mut old_mpsc::UnboundedReceiver<TunnelCommand>,
+ runtime: &tokio::runtime::Handle,
+ commands: &mut TunnelCommandReceiver,
shared_values: &mut SharedTunnelStateValues,
- ) -> EventConsequence<Self>;
+ ) -> EventConsequence;
}
macro_rules! state_wrapper {
@@ -425,13 +373,13 @@ macro_rules! state_wrapper {
impl $wrapper_name {
fn handle_event(
self,
- commands: &mut old_mpsc::UnboundedReceiver<TunnelCommand>,
+ runtime: &tokio::runtime::Handle,
+ commands: &mut TunnelCommandReceiver,
shared_values: &mut SharedTunnelStateValues,
- ) -> TunnelStateMachineAction {
+ ) -> EventConsequence {
match self {
$($wrapper_name::$state_variant(state) => {
- let event_consequence = state.handle_event(commands, shared_values);
- TunnelStateMachineAction::from(event_consequence)
+ state.handle_event(runtime, commands, shared_values)
})*
}
}