summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorDavid Lönnhager <david.l@mullvad.net>2023-11-17 15:37:36 +0100
committerDavid Lönnhager <david.l@mullvad.net>2023-11-22 16:12:25 +0100
commiteed7a7d93e829198725977e15bb8e8de56eb2ac5 (patch)
tree51b43d28461a45d0950f6769511bb5152a45ab91
parentac4d1eac84cbcd5fb6c47b97a6dc035da7de4d37 (diff)
downloadmullvadvpn-eed7a7d93e829198725977e15bb8e8de56eb2ac5.tar.xz
mullvadvpn-eed7a7d93e829198725977e15bb8e8de56eb2ac5.zip
Complete certain management interface commands when the tunnel state machine has actually handled the request
-rw-r--r--mullvad-daemon/src/lib.rs48
-rw-r--r--talpid-core/src/tunnel_state_machine/connected_state.rs70
-rw-r--r--talpid-core/src/tunnel_state_machine/connecting_state.rs27
-rw-r--r--talpid-core/src/tunnel_state_machine/disconnected_state.rs10
-rw-r--r--talpid-core/src/tunnel_state_machine/disconnecting_state.rs36
-rw-r--r--talpid-core/src/tunnel_state_machine/error_state.rs35
-rw-r--r--talpid-core/src/tunnel_state_machine/mod.rs6
7 files changed, 156 insertions, 76 deletions
diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs
index b8a21dda07..a540791838 100644
--- a/mullvad-daemon/src/lib.rs
+++ b/mullvad-daemon/src/lib.rs
@@ -1899,9 +1899,15 @@ where
.await
{
Ok(settings_changed) => {
- Self::oneshot_send(tx, Ok(()), "set_allow_lan response");
if settings_changed {
- self.send_tunnel_command(TunnelCommand::AllowLan(allow_lan));
+ self.send_tunnel_command(TunnelCommand::AllowLan(
+ allow_lan,
+ oneshot_map(tx, |tx, ()| {
+ Self::oneshot_send(tx, Ok(()), "set_allow_lan response");
+ }),
+ ));
+ } else {
+ Self::oneshot_send(tx, Ok(()), "set_allow_lan response");
}
}
Err(e) => {
@@ -1946,11 +1952,15 @@ where
.await
{
Ok(settings_changed) => {
- Self::oneshot_send(tx, Ok(()), "set_block_when_disconnected response");
if settings_changed {
self.send_tunnel_command(TunnelCommand::BlockWhenDisconnected(
block_when_disconnected,
+ oneshot_map(tx, |tx, ()| {
+ Self::oneshot_send(tx, Ok(()), "set_block_when_disconnected response");
+ }),
));
+ } else {
+ Self::oneshot_send(tx, Ok(()), "set_block_when_disconnected response");
}
}
Err(e) => {
@@ -2148,12 +2158,18 @@ where
.await
{
Ok(settings_changed) => {
- Self::oneshot_send(tx, Ok(()), "set_dns_options response");
if settings_changed {
let settings = self.settings.to_settings();
let resolvers =
dns::addresses_from_options(&settings.tunnel_options.dns_options);
- self.send_tunnel_command(TunnelCommand::Dns(resolvers));
+ self.send_tunnel_command(TunnelCommand::Dns(
+ resolvers,
+ oneshot_map(tx, |tx, ()| {
+ Self::oneshot_send(tx, Ok(()), "set_dns_options response");
+ }),
+ ));
+ } else {
+ Self::oneshot_send(tx, Ok(()), "set_dns_options response");
}
}
Err(e) => {
@@ -2396,7 +2412,8 @@ where
&& (*self.target_state == TargetState::Secured || self.settings.auto_connect)
{
log::debug!("Blocking firewall during shutdown since system is going down");
- self.send_tunnel_command(TunnelCommand::BlockWhenDisconnected(true));
+ let (tx, _rx) = oneshot::channel();
+ self.send_tunnel_command(TunnelCommand::BlockWhenDisconnected(true, tx));
}
self.state.shutdown(&self.tunnel_state);
@@ -2408,7 +2425,8 @@ where
// without causing the service to be restarted.
if *self.target_state == TargetState::Secured {
- self.send_tunnel_command(TunnelCommand::BlockWhenDisconnected(true));
+ let (tx, _rx) = oneshot::channel();
+ self.send_tunnel_command(TunnelCommand::BlockWhenDisconnected(true, tx));
}
self.target_state.lock();
}
@@ -2569,3 +2587,19 @@ fn new_selector_config(settings: &Settings) -> SelectorConfig {
relay_overrides: settings.relay_overrides.clone(),
}
}
+
+/// Consume a oneshot sender of `T1` and return a sender that takes a different type `T2`. `forwarder` should map `T1` back to `T2` and
+/// send the result back to the original receiver.
+fn oneshot_map<T1: Send + 'static, T2: Send + 'static>(
+ tx: oneshot::Sender<T1>,
+ forwarder: impl Fn(oneshot::Sender<T1>, T2) + Send + 'static,
+) -> oneshot::Sender<T2> {
+ let (new_tx, new_rx) = oneshot::channel();
+ tokio::spawn(async move {
+ match new_rx.await {
+ Ok(result) => forwarder(tx, result),
+ Err(oneshot::Canceled) => (),
+ }
+ });
+ new_tx
+}
diff --git a/talpid-core/src/tunnel_state_machine/connected_state.rs b/talpid-core/src/tunnel_state_machine/connected_state.rs
index 687e941aa7..21375f056b 100644
--- a/talpid-core/src/tunnel_state_machine/connected_state.rs
+++ b/talpid-core/src/tunnel_state_machine/connected_state.rs
@@ -213,8 +213,8 @@ impl ConnectedState {
use self::EventConsequence::*;
match command {
- Some(TunnelCommand::AllowLan(allow_lan)) => {
- if let Err(error_cause) = shared_values.set_allow_lan(allow_lan) {
+ Some(TunnelCommand::AllowLan(allow_lan, complete_tx)) => {
+ let consequence = if let Err(error_cause) = shared_values.set_allow_lan(allow_lan) {
self.disconnect(shared_values, AfterDisconnect::Block(error_cause))
} else {
match self.set_firewall_policy(shared_values) {
@@ -230,43 +230,55 @@ impl ConnectedState {
AfterDisconnect::Block(ErrorStateCause::SetFirewallPolicyError(error)),
),
}
- }
+ };
+ let _ = complete_tx.send(());
+ consequence
}
Some(TunnelCommand::AllowEndpoint(endpoint, tx)) => {
shared_values.allowed_endpoint = endpoint;
let _ = tx.send(());
SameState(self)
}
- Some(TunnelCommand::Dns(servers)) => match shared_values.set_dns_servers(servers) {
- Ok(true) => {
- if let Err(error) = self.set_firewall_policy(shared_values) {
- return self.disconnect(
- shared_values,
- AfterDisconnect::Block(ErrorStateCause::SetFirewallPolicyError(error)),
- );
- }
-
- match self.set_dns(shared_values) {
- #[cfg(target_os = "android")]
- Ok(()) => self.disconnect(shared_values, AfterDisconnect::Reconnect(0)),
- #[cfg(not(target_os = "android"))]
- Ok(()) => SameState(self),
- Err(error) => {
- log::error!("{}", error.display_chain_with_msg("Failed to set DNS"));
- self.disconnect(
+ Some(TunnelCommand::Dns(servers, complete_tx)) => {
+ let consequence = match shared_values.set_dns_servers(servers) {
+ Ok(true) => {
+ if let Err(error) = self.set_firewall_policy(shared_values) {
+ return self.disconnect(
shared_values,
- AfterDisconnect::Block(ErrorStateCause::SetDnsError),
- )
+ AfterDisconnect::Block(ErrorStateCause::SetFirewallPolicyError(
+ error,
+ )),
+ );
+ }
+
+ match self.set_dns(shared_values) {
+ #[cfg(target_os = "android")]
+ Ok(()) => self.disconnect(shared_values, AfterDisconnect::Reconnect(0)),
+ #[cfg(not(target_os = "android"))]
+ Ok(()) => SameState(self),
+ Err(error) => {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Failed to set DNS")
+ );
+ self.disconnect(
+ shared_values,
+ AfterDisconnect::Block(ErrorStateCause::SetDnsError),
+ )
+ }
}
}
- }
- Ok(false) => SameState(self),
- Err(error_cause) => {
- self.disconnect(shared_values, AfterDisconnect::Block(error_cause))
- }
- },
- Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => {
+ Ok(false) => SameState(self),
+ Err(error_cause) => {
+ self.disconnect(shared_values, AfterDisconnect::Block(error_cause))
+ }
+ };
+ let _ = complete_tx.send(());
+ consequence
+ }
+ Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected, complete_tx)) => {
shared_values.block_when_disconnected = block_when_disconnected;
+ let _ = complete_tx.send(());
SameState(self)
}
Some(TunnelCommand::IsOffline(is_offline)) => {
diff --git a/talpid-core/src/tunnel_state_machine/connecting_state.rs b/talpid-core/src/tunnel_state_machine/connecting_state.rs
index 2a728513ff..d7d93da9d4 100644
--- a/talpid-core/src/tunnel_state_machine/connecting_state.rs
+++ b/talpid-core/src/tunnel_state_machine/connecting_state.rs
@@ -392,12 +392,14 @@ impl ConnectingState {
use self::EventConsequence::*;
match command {
- Some(TunnelCommand::AllowLan(allow_lan)) => {
- if let Err(error_cause) = shared_values.set_allow_lan(allow_lan) {
+ Some(TunnelCommand::AllowLan(allow_lan, complete_tx)) => {
+ let consequence = if let Err(error_cause) = shared_values.set_allow_lan(allow_lan) {
self.disconnect(shared_values, AfterDisconnect::Block(error_cause))
} else {
self.reset_firewall(shared_values)
- }
+ };
+ let _ = complete_tx.send(());
+ consequence
}
Some(TunnelCommand::AllowEndpoint(endpoint, tx)) => {
if shared_values.allowed_endpoint != endpoint {
@@ -418,14 +420,19 @@ impl ConnectingState {
let _ = tx.send(());
SameState(self)
}
- Some(TunnelCommand::Dns(servers)) => match shared_values.set_dns_servers(servers) {
- #[cfg(target_os = "android")]
- Ok(true) => self.disconnect(shared_values, AfterDisconnect::Reconnect(0)),
- Ok(_) => SameState(self),
- Err(cause) => self.disconnect(shared_values, AfterDisconnect::Block(cause)),
- },
- Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => {
+ Some(TunnelCommand::Dns(servers, complete_tx)) => {
+ let consequence = match shared_values.set_dns_servers(servers) {
+ #[cfg(target_os = "android")]
+ Ok(true) => self.disconnect(shared_values, AfterDisconnect::Reconnect(0)),
+ Ok(_) => SameState(self),
+ Err(cause) => self.disconnect(shared_values, AfterDisconnect::Block(cause)),
+ };
+ let _ = complete_tx.send(());
+ consequence
+ }
+ Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected, complete_tx)) => {
shared_values.block_when_disconnected = block_when_disconnected;
+ let _ = complete_tx.send(());
SameState(self)
}
Some(TunnelCommand::IsOffline(is_offline)) => {
diff --git a/talpid-core/src/tunnel_state_machine/disconnected_state.rs b/talpid-core/src/tunnel_state_machine/disconnected_state.rs
index 5a2cf6fc4d..d46f06e782 100644
--- a/talpid-core/src/tunnel_state_machine/disconnected_state.rs
+++ b/talpid-core/src/tunnel_state_machine/disconnected_state.rs
@@ -128,7 +128,7 @@ impl TunnelState for DisconnectedState {
use self::EventConsequence::*;
match runtime.block_on(commands.next()) {
- Some(TunnelCommand::AllowLan(allow_lan)) => {
+ Some(TunnelCommand::AllowLan(allow_lan, complete_tx)) => {
if shared_values.allow_lan != allow_lan {
// The only platform that can fail is Android, but Android doesn't support the
// "block when disconnected" option, so the following call never fails.
@@ -138,6 +138,7 @@ impl TunnelState for DisconnectedState {
Self::set_firewall_policy(shared_values, false);
}
+ let _ = complete_tx.send(());
SameState(self)
}
Some(TunnelCommand::AllowEndpoint(endpoint, tx)) => {
@@ -148,15 +149,15 @@ impl TunnelState for DisconnectedState {
let _ = tx.send(());
SameState(self)
}
- Some(TunnelCommand::Dns(servers)) => {
+ Some(TunnelCommand::Dns(servers, complete_tx)) => {
// Same situation as allow LAN above.
shared_values
.set_dns_servers(servers)
.expect("Failed to reconnect after changing custom DNS servers");
-
+ let _ = complete_tx.send(());
SameState(self)
}
- Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => {
+ Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected, complete_tx)) => {
if shared_values.block_when_disconnected != block_when_disconnected {
shared_values.block_when_disconnected = block_when_disconnected;
Self::set_firewall_policy(shared_values, true);
@@ -178,6 +179,7 @@ impl TunnelState for DisconnectedState {
Self::reset_dns(shared_values);
}
}
+ let _ = complete_tx.send(());
SameState(self)
}
Some(TunnelCommand::IsOffline(is_offline)) => {
diff --git a/talpid-core/src/tunnel_state_machine/disconnecting_state.rs b/talpid-core/src/tunnel_state_machine/disconnecting_state.rs
index 08248fbac2..185d2f7d0a 100644
--- a/talpid-core/src/tunnel_state_machine/disconnecting_state.rs
+++ b/talpid-core/src/tunnel_state_machine/disconnecting_state.rs
@@ -40,8 +40,9 @@ impl DisconnectingState {
self.after_disconnect = match after_disconnect {
AfterDisconnect::Nothing => match command {
- Some(TunnelCommand::AllowLan(allow_lan)) => {
+ Some(TunnelCommand::AllowLan(allow_lan, complete_tx)) => {
let _ = shared_values.set_allow_lan(allow_lan);
+ let _ = complete_tx.send(());
AfterDisconnect::Nothing
}
Some(TunnelCommand::AllowEndpoint(endpoint, tx)) => {
@@ -49,12 +50,17 @@ impl DisconnectingState {
let _ = tx.send(());
AfterDisconnect::Nothing
}
- Some(TunnelCommand::Dns(servers)) => {
+ Some(TunnelCommand::Dns(servers, complete_tx)) => {
let _ = shared_values.set_dns_servers(servers);
+ let _ = complete_tx.send(());
AfterDisconnect::Nothing
}
- Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => {
+ Some(TunnelCommand::BlockWhenDisconnected(
+ block_when_disconnected,
+ complete_tx,
+ )) => {
shared_values.block_when_disconnected = block_when_disconnected;
+ let _ = complete_tx.send(());
AfterDisconnect::Nothing
}
Some(TunnelCommand::IsOffline(is_offline)) => {
@@ -76,8 +82,9 @@ impl DisconnectingState {
}
},
AfterDisconnect::Block(reason) => match command {
- Some(TunnelCommand::AllowLan(allow_lan)) => {
+ Some(TunnelCommand::AllowLan(allow_lan, complete_tx)) => {
let _ = shared_values.set_allow_lan(allow_lan);
+ let _ = complete_tx.send(());
AfterDisconnect::Block(reason)
}
Some(TunnelCommand::AllowEndpoint(endpoint, tx)) => {
@@ -85,12 +92,17 @@ impl DisconnectingState {
let _ = tx.send(());
AfterDisconnect::Block(reason)
}
- Some(TunnelCommand::Dns(servers)) => {
+ Some(TunnelCommand::Dns(servers, complete_tx)) => {
let _ = shared_values.set_dns_servers(servers);
+ let _ = complete_tx.send(());
AfterDisconnect::Block(reason)
}
- Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => {
+ Some(TunnelCommand::BlockWhenDisconnected(
+ block_when_disconnected,
+ complete_tx,
+ )) => {
shared_values.block_when_disconnected = block_when_disconnected;
+ let _ = complete_tx.send(());
AfterDisconnect::Block(reason)
}
Some(TunnelCommand::IsOffline(is_offline)) => {
@@ -117,8 +129,9 @@ impl DisconnectingState {
None => AfterDisconnect::Block(reason),
},
AfterDisconnect::Reconnect(retry_attempt) => match command {
- Some(TunnelCommand::AllowLan(allow_lan)) => {
+ Some(TunnelCommand::AllowLan(allow_lan, complete_tx)) => {
let _ = shared_values.set_allow_lan(allow_lan);
+ let _ = complete_tx.send(());
AfterDisconnect::Reconnect(retry_attempt)
}
Some(TunnelCommand::AllowEndpoint(endpoint, tx)) => {
@@ -126,12 +139,17 @@ impl DisconnectingState {
let _ = tx.send(());
AfterDisconnect::Reconnect(retry_attempt)
}
- Some(TunnelCommand::Dns(servers)) => {
+ Some(TunnelCommand::Dns(servers, complete_tx)) => {
let _ = shared_values.set_dns_servers(servers);
+ let _ = complete_tx.send(());
AfterDisconnect::Reconnect(retry_attempt)
}
- Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => {
+ Some(TunnelCommand::BlockWhenDisconnected(
+ block_when_disconnected,
+ complete_tx,
+ )) => {
shared_values.block_when_disconnected = block_when_disconnected;
+ let _ = complete_tx.send(());
AfterDisconnect::Reconnect(retry_attempt)
}
Some(TunnelCommand::IsOffline(is_offline)) => {
diff --git a/talpid-core/src/tunnel_state_machine/error_state.rs b/talpid-core/src/tunnel_state_machine/error_state.rs
index 11a805f7dc..2f82cb4cf5 100644
--- a/talpid-core/src/tunnel_state_machine/error_state.rs
+++ b/talpid-core/src/tunnel_state_machine/error_state.rs
@@ -138,13 +138,16 @@ impl TunnelState for ErrorState {
use self::EventConsequence::*;
match runtime.block_on(commands.next()) {
- Some(TunnelCommand::AllowLan(allow_lan)) => {
- if let Err(error_state_cause) = shared_values.set_allow_lan(allow_lan) {
- NewState(Self::enter(shared_values, error_state_cause))
- } else {
- let _ = Self::set_firewall_policy(shared_values);
- SameState(self)
- }
+ Some(TunnelCommand::AllowLan(allow_lan, complete_tx)) => {
+ let consequence =
+ if let Err(error_state_cause) = shared_values.set_allow_lan(allow_lan) {
+ NewState(Self::enter(shared_values, error_state_cause))
+ } else {
+ let _ = Self::set_firewall_policy(shared_values);
+ SameState(self)
+ };
+ let _ = complete_tx.send(());
+ consequence
}
Some(TunnelCommand::AllowEndpoint(endpoint, tx)) => {
if shared_values.allowed_endpoint != endpoint {
@@ -163,15 +166,19 @@ impl TunnelState for ErrorState {
let _ = tx.send(());
SameState(self)
}
- Some(TunnelCommand::Dns(servers)) => {
- if let Err(error_state_cause) = shared_values.set_dns_servers(servers) {
- NewState(Self::enter(shared_values, error_state_cause))
- } else {
- SameState(self)
- }
+ Some(TunnelCommand::Dns(servers, complete_tx)) => {
+ let consequence =
+ if let Err(error_state_cause) = shared_values.set_dns_servers(servers) {
+ NewState(Self::enter(shared_values, error_state_cause))
+ } else {
+ SameState(self)
+ };
+ let _ = complete_tx.send(());
+ consequence
}
- Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => {
+ Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected, complete_tx)) => {
shared_values.block_when_disconnected = block_when_disconnected;
+ let _ = complete_tx.send(());
SameState(self)
}
Some(TunnelCommand::IsOffline(is_offline)) => {
diff --git a/talpid-core/src/tunnel_state_machine/mod.rs b/talpid-core/src/tunnel_state_machine/mod.rs
index 12bc4cfc86..5957b2f731 100644
--- a/talpid-core/src/tunnel_state_machine/mod.rs
+++ b/talpid-core/src/tunnel_state_machine/mod.rs
@@ -189,15 +189,15 @@ pub async fn spawn(
/// Representation of external commands for the tunnel state machine.
pub enum TunnelCommand {
/// Enable or disable LAN access in the firewall.
- AllowLan(bool),
+ AllowLan(bool, oneshot::Sender<()>),
/// Endpoint that should never be blocked. `()` is sent to the
/// channel after attempting to set the firewall policy, regardless
/// of whether it succeeded.
AllowEndpoint(AllowedEndpoint, oneshot::Sender<()>),
/// Set DNS servers to use.
- Dns(Option<Vec<IpAddr>>),
+ Dns(Option<Vec<IpAddr>>, oneshot::Sender<()>),
/// Enable or disable the block_when_disconnected feature.
- BlockWhenDisconnected(bool),
+ BlockWhenDisconnected(bool, oneshot::Sender<()>),
/// Notify the state machine of the connectivity of the device.
IsOffline(bool),
/// Open tunnel connection.