diff options
| author | David Lönnhager <david.l@mullvad.net> | 2023-11-17 15:37:36 +0100 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2023-11-22 16:12:25 +0100 |
| commit | eed7a7d93e829198725977e15bb8e8de56eb2ac5 (patch) | |
| tree | 51b43d28461a45d0950f6769511bb5152a45ab91 /talpid-core/src | |
| parent | ac4d1eac84cbcd5fb6c47b97a6dc035da7de4d37 (diff) | |
| download | mullvadvpn-eed7a7d93e829198725977e15bb8e8de56eb2ac5.tar.xz mullvadvpn-eed7a7d93e829198725977e15bb8e8de56eb2ac5.zip | |
Complete certain management interface commands when the tunnel state machine has actually handled the request
Diffstat (limited to 'talpid-core/src')
6 files changed, 115 insertions, 69 deletions
diff --git a/talpid-core/src/tunnel_state_machine/connected_state.rs b/talpid-core/src/tunnel_state_machine/connected_state.rs index 687e941aa7..21375f056b 100644 --- a/talpid-core/src/tunnel_state_machine/connected_state.rs +++ b/talpid-core/src/tunnel_state_machine/connected_state.rs @@ -213,8 +213,8 @@ impl ConnectedState { use self::EventConsequence::*; match command { - Some(TunnelCommand::AllowLan(allow_lan)) => { - if let Err(error_cause) = shared_values.set_allow_lan(allow_lan) { + Some(TunnelCommand::AllowLan(allow_lan, complete_tx)) => { + let consequence = if let Err(error_cause) = shared_values.set_allow_lan(allow_lan) { self.disconnect(shared_values, AfterDisconnect::Block(error_cause)) } else { match self.set_firewall_policy(shared_values) { @@ -230,43 +230,55 @@ impl ConnectedState { AfterDisconnect::Block(ErrorStateCause::SetFirewallPolicyError(error)), ), } - } + }; + let _ = complete_tx.send(()); + consequence } Some(TunnelCommand::AllowEndpoint(endpoint, tx)) => { shared_values.allowed_endpoint = endpoint; let _ = tx.send(()); SameState(self) } - Some(TunnelCommand::Dns(servers)) => match shared_values.set_dns_servers(servers) { - Ok(true) => { - if let Err(error) = self.set_firewall_policy(shared_values) { - return self.disconnect( - shared_values, - AfterDisconnect::Block(ErrorStateCause::SetFirewallPolicyError(error)), - ); - } - - match self.set_dns(shared_values) { - #[cfg(target_os = "android")] - Ok(()) => self.disconnect(shared_values, AfterDisconnect::Reconnect(0)), - #[cfg(not(target_os = "android"))] - Ok(()) => SameState(self), - Err(error) => { - log::error!("{}", error.display_chain_with_msg("Failed to set DNS")); - self.disconnect( + Some(TunnelCommand::Dns(servers, complete_tx)) => { + let consequence = match shared_values.set_dns_servers(servers) { + Ok(true) => { + if let Err(error) = self.set_firewall_policy(shared_values) { + return self.disconnect( shared_values, - AfterDisconnect::Block(ErrorStateCause::SetDnsError), - ) + AfterDisconnect::Block(ErrorStateCause::SetFirewallPolicyError( + error, + )), + ); + } + + match self.set_dns(shared_values) { + #[cfg(target_os = "android")] + Ok(()) => self.disconnect(shared_values, AfterDisconnect::Reconnect(0)), + #[cfg(not(target_os = "android"))] + Ok(()) => SameState(self), + Err(error) => { + log::error!( + "{}", + error.display_chain_with_msg("Failed to set DNS") + ); + self.disconnect( + shared_values, + AfterDisconnect::Block(ErrorStateCause::SetDnsError), + ) + } } } - } - Ok(false) => SameState(self), - Err(error_cause) => { - self.disconnect(shared_values, AfterDisconnect::Block(error_cause)) - } - }, - Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => { + Ok(false) => SameState(self), + Err(error_cause) => { + self.disconnect(shared_values, AfterDisconnect::Block(error_cause)) + } + }; + let _ = complete_tx.send(()); + consequence + } + Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected, complete_tx)) => { shared_values.block_when_disconnected = block_when_disconnected; + let _ = complete_tx.send(()); SameState(self) } Some(TunnelCommand::IsOffline(is_offline)) => { diff --git a/talpid-core/src/tunnel_state_machine/connecting_state.rs b/talpid-core/src/tunnel_state_machine/connecting_state.rs index 2a728513ff..d7d93da9d4 100644 --- a/talpid-core/src/tunnel_state_machine/connecting_state.rs +++ b/talpid-core/src/tunnel_state_machine/connecting_state.rs @@ -392,12 +392,14 @@ impl ConnectingState { use self::EventConsequence::*; match command { - Some(TunnelCommand::AllowLan(allow_lan)) => { - if let Err(error_cause) = shared_values.set_allow_lan(allow_lan) { + Some(TunnelCommand::AllowLan(allow_lan, complete_tx)) => { + let consequence = if let Err(error_cause) = shared_values.set_allow_lan(allow_lan) { self.disconnect(shared_values, AfterDisconnect::Block(error_cause)) } else { self.reset_firewall(shared_values) - } + }; + let _ = complete_tx.send(()); + consequence } Some(TunnelCommand::AllowEndpoint(endpoint, tx)) => { if shared_values.allowed_endpoint != endpoint { @@ -418,14 +420,19 @@ impl ConnectingState { let _ = tx.send(()); SameState(self) } - Some(TunnelCommand::Dns(servers)) => match shared_values.set_dns_servers(servers) { - #[cfg(target_os = "android")] - Ok(true) => self.disconnect(shared_values, AfterDisconnect::Reconnect(0)), - Ok(_) => SameState(self), - Err(cause) => self.disconnect(shared_values, AfterDisconnect::Block(cause)), - }, - Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => { + Some(TunnelCommand::Dns(servers, complete_tx)) => { + let consequence = match shared_values.set_dns_servers(servers) { + #[cfg(target_os = "android")] + Ok(true) => self.disconnect(shared_values, AfterDisconnect::Reconnect(0)), + Ok(_) => SameState(self), + Err(cause) => self.disconnect(shared_values, AfterDisconnect::Block(cause)), + }; + let _ = complete_tx.send(()); + consequence + } + Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected, complete_tx)) => { shared_values.block_when_disconnected = block_when_disconnected; + let _ = complete_tx.send(()); SameState(self) } Some(TunnelCommand::IsOffline(is_offline)) => { diff --git a/talpid-core/src/tunnel_state_machine/disconnected_state.rs b/talpid-core/src/tunnel_state_machine/disconnected_state.rs index 5a2cf6fc4d..d46f06e782 100644 --- a/talpid-core/src/tunnel_state_machine/disconnected_state.rs +++ b/talpid-core/src/tunnel_state_machine/disconnected_state.rs @@ -128,7 +128,7 @@ impl TunnelState for DisconnectedState { use self::EventConsequence::*; match runtime.block_on(commands.next()) { - Some(TunnelCommand::AllowLan(allow_lan)) => { + Some(TunnelCommand::AllowLan(allow_lan, complete_tx)) => { if shared_values.allow_lan != allow_lan { // The only platform that can fail is Android, but Android doesn't support the // "block when disconnected" option, so the following call never fails. @@ -138,6 +138,7 @@ impl TunnelState for DisconnectedState { Self::set_firewall_policy(shared_values, false); } + let _ = complete_tx.send(()); SameState(self) } Some(TunnelCommand::AllowEndpoint(endpoint, tx)) => { @@ -148,15 +149,15 @@ impl TunnelState for DisconnectedState { let _ = tx.send(()); SameState(self) } - Some(TunnelCommand::Dns(servers)) => { + Some(TunnelCommand::Dns(servers, complete_tx)) => { // Same situation as allow LAN above. shared_values .set_dns_servers(servers) .expect("Failed to reconnect after changing custom DNS servers"); - + let _ = complete_tx.send(()); SameState(self) } - Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => { + Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected, complete_tx)) => { if shared_values.block_when_disconnected != block_when_disconnected { shared_values.block_when_disconnected = block_when_disconnected; Self::set_firewall_policy(shared_values, true); @@ -178,6 +179,7 @@ impl TunnelState for DisconnectedState { Self::reset_dns(shared_values); } } + let _ = complete_tx.send(()); SameState(self) } Some(TunnelCommand::IsOffline(is_offline)) => { diff --git a/talpid-core/src/tunnel_state_machine/disconnecting_state.rs b/talpid-core/src/tunnel_state_machine/disconnecting_state.rs index 08248fbac2..185d2f7d0a 100644 --- a/talpid-core/src/tunnel_state_machine/disconnecting_state.rs +++ b/talpid-core/src/tunnel_state_machine/disconnecting_state.rs @@ -40,8 +40,9 @@ impl DisconnectingState { self.after_disconnect = match after_disconnect { AfterDisconnect::Nothing => match command { - Some(TunnelCommand::AllowLan(allow_lan)) => { + Some(TunnelCommand::AllowLan(allow_lan, complete_tx)) => { let _ = shared_values.set_allow_lan(allow_lan); + let _ = complete_tx.send(()); AfterDisconnect::Nothing } Some(TunnelCommand::AllowEndpoint(endpoint, tx)) => { @@ -49,12 +50,17 @@ impl DisconnectingState { let _ = tx.send(()); AfterDisconnect::Nothing } - Some(TunnelCommand::Dns(servers)) => { + Some(TunnelCommand::Dns(servers, complete_tx)) => { let _ = shared_values.set_dns_servers(servers); + let _ = complete_tx.send(()); AfterDisconnect::Nothing } - Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => { + Some(TunnelCommand::BlockWhenDisconnected( + block_when_disconnected, + complete_tx, + )) => { shared_values.block_when_disconnected = block_when_disconnected; + let _ = complete_tx.send(()); AfterDisconnect::Nothing } Some(TunnelCommand::IsOffline(is_offline)) => { @@ -76,8 +82,9 @@ impl DisconnectingState { } }, AfterDisconnect::Block(reason) => match command { - Some(TunnelCommand::AllowLan(allow_lan)) => { + Some(TunnelCommand::AllowLan(allow_lan, complete_tx)) => { let _ = shared_values.set_allow_lan(allow_lan); + let _ = complete_tx.send(()); AfterDisconnect::Block(reason) } Some(TunnelCommand::AllowEndpoint(endpoint, tx)) => { @@ -85,12 +92,17 @@ impl DisconnectingState { let _ = tx.send(()); AfterDisconnect::Block(reason) } - Some(TunnelCommand::Dns(servers)) => { + Some(TunnelCommand::Dns(servers, complete_tx)) => { let _ = shared_values.set_dns_servers(servers); + let _ = complete_tx.send(()); AfterDisconnect::Block(reason) } - Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => { + Some(TunnelCommand::BlockWhenDisconnected( + block_when_disconnected, + complete_tx, + )) => { shared_values.block_when_disconnected = block_when_disconnected; + let _ = complete_tx.send(()); AfterDisconnect::Block(reason) } Some(TunnelCommand::IsOffline(is_offline)) => { @@ -117,8 +129,9 @@ impl DisconnectingState { None => AfterDisconnect::Block(reason), }, AfterDisconnect::Reconnect(retry_attempt) => match command { - Some(TunnelCommand::AllowLan(allow_lan)) => { + Some(TunnelCommand::AllowLan(allow_lan, complete_tx)) => { let _ = shared_values.set_allow_lan(allow_lan); + let _ = complete_tx.send(()); AfterDisconnect::Reconnect(retry_attempt) } Some(TunnelCommand::AllowEndpoint(endpoint, tx)) => { @@ -126,12 +139,17 @@ impl DisconnectingState { let _ = tx.send(()); AfterDisconnect::Reconnect(retry_attempt) } - Some(TunnelCommand::Dns(servers)) => { + Some(TunnelCommand::Dns(servers, complete_tx)) => { let _ = shared_values.set_dns_servers(servers); + let _ = complete_tx.send(()); AfterDisconnect::Reconnect(retry_attempt) } - Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => { + Some(TunnelCommand::BlockWhenDisconnected( + block_when_disconnected, + complete_tx, + )) => { shared_values.block_when_disconnected = block_when_disconnected; + let _ = complete_tx.send(()); AfterDisconnect::Reconnect(retry_attempt) } Some(TunnelCommand::IsOffline(is_offline)) => { diff --git a/talpid-core/src/tunnel_state_machine/error_state.rs b/talpid-core/src/tunnel_state_machine/error_state.rs index 11a805f7dc..2f82cb4cf5 100644 --- a/talpid-core/src/tunnel_state_machine/error_state.rs +++ b/talpid-core/src/tunnel_state_machine/error_state.rs @@ -138,13 +138,16 @@ impl TunnelState for ErrorState { use self::EventConsequence::*; match runtime.block_on(commands.next()) { - Some(TunnelCommand::AllowLan(allow_lan)) => { - if let Err(error_state_cause) = shared_values.set_allow_lan(allow_lan) { - NewState(Self::enter(shared_values, error_state_cause)) - } else { - let _ = Self::set_firewall_policy(shared_values); - SameState(self) - } + Some(TunnelCommand::AllowLan(allow_lan, complete_tx)) => { + let consequence = + if let Err(error_state_cause) = shared_values.set_allow_lan(allow_lan) { + NewState(Self::enter(shared_values, error_state_cause)) + } else { + let _ = Self::set_firewall_policy(shared_values); + SameState(self) + }; + let _ = complete_tx.send(()); + consequence } Some(TunnelCommand::AllowEndpoint(endpoint, tx)) => { if shared_values.allowed_endpoint != endpoint { @@ -163,15 +166,19 @@ impl TunnelState for ErrorState { let _ = tx.send(()); SameState(self) } - Some(TunnelCommand::Dns(servers)) => { - if let Err(error_state_cause) = shared_values.set_dns_servers(servers) { - NewState(Self::enter(shared_values, error_state_cause)) - } else { - SameState(self) - } + Some(TunnelCommand::Dns(servers, complete_tx)) => { + let consequence = + if let Err(error_state_cause) = shared_values.set_dns_servers(servers) { + NewState(Self::enter(shared_values, error_state_cause)) + } else { + SameState(self) + }; + let _ = complete_tx.send(()); + consequence } - Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => { + Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected, complete_tx)) => { shared_values.block_when_disconnected = block_when_disconnected; + let _ = complete_tx.send(()); SameState(self) } Some(TunnelCommand::IsOffline(is_offline)) => { diff --git a/talpid-core/src/tunnel_state_machine/mod.rs b/talpid-core/src/tunnel_state_machine/mod.rs index 12bc4cfc86..5957b2f731 100644 --- a/talpid-core/src/tunnel_state_machine/mod.rs +++ b/talpid-core/src/tunnel_state_machine/mod.rs @@ -189,15 +189,15 @@ pub async fn spawn( /// Representation of external commands for the tunnel state machine. pub enum TunnelCommand { /// Enable or disable LAN access in the firewall. - AllowLan(bool), + AllowLan(bool, oneshot::Sender<()>), /// Endpoint that should never be blocked. `()` is sent to the /// channel after attempting to set the firewall policy, regardless /// of whether it succeeded. AllowEndpoint(AllowedEndpoint, oneshot::Sender<()>), /// Set DNS servers to use. - Dns(Option<Vec<IpAddr>>), + Dns(Option<Vec<IpAddr>>, oneshot::Sender<()>), /// Enable or disable the block_when_disconnected feature. - BlockWhenDisconnected(bool), + BlockWhenDisconnected(bool, oneshot::Sender<()>), /// Notify the state machine of the connectivity of the device. IsOffline(bool), /// Open tunnel connection. |
