summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorDavid Lönnhager <david.l@mullvad.net>2021-12-01 20:22:19 +0100
committerDavid Lönnhager <david.l@mullvad.net>2021-12-02 19:12:30 +0100
commit9c67e77ced3d6b592483df1775823f0f7dc55ef0 (patch)
tree79bebc1c7b97a2d6efb4e41af2ea4c83aa08d4f9
parentde6e7094dceb8d31e0962eb809a9603ee4536ab1 (diff)
downloadmullvadvpn-9c67e77ced3d6b592483df1775823f0f7dc55ef0.tar.xz
mullvadvpn-9c67e77ced3d6b592483df1775823f0f7dc55ef0.zip
Make initial allowed endpoint mandatory
-rw-r--r--mullvad-setup/src/main.rs5
-rw-r--r--talpid-core/src/firewall/mod.rs17
-rw-r--r--talpid-core/src/firewall/windows.rs48
-rw-r--r--talpid-core/src/tunnel_state_machine/mod.rs11
4 files changed, 38 insertions, 43 deletions
diff --git a/mullvad-setup/src/main.rs b/mullvad-setup/src/main.rs
index 8b9df4b154..f59cc46a0b 100644
--- a/mullvad-setup/src/main.rs
+++ b/mullvad-setup/src/main.rs
@@ -4,7 +4,7 @@ use mullvad_rpc::MullvadRpcRuntime;
use mullvad_types::version::ParsedAppVersion;
use std::{path::PathBuf, process, time::Duration};
use talpid_core::{
- firewall::{self, Firewall, FirewallArguments},
+ firewall::{self, Firewall, FirewallArguments, InitialFirewallState},
future_retry::{constant_interval, retry_future_n},
};
use talpid_types::ErrorExt;
@@ -158,9 +158,8 @@ async fn reset_firewall() -> Result<(), Error> {
}
let mut firewall = Firewall::new(FirewallArguments {
- initialize_blocked: false,
+ initial_state: InitialFirewallState::None,
allow_lan: true,
- allowed_endpoint: None,
})
.map_err(Error::FirewallError)?;
diff --git a/talpid-core/src/firewall/mod.rs b/talpid-core/src/firewall/mod.rs
index 3a003b4c1e..380953f7c5 100644
--- a/talpid-core/src/firewall/mod.rs
+++ b/talpid-core/src/firewall/mod.rs
@@ -214,12 +214,21 @@ pub struct Firewall {
/// Arguments required when first initializing the firewall.
pub struct FirewallArguments {
- /// Determines whether the firewall should atomically enter the blocked state during init.
- pub initialize_blocked: bool,
+ /// Initial firewall state to enter during init.
+ pub initial_state: InitialFirewallState,
/// This argument is required for the blocked state to configure the firewall correctly.
pub allow_lan: bool,
- /// This argument is required for the blocked state to configure the firewall correctly.
- pub allowed_endpoint: Option<Endpoint>,
+}
+
+/// State to enter during firewall init.
+pub enum InitialFirewallState {
+ /// Do not set any policy.
+ None,
+ /// Atomically enter the blocked state.
+ Blocked {
+ /// Host that should be reachable while in the blocked state.
+ allowed_endpoint: Endpoint,
+ },
}
impl Firewall {
diff --git a/talpid-core/src/firewall/windows.rs b/talpid-core/src/firewall/windows.rs
index eedcf1d1c9..3bb201507c 100644
--- a/talpid-core/src/firewall/windows.rs
+++ b/talpid-core/src/firewall/windows.rs
@@ -3,7 +3,7 @@ use crate::{logging::windows::log_sink, tunnel::TunnelMetadata};
use std::{net::IpAddr, path::Path, ptr};
use self::winfw::*;
-use super::{FirewallArguments, FirewallPolicy, FirewallT};
+use super::{FirewallArguments, FirewallPolicy, FirewallT, InitialFirewallState};
use crate::winnet;
use log::{debug, error, trace};
use talpid_types::{net::Endpoint, tunnel::FirewallPolicyError};
@@ -53,24 +53,19 @@ impl FirewallT for Firewall {
fn new(args: FirewallArguments) -> Result<Self, Self::Error> {
let logging_context = b"WinFw\0".as_ptr();
- if args.initialize_blocked {
+ if let InitialFirewallState::Blocked { allowed_endpoint } = args.initial_state {
let cfg = &WinFwSettings::new(args.allow_lan);
- let allowed_endpoint_ip = args
- .allowed_endpoint
- .map(|endpoint| (endpoint, widestring_ip(endpoint.address.ip())));
- let allowed_endpoint =
- allowed_endpoint_ip
- .as_ref()
- .map(|(endpoint, ip)| WinFwEndpoint {
- ip: ip.as_ptr(),
- port: endpoint.address.port(),
- protocol: WinFwProt::from(endpoint.protocol),
- });
+ let allowed_endpoint_ip = widestring_ip(allowed_endpoint.address.ip());
+ let winfw_allowed_endpoint = WinFwEndpoint {
+ ip: allowed_endpoint_ip.as_ptr(),
+ port: allowed_endpoint.address.port(),
+ protocol: WinFwProt::from(allowed_endpoint.protocol),
+ };
unsafe {
WinFw_InitializeBlocked(
WINFW_TIMEOUT_SECONDS,
&cfg,
- allowed_endpoint.as_ptr(),
+ &winfw_allowed_endpoint,
Some(log_sink),
logging_context,
)
@@ -165,11 +160,11 @@ impl Firewall {
let relay_client = WideCString::from_os_str_truncate(relay_client);
let allowed_endpoint_ip = widestring_ip(allowed_endpoint.address.ip());
- let winfw_allowed_endpoint = Some(WinFwEndpoint {
+ let winfw_allowed_endpoint = WinFwEndpoint {
ip: allowed_endpoint_ip.as_ptr(),
port: allowed_endpoint.address.port(),
protocol: WinFwProt::from(allowed_endpoint.protocol),
- });
+ };
let interface_wstr = tunnel_metadata
.as_ref()
@@ -186,7 +181,7 @@ impl Firewall {
&winfw_relay,
relay_client.as_ptr(),
interface_wstr_ptr,
- winfw_allowed_endpoint.as_ptr(),
+ &winfw_allowed_endpoint,
)
.into_result()
.map_err(Error::ApplyingConnectingPolicy)
@@ -261,33 +256,20 @@ impl Firewall {
trace!("Applying 'blocked' firewall policy");
let allowed_endpoint_ip = widestring_ip(allowed_endpoint.address.ip());
- let winfw_allowed_endpoint = Some(WinFwEndpoint {
+ let winfw_allowed_endpoint = WinFwEndpoint {
ip: allowed_endpoint_ip.as_ptr(),
port: allowed_endpoint.address.port(),
protocol: WinFwProt::from(allowed_endpoint.protocol),
- });
+ };
unsafe {
- WinFw_ApplyPolicyBlocked(winfw_settings, winfw_allowed_endpoint.as_ptr())
+ WinFw_ApplyPolicyBlocked(winfw_settings, &winfw_allowed_endpoint)
.into_result()
.map_err(Error::ApplyingBlockedPolicy)
}
}
}
-trait NullablePointer<T> {
- fn as_ptr(&self) -> *const T;
-}
-
-impl<T> NullablePointer<T> for Option<T> {
- fn as_ptr(&self) -> *const T {
- match self {
- Some(ref value) => value,
- None => ptr::null(),
- }
- }
-}
-
fn widestring_ip(ip: IpAddr) -> WideCString {
WideCString::from_str_truncate(ip.to_string())
}
diff --git a/talpid-core/src/tunnel_state_machine/mod.rs b/talpid-core/src/tunnel_state_machine/mod.rs
index ef80293399..671148fde0 100644
--- a/talpid-core/src/tunnel_state_machine/mod.rs
+++ b/talpid-core/src/tunnel_state_machine/mod.rs
@@ -15,7 +15,7 @@ use self::{
use crate::split_tunnel;
use crate::{
dns::DnsMonitor,
- firewall::{Firewall, FirewallArguments},
+ firewall::{Firewall, FirewallArguments, InitialFirewallState},
mpsc::Sender,
offline,
routing::RouteManager,
@@ -208,9 +208,14 @@ impl TunnelStateMachine {
.map_err(Error::InitSplitTunneling)?;
let args = FirewallArguments {
- initialize_blocked: settings.block_when_disconnected || !settings.reset_firewall,
+ initial_state: if settings.block_when_disconnected || !settings.reset_firewall {
+ InitialFirewallState::Blocked {
+ allowed_endpoint: settings.allowed_endpoint,
+ }
+ } else {
+ InitialFirewallState::None
+ },
allow_lan: settings.allow_lan,
- allowed_endpoint: Some(settings.allowed_endpoint),
};
let firewall = Firewall::new(args).map_err(Error::InitFirewallError)?;