summaryrefslogtreecommitdiffhomepage
path: root/talpid-core/src
diff options
context:
space:
mode:
authorDavid Lönnhager <david.l@mullvad.net>2023-11-17 15:37:36 +0100
committerDavid Lönnhager <david.l@mullvad.net>2023-11-22 16:12:25 +0100
commiteed7a7d93e829198725977e15bb8e8de56eb2ac5 (patch)
tree51b43d28461a45d0950f6769511bb5152a45ab91 /talpid-core/src
parentac4d1eac84cbcd5fb6c47b97a6dc035da7de4d37 (diff)
downloadmullvadvpn-eed7a7d93e829198725977e15bb8e8de56eb2ac5.tar.xz
mullvadvpn-eed7a7d93e829198725977e15bb8e8de56eb2ac5.zip
Complete certain management interface commands when the tunnel state machine has actually handled the request
Diffstat (limited to 'talpid-core/src')
-rw-r--r--talpid-core/src/tunnel_state_machine/connected_state.rs70
-rw-r--r--talpid-core/src/tunnel_state_machine/connecting_state.rs27
-rw-r--r--talpid-core/src/tunnel_state_machine/disconnected_state.rs10
-rw-r--r--talpid-core/src/tunnel_state_machine/disconnecting_state.rs36
-rw-r--r--talpid-core/src/tunnel_state_machine/error_state.rs35
-rw-r--r--talpid-core/src/tunnel_state_machine/mod.rs6
6 files changed, 115 insertions, 69 deletions
diff --git a/talpid-core/src/tunnel_state_machine/connected_state.rs b/talpid-core/src/tunnel_state_machine/connected_state.rs
index 687e941aa7..21375f056b 100644
--- a/talpid-core/src/tunnel_state_machine/connected_state.rs
+++ b/talpid-core/src/tunnel_state_machine/connected_state.rs
@@ -213,8 +213,8 @@ impl ConnectedState {
use self::EventConsequence::*;
match command {
- Some(TunnelCommand::AllowLan(allow_lan)) => {
- if let Err(error_cause) = shared_values.set_allow_lan(allow_lan) {
+ Some(TunnelCommand::AllowLan(allow_lan, complete_tx)) => {
+ let consequence = if let Err(error_cause) = shared_values.set_allow_lan(allow_lan) {
self.disconnect(shared_values, AfterDisconnect::Block(error_cause))
} else {
match self.set_firewall_policy(shared_values) {
@@ -230,43 +230,55 @@ impl ConnectedState {
AfterDisconnect::Block(ErrorStateCause::SetFirewallPolicyError(error)),
),
}
- }
+ };
+ let _ = complete_tx.send(());
+ consequence
}
Some(TunnelCommand::AllowEndpoint(endpoint, tx)) => {
shared_values.allowed_endpoint = endpoint;
let _ = tx.send(());
SameState(self)
}
- Some(TunnelCommand::Dns(servers)) => match shared_values.set_dns_servers(servers) {
- Ok(true) => {
- if let Err(error) = self.set_firewall_policy(shared_values) {
- return self.disconnect(
- shared_values,
- AfterDisconnect::Block(ErrorStateCause::SetFirewallPolicyError(error)),
- );
- }
-
- match self.set_dns(shared_values) {
- #[cfg(target_os = "android")]
- Ok(()) => self.disconnect(shared_values, AfterDisconnect::Reconnect(0)),
- #[cfg(not(target_os = "android"))]
- Ok(()) => SameState(self),
- Err(error) => {
- log::error!("{}", error.display_chain_with_msg("Failed to set DNS"));
- self.disconnect(
+ Some(TunnelCommand::Dns(servers, complete_tx)) => {
+ let consequence = match shared_values.set_dns_servers(servers) {
+ Ok(true) => {
+ if let Err(error) = self.set_firewall_policy(shared_values) {
+ return self.disconnect(
shared_values,
- AfterDisconnect::Block(ErrorStateCause::SetDnsError),
- )
+ AfterDisconnect::Block(ErrorStateCause::SetFirewallPolicyError(
+ error,
+ )),
+ );
+ }
+
+ match self.set_dns(shared_values) {
+ #[cfg(target_os = "android")]
+ Ok(()) => self.disconnect(shared_values, AfterDisconnect::Reconnect(0)),
+ #[cfg(not(target_os = "android"))]
+ Ok(()) => SameState(self),
+ Err(error) => {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Failed to set DNS")
+ );
+ self.disconnect(
+ shared_values,
+ AfterDisconnect::Block(ErrorStateCause::SetDnsError),
+ )
+ }
}
}
- }
- Ok(false) => SameState(self),
- Err(error_cause) => {
- self.disconnect(shared_values, AfterDisconnect::Block(error_cause))
- }
- },
- Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => {
+ Ok(false) => SameState(self),
+ Err(error_cause) => {
+ self.disconnect(shared_values, AfterDisconnect::Block(error_cause))
+ }
+ };
+ let _ = complete_tx.send(());
+ consequence
+ }
+ Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected, complete_tx)) => {
shared_values.block_when_disconnected = block_when_disconnected;
+ let _ = complete_tx.send(());
SameState(self)
}
Some(TunnelCommand::IsOffline(is_offline)) => {
diff --git a/talpid-core/src/tunnel_state_machine/connecting_state.rs b/talpid-core/src/tunnel_state_machine/connecting_state.rs
index 2a728513ff..d7d93da9d4 100644
--- a/talpid-core/src/tunnel_state_machine/connecting_state.rs
+++ b/talpid-core/src/tunnel_state_machine/connecting_state.rs
@@ -392,12 +392,14 @@ impl ConnectingState {
use self::EventConsequence::*;
match command {
- Some(TunnelCommand::AllowLan(allow_lan)) => {
- if let Err(error_cause) = shared_values.set_allow_lan(allow_lan) {
+ Some(TunnelCommand::AllowLan(allow_lan, complete_tx)) => {
+ let consequence = if let Err(error_cause) = shared_values.set_allow_lan(allow_lan) {
self.disconnect(shared_values, AfterDisconnect::Block(error_cause))
} else {
self.reset_firewall(shared_values)
- }
+ };
+ let _ = complete_tx.send(());
+ consequence
}
Some(TunnelCommand::AllowEndpoint(endpoint, tx)) => {
if shared_values.allowed_endpoint != endpoint {
@@ -418,14 +420,19 @@ impl ConnectingState {
let _ = tx.send(());
SameState(self)
}
- Some(TunnelCommand::Dns(servers)) => match shared_values.set_dns_servers(servers) {
- #[cfg(target_os = "android")]
- Ok(true) => self.disconnect(shared_values, AfterDisconnect::Reconnect(0)),
- Ok(_) => SameState(self),
- Err(cause) => self.disconnect(shared_values, AfterDisconnect::Block(cause)),
- },
- Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => {
+ Some(TunnelCommand::Dns(servers, complete_tx)) => {
+ let consequence = match shared_values.set_dns_servers(servers) {
+ #[cfg(target_os = "android")]
+ Ok(true) => self.disconnect(shared_values, AfterDisconnect::Reconnect(0)),
+ Ok(_) => SameState(self),
+ Err(cause) => self.disconnect(shared_values, AfterDisconnect::Block(cause)),
+ };
+ let _ = complete_tx.send(());
+ consequence
+ }
+ Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected, complete_tx)) => {
shared_values.block_when_disconnected = block_when_disconnected;
+ let _ = complete_tx.send(());
SameState(self)
}
Some(TunnelCommand::IsOffline(is_offline)) => {
diff --git a/talpid-core/src/tunnel_state_machine/disconnected_state.rs b/talpid-core/src/tunnel_state_machine/disconnected_state.rs
index 5a2cf6fc4d..d46f06e782 100644
--- a/talpid-core/src/tunnel_state_machine/disconnected_state.rs
+++ b/talpid-core/src/tunnel_state_machine/disconnected_state.rs
@@ -128,7 +128,7 @@ impl TunnelState for DisconnectedState {
use self::EventConsequence::*;
match runtime.block_on(commands.next()) {
- Some(TunnelCommand::AllowLan(allow_lan)) => {
+ Some(TunnelCommand::AllowLan(allow_lan, complete_tx)) => {
if shared_values.allow_lan != allow_lan {
// The only platform that can fail is Android, but Android doesn't support the
// "block when disconnected" option, so the following call never fails.
@@ -138,6 +138,7 @@ impl TunnelState for DisconnectedState {
Self::set_firewall_policy(shared_values, false);
}
+ let _ = complete_tx.send(());
SameState(self)
}
Some(TunnelCommand::AllowEndpoint(endpoint, tx)) => {
@@ -148,15 +149,15 @@ impl TunnelState for DisconnectedState {
let _ = tx.send(());
SameState(self)
}
- Some(TunnelCommand::Dns(servers)) => {
+ Some(TunnelCommand::Dns(servers, complete_tx)) => {
// Same situation as allow LAN above.
shared_values
.set_dns_servers(servers)
.expect("Failed to reconnect after changing custom DNS servers");
-
+ let _ = complete_tx.send(());
SameState(self)
}
- Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => {
+ Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected, complete_tx)) => {
if shared_values.block_when_disconnected != block_when_disconnected {
shared_values.block_when_disconnected = block_when_disconnected;
Self::set_firewall_policy(shared_values, true);
@@ -178,6 +179,7 @@ impl TunnelState for DisconnectedState {
Self::reset_dns(shared_values);
}
}
+ let _ = complete_tx.send(());
SameState(self)
}
Some(TunnelCommand::IsOffline(is_offline)) => {
diff --git a/talpid-core/src/tunnel_state_machine/disconnecting_state.rs b/talpid-core/src/tunnel_state_machine/disconnecting_state.rs
index 08248fbac2..185d2f7d0a 100644
--- a/talpid-core/src/tunnel_state_machine/disconnecting_state.rs
+++ b/talpid-core/src/tunnel_state_machine/disconnecting_state.rs
@@ -40,8 +40,9 @@ impl DisconnectingState {
self.after_disconnect = match after_disconnect {
AfterDisconnect::Nothing => match command {
- Some(TunnelCommand::AllowLan(allow_lan)) => {
+ Some(TunnelCommand::AllowLan(allow_lan, complete_tx)) => {
let _ = shared_values.set_allow_lan(allow_lan);
+ let _ = complete_tx.send(());
AfterDisconnect::Nothing
}
Some(TunnelCommand::AllowEndpoint(endpoint, tx)) => {
@@ -49,12 +50,17 @@ impl DisconnectingState {
let _ = tx.send(());
AfterDisconnect::Nothing
}
- Some(TunnelCommand::Dns(servers)) => {
+ Some(TunnelCommand::Dns(servers, complete_tx)) => {
let _ = shared_values.set_dns_servers(servers);
+ let _ = complete_tx.send(());
AfterDisconnect::Nothing
}
- Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => {
+ Some(TunnelCommand::BlockWhenDisconnected(
+ block_when_disconnected,
+ complete_tx,
+ )) => {
shared_values.block_when_disconnected = block_when_disconnected;
+ let _ = complete_tx.send(());
AfterDisconnect::Nothing
}
Some(TunnelCommand::IsOffline(is_offline)) => {
@@ -76,8 +82,9 @@ impl DisconnectingState {
}
},
AfterDisconnect::Block(reason) => match command {
- Some(TunnelCommand::AllowLan(allow_lan)) => {
+ Some(TunnelCommand::AllowLan(allow_lan, complete_tx)) => {
let _ = shared_values.set_allow_lan(allow_lan);
+ let _ = complete_tx.send(());
AfterDisconnect::Block(reason)
}
Some(TunnelCommand::AllowEndpoint(endpoint, tx)) => {
@@ -85,12 +92,17 @@ impl DisconnectingState {
let _ = tx.send(());
AfterDisconnect::Block(reason)
}
- Some(TunnelCommand::Dns(servers)) => {
+ Some(TunnelCommand::Dns(servers, complete_tx)) => {
let _ = shared_values.set_dns_servers(servers);
+ let _ = complete_tx.send(());
AfterDisconnect::Block(reason)
}
- Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => {
+ Some(TunnelCommand::BlockWhenDisconnected(
+ block_when_disconnected,
+ complete_tx,
+ )) => {
shared_values.block_when_disconnected = block_when_disconnected;
+ let _ = complete_tx.send(());
AfterDisconnect::Block(reason)
}
Some(TunnelCommand::IsOffline(is_offline)) => {
@@ -117,8 +129,9 @@ impl DisconnectingState {
None => AfterDisconnect::Block(reason),
},
AfterDisconnect::Reconnect(retry_attempt) => match command {
- Some(TunnelCommand::AllowLan(allow_lan)) => {
+ Some(TunnelCommand::AllowLan(allow_lan, complete_tx)) => {
let _ = shared_values.set_allow_lan(allow_lan);
+ let _ = complete_tx.send(());
AfterDisconnect::Reconnect(retry_attempt)
}
Some(TunnelCommand::AllowEndpoint(endpoint, tx)) => {
@@ -126,12 +139,17 @@ impl DisconnectingState {
let _ = tx.send(());
AfterDisconnect::Reconnect(retry_attempt)
}
- Some(TunnelCommand::Dns(servers)) => {
+ Some(TunnelCommand::Dns(servers, complete_tx)) => {
let _ = shared_values.set_dns_servers(servers);
+ let _ = complete_tx.send(());
AfterDisconnect::Reconnect(retry_attempt)
}
- Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => {
+ Some(TunnelCommand::BlockWhenDisconnected(
+ block_when_disconnected,
+ complete_tx,
+ )) => {
shared_values.block_when_disconnected = block_when_disconnected;
+ let _ = complete_tx.send(());
AfterDisconnect::Reconnect(retry_attempt)
}
Some(TunnelCommand::IsOffline(is_offline)) => {
diff --git a/talpid-core/src/tunnel_state_machine/error_state.rs b/talpid-core/src/tunnel_state_machine/error_state.rs
index 11a805f7dc..2f82cb4cf5 100644
--- a/talpid-core/src/tunnel_state_machine/error_state.rs
+++ b/talpid-core/src/tunnel_state_machine/error_state.rs
@@ -138,13 +138,16 @@ impl TunnelState for ErrorState {
use self::EventConsequence::*;
match runtime.block_on(commands.next()) {
- Some(TunnelCommand::AllowLan(allow_lan)) => {
- if let Err(error_state_cause) = shared_values.set_allow_lan(allow_lan) {
- NewState(Self::enter(shared_values, error_state_cause))
- } else {
- let _ = Self::set_firewall_policy(shared_values);
- SameState(self)
- }
+ Some(TunnelCommand::AllowLan(allow_lan, complete_tx)) => {
+ let consequence =
+ if let Err(error_state_cause) = shared_values.set_allow_lan(allow_lan) {
+ NewState(Self::enter(shared_values, error_state_cause))
+ } else {
+ let _ = Self::set_firewall_policy(shared_values);
+ SameState(self)
+ };
+ let _ = complete_tx.send(());
+ consequence
}
Some(TunnelCommand::AllowEndpoint(endpoint, tx)) => {
if shared_values.allowed_endpoint != endpoint {
@@ -163,15 +166,19 @@ impl TunnelState for ErrorState {
let _ = tx.send(());
SameState(self)
}
- Some(TunnelCommand::Dns(servers)) => {
- if let Err(error_state_cause) = shared_values.set_dns_servers(servers) {
- NewState(Self::enter(shared_values, error_state_cause))
- } else {
- SameState(self)
- }
+ Some(TunnelCommand::Dns(servers, complete_tx)) => {
+ let consequence =
+ if let Err(error_state_cause) = shared_values.set_dns_servers(servers) {
+ NewState(Self::enter(shared_values, error_state_cause))
+ } else {
+ SameState(self)
+ };
+ let _ = complete_tx.send(());
+ consequence
}
- Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => {
+ Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected, complete_tx)) => {
shared_values.block_when_disconnected = block_when_disconnected;
+ let _ = complete_tx.send(());
SameState(self)
}
Some(TunnelCommand::IsOffline(is_offline)) => {
diff --git a/talpid-core/src/tunnel_state_machine/mod.rs b/talpid-core/src/tunnel_state_machine/mod.rs
index 12bc4cfc86..5957b2f731 100644
--- a/talpid-core/src/tunnel_state_machine/mod.rs
+++ b/talpid-core/src/tunnel_state_machine/mod.rs
@@ -189,15 +189,15 @@ pub async fn spawn(
/// Representation of external commands for the tunnel state machine.
pub enum TunnelCommand {
/// Enable or disable LAN access in the firewall.
- AllowLan(bool),
+ AllowLan(bool, oneshot::Sender<()>),
/// Endpoint that should never be blocked. `()` is sent to the
/// channel after attempting to set the firewall policy, regardless
/// of whether it succeeded.
AllowEndpoint(AllowedEndpoint, oneshot::Sender<()>),
/// Set DNS servers to use.
- Dns(Option<Vec<IpAddr>>),
+ Dns(Option<Vec<IpAddr>>, oneshot::Sender<()>),
/// Enable or disable the block_when_disconnected feature.
- BlockWhenDisconnected(bool),
+ BlockWhenDisconnected(bool, oneshot::Sender<()>),
/// Notify the state machine of the connectivity of the device.
IsOffline(bool),
/// Open tunnel connection.