diff options
| author | Joakim Hulthe <joakim.hulthe@mullvad.net> | 2025-06-12 14:13:58 +0200 |
|---|---|---|
| committer | Joakim Hulthe <joakim.hulthe@mullvad.net> | 2025-06-12 15:41:13 +0200 |
| commit | b39d040d9f28cc06265f02c2bcf3910eb43d2ab9 (patch) | |
| tree | db801eca5f18defed3246cd0dfe1c522cd7b3abe | |
| parent | f0efcc68cfc310f6965443c28fbbf59455187165 (diff) | |
| download | mullvadvpn-b39d040d9f28cc06265f02c2bcf3910eb43d2ab9.tar.xz mullvadvpn-b39d040d9f28cc06265f02c2bcf3910eb43d2ab9.zip | |
Fix tun file descriptor ownership
We accidentally borrowed the file descriptor when we should have moved
it. This commit adds more `OwnedFd` and friends to help handle
ownership correctly.
Signed-off-by: Joakim Hulthe <joakim.hulthe@mullvad.net>
| -rw-r--r-- | talpid-core/src/tunnel_state_machine/mod.rs | 2 | ||||
| -rw-r--r-- | talpid-tunnel/src/tun_provider/android/mod.rs | 8 | ||||
| -rw-r--r-- | talpid-wireguard/src/boringtun/mod.rs | 4 | ||||
| -rw-r--r-- | talpid-wireguard/src/obfuscation.rs | 2 | ||||
| -rw-r--r-- | talpid-wireguard/src/wireguard_go/mod.rs | 16 | ||||
| -rw-r--r-- | wireguard-go-rs/src/lib.rs | 65 |
6 files changed, 54 insertions, 43 deletions
diff --git a/talpid-core/src/tunnel_state_machine/mod.rs b/talpid-core/src/tunnel_state_machine/mod.rs index 207291360c..059edd417d 100644 --- a/talpid-core/src/tunnel_state_machine/mod.rs +++ b/talpid-core/src/tunnel_state_machine/mod.rs @@ -604,7 +604,7 @@ impl SharedTunnelStateValues { #[cfg(target_os = "android")] pub fn bypass_socket(&mut self, fd: RawFd, tx: oneshot::Sender<()>) { - if let Err(err) = self.tun_provider.lock().unwrap().bypass(fd) { + if let Err(err) = self.tun_provider.lock().unwrap().bypass(&fd) { log::error!("Failed to bypass socket {}", err); } let _ = tx.send(()); diff --git a/talpid-tunnel/src/tun_provider/android/mod.rs b/talpid-tunnel/src/tun_provider/android/mod.rs index 7eca3dbba2..f519d8be84 100644 --- a/talpid-tunnel/src/tun_provider/android/mod.rs +++ b/talpid-tunnel/src/tun_provider/android/mod.rs @@ -197,7 +197,7 @@ impl AndroidTunProvider { } /// Allow a socket to bypass the tunnel. - pub fn bypass(&mut self, socket: RawFd) -> Result<(), Error> { + pub fn bypass(&mut self, socket: &impl AsRawFd) -> Result<(), Error> { let env = JnixEnv::from( self.jvm .attach_current_thread_as_daemon() @@ -212,7 +212,7 @@ impl AndroidTunProvider { self.object.as_obj(), create_tun_method, JavaType::Primitive(Primitive::Boolean), - &[JValue::Int(socket)], + &[JValue::Int(socket.as_raw_fd())], ) .map_err(|cause| Error::CallMethod("bypass", cause))?; @@ -404,7 +404,7 @@ impl VpnServiceTun { } /// Allow a socket to bypass the tunnel. - pub fn bypass(&mut self, socket: RawFd) -> Result<(), Error> { + pub fn bypass(&mut self, socket: &impl AsFd) -> Result<(), Error> { let env = JnixEnv::from( self.jvm .attach_current_thread_as_daemon() @@ -419,7 +419,7 @@ impl VpnServiceTun { self.object.as_obj(), create_tun_method, JavaType::Primitive(Primitive::Boolean), - &[JValue::Int(socket)], + &[JValue::Int(socket.as_fd().as_raw_fd())], ) .map_err(|cause| Error::CallMethod("bypass", cause))?; diff --git a/talpid-wireguard/src/boringtun/mod.rs b/talpid-wireguard/src/boringtun/mod.rs index 0c465cba49..44ba1eb1d8 100644 --- a/talpid-wireguard/src/boringtun/mod.rs +++ b/talpid-wireguard/src/boringtun/mod.rs @@ -90,9 +90,7 @@ pub async fn open_boringtun_tunnel( let mut config = tun07::Configuration::default(); config.raw_fd(fd); - boringtun_config.on_bind = Some(Box::new(move |socket| { - tun.bypass(socket.as_raw_fd()).unwrap() - })); + boringtun_config.on_bind = Some(Box::new(move |socket| tun.bypass(socket).unwrap())); let device = tun07::Device::new(&config).unwrap(); tun07::AsyncDevice::new(device).unwrap() diff --git a/talpid-wireguard/src/obfuscation.rs b/talpid-wireguard/src/obfuscation.rs index fafe8c8860..b214e89ae0 100644 --- a/talpid-wireguard/src/obfuscation.rs +++ b/talpid-wireguard/src/obfuscation.rs @@ -117,7 +117,7 @@ async fn bypass_vpn( // Exclude remote obfuscation socket or bridge log::debug!("Excluding remote socket fd from the tunnel"); let _ = tokio::task::spawn_blocking(move || { - if let Err(error) = tun_provider.lock().unwrap().bypass(remote_socket_fd) { + if let Err(error) = tun_provider.lock().unwrap().bypass(&remote_socket_fd) { log::error!("Failed to exclude remote socket fd: {error}"); } }) diff --git a/talpid-wireguard/src/wireguard_go/mod.rs b/talpid-wireguard/src/wireguard_go/mod.rs index 5b172c4f47..bad0bf88cb 100644 --- a/talpid-wireguard/src/wireguard_go/mod.rs +++ b/talpid-wireguard/src/wireguard_go/mod.rs @@ -18,8 +18,6 @@ use std::borrow::Cow; #[cfg(daita)] use std::ffi::CString; #[cfg(unix)] -use std::os::unix::io::AsRawFd; -#[cfg(unix)] use std::sync::{Arc, Mutex}; use std::{ future::Future, @@ -300,10 +298,10 @@ impl WgGoTunnelState { let socket_v6 = self.tunnel_handle.get_socket_v6(); let mut provider = tun_provider.lock().unwrap(); provider - .bypass(socket_v4) + .bypass(&socket_v4) .map_err(super::TunnelError::BypassError)?; provider - .bypass(socket_v6) + .bypass(&socket_v6) .map_err(super::TunnelError::BypassError)?; } @@ -334,7 +332,7 @@ impl WgGoTunnel { let handle = wireguard_go_rs::Tunnel::turn_on( mtu, &wg_config_str, - tunnel_fd.as_raw_fd(), + tunnel_fd, Some(logging::wg_go_logging_callback), logging_context.ordinal, ) @@ -529,7 +527,7 @@ impl WgGoTunnel { let handle = wireguard_go_rs::Tunnel::turn_on( &wg_config_str, - tunnel_fd.as_raw_fd(), + tunnel_fd, Some(logging::wg_go_logging_callback), logging_context.ordinal, ) @@ -611,7 +609,7 @@ impl WgGoTunnel { &exit_config_str, &entry_config_str, &private_ip, - tunnel_fd.as_raw_fd(), + tunnel_fd, Some(logging::wg_go_logging_callback), logging_context.ordinal, ) @@ -658,8 +656,8 @@ impl WgGoTunnel { let socket_v4 = handle.get_socket_v4(); let socket_v6 = handle.get_socket_v6(); - tunnel_device.bypass(socket_v4)?; - tunnel_device.bypass(socket_v6)?; + tunnel_device.bypass(&socket_v4)?; + tunnel_device.bypass(&socket_v6)?; Ok(()) } diff --git a/wireguard-go-rs/src/lib.rs b/wireguard-go-rs/src/lib.rs index 8ffbecc747..a9381f5cab 100644 --- a/wireguard-go-rs/src/lib.rs +++ b/wireguard-go-rs/src/lib.rs @@ -8,18 +8,22 @@ use core::ffi::{c_char, CStr}; use core::mem::ManuallyDrop; +use core::slice; +use talpid_types::drop_guard::on_drop; +use zeroize::Zeroize; + +#[cfg(target_os = "android")] +use std::os::fd::BorrowedFd; + +#[cfg(not(target_os = "windows"))] +use std::os::fd::{IntoRawFd, OwnedFd}; + #[cfg(target_os = "windows")] use core::mem::MaybeUninit; -use core::slice; #[cfg(target_os = "windows")] use std::ffi::CString; -use talpid_types::drop_guard::on_drop; #[cfg(target_os = "windows")] use windows_sys::Win32::NetworkManagement::Ndis::NET_LUID_LH; -use zeroize::Zeroize; - -#[cfg(unix)] -pub type Fd = std::os::unix::io::RawFd; pub type WgLogLevel = u32; @@ -84,17 +88,19 @@ impl Tunnel { pub fn turn_on( #[cfg(not(target_os = "android"))] mtu: isize, settings: &CStr, - device: Fd, + device: OwnedFd, logging_callback: Option<LoggingCallback>, logging_context: LoggingContext, ) -> Result<Self, Error> { - // SAFETY: pointer is valid for the lifetime of this function + // SAFETY: + // - pointer is valid for the lifetime of `wgTurnOn`. + // - OwnedFd asserts that fd is open, and into_raw_fd will transfer ownership to Go. let code = unsafe { ffi::wgTurnOn( #[cfg(not(target_os = "android"))] mtu, settings.as_ptr(), - device, + device.into_raw_fd(), // Transfer ownership of the fd to Go logging_callback, logging_context, ) @@ -181,17 +187,19 @@ impl Tunnel { exit_settings: &CStr, entry_settings: &CStr, private_ip: &CStr, - device: Fd, + device: OwnedFd, logging_callback: Option<LoggingCallback>, logging_context: LoggingContext, ) -> Result<Self, Error> { - // SAFETY: pointer is valid for the lifetime of this function + // SAFETY: + // - pointers are valid for the lifetime of `wgTurnOnMultihop`. + // - OwnedFd asserts that fd is open, and into_raw_fd will transfer ownership to Go. let code = unsafe { ffi::wgTurnOnMultihop( exit_settings.as_ptr(), entry_settings.as_ptr(), private_ip.as_ptr(), - device, + device.into_raw_fd(), // Transfer ownership of the fd to Go logging_callback, logging_context, ) @@ -279,16 +287,22 @@ impl Tunnel { /// Get the file descriptor of the tunnel IPv4 socket. #[cfg(target_os = "android")] - pub fn get_socket_v4(&self) -> Fd { - // SAFETY: self.handle is a valid pointer to an active wireguard-go tunnel. - unsafe { ffi::wgGetSocketV4(self.handle) } + pub fn get_socket_v4(&self) -> BorrowedFd { + // SAFETY: + // - self.handle is a valid pointer to an active wireguard-go tunnel. + // - file descriptor won't be closed until wgTurnOff is called, + // which can't happen while `self` is borrowed. + unsafe { BorrowedFd::borrow_raw(ffi::wgGetSocketV4(self.handle)) } } /// Get the file descriptor of the tunnel IPv6 socket. #[cfg(target_os = "android")] - pub fn get_socket_v6(&self) -> Fd { - // SAFETY: self.handle is a valid pointer to an active wireguard-go tunnel. - unsafe { ffi::wgGetSocketV6(self.handle) } + pub fn get_socket_v6(&self) -> BorrowedFd { + // SAFETY: + // - self.handle is a valid pointer to an active wireguard-go tunnel. + // - file descriptor won't be closed until wgTurnOff is called, + // which can't happen while `self` is borrowed. + unsafe { BorrowedFd::borrow_raw(ffi::wgGetSocketV6(self.handle)) } } } @@ -329,11 +343,12 @@ impl Error { } mod ffi { - #[cfg(not(target_os = "windows"))] - use super::Fd; use super::{LoggingCallback, LoggingContext}; use core::ffi::{c_char, c_void}; + #[cfg(not(target_os = "windows"))] + use std::os::fd::RawFd; + unsafe extern "C" { /// Creates a new wireguard tunnel, uses the specific interface name, and file descriptors /// for the tunnel device and logging. For targets other than android, this also takes an @@ -345,7 +360,7 @@ mod ffi { pub fn wgTurnOn( mtu: isize, settings: *const c_char, - fd: Fd, + fd: RawFd, logging_callback: Option<LoggingCallback>, logging_context: LoggingContext, ) -> i32; @@ -353,7 +368,7 @@ mod ffi { #[cfg(target_os = "android")] pub fn wgTurnOn( settings: *const c_char, - fd: Fd, + fd: RawFd, logging_callback: Option<LoggingCallback>, logging_context: LoggingContext, ) -> i32; @@ -380,7 +395,7 @@ mod ffi { exit_settings: *const c_char, entry_settings: *const c_char, private_ip: *const c_char, - fd: Fd, + fd: RawFd, logging_callback: Option<LoggingCallback>, logging_context: LoggingContext, ) -> i32; @@ -433,11 +448,11 @@ mod ffi { /// Get the file descriptor of the tunnel IPv4 socket. #[cfg(target_os = "android")] - pub fn wgGetSocketV4(handle: i32) -> Fd; + pub fn wgGetSocketV4(handle: i32) -> RawFd; /// Get the file descriptor of the tunnel IPv6 socket. #[cfg(target_os = "android")] - pub fn wgGetSocketV6(handle: i32) -> Fd; + pub fn wgGetSocketV6(handle: i32) -> RawFd; /// Rebind endpoint sockets #[cfg(target_os = "windows")] |
