summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--talpid-core/src/tunnel_state_machine/connected_state.rs132
-rw-r--r--talpid-core/src/tunnel_state_machine/connecting_state.rs218
-rw-r--r--talpid-core/src/tunnel_state_machine/disconnected_state.rs90
-rw-r--r--talpid-core/src/tunnel_state_machine/disconnecting_state.rs45
-rw-r--r--talpid-core/src/tunnel_state_machine/error_state.rs120
-rw-r--r--talpid-core/src/tunnel_state_machine/mod.rs72
6 files changed, 295 insertions, 382 deletions
diff --git a/talpid-core/src/tunnel_state_machine/connected_state.rs b/talpid-core/src/tunnel_state_machine/connected_state.rs
index b231518838..687e941aa7 100644
--- a/talpid-core/src/tunnel_state_machine/connected_state.rs
+++ b/talpid-core/src/tunnel_state_machine/connected_state.rs
@@ -1,7 +1,7 @@
use super::{
AfterDisconnect, ConnectingState, DisconnectingState, ErrorState, EventConsequence,
EventResult, SharedTunnelStateValues, TunnelCommand, TunnelCommandReceiver, TunnelState,
- TunnelStateTransition, TunnelStateWrapper,
+ TunnelStateTransition,
};
use crate::{
firewall::FirewallPolicy,
@@ -27,14 +27,6 @@ use super::connecting_state::TunnelCloseEvent;
pub(crate) type TunnelEventsReceiver =
Fuse<mpsc::UnboundedReceiver<(TunnelEvent, oneshot::Sender<()>)>>;
-pub struct ConnectedStateBootstrap {
- pub metadata: TunnelMetadata,
- pub tunnel_events: TunnelEventsReceiver,
- pub tunnel_parameters: TunnelParameters,
- pub tunnel_close_event: TunnelCloseEvent,
- pub tunnel_close_tx: oneshot::Sender<()>,
-}
-
/// The tunnel is up and working.
pub struct ConnectedState {
metadata: TunnelMetadata,
@@ -45,13 +37,47 @@ pub struct ConnectedState {
}
impl ConnectedState {
- fn from(bootstrap: ConnectedStateBootstrap) -> Self {
- ConnectedState {
- metadata: bootstrap.metadata,
- tunnel_events: bootstrap.tunnel_events,
- tunnel_parameters: bootstrap.tunnel_parameters,
- tunnel_close_event: bootstrap.tunnel_close_event,
- tunnel_close_tx: bootstrap.tunnel_close_tx,
+ #[cfg_attr(target_os = "android", allow(unused_variables))]
+ pub(super) fn enter(
+ shared_values: &mut SharedTunnelStateValues,
+ metadata: TunnelMetadata,
+ tunnel_events: TunnelEventsReceiver,
+ tunnel_parameters: TunnelParameters,
+ tunnel_close_event: TunnelCloseEvent,
+ tunnel_close_tx: oneshot::Sender<()>,
+ ) -> (Box<dyn TunnelState>, TunnelStateTransition) {
+ let connected_state = ConnectedState {
+ metadata,
+ tunnel_events,
+ tunnel_parameters,
+ tunnel_close_event,
+ tunnel_close_tx,
+ };
+
+ let tunnel_interface = Some(connected_state.metadata.interface.clone());
+ let tunnel_endpoint = talpid_types::net::TunnelEndpoint {
+ tunnel_interface,
+ ..connected_state.tunnel_parameters.get_tunnel_endpoint()
+ };
+
+ if let Err(error) = connected_state.set_firewall_policy(shared_values) {
+ DisconnectingState::enter(
+ connected_state.tunnel_close_tx,
+ connected_state.tunnel_close_event,
+ AfterDisconnect::Block(ErrorStateCause::SetFirewallPolicyError(error)),
+ )
+ } else if let Err(error) = connected_state.set_dns(shared_values) {
+ log::error!("{}", error.display_chain_with_msg("Failed to set DNS"));
+ DisconnectingState::enter(
+ connected_state.tunnel_close_tx,
+ connected_state.tunnel_close_event,
+ AfterDisconnect::Block(ErrorStateCause::SetDnsError),
+ )
+ } else {
+ (
+ Box::new(connected_state),
+ TunnelStateTransition::Connected(tunnel_endpoint),
+ )
}
}
@@ -173,17 +199,14 @@ impl ConnectedState {
Self::reset_routes(shared_values);
EventConsequence::NewState(DisconnectingState::enter(
- shared_values,
- (
- self.tunnel_close_tx,
- self.tunnel_close_event,
- after_disconnect,
- ),
+ self.tunnel_close_tx,
+ self.tunnel_close_event,
+ after_disconnect,
))
}
fn handle_commands(
- self,
+ self: Box<Self>,
command: Option<TunnelCommand>,
shared_values: &mut SharedTunnelStateValues,
) -> EventConsequence {
@@ -199,7 +222,7 @@ impl ConnectedState {
if cfg!(target_os = "android") {
self.disconnect(shared_values, AfterDisconnect::Reconnect(0))
} else {
- SameState(self.into())
+ SameState(self)
}
}
Err(error) => self.disconnect(
@@ -212,7 +235,7 @@ impl ConnectedState {
Some(TunnelCommand::AllowEndpoint(endpoint, tx)) => {
shared_values.allowed_endpoint = endpoint;
let _ = tx.send(());
- SameState(self.into())
+ SameState(self)
}
Some(TunnelCommand::Dns(servers)) => match shared_values.set_dns_servers(servers) {
Ok(true) => {
@@ -227,7 +250,7 @@ impl ConnectedState {
#[cfg(target_os = "android")]
Ok(()) => self.disconnect(shared_values, AfterDisconnect::Reconnect(0)),
#[cfg(not(target_os = "android"))]
- Ok(()) => SameState(self.into()),
+ Ok(()) => SameState(self),
Err(error) => {
log::error!("{}", error.display_chain_with_msg("Failed to set DNS"));
self.disconnect(
@@ -237,14 +260,14 @@ impl ConnectedState {
}
}
}
- Ok(false) => SameState(self.into()),
+ Ok(false) => SameState(self),
Err(error_cause) => {
self.disconnect(shared_values, AfterDisconnect::Block(error_cause))
}
},
Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => {
shared_values.block_when_disconnected = block_when_disconnected;
- SameState(self.into())
+ SameState(self)
}
Some(TunnelCommand::IsOffline(is_offline)) => {
shared_values.is_offline = is_offline;
@@ -254,7 +277,7 @@ impl ConnectedState {
AfterDisconnect::Block(ErrorStateCause::IsOffline),
)
} else {
- SameState(self.into())
+ SameState(self)
}
}
Some(TunnelCommand::Connect) => {
@@ -269,18 +292,18 @@ impl ConnectedState {
#[cfg(target_os = "android")]
Some(TunnelCommand::BypassSocket(fd, done_tx)) => {
shared_values.bypass_socket(fd, done_tx);
- SameState(self.into())
+ SameState(self)
}
#[cfg(windows)]
Some(TunnelCommand::SetExcludedApps(result_tx, paths)) => {
shared_values.split_tunnel.set_paths(&paths, result_tx);
- SameState(self.into())
+ SameState(self)
}
}
}
fn handle_tunnel_events(
- self,
+ self: Box<Self>,
event: Option<(TunnelEvent, oneshot::Sender<()>)>,
shared_values: &mut SharedTunnelStateValues,
) -> EventConsequence {
@@ -290,7 +313,7 @@ impl ConnectedState {
Some((TunnelEvent::Down, _)) | None => {
self.disconnect(shared_values, AfterDisconnect::Reconnect(0))
}
- Some(_) => SameState(self.into()),
+ Some(_) => SameState(self),
}
}
@@ -315,49 +338,8 @@ impl ConnectedState {
}
impl TunnelState for ConnectedState {
- type Bootstrap = ConnectedStateBootstrap;
-
- #[cfg_attr(target_os = "android", allow(unused_variables))]
- fn enter(
- shared_values: &mut SharedTunnelStateValues,
- bootstrap: Self::Bootstrap,
- ) -> (TunnelStateWrapper, TunnelStateTransition) {
- let connected_state = ConnectedState::from(bootstrap);
- let tunnel_interface = Some(connected_state.metadata.interface.clone());
- let tunnel_endpoint = talpid_types::net::TunnelEndpoint {
- tunnel_interface,
- ..connected_state.tunnel_parameters.get_tunnel_endpoint()
- };
-
- if let Err(error) = connected_state.set_firewall_policy(shared_values) {
- DisconnectingState::enter(
- shared_values,
- (
- connected_state.tunnel_close_tx,
- connected_state.tunnel_close_event,
- AfterDisconnect::Block(ErrorStateCause::SetFirewallPolicyError(error)),
- ),
- )
- } else if let Err(error) = connected_state.set_dns(shared_values) {
- log::error!("{}", error.display_chain_with_msg("Failed to set DNS"));
- DisconnectingState::enter(
- shared_values,
- (
- connected_state.tunnel_close_tx,
- connected_state.tunnel_close_event,
- AfterDisconnect::Block(ErrorStateCause::SetDnsError),
- ),
- )
- } else {
- (
- TunnelStateWrapper::from(connected_state),
- TunnelStateTransition::Connected(tunnel_endpoint),
- )
- }
- }
-
fn handle_event(
- mut self,
+ mut self: Box<Self>,
runtime: &tokio::runtime::Handle,
commands: &mut TunnelCommandReceiver,
shared_values: &mut SharedTunnelStateValues,
diff --git a/talpid-core/src/tunnel_state_machine/connecting_state.rs b/talpid-core/src/tunnel_state_machine/connecting_state.rs
index 28731d617d..2a728513ff 100644
--- a/talpid-core/src/tunnel_state_machine/connecting_state.rs
+++ b/talpid-core/src/tunnel_state_machine/connecting_state.rs
@@ -1,7 +1,7 @@
use super::{
- AfterDisconnect, ConnectedState, ConnectedStateBootstrap, DisconnectingState, ErrorState,
- EventConsequence, EventResult, SharedTunnelStateValues, TunnelCommand, TunnelCommandReceiver,
- TunnelState, TunnelStateTransition, TunnelStateWrapper,
+ AfterDisconnect, ConnectedState, DisconnectingState, ErrorState, EventConsequence, EventResult,
+ SharedTunnelStateValues, TunnelCommand, TunnelCommandReceiver, TunnelState,
+ TunnelStateTransition,
};
use crate::{
firewall::FirewallPolicy,
@@ -53,6 +53,84 @@ pub struct ConnectingState {
}
impl ConnectingState {
+ pub(super) fn enter(
+ shared_values: &mut SharedTunnelStateValues,
+ retry_attempt: u32,
+ ) -> (Box<dyn TunnelState>, TunnelStateTransition) {
+ if shared_values.is_offline {
+ // FIXME: Temporary: Nudge route manager to update the default interface
+ #[cfg(target_os = "macos")]
+ if let Ok(handle) = shared_values.route_manager.handle() {
+ log::debug!("Poking route manager to update default routes");
+ let _ = handle.refresh_routes();
+ }
+ return ErrorState::enter(shared_values, ErrorStateCause::IsOffline);
+ }
+ match shared_values.runtime.block_on(
+ shared_values
+ .tunnel_parameters_generator
+ .generate(retry_attempt),
+ ) {
+ Err(err) => {
+ ErrorState::enter(shared_values, ErrorStateCause::TunnelParameterError(err))
+ }
+ Ok(tunnel_parameters) => {
+ #[cfg(windows)]
+ if let Err(error) = shared_values.split_tunnel.set_tunnel_addresses(None) {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg(
+ "Failed to reset addresses in split tunnel driver"
+ )
+ );
+
+ return ErrorState::enter(shared_values, ErrorStateCause::SplitTunnelError);
+ }
+
+ if let Err(error) = Self::set_firewall_policy(
+ shared_values,
+ &tunnel_parameters,
+ &None,
+ AllowedTunnelTraffic::None,
+ ) {
+ ErrorState::enter(
+ shared_values,
+ ErrorStateCause::SetFirewallPolicyError(error),
+ )
+ } else {
+ #[cfg(target_os = "android")]
+ {
+ if retry_attempt > 0 && retry_attempt % MAX_ATTEMPTS_WITH_SAME_TUN == 0 {
+ if let Err(error) =
+ { shared_values.tun_provider.lock().unwrap().create_tun() }
+ {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Failed to recreate tun device")
+ );
+ }
+ }
+ }
+
+ let connecting_state = Self::start_tunnel(
+ shared_values.runtime.clone(),
+ tunnel_parameters,
+ &shared_values.log_dir,
+ &shared_values.resource_dir,
+ shared_values.tun_provider.clone(),
+ &shared_values.route_manager,
+ retry_attempt,
+ );
+ let params = connecting_state.tunnel_parameters.clone();
+ (
+ Box::new(connecting_state),
+ TunnelStateTransition::Connecting(params.get_tunnel_endpoint()),
+ )
+ }
+ }
+ }
+ }
+
fn set_firewall_policy(
shared_values: &mut SharedTunnelStateValues,
params: &TunnelParameters,
@@ -249,16 +327,6 @@ impl ConnectingState {
}
}
- fn into_connected_state_bootstrap(self, metadata: TunnelMetadata) -> ConnectedStateBootstrap {
- ConnectedStateBootstrap {
- metadata,
- tunnel_events: self.tunnel_events,
- tunnel_parameters: self.tunnel_parameters,
- tunnel_close_event: self.tunnel_close_event,
- tunnel_close_tx: self.tunnel_close_tx,
- }
- }
-
fn reset_routes(
#[cfg(target_os = "windows")] shared_values: &SharedTunnelStateValues,
#[cfg(not(target_os = "windows"))] shared_values: &mut SharedTunnelStateValues,
@@ -286,16 +354,16 @@ impl ConnectingState {
Self::reset_routes(shared_values);
EventConsequence::NewState(DisconnectingState::enter(
- shared_values,
- (
- self.tunnel_close_tx,
- self.tunnel_close_event,
- after_disconnect,
- ),
+ self.tunnel_close_tx,
+ self.tunnel_close_event,
+ after_disconnect,
))
}
- fn reset_firewall(self, shared_values: &mut SharedTunnelStateValues) -> EventConsequence {
+ fn reset_firewall(
+ self: Box<Self>,
+ shared_values: &mut SharedTunnelStateValues,
+ ) -> EventConsequence {
match Self::set_firewall_policy(
shared_values,
&self.tunnel_parameters,
@@ -306,7 +374,7 @@ impl ConnectingState {
if cfg!(target_os = "android") {
self.disconnect(shared_values, AfterDisconnect::Reconnect(0))
} else {
- EventConsequence::SameState(self.into())
+ EventConsequence::SameState(self)
}
}
Err(error) => self.disconnect(
@@ -317,7 +385,7 @@ impl ConnectingState {
}
fn handle_commands(
- self,
+ self: Box<Self>,
command: Option<TunnelCommand>,
shared_values: &mut SharedTunnelStateValues,
) -> EventConsequence {
@@ -348,17 +416,17 @@ impl ConnectingState {
}
}
let _ = tx.send(());
- SameState(self.into())
+ SameState(self)
}
Some(TunnelCommand::Dns(servers)) => match shared_values.set_dns_servers(servers) {
#[cfg(target_os = "android")]
Ok(true) => self.disconnect(shared_values, AfterDisconnect::Reconnect(0)),
- Ok(_) => SameState(self.into()),
+ Ok(_) => SameState(self),
Err(cause) => self.disconnect(shared_values, AfterDisconnect::Block(cause)),
},
Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => {
shared_values.block_when_disconnected = block_when_disconnected;
- SameState(self.into())
+ SameState(self)
}
Some(TunnelCommand::IsOffline(is_offline)) => {
shared_values.is_offline = is_offline;
@@ -368,7 +436,7 @@ impl ConnectingState {
AfterDisconnect::Block(ErrorStateCause::IsOffline),
)
} else {
- SameState(self.into())
+ SameState(self)
}
}
Some(TunnelCommand::Connect) => {
@@ -383,18 +451,18 @@ impl ConnectingState {
#[cfg(target_os = "android")]
Some(TunnelCommand::BypassSocket(fd, done_tx)) => {
shared_values.bypass_socket(fd, done_tx);
- SameState(self.into())
+ SameState(self)
}
#[cfg(windows)]
Some(TunnelCommand::SetExcludedApps(result_tx, paths)) => {
shared_values.split_tunnel.set_paths(&paths, result_tx);
- SameState(self.into())
+ SameState(self)
}
}
}
fn handle_tunnel_events(
- mut self,
+ mut self: Box<Self>,
event: Option<(tunnel::TunnelEvent, oneshot::Sender<()>)>,
shared_values: &mut SharedTunnelStateValues,
) -> EventConsequence {
@@ -432,7 +500,7 @@ impl ConnectingState {
&self.tunnel_metadata,
self.allowed_tunnel_traffic.clone(),
) {
- Ok(()) => SameState(self.into()),
+ Ok(()) => SameState(self),
Err(error) => self.disconnect(
shared_values,
AfterDisconnect::Block(ErrorStateCause::SetFirewallPolicyError(error)),
@@ -441,7 +509,11 @@ impl ConnectingState {
}
Some((TunnelEvent::Up(metadata), _)) => NewState(ConnectedState::enter(
shared_values,
- self.into_connected_state_bootstrap(metadata),
+ metadata,
+ self.tunnel_events,
+ self.tunnel_parameters,
+ self.tunnel_close_event,
+ self.tunnel_close_tx,
)),
Some((TunnelEvent::Down, _)) => {
// It is important to reset this before the tunnel device is down,
@@ -450,7 +522,7 @@ impl ConnectingState {
self.allowed_tunnel_traffic = INITIAL_ALLOWED_TUNNEL_TRAFFIC;
self.tunnel_metadata = None;
- SameState(self.into())
+ SameState(self)
}
None => {
// The channel was closed
@@ -532,88 +604,8 @@ fn is_recoverable_routing_error(error: &talpid_routing::Error) -> bool {
}
impl TunnelState for ConnectingState {
- type Bootstrap = u32;
-
- fn enter(
- shared_values: &mut SharedTunnelStateValues,
- retry_attempt: u32,
- ) -> (TunnelStateWrapper, TunnelStateTransition) {
- if shared_values.is_offline {
- // FIXME: Temporary: Nudge route manager to update the default interface
- #[cfg(target_os = "macos")]
- if let Ok(handle) = shared_values.route_manager.handle() {
- log::debug!("Poking route manager to update default routes");
- let _ = handle.refresh_routes();
- }
- return ErrorState::enter(shared_values, ErrorStateCause::IsOffline);
- }
- match shared_values.runtime.block_on(
- shared_values
- .tunnel_parameters_generator
- .generate(retry_attempt),
- ) {
- Err(err) => {
- ErrorState::enter(shared_values, ErrorStateCause::TunnelParameterError(err))
- }
- Ok(tunnel_parameters) => {
- #[cfg(windows)]
- if let Err(error) = shared_values.split_tunnel.set_tunnel_addresses(None) {
- log::error!(
- "{}",
- error.display_chain_with_msg(
- "Failed to reset addresses in split tunnel driver"
- )
- );
-
- return ErrorState::enter(shared_values, ErrorStateCause::SplitTunnelError);
- }
-
- if let Err(error) = Self::set_firewall_policy(
- shared_values,
- &tunnel_parameters,
- &None,
- AllowedTunnelTraffic::None,
- ) {
- ErrorState::enter(
- shared_values,
- ErrorStateCause::SetFirewallPolicyError(error),
- )
- } else {
- #[cfg(target_os = "android")]
- {
- if retry_attempt > 0 && retry_attempt % MAX_ATTEMPTS_WITH_SAME_TUN == 0 {
- if let Err(error) =
- { shared_values.tun_provider.lock().unwrap().create_tun() }
- {
- log::error!(
- "{}",
- error.display_chain_with_msg("Failed to recreate tun device")
- );
- }
- }
- }
-
- let connecting_state = Self::start_tunnel(
- shared_values.runtime.clone(),
- tunnel_parameters,
- &shared_values.log_dir,
- &shared_values.resource_dir,
- shared_values.tun_provider.clone(),
- &shared_values.route_manager,
- retry_attempt,
- );
- let params = connecting_state.tunnel_parameters.clone();
- (
- TunnelStateWrapper::from(connecting_state),
- TunnelStateTransition::Connecting(params.get_tunnel_endpoint()),
- )
- }
- }
- }
- }
-
fn handle_event(
- mut self,
+ mut self: Box<Self>,
runtime: &tokio::runtime::Handle,
commands: &mut TunnelCommandReceiver,
shared_values: &mut SharedTunnelStateValues,
diff --git a/talpid-core/src/tunnel_state_machine/disconnected_state.rs b/talpid-core/src/tunnel_state_machine/disconnected_state.rs
index 92a0586245..5a2cf6fc4d 100644
--- a/talpid-core/src/tunnel_state_machine/disconnected_state.rs
+++ b/talpid-core/src/tunnel_state_machine/disconnected_state.rs
@@ -1,6 +1,6 @@
use super::{
ConnectingState, ErrorState, EventConsequence, SharedTunnelStateValues, TunnelCommand,
- TunnelCommandReceiver, TunnelState, TunnelStateTransition, TunnelStateWrapper,
+ TunnelCommandReceiver, TunnelState, TunnelStateTransition,
};
#[cfg(target_os = "macos")]
use crate::dns;
@@ -13,9 +13,42 @@ use talpid_types::tunnel::ErrorStateCause;
use talpid_types::ErrorExt;
/// No tunnel is running.
-pub struct DisconnectedState;
+pub struct DisconnectedState(());
impl DisconnectedState {
+ pub(super) fn enter(
+ shared_values: &mut SharedTunnelStateValues,
+ should_reset_firewall: bool,
+ ) -> (Box<dyn TunnelState>, TunnelStateTransition) {
+ #[cfg(target_os = "macos")]
+ if shared_values.block_when_disconnected {
+ if let Err(err) = Self::setup_local_dns_config(shared_values) {
+ log::error!(
+ "{}",
+ err.display_chain_with_msg("Failed to start filtering resolver:")
+ );
+ }
+ } else if let Err(error) = shared_values.dns_monitor.reset() {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Unable to disable filtering resolver")
+ );
+ }
+
+ #[cfg(windows)]
+ Self::register_split_tunnel_addresses(shared_values, should_reset_firewall);
+ Self::set_firewall_policy(shared_values, should_reset_firewall);
+ #[cfg(target_os = "linux")]
+ shared_values.reset_connectivity_check();
+ #[cfg(target_os = "android")]
+ shared_values.tun_provider.lock().unwrap().close_tun();
+
+ (
+ Box::new(DisconnectedState(())),
+ TunnelStateTransition::Disconnected,
+ )
+ }
+
fn set_firewall_policy(
shared_values: &mut SharedTunnelStateValues,
should_reset_firewall: bool,
@@ -86,43 +119,8 @@ impl DisconnectedState {
}
impl TunnelState for DisconnectedState {
- type Bootstrap = bool;
-
- fn enter(
- shared_values: &mut SharedTunnelStateValues,
- should_reset_firewall: Self::Bootstrap,
- ) -> (TunnelStateWrapper, TunnelStateTransition) {
- #[cfg(target_os = "macos")]
- if shared_values.block_when_disconnected {
- if let Err(err) = Self::setup_local_dns_config(shared_values) {
- log::error!(
- "{}",
- err.display_chain_with_msg("Failed to start filtering resolver:")
- );
- }
- } else if let Err(error) = shared_values.dns_monitor.reset() {
- log::error!(
- "{}",
- error.display_chain_with_msg("Unable to disable filtering resolver")
- );
- }
-
- #[cfg(windows)]
- Self::register_split_tunnel_addresses(shared_values, should_reset_firewall);
- Self::set_firewall_policy(shared_values, should_reset_firewall);
- #[cfg(target_os = "linux")]
- shared_values.reset_connectivity_check();
- #[cfg(target_os = "android")]
- shared_values.tun_provider.lock().unwrap().close_tun();
-
- (
- TunnelStateWrapper::from(DisconnectedState),
- TunnelStateTransition::Disconnected,
- )
- }
-
fn handle_event(
- self,
+ self: Box<Self>,
runtime: &tokio::runtime::Handle,
commands: &mut TunnelCommandReceiver,
shared_values: &mut SharedTunnelStateValues,
@@ -140,7 +138,7 @@ impl TunnelState for DisconnectedState {
Self::set_firewall_policy(shared_values, false);
}
- SameState(self.into())
+ SameState(self)
}
Some(TunnelCommand::AllowEndpoint(endpoint, tx)) => {
if shared_values.allowed_endpoint != endpoint {
@@ -148,7 +146,7 @@ impl TunnelState for DisconnectedState {
Self::set_firewall_policy(shared_values, false);
}
let _ = tx.send(());
- SameState(self.into())
+ SameState(self)
}
Some(TunnelCommand::Dns(servers)) => {
// Same situation as allow LAN above.
@@ -156,7 +154,7 @@ impl TunnelState for DisconnectedState {
.set_dns_servers(servers)
.expect("Failed to reconnect after changing custom DNS servers");
- SameState(self.into())
+ SameState(self)
}
Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => {
if shared_values.block_when_disconnected != block_when_disconnected {
@@ -180,11 +178,11 @@ impl TunnelState for DisconnectedState {
Self::reset_dns(shared_values);
}
}
- SameState(self.into())
+ SameState(self)
}
Some(TunnelCommand::IsOffline(is_offline)) => {
shared_values.is_offline = is_offline;
- SameState(self.into())
+ SameState(self)
}
Some(TunnelCommand::Connect) => NewState(ConnectingState::enter(shared_values, 0)),
Some(TunnelCommand::Block(reason)) => {
@@ -194,18 +192,18 @@ impl TunnelState for DisconnectedState {
#[cfg(target_os = "android")]
Some(TunnelCommand::BypassSocket(fd, done_tx)) => {
shared_values.bypass_socket(fd, done_tx);
- SameState(self.into())
+ SameState(self)
}
#[cfg(windows)]
Some(TunnelCommand::SetExcludedApps(result_tx, paths)) => {
shared_values.split_tunnel.set_paths(&paths, result_tx);
- SameState(self.into())
+ SameState(self)
}
None => {
Self::reset_dns(shared_values);
Finished
}
- Some(_) => SameState(self.into()),
+ Some(_) => SameState(self),
}
}
}
diff --git a/talpid-core/src/tunnel_state_machine/disconnecting_state.rs b/talpid-core/src/tunnel_state_machine/disconnecting_state.rs
index 954de76034..08248fbac2 100644
--- a/talpid-core/src/tunnel_state_machine/disconnecting_state.rs
+++ b/talpid-core/src/tunnel_state_machine/disconnecting_state.rs
@@ -1,7 +1,7 @@
use super::{
connecting_state::TunnelCloseEvent, ConnectingState, DisconnectedState, ErrorState,
EventConsequence, EventResult, SharedTunnelStateValues, TunnelCommand, TunnelCommandReceiver,
- TunnelState, TunnelStateTransition, TunnelStateWrapper,
+ TunnelState, TunnelStateTransition,
};
use futures::{channel::oneshot, future::FusedFuture, StreamExt};
use talpid_types::tunnel::{ActionAfterDisconnect, ErrorStateCause};
@@ -14,8 +14,25 @@ pub struct DisconnectingState {
}
impl DisconnectingState {
+ pub(super) fn enter(
+ tunnel_close_tx: oneshot::Sender<()>,
+ tunnel_close_event: TunnelCloseEvent,
+ after_disconnect: AfterDisconnect,
+ ) -> (Box<dyn TunnelState>, TunnelStateTransition) {
+ let _ = tunnel_close_tx.send(());
+ let action_after_disconnect = after_disconnect.action();
+
+ (
+ Box::new(DisconnectingState {
+ tunnel_close_event,
+ after_disconnect,
+ }),
+ TunnelStateTransition::Disconnecting(action_after_disconnect),
+ )
+ }
+
fn handle_commands(
- mut self,
+ mut self: Box<Self>,
command: Option<TunnelCommand>,
shared_values: &mut SharedTunnelStateValues,
) -> EventConsequence {
@@ -141,14 +158,14 @@ impl DisconnectingState {
},
};
- EventConsequence::SameState(self.into())
+ EventConsequence::SameState(self)
}
fn after_disconnect(
self,
block_reason: Option<ErrorStateCause>,
shared_values: &mut SharedTunnelStateValues,
- ) -> (TunnelStateWrapper, TunnelStateTransition) {
+ ) -> (Box<dyn TunnelState>, TunnelStateTransition) {
if let Some(reason) = block_reason {
return ErrorState::enter(shared_values, reason);
}
@@ -164,26 +181,8 @@ impl DisconnectingState {
}
impl TunnelState for DisconnectingState {
- type Bootstrap = (oneshot::Sender<()>, TunnelCloseEvent, AfterDisconnect);
-
- fn enter(
- _: &mut SharedTunnelStateValues,
- (tunnel_close_tx, tunnel_close_event, after_disconnect): Self::Bootstrap,
- ) -> (TunnelStateWrapper, TunnelStateTransition) {
- let _ = tunnel_close_tx.send(());
- let action_after_disconnect = after_disconnect.action();
-
- (
- TunnelStateWrapper::from(DisconnectingState {
- tunnel_close_event,
- after_disconnect,
- }),
- TunnelStateTransition::Disconnecting(action_after_disconnect),
- )
- }
-
fn handle_event(
- mut self,
+ mut self: Box<Self>,
runtime: &tokio::runtime::Handle,
commands: &mut TunnelCommandReceiver,
shared_values: &mut SharedTunnelStateValues,
diff --git a/talpid-core/src/tunnel_state_machine/error_state.rs b/talpid-core/src/tunnel_state_machine/error_state.rs
index 7fe95c9f67..11a805f7dc 100644
--- a/talpid-core/src/tunnel_state_machine/error_state.rs
+++ b/talpid-core/src/tunnel_state_machine/error_state.rs
@@ -1,6 +1,6 @@
use super::{
ConnectingState, DisconnectedState, EventConsequence, SharedTunnelStateValues, TunnelCommand,
- TunnelCommandReceiver, TunnelState, TunnelStateTransition, TunnelStateWrapper,
+ TunnelCommandReceiver, TunnelState, TunnelStateTransition,
};
use crate::firewall::FirewallPolicy;
use futures::StreamExt;
@@ -17,6 +17,56 @@ pub struct ErrorState {
}
impl ErrorState {
+ pub(super) fn enter(
+ shared_values: &mut SharedTunnelStateValues,
+ block_reason: ErrorStateCause,
+ ) -> (Box<dyn TunnelState>, TunnelStateTransition) {
+ #[cfg(windows)]
+ if let Err(error) = shared_values.split_tunnel.set_tunnel_addresses(None) {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg(
+ "Failed to register addresses with split tunnel driver"
+ )
+ );
+ }
+
+ #[cfg(target_os = "macos")]
+ if !block_reason.prevents_filtering_resolver() {
+ if let Err(err) = shared_values
+ .dns_monitor
+ .set("lo", &[Ipv4Addr::LOCALHOST.into()])
+ {
+ log::error!(
+ "{}",
+ err.display_chain_with_msg(
+ "Failed to configure system to use filtering resolver"
+ )
+ );
+ return Self::enter(shared_values, ErrorStateCause::SetDnsError);
+ }
+ };
+
+ #[cfg(not(target_os = "android"))]
+ let block_failure = Self::set_firewall_policy(shared_values).err();
+
+ #[cfg(target_os = "android")]
+ let block_failure = if !Self::create_blocking_tun(shared_values) {
+ Some(FirewallPolicyError::Generic)
+ } else {
+ None
+ };
+ (
+ Box::new(ErrorState {
+ block_reason: block_reason.clone(),
+ }),
+ TunnelStateTransition::Error(talpid_tunnel::ErrorState::new(
+ block_reason,
+ block_failure,
+ )),
+ )
+ }
+
fn set_firewall_policy(
shared_values: &mut SharedTunnelStateValues,
) -> Result<(), FirewallPolicyError> {
@@ -78,61 +128,9 @@ impl ErrorState {
}
impl TunnelState for ErrorState {
- type Bootstrap = ErrorStateCause;
-
- fn enter(
- shared_values: &mut SharedTunnelStateValues,
- block_reason: Self::Bootstrap,
- ) -> (TunnelStateWrapper, TunnelStateTransition) {
- #[cfg(windows)]
- if let Err(error) = shared_values.split_tunnel.set_tunnel_addresses(None) {
- log::error!(
- "{}",
- error.display_chain_with_msg(
- "Failed to register addresses with split tunnel driver"
- )
- );
- }
-
- #[cfg(target_os = "macos")]
- if !block_reason.prevents_filtering_resolver() {
- if let Err(err) = shared_values
- .dns_monitor
- .set("lo", &[Ipv4Addr::LOCALHOST.into()])
- {
- log::error!(
- "{}",
- err.display_chain_with_msg(
- "Failed to configure system to use filtering resolver"
- )
- );
- return Self::enter(shared_values, ErrorStateCause::SetDnsError);
- }
- };
-
- #[cfg(not(target_os = "android"))]
- let block_failure = Self::set_firewall_policy(shared_values).err();
-
- #[cfg(target_os = "android")]
- let block_failure = if !Self::create_blocking_tun(shared_values) {
- Some(FirewallPolicyError::Generic)
- } else {
- None
- };
- (
- TunnelStateWrapper::from(ErrorState {
- block_reason: block_reason.clone(),
- }),
- TunnelStateTransition::Error(talpid_tunnel::ErrorState::new(
- block_reason,
- block_failure,
- )),
- )
- }
-
#[cfg_attr(not(target_os = "macos"), allow(unused_mut))]
fn handle_event(
- self,
+ self: Box<Self>,
runtime: &tokio::runtime::Handle,
commands: &mut TunnelCommandReceiver,
shared_values: &mut SharedTunnelStateValues,
@@ -145,7 +143,7 @@ impl TunnelState for ErrorState {
NewState(Self::enter(shared_values, error_state_cause))
} else {
let _ = Self::set_firewall_policy(shared_values);
- SameState(self.into())
+ SameState(self)
}
}
Some(TunnelCommand::AllowEndpoint(endpoint, tx)) => {
@@ -163,18 +161,18 @@ impl TunnelState for ErrorState {
}
}
let _ = tx.send(());
- SameState(self.into())
+ SameState(self)
}
Some(TunnelCommand::Dns(servers)) => {
if let Err(error_state_cause) = shared_values.set_dns_servers(servers) {
NewState(Self::enter(shared_values, error_state_cause))
} else {
- SameState(self.into())
+ SameState(self)
}
}
Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => {
shared_values.block_when_disconnected = block_when_disconnected;
- SameState(self.into())
+ SameState(self)
}
Some(TunnelCommand::IsOffline(is_offline)) => {
shared_values.is_offline = is_offline;
@@ -182,7 +180,7 @@ impl TunnelState for ErrorState {
Self::reset_dns(shared_values);
NewState(ConnectingState::enter(shared_values, 0))
} else {
- SameState(self.into())
+ SameState(self)
}
}
Some(TunnelCommand::Connect) => {
@@ -202,12 +200,12 @@ impl TunnelState for ErrorState {
#[cfg(target_os = "android")]
Some(TunnelCommand::BypassSocket(fd, done_tx)) => {
shared_values.bypass_socket(fd, done_tx);
- SameState(self.into())
+ SameState(self)
}
#[cfg(windows)]
Some(TunnelCommand::SetExcludedApps(result_tx, paths)) => {
shared_values.split_tunnel.set_paths(&paths, result_tx);
- SameState(self.into())
+ SameState(self)
}
}
}
diff --git a/talpid-core/src/tunnel_state_machine/mod.rs b/talpid-core/src/tunnel_state_machine/mod.rs
index f4c58b849c..12bc4cfc86 100644
--- a/talpid-core/src/tunnel_state_machine/mod.rs
+++ b/talpid-core/src/tunnel_state_machine/mod.rs
@@ -5,7 +5,7 @@ mod disconnecting_state;
mod error_state;
use self::{
- connected_state::{ConnectedState, ConnectedStateBootstrap},
+ connected_state::ConnectedState,
connecting_state::ConnectingState,
disconnected_state::DisconnectedState,
disconnecting_state::{AfterDisconnect, DisconnectingState},
@@ -232,7 +232,7 @@ enum EventResult {
/// to. Every time it successfully advances the state machine a `TunnelStateTransition` is emitted
/// by the stream.
struct TunnelStateMachine {
- current_state: Option<TunnelStateWrapper>,
+ current_state: Option<Box<dyn TunnelState>>,
commands: TunnelCommandReceiver,
shared_values: SharedTunnelStateValues,
}
@@ -389,9 +389,8 @@ impl TunnelStateMachine {
let runtime = self.shared_values.runtime.clone();
- while let Some(state_wrapper) = self.current_state.take() {
- match state_wrapper.handle_event(&runtime, &mut self.commands, &mut self.shared_values)
- {
+ while let Some(state) = self.current_state.take() {
+ match state.handle_event(&runtime, &mut self.commands, &mut self.shared_values) {
NewState((state, transition)) => {
self.current_state = Some(state);
@@ -557,28 +556,16 @@ impl SharedTunnelStateValues {
/// Asynchronous result of an attempt to progress a state.
enum EventConsequence {
/// Transition to a new state.
- NewState((TunnelStateWrapper, TunnelStateTransition)),
+ NewState((Box<dyn TunnelState>, TunnelStateTransition)),
/// An event was received, but it was ignored by the state so no transition is performed.
- SameState(TunnelStateWrapper),
+ SameState(Box<dyn TunnelState>),
/// The state machine has finished its execution.
Finished,
}
/// Trait that contains the method all states should implement to handle an event and advance the
/// state machine.
-trait TunnelState: Into<TunnelStateWrapper> + Sized {
- /// Type representing extra information required for entering the state.
- type Bootstrap;
-
- /// Constructor function.
- ///
- /// This is the state entry point. It attempts to enter the state, and may fail by entering an
- /// error or fallback state instead.
- fn enter(
- shared_values: &mut SharedTunnelStateValues,
- bootstrap: Self::Bootstrap,
- ) -> (TunnelStateWrapper, TunnelStateTransition);
-
+trait TunnelState: Send {
/// Main state function.
///
/// This is state exit point. It consumes itself and returns the next state to advance to when
@@ -590,56 +577,13 @@ trait TunnelState: Into<TunnelStateWrapper> + Sized {
///
/// [`EventConsequence`]: enum.EventConsequence.html
fn handle_event(
- self,
+ self: Box<Self>,
runtime: &tokio::runtime::Handle,
commands: &mut TunnelCommandReceiver,
shared_values: &mut SharedTunnelStateValues,
) -> EventConsequence;
}
-macro_rules! state_wrapper {
- (enum $wrapper_name:ident { $($state_variant:ident($state_type:ident)),* $(,)* }) => {
- /// Valid states of the tunnel.
- ///
- /// All implementations must implement `TunnelState` so that they can handle events and
- /// commands in order to advance the state machine.
- enum $wrapper_name {
- $($state_variant($state_type),)*
- }
-
- $(impl From<$state_type> for $wrapper_name {
- fn from(state: $state_type) -> Self {
- $wrapper_name::$state_variant(state)
- }
- })*
-
- impl $wrapper_name {
- fn handle_event(
- self,
- runtime: &tokio::runtime::Handle,
- commands: &mut TunnelCommandReceiver,
- shared_values: &mut SharedTunnelStateValues,
- ) -> EventConsequence {
- match self {
- $($wrapper_name::$state_variant(state) => {
- state.handle_event(runtime, commands, shared_values)
- })*
- }
- }
- }
- }
-}
-
-state_wrapper! {
- enum TunnelStateWrapper {
- Disconnected(DisconnectedState),
- Connecting(ConnectingState),
- Connected(ConnectedState),
- Disconnecting(DisconnectingState),
- Error(ErrorState),
- }
-}
-
/// Handle used to control the tunnel state machine.
pub struct TunnelStateMachineHandle {
command_tx: Arc<mpsc::UnboundedSender<TunnelCommand>>,