diff options
| author | David Lönnhager <david.l@mullvad.net> | 2021-12-03 18:51:08 +0100 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2021-12-08 10:56:17 +0100 |
| commit | ce34e99c43fdad50acbb768bc7d3de3d94201666 (patch) | |
| tree | fd3753f485a39e4868c30131e1cb541ded1e0475 /talpid-core/src | |
| parent | f5bdcee250d8d1eb6e243b617760e3228b24fd16 (diff) | |
| download | mullvadvpn-ce34e99c43fdad50acbb768bc7d3de3d94201666.tar.xz mullvadvpn-ce34e99c43fdad50acbb768bc7d3de3d94201666.zip | |
Rewrite construction of WinFwAllowedEndpoint using only safe code
Diffstat (limited to 'talpid-core/src')
| -rw-r--r-- | talpid-core/src/firewall/windows.rs | 102 |
1 files changed, 51 insertions, 51 deletions
diff --git a/talpid-core/src/firewall/windows.rs b/talpid-core/src/firewall/windows.rs index 2ecb425e37..989e453ff7 100644 --- a/talpid-core/src/firewall/windows.rs +++ b/talpid-core/src/firewall/windows.rs @@ -58,12 +58,12 @@ impl FirewallT for Firewall { if let InitialFirewallState::Blocked(allowed_endpoint) = args.initial_state { let cfg = &WinFwSettings::new(args.allow_lan); - let allowed_endpoint = WinFwAllowedEndpoint::from(allowed_endpoint); + let allowed_endpoint = WinFwAllowedEndpointContainer::from(allowed_endpoint); unsafe { WinFw_InitializeBlocked( WINFW_TIMEOUT_SECONDS, &cfg, - &allowed_endpoint, + &allowed_endpoint.as_endpoint(), Some(log_sink), logging_context, ) @@ -95,7 +95,7 @@ impl FirewallT for Firewall { &peer_endpoint, &cfg, &tunnel, - WinFwAllowedEndpoint::from(allowed_endpoint), + &WinFwAllowedEndpointContainer::from(allowed_endpoint).as_endpoint(), &relay_client, ) } @@ -114,7 +114,10 @@ impl FirewallT for Firewall { allowed_endpoint, } => { let cfg = &WinFwSettings::new(allow_lan); - self.set_blocked_state(&cfg, WinFwAllowedEndpoint::from(allowed_endpoint)) + self.set_blocked_state( + &cfg, + &WinFwAllowedEndpointContainer::from(allowed_endpoint).as_endpoint(), + ) } } } @@ -145,7 +148,7 @@ impl Firewall { endpoint: &Endpoint, winfw_settings: &WinFwSettings, tunnel_metadata: &Option<TunnelMetadata>, - allowed_endpoint: WinFwAllowedEndpoint, + allowed_endpoint: &WinFwAllowedEndpoint<'_>, relay_client: &Path, ) -> Result<(), Error> { trace!("Applying 'connecting' firewall policy"); @@ -173,7 +176,7 @@ impl Firewall { &winfw_relay, relay_client.as_ptr(), interface_wstr_ptr, - &allowed_endpoint, + allowed_endpoint, ) .into_result() .map_err(Error::ApplyingConnectingPolicy) @@ -243,11 +246,11 @@ impl Firewall { fn set_blocked_state( &mut self, winfw_settings: &WinFwSettings, - allowed_endpoint: WinFwAllowedEndpoint, + allowed_endpoint: &WinFwAllowedEndpoint<'_>, ) -> Result<(), Error> { trace!("Applying 'blocked' firewall policy"); unsafe { - WinFw_ApplyPolicyBlocked(winfw_settings, &allowed_endpoint) + WinFw_ApplyPolicyBlocked(winfw_settings, allowed_endpoint) .into_result() .map_err(Error::ApplyingBlockedPolicy) } @@ -265,63 +268,60 @@ mod winfw { use libc; use talpid_types::net::TransportProtocol; - #[repr(C)] - pub struct WinFwAllowedEndpoint { - num_clients: u32, - clients: *const *const libc::wchar_t, - endpoint: WinFwEndpoint, + pub struct WinFwAllowedEndpointContainer { + _clients: Box<[WideCString]>, + clients_ptrs: Box<[*const u16]>, + ip: WideCString, + port: u16, + protocol: WinFwProt, } - impl From<AllowedEndpoint> for WinFwAllowedEndpoint { + impl From<AllowedEndpoint> for WinFwAllowedEndpointContainer { fn from(endpoint: AllowedEndpoint) -> Self { - let allowed_endpoint_ip = widestring_ip(endpoint.endpoint.address.ip()); - let clients = endpoint .clients .iter() - .map(|client| WideCString::from_os_str_truncate(client).into_raw() as *const _) - .collect::<Vec<_>>(); + .map(|client| WideCString::from_os_str_truncate(client)) + .collect::<Box<_>>(); + let clients_ptrs = clients + .iter() + .map(|client| client.as_ptr()) + .collect::<Box<_>>(); + let ip = widestring_ip(endpoint.endpoint.address.ip()); - let (clients, num_clients) = vec_into_raw_parts(clients); + WinFwAllowedEndpointContainer { + _clients: clients, + clients_ptrs, + ip, + port: endpoint.endpoint.address.port(), + protocol: WinFwProt::from(endpoint.endpoint.protocol), + } + } + } + impl WinFwAllowedEndpointContainer { + pub fn as_endpoint(&self) -> WinFwAllowedEndpoint<'_> { WinFwAllowedEndpoint { - num_clients: num_clients as u32, - clients, + num_clients: self.clients_ptrs.len() as u32, + clients: self.clients_ptrs.as_ptr(), endpoint: WinFwEndpoint { - ip: allowed_endpoint_ip.into_raw(), - port: endpoint.endpoint.address.port(), - protocol: WinFwProt::from(endpoint.endpoint.protocol), + ip: self.ip.as_ptr(), + port: self.port, + protocol: self.protocol, }, - } - } - } - impl Drop for WinFwAllowedEndpoint { - fn drop(&mut self) { - // Drop paths - let clients: Vec<*mut u16> = - unsafe { vec_from_raw_parts(self.clients as *mut _, self.num_clients as usize) }; - for client in clients { - unsafe { WideCString::from_raw(client) }; + _phantom: std::marker::PhantomData, } - - // Drop address - unsafe { WideCString::from_raw(self.endpoint.ip as *mut _) }; } } - /// Deconstructs a vec into raw parts without capacity. - fn vec_into_raw_parts<T>(v: Vec<T>) -> (*mut T, usize) { - let mut raw = v.into_boxed_slice(); - let ptr = raw.as_mut_ptr(); - let len = raw.len(); - std::mem::forget(raw); - (ptr, len) - } + #[repr(C)] + pub struct WinFwAllowedEndpoint<'a> { + num_clients: u32, + clients: *const *const libc::wchar_t, + endpoint: WinFwEndpoint, - /// Constructs a vec from raw parts without capacity. - unsafe fn vec_from_raw_parts<T>(v: *mut T, len: usize) -> Vec<T> { - Box::from_raw(std::slice::from_raw_parts_mut(v, len)).into_vec() + _phantom: std::marker::PhantomData<&'a WinFwAllowedEndpointContainer>, } #[repr(C)] @@ -413,7 +413,7 @@ mod winfw { pub fn WinFw_InitializeBlocked( timeout: libc::c_uint, settings: &WinFwSettings, - allowed_endpoint: *const WinFwAllowedEndpoint, + allowed_endpoint: *const WinFwAllowedEndpoint<'_>, sink: Option<LogSink>, sink_context: *const u8, ) -> InitializationResult; @@ -427,7 +427,7 @@ mod winfw { relay: &WinFwEndpoint, relayClient: *const libc::wchar_t, tunnelIfaceAlias: *const libc::wchar_t, - allowed_endpoint: *const WinFwAllowedEndpoint, + allowed_endpoint: *const WinFwAllowedEndpoint<'_>, ) -> WinFwPolicyStatus; #[link_name = "WinFw_ApplyPolicyConnected"] @@ -445,7 +445,7 @@ mod winfw { #[link_name = "WinFw_ApplyPolicyBlocked"] pub fn WinFw_ApplyPolicyBlocked( settings: &WinFwSettings, - allowed_endpoint: *const WinFwAllowedEndpoint, + allowed_endpoint: *const WinFwAllowedEndpoint<'_>, ) -> WinFwPolicyStatus; #[link_name = "WinFw_Reset"] |
