summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorDavid Lönnhager <david.l@mullvad.net>2021-12-03 18:51:08 +0100
committerDavid Lönnhager <david.l@mullvad.net>2021-12-08 10:56:17 +0100
commitce34e99c43fdad50acbb768bc7d3de3d94201666 (patch)
treefd3753f485a39e4868c30131e1cb541ded1e0475
parentf5bdcee250d8d1eb6e243b617760e3228b24fd16 (diff)
downloadmullvadvpn-ce34e99c43fdad50acbb768bc7d3de3d94201666.tar.xz
mullvadvpn-ce34e99c43fdad50acbb768bc7d3de3d94201666.zip
Rewrite construction of WinFwAllowedEndpoint using only safe code
-rw-r--r--talpid-core/src/firewall/windows.rs102
1 files changed, 51 insertions, 51 deletions
diff --git a/talpid-core/src/firewall/windows.rs b/talpid-core/src/firewall/windows.rs
index 2ecb425e37..989e453ff7 100644
--- a/talpid-core/src/firewall/windows.rs
+++ b/talpid-core/src/firewall/windows.rs
@@ -58,12 +58,12 @@ impl FirewallT for Firewall {
if let InitialFirewallState::Blocked(allowed_endpoint) = args.initial_state {
let cfg = &WinFwSettings::new(args.allow_lan);
- let allowed_endpoint = WinFwAllowedEndpoint::from(allowed_endpoint);
+ let allowed_endpoint = WinFwAllowedEndpointContainer::from(allowed_endpoint);
unsafe {
WinFw_InitializeBlocked(
WINFW_TIMEOUT_SECONDS,
&cfg,
- &allowed_endpoint,
+ &allowed_endpoint.as_endpoint(),
Some(log_sink),
logging_context,
)
@@ -95,7 +95,7 @@ impl FirewallT for Firewall {
&peer_endpoint,
&cfg,
&tunnel,
- WinFwAllowedEndpoint::from(allowed_endpoint),
+ &WinFwAllowedEndpointContainer::from(allowed_endpoint).as_endpoint(),
&relay_client,
)
}
@@ -114,7 +114,10 @@ impl FirewallT for Firewall {
allowed_endpoint,
} => {
let cfg = &WinFwSettings::new(allow_lan);
- self.set_blocked_state(&cfg, WinFwAllowedEndpoint::from(allowed_endpoint))
+ self.set_blocked_state(
+ &cfg,
+ &WinFwAllowedEndpointContainer::from(allowed_endpoint).as_endpoint(),
+ )
}
}
}
@@ -145,7 +148,7 @@ impl Firewall {
endpoint: &Endpoint,
winfw_settings: &WinFwSettings,
tunnel_metadata: &Option<TunnelMetadata>,
- allowed_endpoint: WinFwAllowedEndpoint,
+ allowed_endpoint: &WinFwAllowedEndpoint<'_>,
relay_client: &Path,
) -> Result<(), Error> {
trace!("Applying 'connecting' firewall policy");
@@ -173,7 +176,7 @@ impl Firewall {
&winfw_relay,
relay_client.as_ptr(),
interface_wstr_ptr,
- &allowed_endpoint,
+ allowed_endpoint,
)
.into_result()
.map_err(Error::ApplyingConnectingPolicy)
@@ -243,11 +246,11 @@ impl Firewall {
fn set_blocked_state(
&mut self,
winfw_settings: &WinFwSettings,
- allowed_endpoint: WinFwAllowedEndpoint,
+ allowed_endpoint: &WinFwAllowedEndpoint<'_>,
) -> Result<(), Error> {
trace!("Applying 'blocked' firewall policy");
unsafe {
- WinFw_ApplyPolicyBlocked(winfw_settings, &allowed_endpoint)
+ WinFw_ApplyPolicyBlocked(winfw_settings, allowed_endpoint)
.into_result()
.map_err(Error::ApplyingBlockedPolicy)
}
@@ -265,63 +268,60 @@ mod winfw {
use libc;
use talpid_types::net::TransportProtocol;
- #[repr(C)]
- pub struct WinFwAllowedEndpoint {
- num_clients: u32,
- clients: *const *const libc::wchar_t,
- endpoint: WinFwEndpoint,
+ pub struct WinFwAllowedEndpointContainer {
+ _clients: Box<[WideCString]>,
+ clients_ptrs: Box<[*const u16]>,
+ ip: WideCString,
+ port: u16,
+ protocol: WinFwProt,
}
- impl From<AllowedEndpoint> for WinFwAllowedEndpoint {
+ impl From<AllowedEndpoint> for WinFwAllowedEndpointContainer {
fn from(endpoint: AllowedEndpoint) -> Self {
- let allowed_endpoint_ip = widestring_ip(endpoint.endpoint.address.ip());
-
let clients = endpoint
.clients
.iter()
- .map(|client| WideCString::from_os_str_truncate(client).into_raw() as *const _)
- .collect::<Vec<_>>();
+ .map(|client| WideCString::from_os_str_truncate(client))
+ .collect::<Box<_>>();
+ let clients_ptrs = clients
+ .iter()
+ .map(|client| client.as_ptr())
+ .collect::<Box<_>>();
+ let ip = widestring_ip(endpoint.endpoint.address.ip());
- let (clients, num_clients) = vec_into_raw_parts(clients);
+ WinFwAllowedEndpointContainer {
+ _clients: clients,
+ clients_ptrs,
+ ip,
+ port: endpoint.endpoint.address.port(),
+ protocol: WinFwProt::from(endpoint.endpoint.protocol),
+ }
+ }
+ }
+ impl WinFwAllowedEndpointContainer {
+ pub fn as_endpoint(&self) -> WinFwAllowedEndpoint<'_> {
WinFwAllowedEndpoint {
- num_clients: num_clients as u32,
- clients,
+ num_clients: self.clients_ptrs.len() as u32,
+ clients: self.clients_ptrs.as_ptr(),
endpoint: WinFwEndpoint {
- ip: allowed_endpoint_ip.into_raw(),
- port: endpoint.endpoint.address.port(),
- protocol: WinFwProt::from(endpoint.endpoint.protocol),
+ ip: self.ip.as_ptr(),
+ port: self.port,
+ protocol: self.protocol,
},
- }
- }
- }
- impl Drop for WinFwAllowedEndpoint {
- fn drop(&mut self) {
- // Drop paths
- let clients: Vec<*mut u16> =
- unsafe { vec_from_raw_parts(self.clients as *mut _, self.num_clients as usize) };
- for client in clients {
- unsafe { WideCString::from_raw(client) };
+ _phantom: std::marker::PhantomData,
}
-
- // Drop address
- unsafe { WideCString::from_raw(self.endpoint.ip as *mut _) };
}
}
- /// Deconstructs a vec into raw parts without capacity.
- fn vec_into_raw_parts<T>(v: Vec<T>) -> (*mut T, usize) {
- let mut raw = v.into_boxed_slice();
- let ptr = raw.as_mut_ptr();
- let len = raw.len();
- std::mem::forget(raw);
- (ptr, len)
- }
+ #[repr(C)]
+ pub struct WinFwAllowedEndpoint<'a> {
+ num_clients: u32,
+ clients: *const *const libc::wchar_t,
+ endpoint: WinFwEndpoint,
- /// Constructs a vec from raw parts without capacity.
- unsafe fn vec_from_raw_parts<T>(v: *mut T, len: usize) -> Vec<T> {
- Box::from_raw(std::slice::from_raw_parts_mut(v, len)).into_vec()
+ _phantom: std::marker::PhantomData<&'a WinFwAllowedEndpointContainer>,
}
#[repr(C)]
@@ -413,7 +413,7 @@ mod winfw {
pub fn WinFw_InitializeBlocked(
timeout: libc::c_uint,
settings: &WinFwSettings,
- allowed_endpoint: *const WinFwAllowedEndpoint,
+ allowed_endpoint: *const WinFwAllowedEndpoint<'_>,
sink: Option<LogSink>,
sink_context: *const u8,
) -> InitializationResult;
@@ -427,7 +427,7 @@ mod winfw {
relay: &WinFwEndpoint,
relayClient: *const libc::wchar_t,
tunnelIfaceAlias: *const libc::wchar_t,
- allowed_endpoint: *const WinFwAllowedEndpoint,
+ allowed_endpoint: *const WinFwAllowedEndpoint<'_>,
) -> WinFwPolicyStatus;
#[link_name = "WinFw_ApplyPolicyConnected"]
@@ -445,7 +445,7 @@ mod winfw {
#[link_name = "WinFw_ApplyPolicyBlocked"]
pub fn WinFw_ApplyPolicyBlocked(
settings: &WinFwSettings,
- allowed_endpoint: *const WinFwAllowedEndpoint,
+ allowed_endpoint: *const WinFwAllowedEndpoint<'_>,
) -> WinFwPolicyStatus;
#[link_name = "WinFw_Reset"]