summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorJoakim Hulthe <joakim.hulthe@mullvad.net>2025-06-12 14:13:58 +0200
committerJoakim Hulthe <joakim.hulthe@mullvad.net>2025-06-12 15:41:13 +0200
commitb39d040d9f28cc06265f02c2bcf3910eb43d2ab9 (patch)
treedb801eca5f18defed3246cd0dfe1c522cd7b3abe
parentf0efcc68cfc310f6965443c28fbbf59455187165 (diff)
downloadmullvadvpn-b39d040d9f28cc06265f02c2bcf3910eb43d2ab9.tar.xz
mullvadvpn-b39d040d9f28cc06265f02c2bcf3910eb43d2ab9.zip
Fix tun file descriptor ownership
We accidentally borrowed the file descriptor when we should have moved it. This commit adds more `OwnedFd` and friends to help handle ownership correctly. Signed-off-by: Joakim Hulthe <joakim.hulthe@mullvad.net>
-rw-r--r--talpid-core/src/tunnel_state_machine/mod.rs2
-rw-r--r--talpid-tunnel/src/tun_provider/android/mod.rs8
-rw-r--r--talpid-wireguard/src/boringtun/mod.rs4
-rw-r--r--talpid-wireguard/src/obfuscation.rs2
-rw-r--r--talpid-wireguard/src/wireguard_go/mod.rs16
-rw-r--r--wireguard-go-rs/src/lib.rs65
6 files changed, 54 insertions, 43 deletions
diff --git a/talpid-core/src/tunnel_state_machine/mod.rs b/talpid-core/src/tunnel_state_machine/mod.rs
index 207291360c..059edd417d 100644
--- a/talpid-core/src/tunnel_state_machine/mod.rs
+++ b/talpid-core/src/tunnel_state_machine/mod.rs
@@ -604,7 +604,7 @@ impl SharedTunnelStateValues {
#[cfg(target_os = "android")]
pub fn bypass_socket(&mut self, fd: RawFd, tx: oneshot::Sender<()>) {
- if let Err(err) = self.tun_provider.lock().unwrap().bypass(fd) {
+ if let Err(err) = self.tun_provider.lock().unwrap().bypass(&fd) {
log::error!("Failed to bypass socket {}", err);
}
let _ = tx.send(());
diff --git a/talpid-tunnel/src/tun_provider/android/mod.rs b/talpid-tunnel/src/tun_provider/android/mod.rs
index 7eca3dbba2..f519d8be84 100644
--- a/talpid-tunnel/src/tun_provider/android/mod.rs
+++ b/talpid-tunnel/src/tun_provider/android/mod.rs
@@ -197,7 +197,7 @@ impl AndroidTunProvider {
}
/// Allow a socket to bypass the tunnel.
- pub fn bypass(&mut self, socket: RawFd) -> Result<(), Error> {
+ pub fn bypass(&mut self, socket: &impl AsRawFd) -> Result<(), Error> {
let env = JnixEnv::from(
self.jvm
.attach_current_thread_as_daemon()
@@ -212,7 +212,7 @@ impl AndroidTunProvider {
self.object.as_obj(),
create_tun_method,
JavaType::Primitive(Primitive::Boolean),
- &[JValue::Int(socket)],
+ &[JValue::Int(socket.as_raw_fd())],
)
.map_err(|cause| Error::CallMethod("bypass", cause))?;
@@ -404,7 +404,7 @@ impl VpnServiceTun {
}
/// Allow a socket to bypass the tunnel.
- pub fn bypass(&mut self, socket: RawFd) -> Result<(), Error> {
+ pub fn bypass(&mut self, socket: &impl AsFd) -> Result<(), Error> {
let env = JnixEnv::from(
self.jvm
.attach_current_thread_as_daemon()
@@ -419,7 +419,7 @@ impl VpnServiceTun {
self.object.as_obj(),
create_tun_method,
JavaType::Primitive(Primitive::Boolean),
- &[JValue::Int(socket)],
+ &[JValue::Int(socket.as_fd().as_raw_fd())],
)
.map_err(|cause| Error::CallMethod("bypass", cause))?;
diff --git a/talpid-wireguard/src/boringtun/mod.rs b/talpid-wireguard/src/boringtun/mod.rs
index 0c465cba49..44ba1eb1d8 100644
--- a/talpid-wireguard/src/boringtun/mod.rs
+++ b/talpid-wireguard/src/boringtun/mod.rs
@@ -90,9 +90,7 @@ pub async fn open_boringtun_tunnel(
let mut config = tun07::Configuration::default();
config.raw_fd(fd);
- boringtun_config.on_bind = Some(Box::new(move |socket| {
- tun.bypass(socket.as_raw_fd()).unwrap()
- }));
+ boringtun_config.on_bind = Some(Box::new(move |socket| tun.bypass(socket).unwrap()));
let device = tun07::Device::new(&config).unwrap();
tun07::AsyncDevice::new(device).unwrap()
diff --git a/talpid-wireguard/src/obfuscation.rs b/talpid-wireguard/src/obfuscation.rs
index fafe8c8860..b214e89ae0 100644
--- a/talpid-wireguard/src/obfuscation.rs
+++ b/talpid-wireguard/src/obfuscation.rs
@@ -117,7 +117,7 @@ async fn bypass_vpn(
// Exclude remote obfuscation socket or bridge
log::debug!("Excluding remote socket fd from the tunnel");
let _ = tokio::task::spawn_blocking(move || {
- if let Err(error) = tun_provider.lock().unwrap().bypass(remote_socket_fd) {
+ if let Err(error) = tun_provider.lock().unwrap().bypass(&remote_socket_fd) {
log::error!("Failed to exclude remote socket fd: {error}");
}
})
diff --git a/talpid-wireguard/src/wireguard_go/mod.rs b/talpid-wireguard/src/wireguard_go/mod.rs
index 5b172c4f47..bad0bf88cb 100644
--- a/talpid-wireguard/src/wireguard_go/mod.rs
+++ b/talpid-wireguard/src/wireguard_go/mod.rs
@@ -18,8 +18,6 @@ use std::borrow::Cow;
#[cfg(daita)]
use std::ffi::CString;
#[cfg(unix)]
-use std::os::unix::io::AsRawFd;
-#[cfg(unix)]
use std::sync::{Arc, Mutex};
use std::{
future::Future,
@@ -300,10 +298,10 @@ impl WgGoTunnelState {
let socket_v6 = self.tunnel_handle.get_socket_v6();
let mut provider = tun_provider.lock().unwrap();
provider
- .bypass(socket_v4)
+ .bypass(&socket_v4)
.map_err(super::TunnelError::BypassError)?;
provider
- .bypass(socket_v6)
+ .bypass(&socket_v6)
.map_err(super::TunnelError::BypassError)?;
}
@@ -334,7 +332,7 @@ impl WgGoTunnel {
let handle = wireguard_go_rs::Tunnel::turn_on(
mtu,
&wg_config_str,
- tunnel_fd.as_raw_fd(),
+ tunnel_fd,
Some(logging::wg_go_logging_callback),
logging_context.ordinal,
)
@@ -529,7 +527,7 @@ impl WgGoTunnel {
let handle = wireguard_go_rs::Tunnel::turn_on(
&wg_config_str,
- tunnel_fd.as_raw_fd(),
+ tunnel_fd,
Some(logging::wg_go_logging_callback),
logging_context.ordinal,
)
@@ -611,7 +609,7 @@ impl WgGoTunnel {
&exit_config_str,
&entry_config_str,
&private_ip,
- tunnel_fd.as_raw_fd(),
+ tunnel_fd,
Some(logging::wg_go_logging_callback),
logging_context.ordinal,
)
@@ -658,8 +656,8 @@ impl WgGoTunnel {
let socket_v4 = handle.get_socket_v4();
let socket_v6 = handle.get_socket_v6();
- tunnel_device.bypass(socket_v4)?;
- tunnel_device.bypass(socket_v6)?;
+ tunnel_device.bypass(&socket_v4)?;
+ tunnel_device.bypass(&socket_v6)?;
Ok(())
}
diff --git a/wireguard-go-rs/src/lib.rs b/wireguard-go-rs/src/lib.rs
index 8ffbecc747..a9381f5cab 100644
--- a/wireguard-go-rs/src/lib.rs
+++ b/wireguard-go-rs/src/lib.rs
@@ -8,18 +8,22 @@
use core::ffi::{c_char, CStr};
use core::mem::ManuallyDrop;
+use core::slice;
+use talpid_types::drop_guard::on_drop;
+use zeroize::Zeroize;
+
+#[cfg(target_os = "android")]
+use std::os::fd::BorrowedFd;
+
+#[cfg(not(target_os = "windows"))]
+use std::os::fd::{IntoRawFd, OwnedFd};
+
#[cfg(target_os = "windows")]
use core::mem::MaybeUninit;
-use core::slice;
#[cfg(target_os = "windows")]
use std::ffi::CString;
-use talpid_types::drop_guard::on_drop;
#[cfg(target_os = "windows")]
use windows_sys::Win32::NetworkManagement::Ndis::NET_LUID_LH;
-use zeroize::Zeroize;
-
-#[cfg(unix)]
-pub type Fd = std::os::unix::io::RawFd;
pub type WgLogLevel = u32;
@@ -84,17 +88,19 @@ impl Tunnel {
pub fn turn_on(
#[cfg(not(target_os = "android"))] mtu: isize,
settings: &CStr,
- device: Fd,
+ device: OwnedFd,
logging_callback: Option<LoggingCallback>,
logging_context: LoggingContext,
) -> Result<Self, Error> {
- // SAFETY: pointer is valid for the lifetime of this function
+ // SAFETY:
+ // - pointer is valid for the lifetime of `wgTurnOn`.
+ // - OwnedFd asserts that fd is open, and into_raw_fd will transfer ownership to Go.
let code = unsafe {
ffi::wgTurnOn(
#[cfg(not(target_os = "android"))]
mtu,
settings.as_ptr(),
- device,
+ device.into_raw_fd(), // Transfer ownership of the fd to Go
logging_callback,
logging_context,
)
@@ -181,17 +187,19 @@ impl Tunnel {
exit_settings: &CStr,
entry_settings: &CStr,
private_ip: &CStr,
- device: Fd,
+ device: OwnedFd,
logging_callback: Option<LoggingCallback>,
logging_context: LoggingContext,
) -> Result<Self, Error> {
- // SAFETY: pointer is valid for the lifetime of this function
+ // SAFETY:
+ // - pointers are valid for the lifetime of `wgTurnOnMultihop`.
+ // - OwnedFd asserts that fd is open, and into_raw_fd will transfer ownership to Go.
let code = unsafe {
ffi::wgTurnOnMultihop(
exit_settings.as_ptr(),
entry_settings.as_ptr(),
private_ip.as_ptr(),
- device,
+ device.into_raw_fd(), // Transfer ownership of the fd to Go
logging_callback,
logging_context,
)
@@ -279,16 +287,22 @@ impl Tunnel {
/// Get the file descriptor of the tunnel IPv4 socket.
#[cfg(target_os = "android")]
- pub fn get_socket_v4(&self) -> Fd {
- // SAFETY: self.handle is a valid pointer to an active wireguard-go tunnel.
- unsafe { ffi::wgGetSocketV4(self.handle) }
+ pub fn get_socket_v4(&self) -> BorrowedFd {
+ // SAFETY:
+ // - self.handle is a valid pointer to an active wireguard-go tunnel.
+ // - file descriptor won't be closed until wgTurnOff is called,
+ // which can't happen while `self` is borrowed.
+ unsafe { BorrowedFd::borrow_raw(ffi::wgGetSocketV4(self.handle)) }
}
/// Get the file descriptor of the tunnel IPv6 socket.
#[cfg(target_os = "android")]
- pub fn get_socket_v6(&self) -> Fd {
- // SAFETY: self.handle is a valid pointer to an active wireguard-go tunnel.
- unsafe { ffi::wgGetSocketV6(self.handle) }
+ pub fn get_socket_v6(&self) -> BorrowedFd {
+ // SAFETY:
+ // - self.handle is a valid pointer to an active wireguard-go tunnel.
+ // - file descriptor won't be closed until wgTurnOff is called,
+ // which can't happen while `self` is borrowed.
+ unsafe { BorrowedFd::borrow_raw(ffi::wgGetSocketV6(self.handle)) }
}
}
@@ -329,11 +343,12 @@ impl Error {
}
mod ffi {
- #[cfg(not(target_os = "windows"))]
- use super::Fd;
use super::{LoggingCallback, LoggingContext};
use core::ffi::{c_char, c_void};
+ #[cfg(not(target_os = "windows"))]
+ use std::os::fd::RawFd;
+
unsafe extern "C" {
/// Creates a new wireguard tunnel, uses the specific interface name, and file descriptors
/// for the tunnel device and logging. For targets other than android, this also takes an
@@ -345,7 +360,7 @@ mod ffi {
pub fn wgTurnOn(
mtu: isize,
settings: *const c_char,
- fd: Fd,
+ fd: RawFd,
logging_callback: Option<LoggingCallback>,
logging_context: LoggingContext,
) -> i32;
@@ -353,7 +368,7 @@ mod ffi {
#[cfg(target_os = "android")]
pub fn wgTurnOn(
settings: *const c_char,
- fd: Fd,
+ fd: RawFd,
logging_callback: Option<LoggingCallback>,
logging_context: LoggingContext,
) -> i32;
@@ -380,7 +395,7 @@ mod ffi {
exit_settings: *const c_char,
entry_settings: *const c_char,
private_ip: *const c_char,
- fd: Fd,
+ fd: RawFd,
logging_callback: Option<LoggingCallback>,
logging_context: LoggingContext,
) -> i32;
@@ -433,11 +448,11 @@ mod ffi {
/// Get the file descriptor of the tunnel IPv4 socket.
#[cfg(target_os = "android")]
- pub fn wgGetSocketV4(handle: i32) -> Fd;
+ pub fn wgGetSocketV4(handle: i32) -> RawFd;
/// Get the file descriptor of the tunnel IPv6 socket.
#[cfg(target_os = "android")]
- pub fn wgGetSocketV6(handle: i32) -> Fd;
+ pub fn wgGetSocketV6(handle: i32) -> RawFd;
/// Rebind endpoint sockets
#[cfg(target_os = "windows")]