diff options
| author | Sebastian Holmin <sebastian.holmin@mullvad.net> | 2024-05-02 11:54:13 +0200 |
|---|---|---|
| committer | Markus Pettersson <markus.pettersson@mullvad.net> | 2024-06-25 17:31:32 +0200 |
| commit | da95b2603470841b64518959ceac0d03aab0068a (patch) | |
| tree | 83028a23bd2439813a0dcf826eac123ea83d956b | |
| parent | 9dfdf2fa52422709ce3df7a50643e0abc6ade319 (diff) | |
| download | mullvadvpn-da95b2603470841b64518959ceac0d03aab0068a.tar.xz mullvadvpn-da95b2603470841b64518959ceac0d03aab0068a.zip | |
Add a safe FFI wrapper in `wireguard-go-rs`
- Add local wireguard go import
- Activate DAITA and add `wgActivateDaita` and `wgReceiveEvent` FFI
- Implement `start_daita` on Wireguard-go tunnel type
- Mention DAITA in `wireguard-go-rs` description
- Do not compile `wireguard-go-rs` on Windows
- Handle DAITA closed on `nil` event
- Handle daita action timeouts in libwg
- Remove noisy log lines
- Remove `maybenot_on_action` callback
- Remove unused link to `../build/lib` for `talpid-wireguard`
- Bump the `wireguard-go` submodule to a signed release tag in Mullvad's
`wireguard-go` fork.
- Update path to `libwg/go.sum` in verification script
Also:
- Use u64 instead of *mut void as log context
- Make Tunnel::set_config take a &mut self
- Use dyn Error instead of i32s for wg errors
Co-authored-by: Joakim Hulthe <joakim@hulthe.net>
37 files changed, 901 insertions, 314 deletions
diff --git a/.github/workflows/verify-locked-down-signatures.yml b/.github/workflows/verify-locked-down-signatures.yml index 0e36154a15..118e44914a 100644 --- a/.github/workflows/verify-locked-down-signatures.yml +++ b/.github/workflows/verify-locked-down-signatures.yml @@ -11,7 +11,7 @@ on: - deny.toml - test/deny.toml - gui/package-lock.json - - wireguard/libwg/go.sum + - wireguard-go-rs/libwg/go.sum - ci/keys/** - ci/verify-locked-down-signatures.sh - ios/MullvadVPN.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved diff --git a/.gitignore b/.gitignore index 1efb0161fc..569d6ca706 100644 --- a/.gitignore +++ b/.gitignore @@ -26,7 +26,6 @@ /android/keystore.properties /android/local.properties /android/play-api-key.json -/wireguard/libwg/libwg.h /wireguard/libwg/libwg.exp /wireguard/libwg/exports.def **/.vs/ diff --git a/.gitmodules b/.gitmodules index e8c3ba4187..39aa8b722d 100644 --- a/.gitmodules +++ b/.gitmodules @@ -7,3 +7,6 @@ [submodule "windows/windows-libraries"] path = windows/windows-libraries url = https://github.com/mullvad/windows-libraries +[submodule "wireguard-go-rs/libwg/wireguard-go"] + path = wireguard-go-rs/libwg/wireguard-go + url = https://github.com/mullvad/wireguard-go/ diff --git a/Cargo.lock b/Cargo.lock index c6a1e328fb..8709388967 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2026,9 +2026,9 @@ checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" [[package]] name = "maybenot" -version = "1.1.0" +version = "1.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94ed977e86fc65a7ffae967a6a973e6f7a90b5d747ebd755703d5718804f7c16" +checksum = "a7fe205734d700937dabf0b8687e290f8574fac996f8a9d04bd7a62d7c2c1dad" dependencies = [ "byteorder", "hex", @@ -4144,6 +4144,7 @@ dependencies = [ "tunnel-obfuscation", "widestring", "windows-sys 0.52.0", + "wireguard-go-rs", "zeroize", ] @@ -5079,6 +5080,15 @@ dependencies = [ ] [[package]] +name = "wireguard-go-rs" +version = "0.0.0" +dependencies = [ + "log", + "thiserror", + "zeroize", +] + +[[package]] name = "x25519-dalek" version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -5112,9 +5122,9 @@ dependencies = [ [[package]] name = "zeroize" -version = "1.7.0" +version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "525b4ec142c6b68a2d10f01f7bbf6755599ca3f81ea53b8431b7dd348f5fdb2d" +checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" dependencies = [ "zeroize_derive", ] diff --git a/Cargo.toml b/Cargo.toml index aa3b9145c8..e126c6a13e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,6 +39,7 @@ members = [ "talpid-windows", "talpid-wireguard", "tunnel-obfuscation", + "wireguard-go-rs" ] # Keep all lints in sync with `test/Cargo.toml` @@ -213,11 +213,6 @@ function build { # Compile and link all binaries. ################################################################################ - if [[ "$(uname -s)" != "MINGW"* ]]; then - log_header "Building wireguard-go$for_target_string" - ./wireguard/build-wireguard-go.sh "$current_target" - fi - log_header "Building Rust code in $RUST_BUILD_MODE mode using $RUSTC_VERSION$for_target_string" local cargo_target_arg=() @@ -312,6 +307,7 @@ for t in "${TARGETS[@]:-""}"; do build "$t" done + ################################################################################ # Package app. ################################################################################ diff --git a/ci/check-rust.sh b/ci/check-rust.sh index d42784d36e..cb48fbed3f 100755 --- a/ci/check-rust.sh +++ b/ci/check-rust.sh @@ -4,11 +4,6 @@ set -eux export RUSTFLAGS="--deny warnings" -# Build WireGuard Go -if [[ "$(uname -s)" != "MINGW"* ]]; then - ./wireguard/build-wireguard-go.sh -fi - # Build Rust crates source env.sh time cargo build --locked --verbose diff --git a/talpid-wireguard/Cargo.toml b/talpid-wireguard/Cargo.toml index c1eb335766..09b7baa57f 100644 --- a/talpid-wireguard/Cargo.toml +++ b/talpid-wireguard/Cargo.toml @@ -30,6 +30,9 @@ tunnel-obfuscation = { path = "../tunnel-obfuscation" } rand = "0.8.5" surge-ping = "0.8.0" +[target.'cfg(not(windows))'.dependencies] +wireguard-go-rs = { path = "../wireguard-go-rs"} + [target.'cfg(target_os="android")'.dependencies] duct = "0.13" @@ -42,7 +45,6 @@ tokio-stream = { version = "0.1", features = ["io-util"] } [target.'cfg(unix)'.dependencies] nix = "0.23" -[target.'cfg(target_os = "linux")'.dependencies] rtnetlink = "0.11" netlink-packet-core = "0.4.2" netlink-packet-route = "0.13" diff --git a/talpid-wireguard/build.rs b/talpid-wireguard/build.rs index fe5b45b819..3abec6abe2 100644 --- a/talpid-wireguard/build.rs +++ b/talpid-wireguard/build.rs @@ -4,22 +4,6 @@ fn main() { let target_os = env::var("CARGO_CFG_TARGET_OS").expect("CARGO_CFG_TARGET_OS not set"); declare_libs_dir("../dist-assets/binaries"); - declare_libs_dir("../build/lib"); - - let link_type = match target_os.as_str() { - "android" => "", - "linux" | "macos" => "=static", - // We would like to avoid panicking on windows even if we can not link correctly - // because we would like to be able to run check and clippy. - // This does not allow for correct linking or buijding. - #[cfg(not(windows))] - "windows" => "", - #[cfg(windows)] - "windows" => "dylib", - _ => panic!("Unsupported platform: {target_os}"), - }; - - println!("cargo:rustc-link-lib{link_type}=wg"); add_wireguard_go_cfg(&target_os); } diff --git a/talpid-wireguard/src/config.rs b/talpid-wireguard/src/config.rs index f10a0e4859..e9059c6cca 100644 --- a/talpid-wireguard/src/config.rs +++ b/talpid-wireguard/src/config.rs @@ -97,9 +97,9 @@ impl Config { enable_ipv6: generic_options.enable_ipv6, obfuscator_config: obfuscator_config.to_owned(), quantum_resistant: wg_options.quantum_resistant, - #[cfg(target_os = "windows")] + #[cfg(any(target_os = "windows", target_os = "linux"))] daita: wg_options.daita, - #[cfg(not(target_os = "windows"))] + #[cfg(not(any(target_os = "windows", target_os = "linux")))] daita: false, }; diff --git a/talpid-wireguard/src/connectivity_check.rs b/talpid-wireguard/src/connectivity_check.rs index 70f88e6872..41ad0b5bf3 100644 --- a/talpid-wireguard/src/connectivity_check.rs +++ b/talpid-wireguard/src/connectivity_check.rs @@ -5,9 +5,10 @@ use crate::{ use std::{ cmp, net::Ipv4Addr, - sync::{mpsc, Mutex, Weak}, + sync::{mpsc, Weak}, time::{Duration, Instant}, }; +use tokio::sync::Mutex; use super::{Tunnel, TunnelError}; @@ -211,11 +212,12 @@ impl ConnectivityMonitor { /// If None is returned, then the underlying tunnel has already been closed and all subsequent /// calls will also return None. + /// + /// NOTE: will panic if called from within a tokio runtime. fn get_stats(&self) -> Option<Result<StatsMap, Error>> { self.tunnel_handle .upgrade()? - .lock() - .ok()? + .blocking_lock() .as_ref() .and_then(|tunnel| match tunnel.get_tunnel_stats() { Ok(stats) if stats.is_empty() => { @@ -550,7 +552,7 @@ mod test { rx_bytes: 0, }, ); - let peers = Mutex::new(map); + let peers = std::sync::Mutex::new(map); Self { on_get_stats: Box::new(move || { let mut peers = peers.lock().unwrap(); @@ -607,13 +609,13 @@ mod test { } fn set_config( - &self, + &mut self, _config: Config, ) -> Pin<Box<dyn Future<Output = std::result::Result<(), TunnelError>> + Send>> { Box::pin(async { Ok(()) }) } - #[cfg(target_os = "windows")] + #[cfg(any(target_os = "windows", target_os = "linux"))] fn start_daita(&mut self) -> std::result::Result<(), TunnelError> { Ok(()) } @@ -745,7 +747,7 @@ mod test { rx_bytes: 0, }, ); - let tunnel_stats = Mutex::new(map); + let tunnel_stats = std::sync::Mutex::new(map); let pinger = MockPinger::default(); let (_tunnel_anchor, tunnel) = MockTunnel::new(move || { diff --git a/talpid-wireguard/src/lib.rs b/talpid-wireguard/src/lib.rs index 7c01538b60..e8334f43d1 100644 --- a/talpid-wireguard/src/lib.rs +++ b/talpid-wireguard/src/lib.rs @@ -115,7 +115,7 @@ impl Error { Error::CreateObfuscatorError(_) => true, Error::ObfuscatorError(_) => true, Error::PskNegotiationError(_) => true, - Error::TunnelError(TunnelError::RecoverableStartWireguardError) => true, + Error::TunnelError(TunnelError::RecoverableStartWireguardError(..)) => true, Error::SetupRoutingError(error) => error.is_recoverable(), @@ -144,7 +144,7 @@ impl Error { pub struct WireguardMonitor { runtime: tokio::runtime::Handle, /// Tunnel implementation - tunnel: Arc<Mutex<Option<Box<dyn Tunnel>>>>, + tunnel: Arc<AsyncMutex<Option<Box<dyn Tunnel>>>>, /// Callback to signal tunnel events event_callback: EventCallback, close_msg_receiver: sync_mpsc::Receiver<CloseMsg>, @@ -306,7 +306,7 @@ impl WireguardMonitor { let (pinger_tx, pinger_rx) = sync_mpsc::channel(); let monitor = WireguardMonitor { runtime: args.runtime.clone(), - tunnel: Arc::new(Mutex::new(Some(tunnel))), + tunnel: Arc::new(AsyncMutex::new(Some(tunnel))), event_callback, close_msg_receiver: close_obfs_listener, pinger_stop_sender: pinger_tx, @@ -473,7 +473,7 @@ impl WireguardMonitor { #[allow(clippy::too_many_arguments)] async fn config_ephemeral_peers<F>( - tunnel: &Arc<Mutex<Option<Box<dyn Tunnel>>>>, + tunnel: &Arc<AsyncMutex<Option<Box<dyn Tunnel>>>>, config: &mut Config, retry_attempt: u32, on_event: F, @@ -576,10 +576,10 @@ impl WireguardMonitor { ) .await?; - #[cfg(target_os = "windows")] + #[cfg(any(target_os = "windows", target_os = "linux"))] if config.daita { // Start local DAITA machines - let mut tunnel = tunnel.lock().unwrap(); + let mut tunnel = tunnel.lock().await; if let Some(tunnel) = tunnel.as_mut() { tunnel .start_daita() @@ -601,7 +601,7 @@ impl WireguardMonitor { /// Reconfigures the tunnel to use the provided config while potentially modifying the config /// and restarting the obfuscation provider. Returns the new config used by the new tunnel. async fn reconfigure_tunnel( - tunnel: &Arc<Mutex<Option<Box<dyn Tunnel>>>>, + tunnel: &Arc<AsyncMutex<Option<Box<dyn Tunnel>>>>, mut config: Config, obfuscator: Arc<AsyncMutex<Option<ObfuscatorHandle>>>, close_obfs_sender: sync_mpsc::Sender<CloseMsg>, @@ -625,11 +625,12 @@ impl WireguardMonitor { } } + let mut tunnel = tunnel.lock().await; + let set_config_future = tunnel - .lock() - .unwrap() - .as_ref() + .as_mut() .map(|tunnel| tunnel.set_config(config.clone())); + if let Some(f) = set_config_future { f.await .map_err(Error::TunnelError) @@ -817,6 +818,7 @@ impl WireguardMonitor { log_path, tun_provider, routes, + resource_dir, ) .map_err(Error::TunnelError)?, )) @@ -843,8 +845,11 @@ impl WireguardMonitor { wait_result } + /// Tear down the tunnel. + /// + /// NOTE: will panic if called from within a tokio runtime. fn stop_tunnel(&mut self) { - match self.tunnel.lock().expect("Tunnel lock poisoned").take() { + match self.tunnel.blocking_lock().take() { Some(tunnel) => { if let Err(e) = tunnel.stop() { log::error!("{}", e.display_chain_with_msg("Failed to stop tunnel")); @@ -1025,11 +1030,12 @@ pub(crate) trait Tunnel: Send { fn get_interface_name(&self) -> String; fn stop(self: Box<Self>) -> std::result::Result<(), TunnelError>; fn get_tunnel_stats(&self) -> std::result::Result<stats::StatsMap, TunnelError>; - fn set_config( - &self, + fn set_config<'a>( + &'a mut self, _config: Config, - ) -> Pin<Box<dyn Future<Output = std::result::Result<(), TunnelError>> + Send>>; - #[cfg(target_os = "windows")] + ) -> Pin<Box<dyn Future<Output = std::result::Result<(), TunnelError>> + Send + 'a>>; + #[cfg(any(target_os = "windows", target_os = "linux"))] + /// A [`Tunnel`] capable of using DAITA. fn start_daita(&mut self) -> std::result::Result<(), TunnelError>; } @@ -1041,7 +1047,7 @@ pub enum TunnelError { /// This is an error returned by the implementation that indicates that trying to establish the /// tunnel again should work normally. The error encountered is known to be sporadic. #[error("Recoverable error while starting wireguard tunnel")] - RecoverableStartWireguardError, + RecoverableStartWireguardError(#[source] Box<dyn std::error::Error + Send>), /// An unrecoverable error occurred while starting the wireguard tunnel /// @@ -1049,14 +1055,11 @@ pub enum TunnelError { /// tunnel again will likely fail with the same error. An error was encountered during tunnel /// configuration which can't be dealt with gracefully. #[error("Failed to start wireguard tunnel")] - FatalStartWireguardError, + FatalStartWireguardError(#[source] Box<dyn std::error::Error + Send>), /// Failed to tear down wireguard tunnel. - #[error("Failed to stop wireguard tunnel. Status: {status}")] - StopWireguardError { - /// Returned error code - status: i32, - }, + #[error("Failed to tear down wireguard tunnel")] + StopWireguardError(#[source] Box<dyn std::error::Error + Send>), /// Error whilst trying to parse the WireGuard config to read the stats #[error("Reading tunnel stats failed")] @@ -1107,6 +1110,16 @@ pub enum TunnelError { /// Failure to set up logging #[error("Failed to set up logging")] LoggingError(#[source] logging::Error), + + /// Failed to receive DAITA event + #[cfg(any(target_os = "windows", target_os = "linux"))] + #[error("Failed to start DAITA")] + StartDaita(#[source] Box<dyn std::error::Error + Send>), + + /// This tunnel does not support DAITA. + #[cfg(any(target_os = "windows", target_os = "linux"))] + #[error("Failed to start DAITA - tunnel implemenation does not support DAITA")] + DaitaNotSupported, } #[cfg(target_os = "linux")] diff --git a/talpid-wireguard/src/logging.rs b/talpid-wireguard/src/logging.rs index 6d1d364342..a4d8c7f240 100644 --- a/talpid-wireguard/src/logging.rs +++ b/talpid-wireguard/src/logging.rs @@ -2,9 +2,13 @@ use once_cell::sync::Lazy; use parking_lot::Mutex; use std::{collections::HashMap, fmt, fs, io::Write, path::Path}; -static LOG_MUTEX: Lazy<Mutex<HashMap<u32, fs::File>>> = Lazy::new(|| Mutex::new(HashMap::new())); +static LOG_MUTEX: Lazy<Mutex<LogState>> = Lazy::new(|| Mutex::new(LogState::default())); -static mut LOG_CONTEXT_NEXT_ORDINAL: u32 = 0; +#[derive(Default)] +struct LogState { + map: HashMap<u64, fs::File>, + next_ordinal: u64, +} /// Errors encountered when initializing logging #[derive(thiserror::Error, Debug)] @@ -14,18 +18,15 @@ pub enum Error { PrepareLogFileError(#[from] std::io::Error), } -pub fn initialize_logging(log_path: Option<&Path>) -> Result<u32, Error> { +pub fn initialize_logging(log_path: Option<&Path>) -> Result<u64, Error> { let log_file = create_log_file(log_path)?; - let log_context_ordinal = unsafe { - let mut map = LOG_MUTEX.lock(); - let ordinal = LOG_CONTEXT_NEXT_ORDINAL; - LOG_CONTEXT_NEXT_ORDINAL += 1; - map.insert(ordinal, log_file); - ordinal - }; + let mut state = LOG_MUTEX.lock(); + let ordinal = state.next_ordinal; + state.next_ordinal += 1; + state.map.insert(ordinal, log_file); - Ok(log_context_ordinal) + Ok(ordinal) } #[cfg(target_os = "windows")] @@ -39,9 +40,9 @@ fn create_log_file(log_path: Option<&Path>) -> Result<fs::File, Error> { .map_err(Error::PrepareLogFileError) } -pub fn clean_up_logging(ordinal: u32) { - let mut map = LOG_MUTEX.lock(); - map.remove(&ordinal); +pub fn clean_up_logging(ordinal: u64) { + let mut state = LOG_MUTEX.lock(); + state.map.remove(&ordinal); } pub enum LogLevel { @@ -71,9 +72,9 @@ impl AsRef<str> for LogLevel { } } -pub fn log(context: u32, level: LogLevel, tag: &str, msg: &str) { - let mut map = LOG_MUTEX.lock(); - if let Some(logfile) = map.get_mut(&{ context }) { +pub fn log(context: u64, level: LogLevel, tag: &str, msg: &str) { + let mut state = LOG_MUTEX.lock(); + if let Some(logfile) = state.map.get_mut(&context) { log_inner(logfile, level, tag, msg); } } diff --git a/talpid-wireguard/src/wireguard_go.rs b/talpid-wireguard/src/wireguard_go/mod.rs index b08b241bb9..32181beaea 100644 --- a/talpid-wireguard/src/wireguard_go.rs +++ b/talpid-wireguard/src/wireguard_go/mod.rs @@ -1,35 +1,40 @@ -use super::{ - stats::{Stats, StatsMap}, - Config, Tunnel, TunnelError, -}; -use crate::logging::{clean_up_logging, initialize_logging}; use ipnetwork::IpNetwork; +#[cfg(any(target_os = "windows", target_os = "linux"))] +use once_cell::sync::OnceCell; +#[cfg(any(target_os = "windows", target_os = "linux"))] +use std::{ffi::CString, fs, path::PathBuf}; use std::{ - ffi::{c_char, c_void, CStr}, future::Future, + net::IpAddr, + os::unix::io::{AsRawFd, RawFd}, path::Path, pin::Pin, + sync::{Arc, Mutex}, }; -use talpid_tunnel::tun_provider::TunProvider; -use talpid_types::BoxedError; -use zeroize::Zeroize; - #[cfg(target_os = "android")] -use talpid_tunnel::tun_provider; +use talpid_tunnel::tun_provider::Error as TunProviderError; +use talpid_tunnel::tun_provider::{Tun, TunConfig, TunProvider}; +use talpid_types::BoxedError; -use std::{ - net::IpAddr, - os::unix::io::{AsRawFd, RawFd}, +use super::{ + stats::{Stats, StatsMap}, + Config, Tunnel, TunnelError, }; -use talpid_tunnel::tun_provider::{Tun, TunConfig}; +use crate::logging::{clean_up_logging, initialize_logging}; -type Result<T> = std::result::Result<T, TunnelError>; +const MAX_PREPARE_TUN_ATTEMPTS: usize = 4; -use std::sync::{Arc, Mutex}; +/// Maximum number of events that can be stored in the underlying buffer +#[cfg(any(target_os = "windows", target_os = "linux"))] +const DAITA_EVENTS_CAPACITY: u32 = 1000; -const MAX_PREPARE_TUN_ATTEMPTS: usize = 4; +/// Maximum number of actions that can be stored in the underlying buffer +#[cfg(any(target_os = "windows", target_os = "linux"))] +const DAITA_ACTIONS_CAPACITY: u32 = 1000; + +type Result<T> = std::result::Result<T, TunnelError>; -struct LoggingContext(u32); +struct LoggingContext(u64); impl Drop for LoggingContext { fn drop(&mut self) { @@ -39,7 +44,7 @@ impl Drop for LoggingContext { pub struct WgGoTunnel { interface_name: String, - handle: Option<i32>, + tunnel_handle: wireguard_go_rs::Tunnel, // holding on to the tunnel device and the log file ensures that the associated file handles // live long enough and get closed when the tunnel is stopped _tunnel_device: Tun, @@ -47,6 +52,9 @@ pub struct WgGoTunnel { _logging_context: LoggingContext, #[cfg(target_os = "android")] tun_provider: Arc<Mutex<TunProvider>>, + #[cfg(any(target_os = "windows", target_os = "linux"))] + resource_dir: PathBuf, + config: Config, } impl WgGoTunnel { @@ -55,6 +63,7 @@ impl WgGoTunnel { log_path: Option<&Path>, tun_provider: Arc<Mutex<TunProvider>>, routes: impl Iterator<Item = IpNetwork>, + resource_dir: &Path, ) -> Result<Self> { #[cfg(target_os = "android")] let tun_provider_clone = tun_provider.clone(); @@ -70,29 +79,30 @@ impl WgGoTunnel { #[cfg(not(target_os = "android"))] let mtu = config.mtu as isize; - let handle = unsafe { - wgTurnOn( - #[cfg(not(target_os = "android"))] - mtu, - wg_config_str.as_ptr() as _, - tunnel_fd, - Some(logging::wg_go_logging_callback), - logging_context.0 as *mut c_void, - ) - }; - check_wg_status(handle)?; + let handle = wireguard_go_rs::Tunnel::turn_on( + #[cfg(not(target_os = "android"))] + mtu, + &wg_config_str, + tunnel_fd, + Some(logging::wg_go_logging_callback), + logging_context.0, + ) + .map_err(|e| TunnelError::FatalStartWireguardError(Box::new(e)))?; #[cfg(target_os = "android")] - Self::bypass_tunnel_sockets(&mut tunnel_device, handle) + Self::bypass_tunnel_sockets(&handle, &mut tunnel_device) .map_err(TunnelError::BypassError)?; Ok(WgGoTunnel { interface_name, - handle: Some(handle), + tunnel_handle: handle, _tunnel_device: tunnel_device, _logging_context: logging_context, #[cfg(target_os = "android")] tun_provider: tun_provider_clone, + resource_dir: resource_dir.to_owned(), + #[cfg(any(target_os = "windows", target_os = "linux"))] + config: config.clone(), }) } @@ -130,11 +140,11 @@ impl WgGoTunnel { #[cfg(target_os = "android")] fn bypass_tunnel_sockets( + handle: &wireguard_go_rs::Tunnel, tunnel_device: &mut Tun, - handle: i32, - ) -> std::result::Result<(), tun_provider::Error> { - let socket_v4 = unsafe { wgGetSocketV4(handle) }; - let socket_v6 = unsafe { wgGetSocketV6(handle) }; + ) -> std::result::Result<(), TunProviderError> { + let socket_v4 = handle.get_socket_v4(); + let socket_v6 = handle.get_socket_v6(); tunnel_device.bypass(socket_v4)?; tunnel_device.bypass(socket_v6)?; @@ -142,16 +152,6 @@ impl WgGoTunnel { Ok(()) } - fn stop_tunnel(&mut self) -> Result<()> { - if let Some(handle) = self.handle.take() { - let status = unsafe { wgTurnOff(handle) }; - if status < 0 { - return Err(TunnelError::StopWireguardError { status }); - } - } - Ok(()) - } - fn get_tunnel( tun_provider: Arc<Mutex<TunProvider>>, config: &Config, @@ -187,71 +187,46 @@ impl WgGoTunnel { } } -impl Drop for WgGoTunnel { - fn drop(&mut self) { - if let Err(e) = self.stop_tunnel() { - log::error!("Failed to stop tunnel: {}", e); - } - } -} - impl Tunnel for WgGoTunnel { fn get_interface_name(&self) -> String { self.interface_name.clone() } fn get_tunnel_stats(&self) -> Result<StatsMap> { - let config_str = unsafe { - let ptr = wgGetConfig(self.handle.unwrap()); - if ptr.is_null() { - log::error!("Failed to get config !"); - return Err(TunnelError::GetConfigError); - } - - CStr::from_ptr(ptr) - }; - - let result = - Stats::parse_config_str(config_str.to_str().expect("Go strings are always UTF-8")) - .map_err(|error| TunnelError::StatsError(BoxedError::new(error))); - unsafe { - // Zeroing out config string to not leave private key in memory. - let slice = std::slice::from_raw_parts_mut( - config_str.as_ptr() as *mut c_char, - config_str.to_bytes().len(), - ); - slice.zeroize(); - - wgFreePtr(config_str.as_ptr() as *mut c_void); - } - - result + self.tunnel_handle + .get_config(|cstr| { + Stats::parse_config_str(cstr.to_str().expect("Go strings are always UTF-8")) + }) + .ok_or(TunnelError::GetConfigError)? + .map_err(|error| TunnelError::StatsError(BoxedError::new(error))) } - fn stop(mut self: Box<Self>) -> Result<()> { - self.stop_tunnel() + fn stop(self: Box<Self>) -> Result<()> { + self.tunnel_handle + .turn_off() + .map_err(|e| TunnelError::StopWireguardError(Box::new(e))) } fn set_config( - &self, + &mut self, config: Config, - ) -> Pin<Box<dyn Future<Output = std::result::Result<(), super::TunnelError>> + Send>> { - let wg_config_str = config.to_userspace_format(); - let handle = self.handle.unwrap(); - #[cfg(target_os = "android")] - let tun_provider = self.tun_provider.clone(); + ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + '_>> { Box::pin(async move { - let status = unsafe { wgSetConfig(handle, wg_config_str.as_ptr() as _) }; - if status != 0 { - return Err(TunnelError::SetConfigError); - } + let wg_config_str = config.to_userspace_format(); + + self.tunnel_handle + .set_config(&wg_config_str) + .map_err(|_| TunnelError::SetConfigError)?; + + #[cfg(target_os = "android")] + let tun_provider = self.tun_provider.clone(); // When reapplying the config, the endpoint socket may be discarded // and needs to be excluded again #[cfg(target_os = "android")] { - let socket_v4 = unsafe { wgGetSocketV4(handle) }; - let socket_v6 = unsafe { wgGetSocketV6(handle) }; + let socket_v4 = self.tunnel_handle.get_socket_v4(); + let socket_v6 = self.tunnel_handle.get_socket_v6(); let mut provider = tun_provider.lock().unwrap(); provider .bypass(socket_v4) @@ -264,68 +239,32 @@ impl Tunnel for WgGoTunnel { Ok(()) }) } -} -fn check_wg_status(wg_code: i32) -> Result<()> { - match wg_code { - ERROR_GENERAL_FAILURE => Err(TunnelError::FatalStartWireguardError), - ERROR_INTERMITTENT_FAILURE => Err(TunnelError::RecoverableStartWireguardError), - 0.. => Ok(()), - _ => { - log::error!("Unknown status code returned from wireguard-go"); - Err(TunnelError::FatalStartWireguardError) - } - } -} - -pub type Fd = std::os::unix::io::RawFd; - -const ERROR_GENERAL_FAILURE: i32 = -1; -const ERROR_INTERMITTENT_FAILURE: i32 = -2; - -extern "C" { - /// Creates a new wireguard tunnel, uses the specific interface name, MTU and file descriptors - /// for the tunnel device and logging. - /// - /// Positive return values are tunnel handles for this specific wireguard tunnel instance. - /// Negative return values signify errors. All error codes are opaque. - #[cfg(not(target_os = "android"))] - fn wgTurnOn( - mtu: isize, - settings: *const i8, - fd: Fd, - logging_callback: Option<logging::LoggingCallback>, - logging_context: *mut c_void, - ) -> i32; - - // Android - #[cfg(target_os = "android")] - fn wgTurnOn( - settings: *const i8, - fd: Fd, - logging_callback: Option<logging::LoggingCallback>, - logging_context: *mut c_void, - ) -> i32; - - // Pass a handle that was created by wgTurnOn to stop a wireguard tunnel. - fn wgTurnOff(handle: i32) -> i32; + fn start_daita(&mut self) -> Result<()> { + static MAYBENOT_MACHINES: OnceCell<CString> = OnceCell::new(); + let machines = MAYBENOT_MACHINES.get_or_try_init(|| { + let path = self.resource_dir.join("maybenot_machines"); + log::debug!("Reading maybenot machines from {}", path.display()); - // Returns the file descriptor of the tunnel IPv4 socket. - fn wgGetConfig(handle: i32) -> *mut c_char; + // TODO: errors + let machines = fs::read_to_string(path).unwrap(); + let machines = CString::new(machines).unwrap(); + Ok(machines) + })?; - // Sets the config of the WireGuard interface. - fn wgSetConfig(handle: i32, settings: *const i8) -> i32; - - // Frees a pointer allocated by the go runtime - useful to free return value of wgGetConfig - fn wgFreePtr(ptr: *mut c_void); - - // Returns the file descriptor of the tunnel IPv4 socket. - #[cfg(target_os = "android")] - fn wgGetSocketV4(handle: i32) -> Fd; + log::info!("Initializing DAITA for wireguard device"); + let peer_public_key = &self.config.entry_peer.public_key; + self.tunnel_handle + .activate_daita( + peer_public_key.as_bytes(), + machines, + DAITA_EVENTS_CAPACITY, + DAITA_ACTIONS_CAPACITY, + ) + .map_err(|e| TunnelError::StartDaita(Box::new(e)))?; - // Returns the file descriptor of the tunnel IPv6 socket. - #[cfg(target_os = "android")] - fn wgGetSocketV6(handle: i32) -> Fd; + Ok(()) + } } mod stats { @@ -438,13 +377,13 @@ mod stats { mod logging { use super::super::logging::{log, LogLevel}; - use std::ffi::{c_char, c_void}; + use std::ffi::c_char; // Callback that receives messages from WireGuard pub unsafe extern "system" fn wg_go_logging_callback( level: WgLogLevel, msg: *const c_char, - context: *mut c_void, + context: u64, ) { let managed_msg = if !msg.is_null() { std::ffi::CStr::from_ptr(msg).to_string_lossy().to_string() @@ -457,7 +396,7 @@ mod logging { _ => LogLevel::Error, }; - log(context as u32, level, "wireguard-go", &managed_msg); + log(context, level, "wireguard-go", &managed_msg); } // wireguard-go supports log levels 0 through 3 with 3 being the most verbose @@ -466,7 +405,4 @@ mod logging { const WG_GO_LOG_VERBOSE: WgLogLevel = 2; pub type WgLogLevel = u32; - - pub type LoggingCallback = - unsafe extern "system" fn(level: WgLogLevel, msg: *const c_char, context: *mut c_void); } diff --git a/talpid-wireguard/src/wireguard_kernel/netlink_tunnel.rs b/talpid-wireguard/src/wireguard_kernel/netlink_tunnel.rs index 579bcde65a..bdc187e1bd 100644 --- a/talpid-wireguard/src/wireguard_kernel/netlink_tunnel.rs +++ b/talpid-wireguard/src/wireguard_kernel/netlink_tunnel.rs @@ -93,7 +93,7 @@ impl Tunnel for NetlinkTunnel { tokio_handle.block_on(async move { if let Err(err) = netlink_connections.delete_device(interface_index).await { log::error!("Failed to remove WireGuard device: {}", err); - Err(TunnelError::FatalStartWireguardError) + Err(TunnelError::FatalStartWireguardError(Box::new(err))) } else { Ok(()) } @@ -113,7 +113,7 @@ impl Tunnel for NetlinkTunnel { } fn set_config( - &self, + &mut self, config: Config, ) -> Pin<Box<dyn Future<Output = std::result::Result<(), TunnelError>> + Send + 'static>> { let mut wg = self.netlink_connections.wg_handle.clone(); @@ -127,4 +127,9 @@ impl Tunnel for NetlinkTunnel { }) }) } + + // TODO: We shouldn't force `NetlinkTunnel` to implement `start_daita` + fn start_daita(&mut self) -> std::result::Result<(), TunnelError> { + unreachable!("Netlink tunnel does not support DAITA") + } } diff --git a/talpid-wireguard/src/wireguard_kernel/nm_tunnel.rs b/talpid-wireguard/src/wireguard_kernel/nm_tunnel.rs index 7c24c42a70..9fc9351591 100644 --- a/talpid-wireguard/src/wireguard_kernel/nm_tunnel.rs +++ b/talpid-wireguard/src/wireguard_kernel/nm_tunnel.rs @@ -70,7 +70,7 @@ impl Tunnel for NetworkManagerTunnel { if let Some(tunnel) = self.tunnel.take() { if let Err(err) = self.network_manager.remove_tunnel(tunnel) { log::error!("Failed to remove WireGuard tunnel via NM: {}", err); - Err(TunnelError::StopWireguardError { status: 0 }) + Err(TunnelError::StopWireguardError(Box::new(err))) } else { Ok(()) } @@ -94,7 +94,7 @@ impl Tunnel for NetworkManagerTunnel { } fn set_config( - &self, + &mut self, config: Config, ) -> Pin<Box<dyn Future<Output = std::result::Result<(), TunnelError>> + Send>> { let interface_name = self.interface_name.clone(); @@ -110,6 +110,11 @@ impl Tunnel for NetworkManagerTunnel { }) }) } + + // TODO: We shouldn't force `NetworkManagerTunnel` tunnel to implement `start_daita` + fn start_daita(&mut self) -> std::result::Result<(), TunnelError> { + unreachable!("NetworkManager tunnel does not support DAITA") + } } fn convert_config_to_dbus(config: &Config) -> DeviceConfig { diff --git a/talpid-wireguard/src/wireguard_nt/mod.rs b/talpid-wireguard/src/wireguard_nt/mod.rs index 375f28844a..de32c8d83e 100644 --- a/talpid-wireguard/src/wireguard_nt/mod.rs +++ b/talpid-wireguard/src/wireguard_nt/mod.rs @@ -70,7 +70,6 @@ type WireGuardGetConfigurationFn = unsafe extern "stdcall" fn( type WireGuardSetStateFn = unsafe extern "stdcall" fn(adapter: RawHandle, state: WgAdapterState) -> BOOL; -#[cfg(windows)] #[repr(C)] #[allow(dead_code)] enum LogLevel { @@ -79,7 +78,6 @@ enum LogLevel { Err = 2, } -#[cfg(windows)] impl From<LogLevel> for logging::LogLevel { fn from(level: LogLevel) -> Self { match level { @@ -430,7 +428,7 @@ impl WgNtTunnel { match error { Error::CreateTunnelDevice(error) => super::TunnelError::SetupTunnelDevice(error), - _ => super::TunnelError::FatalStartWireguardError, + _ => super::TunnelError::FatalStartWireguardError(Box::new(error)), } }) } @@ -542,11 +540,11 @@ impl Drop for WgNtTunnel { } } -static LOG_CONTEXT: Lazy<Mutex<Option<u32>>> = Lazy::new(|| Mutex::new(None)); +static LOG_CONTEXT: Lazy<Mutex<Option<u64>>> = Lazy::new(|| Mutex::new(None)); struct LoggerHandle { dll: &'static WgNtDll, - context: u32, + context: u64, } impl LoggerHandle { @@ -1080,7 +1078,7 @@ impl Tunnel for WgNtTunnel { } fn set_config( - &self, + &mut self, config: Config, ) -> Pin<Box<dyn Future<Output = std::result::Result<(), super::TunnelError>> + Send>> { let device = self.device.clone(); @@ -1145,7 +1143,6 @@ mod tests { allowed_ips: vec!["1.3.3.0/24".parse().unwrap()], endpoint: "1.2.3.4:1234".parse().unwrap(), psk: None, - #[cfg(target_os = "windows")] constant_packet_size: false, }, exit_peer: None, @@ -1181,7 +1178,6 @@ mod tests { rx_bytes: 0, last_handshake: 0, allowed_ips_count: 1, - #[cfg(target_os = "windows")] constant_packet_size: 0, }, p0_allowed_ip_0: WgAllowedIp { diff --git a/wireguard-go-rs/Cargo.toml b/wireguard-go-rs/Cargo.toml new file mode 100644 index 0000000000..3725787bf6 --- /dev/null +++ b/wireguard-go-rs/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "wireguard-go-rs" +description = "Rust bindings to wireguard-go with DAITA support" +edition = "2021" +license.workspace = true + +[dependencies] +thiserror.workspace = true +log.workspace = true +zeroize = "1.8.1" diff --git a/wireguard-go-rs/README.md b/wireguard-go-rs/README.md new file mode 100644 index 0000000000..24a08fa54b --- /dev/null +++ b/wireguard-go-rs/README.md @@ -0,0 +1,10 @@ +# `wireguard-go-rs` +This crate wraps `libwg`, which in turn wraps [Mullvad VPN's fork of wireguard-go](https://github.com/mullvad/wireguard-go) which extends `wireguard-go` with [DAITA](https://mullvad.net/en/blog/introducing-defense-against-ai-guided-traffic-analysis-daita). + +## Known limitation +To extend `wireguard-go` with DAITA capabilities, it statically links against [maybenot](https://github.com/maybenot-io/maybenot/), which at the time of writing will cause issues if it in turn is statically linked from another Rust crate: https://github.com/rust-lang/rust/issues/104707. +As such, `libwg` is built as a shared object which you have to link to dynamically. +To get rid of this limitation, you could compile `wireguard-go` without DAITA support. See [build-wireguard-go.sh](./build-wireguard-go.sh) for details. + +## Upgrading `wireguard-go` +Upgrading `wireguard-go` involves updating the git submodule found in `libwg/wireguard-go`. This module uses [Mullvad VPN's fork of wireguard-go](https://github.com/mullvad/wireguard-go). diff --git a/wireguard/build-wireguard-go.sh b/wireguard-go-rs/build-wireguard-go.sh index 7e3b11910d..e270fb8219 100755 --- a/wireguard/build-wireguard-go.sh +++ b/wireguard-go-rs/build-wireguard-go.sh @@ -1,30 +1,27 @@ #!/usr/bin/env bash # This script is used to build wireguard-go libraries for all the platforms. +# +# If "DAITA" support should be enabled, pass the `--daita` flag when invoking this script. set -eu -function is_android_build { - for arg in "$@" - do - case "$arg" in - "--android") - return 0 - esac - done - return 1 -} +# If Wireguard-go should be built with DAITA-support. +DAITA="false" +# If the target OS is Adnroid. +ANDROID="false" -function is_docker_build { - for arg in "$@" - do - case "$arg" in - "--no-docker") - return 1 - esac - done - return 0 -} +while [[ "$#" -gt 0 ]]; do + case $1 in + --android) ANDROID="true";; + --daita) DAITA="true";; + *) + log_error "Unknown parameter: $1" + exit 1 + ;; + esac + shift +done function unix_target_triple { local platform @@ -48,6 +45,7 @@ function unix_target_triple { function build_unix { + # TODO: consider using `log_header` here echo "Building wireguard-go for $1" # Flags for cross compiling @@ -80,33 +78,35 @@ function build_unix { fi fi - pushd libwg - target_triple_dir="../../build/lib/$1" - mkdir -p "$target_triple_dir" - go build -v -o "$target_triple_dir"/libwg.a -buildmode c-archive + # Build wiregaurd-go as a library + pushd libwg + if [[ "$DAITA" == "true" ]]; then + pushd wireguard-go + make libmaybenot.a LIBDEST="$OUT_DIR" + popd + go build -v --tags daita -o "$OUT_DIR"/libwg.a -buildmode c-archive + else + go build -v -o "$OUT_DIR"/libwg.a -buildmode c-archive + fi popd } function build_android { - echo "Building for android" + echo "Building wireguard-go for android" - if is_docker_build "$@"; then - ../building/container-run.sh android wireguard/libwg/build-android.sh - else - ./libwg/build-android.sh - fi + ./libwg/build-android.sh } function build_wireguard_go { - if is_android_build "$@"; then + if [[ "$ANDROID" == "true" ]]; then build_android "$@" return fi local platform platform="$(uname -s)"; - case "$platform" in + case "$platform" in Linux*|Darwin*) build_unix "${1:-$(unix_target_triple)}";; *) echo "Unsupported platform" diff --git a/wireguard-go-rs/build.rs b/wireguard-go-rs/build.rs new file mode 100644 index 0000000000..82a8ff254b --- /dev/null +++ b/wireguard-go-rs/build.rs @@ -0,0 +1,65 @@ +use core::{panic, str}; +use std::{env, path::PathBuf}; + +fn main() { + let out_dir = env::var("OUT_DIR").expect("Missing OUT_DIR"); + eprintln!("OUT_DIR: {out_dir}"); + + let target_os = env::var("CARGO_CFG_TARGET_OS").expect("Missing 'CARGO_CFG_TARGET_OS"); + let mut cmd = std::process::Command::new("bash"); + cmd.arg("./build-wireguard-go.sh"); + + match target_os.as_str() { + "linux" => { + // Enable DAITA & Tell rustc to link libmaybenot + println!("cargo::rustc-link-lib=static=maybenot"); + // Tell the build script to build wireguard-go with DAITA support + cmd.arg("--daita"); + } + "android" => { + cmd.arg("--android"); + } + "macos" => {} + // building wireguard-go-rs for windows is not implemented + _ => return, + } + + let output = cmd.output().expect("build-wireguard-go.sh failed"); + if !output.status.success() { + let stdout = str::from_utf8(&output.stdout).unwrap(); + let stderr = str::from_utf8(&output.stderr).unwrap(); + eprintln!("build-wireguard-go.sh failed."); + eprintln!("stdout:\n{stdout}"); + eprintln!("stderr:\n{stderr}"); + panic!(); + } + + if target_os == "android" { + // NOTE: Go programs does not support being statically linked on android + // so we need to dynamically link to libwg + println!("cargo::rustc-link-lib=wg"); + declare_libs_dir("../build/lib"); + } else { + // other platforms can statically link to libwg just fine + // TODO: consider doing dynamic linking everywhere, to keep things simpler + println!("cargo::rustc-link-lib=static=wg"); + println!("cargo::rustc-link-search={out_dir}"); + } + + println!("cargo::rerun-if-changed=libwg"); +} + +/// Tell linker to check `base`/$TARGET for shared libraries. +fn declare_libs_dir(base: &str) { + let target_triplet = env::var("TARGET").expect("TARGET is not set"); + let lib_dir = manifest_dir().join(base).join(target_triplet); + println!("cargo::rerun-if-changed={}", lib_dir.display()); + println!("cargo::rustc-link-search={}", lib_dir.display()); +} + +/// Get the directory containing `Cargo.toml` +fn manifest_dir() -> PathBuf { + env::var("CARGO_MANIFEST_DIR") + .map(PathBuf::from) + .expect("CARGO_MANIFEST_DIR env var not set") +} diff --git a/wireguard/libwg/Android.mk b/wireguard-go-rs/libwg/Android.mk index acf2f6fe88..acf2f6fe88 100644 --- a/wireguard/libwg/Android.mk +++ b/wireguard-go-rs/libwg/Android.mk diff --git a/wireguard/libwg/README.md b/wireguard-go-rs/libwg/README.md index 39ad48e3e0..39ad48e3e0 100644 --- a/wireguard/libwg/README.md +++ b/wireguard-go-rs/libwg/README.md diff --git a/wireguard/libwg/build-android.sh b/wireguard-go-rs/libwg/build-android.sh index 750438c503..750438c503 100755 --- a/wireguard/libwg/build-android.sh +++ b/wireguard-go-rs/libwg/build-android.sh diff --git a/wireguard/libwg/go.mod b/wireguard-go-rs/libwg/go.mod index 2f65c7bf7b..76627dcb7f 100644 --- a/wireguard/libwg/go.mod +++ b/wireguard-go-rs/libwg/go.mod @@ -10,5 +10,7 @@ require ( require ( golang.org/x/crypto v0.22.0 // indirect golang.org/x/net v0.24.0 // indirect - golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224 // indirect + golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect ) + +replace golang.zx2c4.com/wireguard => ./wireguard-go diff --git a/wireguard/libwg/go.sum b/wireguard-go-rs/libwg/go.sum index c90f293b17..b41c5842d1 100644 --- a/wireguard/libwg/go.sum +++ b/wireguard-go-rs/libwg/go.sum @@ -4,7 +4,5 @@ golang.org/x/net v0.24.0 h1:1PcaxkF854Fu3+lvBIx5SYn9wRlBzzcnHZSiaFFAb0w= golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8= golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o= golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224 h1:Ug9qvr1myri/zFN6xL17LSCBGFDnphBBhzmILHsM5TY= -golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= -golang.zx2c4.com/wireguard v0.0.0-20230223181233-21636207a675 h1:/J/RVnr7ng4fWPRH3xa4WtBJ1Jp+Auu4YNLmGiPv5QU= -golang.zx2c4.com/wireguard v0.0.0-20230223181233-21636207a675/go.mod h1:whfbyDBt09xhCYQWtO2+3UVjlaq6/9hDZrjg2ZE6SyA= +golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg= +golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= diff --git a/wireguard-go-rs/libwg/goruntime-boottime-over-monotonic.diff b/wireguard-go-rs/libwg/goruntime-boottime-over-monotonic.diff new file mode 100644 index 0000000000..5d78242b13 --- /dev/null +++ b/wireguard-go-rs/libwg/goruntime-boottime-over-monotonic.diff @@ -0,0 +1,171 @@ +From 61f3ae8298d1c503cbc31539e0f3a73446c7db9d Mon Sep 17 00:00:00 2001 +From: "Jason A. Donenfeld" <Jason@zx2c4.com> +Date: Tue, 21 Mar 2023 15:33:56 +0100 +Subject: [PATCH] [release-branch.go1.20] runtime: use CLOCK_BOOTTIME in + nanotime on Linux + +This makes timers account for having expired while a computer was +asleep, which is quite common on mobile devices. Note that BOOTTIME is +identical to MONOTONIC, except that it takes into account time spent +in suspend. In Linux 4.17, the kernel will actually make MONOTONIC act +like BOOTTIME anyway, so this switch will additionally unify the +timer behavior across kernels. + +BOOTTIME was introduced into Linux 2.6.39-rc1 with 70a08cca1227d in +2011. + +Fixes #24595 + +Change-Id: I7b2a6ca0c5bc5fce57ec0eeafe7b68270b429321 +--- + src/runtime/sys_linux_386.s | 4 ++-- + src/runtime/sys_linux_amd64.s | 2 +- + src/runtime/sys_linux_arm.s | 4 ++-- + src/runtime/sys_linux_arm64.s | 4 ++-- + src/runtime/sys_linux_mips64x.s | 4 ++-- + src/runtime/sys_linux_mipsx.s | 2 +- + src/runtime/sys_linux_ppc64x.s | 2 +- + src/runtime/sys_linux_s390x.s | 2 +- + 8 files changed, 12 insertions(+), 12 deletions(-) + +diff --git a/src/runtime/sys_linux_386.s b/src/runtime/sys_linux_386.s +index 12a294153d..17e3524b40 100644 +--- a/src/runtime/sys_linux_386.s ++++ b/src/runtime/sys_linux_386.s +@@ -352,13 +352,13 @@ noswitch: + + LEAL 8(SP), BX // &ts (struct timespec) + MOVL BX, 4(SP) +- MOVL $1, 0(SP) // CLOCK_MONOTONIC ++ MOVL $7, 0(SP) // CLOCK_BOOTTIME + CALL AX + JMP finish + + fallback: + MOVL $SYS_clock_gettime, AX +- MOVL $1, BX // CLOCK_MONOTONIC ++ MOVL $7, BX // CLOCK_BOOTTIME + LEAL 8(SP), CX + INVOKE_SYSCALL + +diff --git a/src/runtime/sys_linux_amd64.s b/src/runtime/sys_linux_amd64.s +index c7a89ba536..01f0a6a26e 100644 +--- a/src/runtime/sys_linux_amd64.s ++++ b/src/runtime/sys_linux_amd64.s +@@ -255,7 +255,7 @@ noswitch: + SUBQ $16, SP // Space for results + ANDQ $~15, SP // Align for C code + +- MOVL $1, DI // CLOCK_MONOTONIC ++ MOVL $7, DI // CLOCK_BOOTTIME + LEAQ 0(SP), SI + MOVQ runtime·vdsoClockgettimeSym(SB), AX + CMPQ AX, $0 +diff --git a/src/runtime/sys_linux_arm.s b/src/runtime/sys_linux_arm.s +index 7b8c4f0e04..9798a1334e 100644 +--- a/src/runtime/sys_linux_arm.s ++++ b/src/runtime/sys_linux_arm.s +@@ -11,7 +11,7 @@ + #include "textflag.h" + + #define CLOCK_REALTIME 0 +-#define CLOCK_MONOTONIC 1 ++#define CLOCK_BOOTTIME 7 + + // for EABI, as we don't support OABI + #define SYS_BASE 0x0 +@@ -374,7 +374,7 @@ finish: + + // func nanotime1() int64 + TEXT runtime·nanotime1(SB),NOSPLIT,$12-8 +- MOVW $CLOCK_MONOTONIC, R0 ++ MOVW $CLOCK_BOOTTIME, R0 + MOVW $spec-12(SP), R1 // timespec + + MOVW runtime·vdsoClockgettimeSym(SB), R4 +diff --git a/src/runtime/sys_linux_arm64.s b/src/runtime/sys_linux_arm64.s +index 38ff6ac330..6b819c5441 100644 +--- a/src/runtime/sys_linux_arm64.s ++++ b/src/runtime/sys_linux_arm64.s +@@ -14,7 +14,7 @@ + #define AT_FDCWD -100 + + #define CLOCK_REALTIME 0 +-#define CLOCK_MONOTONIC 1 ++#define CLOCK_BOOTTIME 7 + + #define SYS_exit 93 + #define SYS_read 63 +@@ -338,7 +338,7 @@ noswitch: + BIC $15, R1 + MOVD R1, RSP + +- MOVW $CLOCK_MONOTONIC, R0 ++ MOVW $CLOCK_BOOTTIME, R0 + MOVD runtime·vdsoClockgettimeSym(SB), R2 + CBZ R2, fallback + +diff --git a/src/runtime/sys_linux_mips64x.s b/src/runtime/sys_linux_mips64x.s +index 47f2da524d..a8b387f193 100644 +--- a/src/runtime/sys_linux_mips64x.s ++++ b/src/runtime/sys_linux_mips64x.s +@@ -326,7 +326,7 @@ noswitch: + AND $~15, R1 // Align for C code + MOVV R1, R29 + +- MOVW $1, R4 // CLOCK_MONOTONIC ++ MOVW $7, R4 // CLOCK_BOOTTIME + MOVV $0(R29), R5 + + MOVV runtime·vdsoClockgettimeSym(SB), R25 +@@ -336,7 +336,7 @@ noswitch: + // see walltime for detail + BEQ R2, R0, finish + MOVV R0, runtime·vdsoClockgettimeSym(SB) +- MOVW $1, R4 // CLOCK_MONOTONIC ++ MOVW $7, R4 // CLOCK_BOOTTIME + MOVV $0(R29), R5 + JMP fallback + +diff --git a/src/runtime/sys_linux_mipsx.s b/src/runtime/sys_linux_mipsx.s +index 5e6b6c1504..7f5fd2a80e 100644 +--- a/src/runtime/sys_linux_mipsx.s ++++ b/src/runtime/sys_linux_mipsx.s +@@ -243,7 +243,7 @@ TEXT runtime·walltime(SB),NOSPLIT,$8-12 + RET + + TEXT runtime·nanotime1(SB),NOSPLIT,$8-8 +- MOVW $1, R4 // CLOCK_MONOTONIC ++ MOVW $7, R4 // CLOCK_BOOTTIME + MOVW $4(R29), R5 + MOVW $SYS_clock_gettime, R2 + SYSCALL +diff --git a/src/runtime/sys_linux_ppc64x.s b/src/runtime/sys_linux_ppc64x.s +index d0427a4807..05ee9fede9 100644 +--- a/src/runtime/sys_linux_ppc64x.s ++++ b/src/runtime/sys_linux_ppc64x.s +@@ -298,7 +298,7 @@ fallback: + JMP return + + TEXT runtime·nanotime1(SB),NOSPLIT,$16-8 +- MOVD $1, R3 // CLOCK_MONOTONIC ++ MOVD $7, R3 // CLOCK_BOOTTIME + + MOVD R1, R15 // R15 is unchanged by C code + MOVD g_m(g), R21 // R21 = m +diff --git a/src/runtime/sys_linux_s390x.s b/src/runtime/sys_linux_s390x.s +index 1448670b91..7d2ee3231c 100644 +--- a/src/runtime/sys_linux_s390x.s ++++ b/src/runtime/sys_linux_s390x.s +@@ -296,7 +296,7 @@ fallback: + RET + + TEXT runtime·nanotime1(SB),NOSPLIT,$32-8 +- MOVW $1, R2 // CLOCK_MONOTONIC ++ MOVW $7, R2 // CLOCK_BOOTTIME + + MOVD R15, R7 // Backup stack pointer + +-- +2.17.1 + diff --git a/wireguard/libwg/libwg.go b/wireguard-go-rs/libwg/libwg.go index e26ea7b7da..aaa03ef838 100644 --- a/wireguard/libwg/libwg.go +++ b/wireguard-go-rs/libwg/libwg.go @@ -6,7 +6,9 @@ package main +// #include <stdio.h> // #include <stdlib.h> +// #include <stdint.h> import "C" import ( @@ -17,11 +19,31 @@ import ( "unsafe" "github.com/mullvad/mullvadvpn-app/wireguard/libwg/tunnelcontainer" + "golang.zx2c4.com/wireguard/device" ) +// FFI integer result codes +// NOTE: Must be kept in sync with the Error enum in wireguard-go-rs const ( - ERROR_GENERAL_FAILURE = -1 - ERROR_INTERMITTENT_FAILURE = -2 + OK = C.int32_t(-iota) + + // Something went wrong. + ERROR_GENERAL_FAILURE + + // Something went wrong, but trying again might help. + ERROR_INTERMITTENT_FAILURE + + // A bad argument was provided to libwg. + ERROR_INVALID_ARGUMENT + + // The provided tunnel handle did not refer to an existing tunnel. + ERROR_UNKNOWN_TUNNEL + + // The provided public key did not refer to an existing peer. + ERROR_UNKNOWN_PEER + + // Something went wrong when enabling DAITA. + ERROR_ENABLE_DAITA ) var tunnels tunnelcontainer.Container @@ -30,6 +52,11 @@ func init() { tunnels = tunnelcontainer.New() } +type EventContext struct { + tunnelHandle int32 + peer device.NoisePublicKey +} + //export wgTurnOff func wgTurnOff(tunnelHandle int32) { { @@ -61,14 +88,14 @@ func wgGetConfig(tunnelHandle int32) *C.char { } //export wgSetConfig -func wgSetConfig(tunnelHandle int32, cSettings *C.char) int32 { +func wgSetConfig(tunnelHandle int32, cSettings *C.char) C.int32_t { tunnel, err := tunnels.Get(tunnelHandle) if err != nil { - return ERROR_GENERAL_FAILURE + return ERROR_UNKNOWN_TUNNEL } if cSettings == nil { tunnel.Logger.Errorf("cSettings is null\n") - return ERROR_GENERAL_FAILURE + return ERROR_INVALID_ARGUMENT } settings := C.GoString(cSettings) diff --git a/wireguard-go-rs/libwg/libwg.h b/wireguard-go-rs/libwg/libwg.h new file mode 100644 index 0000000000..23c87092cc --- /dev/null +++ b/wireguard-go-rs/libwg/libwg.h @@ -0,0 +1,8 @@ +#include <stdint.h> +#include <stdbool.h> + +/// Activate DAITA for the specified tunnel. +int32_t wgActivateDaita(int32_t tunnelHandle, uint8_t* noisePublic, char* machines, uint32_t eventsCapacity, uint32_t actionsCapacity); +char* wgGetConfig(int32_t tunnelHandle); +int32_t wgSetConfig(int32_t tunnelHandle, char* cSettings); +void wgFreePtr(void*); diff --git a/wireguard/libwg/libwg_android.go b/wireguard-go-rs/libwg/libwg_android.go index df54d4cc8b..86410721f5 100644 --- a/wireguard/libwg/libwg_android.go +++ b/wireguard-go-rs/libwg/libwg_android.go @@ -6,8 +6,10 @@ package main +// #include <stdint.h> +import "C" + import ( - "C" "bufio" "strings" "unsafe" @@ -25,15 +27,15 @@ import ( // Redefined here because otherwise the compiler doesn't realize it's a type alias for a type that's safe to export. // Taken from the contained logging package. type LogSink = unsafe.Pointer -type LogContext = unsafe.Pointer +type LogContext = C.uint64_t //export wgTurnOn -func wgTurnOn(cSettings *C.char, fd int, logSink LogSink, logContext LogContext) int32 { - logger := logging.NewLogger(logSink, logContext) +func wgTurnOn(cSettings *C.char, fd int, logSink LogSink, logContext LogContext) C.int32_t { + logger := logging.NewLogger(logSink, logging.LogContext(logContext)) if cSettings == nil { logger.Errorf("cSettings is null\n") - return ERROR_GENERAL_FAILURE + return ERROR_INVALID_ARGUMENT } settings := C.GoString(cSettings) @@ -71,33 +73,33 @@ func wgTurnOn(cSettings *C.char, fd int, logSink LogSink, logContext LogContext) return ERROR_GENERAL_FAILURE } - return handle + return C.int32_t(handle) } //export wgGetSocketV4 -func wgGetSocketV4(tunnelHandle int32) int32 { +func wgGetSocketV4(tunnelHandle int32) C.int32_t { tunnel, err := tunnels.Get(tunnelHandle) if err != nil { - return ERROR_GENERAL_FAILURE + return ERROR_UNKNOWN_TUNNEL } peek := tunnel.Device.Bind().(conn.PeekLookAtSocketFd) fd, err := peek.PeekLookAtSocketFd4() if err != nil { return ERROR_GENERAL_FAILURE } - return int32(fd) + return C.int32_t(fd) } //export wgGetSocketV6 -func wgGetSocketV6(tunnelHandle int32) int32 { +func wgGetSocketV6(tunnelHandle int32) C.int32_t { tunnel, err := tunnels.Get(tunnelHandle) if err != nil { - return ERROR_GENERAL_FAILURE + return ERROR_UNKNOWN_TUNNEL } peek := tunnel.Device.Bind().(conn.PeekLookAtSocketFd) fd, err := peek.PeekLookAtSocketFd6() if err != nil { return ERROR_GENERAL_FAILURE } - return int32(fd) + return C.int32_t(fd) } diff --git a/wireguard-go-rs/libwg/libwg_daita.go b/wireguard-go-rs/libwg/libwg_daita.go new file mode 100644 index 0000000000..e33de84e49 --- /dev/null +++ b/wireguard-go-rs/libwg/libwg_daita.go @@ -0,0 +1,41 @@ +//go:build daita +// +build daita + +package main + +// #include <stdio.h> +// #include <stdlib.h> +// #include <stdint.h> +import "C" + +import ( + "unsafe" + + "golang.zx2c4.com/wireguard/device" +) + +const maxPaddingBytes = 0.0 +const maxBlockingBytes = 0.0 + +//export wgActivateDaita +func wgActivateDaita(tunnelHandle C.int32_t, peerPubkey *C.uint8_t, machines *C.char, eventsCapacity C.uint32_t, actionsCapacity C.uint32_t) C.int32_t { + + tunnel, err := tunnels.Get(int32(tunnelHandle)) + if err != nil { + return ERROR_UNKNOWN_TUNNEL + } + + var publicKey device.NoisePublicKey + copy(publicKey[:], C.GoBytes(unsafe.Pointer(peerPubkey), device.NoisePublicKeySize)) + peer := tunnel.Device.LookupPeer(publicKey) + + if peer == nil { + return ERROR_UNKNOWN_PEER + } + + if !peer.EnableDaita(C.GoString((*C.char)(machines)), uint(eventsCapacity), uint(actionsCapacity), maxPaddingBytes, maxBlockingBytes) { + return ERROR_ENABLE_DAITA + } + + return OK +} diff --git a/wireguard/libwg/libwg_default.go b/wireguard-go-rs/libwg/libwg_default.go index 7282c0ca8a..6741de715f 100644 --- a/wireguard/libwg/libwg_default.go +++ b/wireguard-go-rs/libwg/libwg_default.go @@ -1,3 +1,4 @@ +//go:build (darwin || linux) && !android // +build darwin linux // +build !android @@ -10,6 +11,7 @@ package main // #include <stdlib.h> +// #include <stdint.h> import "C" import ( "bufio" @@ -28,16 +30,15 @@ import ( // Redefined here because otherwise the compiler doesn't realize it's a type alias for a type that's safe to export. // Taken from the contained logging package. type LogSink = unsafe.Pointer -type LogContext = unsafe.Pointer +type LogContext = C.uint64_t //export wgTurnOn -func wgTurnOn(mtu int, cSettings *C.char, fd int, logSink LogSink, logContext LogContext) int32 { - - logger := logging.NewLogger(logSink, logContext) +func wgTurnOn(mtu int, cSettings *C.char, fd int, logSink LogSink, logContext LogContext) C.int32_t { + logger := logging.NewLogger(logSink, logging.LogContext(logContext)) if cSettings == nil { logger.Errorf("cSettings is null\n") - return ERROR_GENERAL_FAILURE + return ERROR_INVALID_ARGUMENT } settings := C.GoString(cSettings) @@ -74,5 +75,5 @@ func wgTurnOn(mtu int, cSettings *C.char, fd int, logSink LogSink, logContext Lo return ERROR_GENERAL_FAILURE } - return handle + return C.int32_t(handle) } diff --git a/wireguard/libwg/logging/logging.go b/wireguard-go-rs/libwg/logging/logging.go index a917e96493..a6782ec39a 100644 --- a/wireguard/libwg/logging/logging.go +++ b/wireguard-go-rs/libwg/logging/logging.go @@ -7,12 +7,13 @@ package logging // #include <stdlib.h> +// #include <stdint.h> // #include <sys/types.h> // #ifndef WIN32 // #define __stdcall // #endif -// typedef void (__stdcall *LogSink)(unsigned int, const char *, void *); -// static void callLogSink(void *logSink, int level, const char *message, void *context) +// typedef void (__stdcall *LogSink)(unsigned int, const char *, uint64_t); +// static void callLogSink(void *logSink, int level, const char *message, uint64_t context) // { // ((LogSink)logSink)((unsigned int)level, message, context); // } @@ -27,7 +28,7 @@ import ( // Define type aliases. type LogSink = unsafe.Pointer -type LogContext = unsafe.Pointer +type LogContext = C.uint64_t type Logger struct { sink LogSink diff --git a/wireguard/libwg/tunnelcontainer/tunnelcontainer.go b/wireguard-go-rs/libwg/tunnelcontainer/tunnelcontainer.go index 91291dcf4b..91291dcf4b 100644 --- a/wireguard/libwg/tunnelcontainer/tunnelcontainer.go +++ b/wireguard-go-rs/libwg/tunnelcontainer/tunnelcontainer.go diff --git a/wireguard-go-rs/libwg/wireguard-go b/wireguard-go-rs/libwg/wireguard-go new file mode 160000 +Subproject f4bc3aefeb6c6ae50567a4bfc177593337eb9cb diff --git a/wireguard-go-rs/src/lib.rs b/wireguard-go-rs/src/lib.rs new file mode 100644 index 0000000000..f87b2586d0 --- /dev/null +++ b/wireguard-go-rs/src/lib.rs @@ -0,0 +1,288 @@ +//! This crate provides Rust bindings to wireguard-go with DAITA support. +//! +//! The bindings on the Go side are provided by `libwg`, which is a Go package that wraps +//! `wireguard-go` and provides a C FFI that we can use from Rust. On the Rust side, the FFI is +//! in the private `ffi` module below. It needs to be kept in sync with any changes to libwg. +//! +//! The [`Tunnel`] type provides a safe Rust wrapper around the C FFI. + +#![cfg(unix)] + +use core::slice; +use std::{ + ffi::{c_char, CStr}, + mem::ManuallyDrop, +}; +use util::OnDrop; +use zeroize::Zeroize; + +mod util; + +pub type Fd = std::os::unix::io::RawFd; + +pub type WgLogLevel = u32; + +pub type LoggingContext = u64; +pub type LoggingCallback = + unsafe extern "system" fn(level: WgLogLevel, msg: *const c_char, context: LoggingContext); + +/// A wireguard-go tunnel +pub struct Tunnel { + /// wireguard-go handle to the tunnel. + handle: i32, +} + +// NOTE: Must be kept in sync with libwg.go +// NOTE: must be kept in sync with `result_from_code` +// INVARIANT: Will always be represented as a negative i32 +#[repr(i32)] +#[non_exhaustive] +#[derive(Clone, Copy, Debug, thiserror::Error)] +pub enum Error { + #[error("Something went wrong.")] + GeneralFailure = -1, + + #[error("Something went wrong, but trying again might help.")] + IntermittentFailure = -2, + + #[error("An argument you provided was invalid.")] + InvalidArgument = -3, + + #[error("The tunnel handle did not refer to an existing tunnel.")] + UnknownTunnel = -4, + + #[error("The provided public key did not refer to an existing peer.")] + UnknownPeer = -5, + + #[error("Something went wrong when enabling DAITA.")] + EnableDaita = -6, + + #[error("`libwg` provided an unknown error code. This is a bug.")] + Other = i32::MIN, +} + +impl Tunnel { + /// Creates a new wireguard tunnel, uses the specific interface name, and file descriptors + /// for the tunnel device and logging. For targets other than android, this also takes an MTU + /// value. + /// + /// The `logging_callback` let's you provide a Rust function that receives any logging output + /// from wireguard-go. `logging_context` is a value that will be passed to each invocation of + /// `logging_callback`. + pub fn turn_on( + #[cfg(not(target_os = "android"))] mtu: isize, + settings: &CStr, + device: Fd, + logging_callback: Option<LoggingCallback>, + logging_context: LoggingContext, + ) -> Result<Self, Error> { + // SAFETY: pointer is valid for the the lifetime of this function + let code = unsafe { + ffi::wgTurnOn( + #[cfg(not(target_os = "android"))] + mtu, + settings.as_ptr(), + device, + logging_callback, + logging_context, + ) + }; + + result_from_code(code)?; + Ok(Tunnel { handle: code }) + } + + /// Stop the wireguard tunnel. This also happens automatically if the [`Tunnel`] is dropped. + pub fn turn_off(self) -> Result<(), Error> { + // we manually turn off the tunnel here, so wrap it in ManuallyDrop to prevent the Drop + // impl from doing the same. + let code = unsafe { ffi::wgTurnOff(self.handle) }; + let _ = ManuallyDrop::new(self); + result_from_code(code) + } + + /// Get the config of the WireGuard interface and make it available in the provided function. + /// + /// This takes a function to make sure the cstr get's zeroed and freed afterwards. + /// Returns `None` if the call to wgGetConfig returned nil. + /// + /// **NOTE:** You should take extra care to avoid copying any secrets from the config without + /// zeroizing them afterwards. + // NOTE: this could return a guard type with a custom Drop impl instead, but me lazy. + pub fn get_config<T>(&self, f: impl FnOnce(&CStr) -> T) -> Option<T> { + let ptr = unsafe { ffi::wgGetConfig(self.handle) }; + + if ptr.is_null() { + return None; + } + + // SAFETY: we checked for null, and wgGetConfig promises that this is a valid cstr + let config = unsafe { CStr::from_ptr(ptr) }; + let config_len = config.to_bytes().len(); + + // execute cleanup code on Drop to make sure that it happens even if `f` panics + let on_drop = OnDrop::new(|| { + { + // SAFETY: + // we checked for null, and wgGetConfig promises that this is a valid cstr. + // config_len comes from the CStr above, so it should be good. + let config_bytes = unsafe { slice::from_raw_parts_mut(ptr, config_len) }; + config_bytes.zeroize(); + } + + // SAFETY: the pointer was created by wgGetConfig, and we are no longer using it. + unsafe { ffi::wgFreePtr(ptr.cast()) }; + }); + + let t = f(config); + let _ = config; + drop(on_drop); + + Some(t) + } + + /// Set the config of the WireGuard interface. + pub fn set_config(&self, config: &CStr) -> Result<(), Error> { + // SAFETY: pointer is valid for the lifetime of this function. + let code = unsafe { ffi::wgSetConfig(self.handle, config.as_ptr()) }; + result_from_code(code) + } + + /// Activate DAITA for the specified peer. + /// + /// `machines` is a string containing LF-separated maybenot machines. + #[cfg(any(target_os = "windows", target_os = "linux"))] + pub fn activate_daita( + &self, + peer_public_key: &[u8; 32], + machines: &CStr, + events_capacity: u32, + actions_capacity: u32, + ) -> Result<(), Error> { + // SAFETY: pointers are valid for the lifetime of this function. + let code = unsafe { + ffi::wgActivateDaita( + self.handle, + peer_public_key.as_ptr(), + machines.as_ptr(), + events_capacity, + actions_capacity, + ) + }; + + result_from_code(code) + } + + /// Get the file descriptor of the tunnel IPv4 socket. + #[cfg(target_os = "android")] + pub fn get_socket_v4(&self) -> Fd { + unsafe { ffi::wgGetSocketV4(self.handle) } + } + + /// Get the file descriptor of the tunnel IPv6 socket. + #[cfg(target_os = "android")] + pub fn get_socket_v6(&self) -> Fd { + unsafe { ffi::wgGetSocketV6(self.handle) } + } +} + +impl Drop for Tunnel { + fn drop(&mut self) { + let code = unsafe { ffi::wgTurnOff(self.handle) }; + if let Err(e) = result_from_code(code) { + log::error!("Failed to stop wireguard-go tunnel,oerror_code={code} ({e:?})") + } + } +} + +fn result_from_code(code: i32) -> Result<(), Error> { + // NOTE: must be kept in sync with enum definition + Err(match code { + 0.. => return Ok(()), + -1 => Error::GeneralFailure, + -2 => Error::IntermittentFailure, + -3 => Error::UnknownTunnel, + -4 => Error::UnknownPeer, + -5 => Error::EnableDaita, + _ => Error::Other, + }) +} + +impl Error { + pub const fn as_raw(self) -> i32 { + self as i32 + } +} + +mod ffi { + use super::{Fd, LoggingCallback, LoggingContext}; + use core::ffi::{c_char, c_void}; + + extern "C" { + /// Creates a new wireguard tunnel, uses the specific interface name, and file descriptors + /// for the tunnel device and logging. For targets other than android, this also takes an + /// MTU value. + /// + /// Positive return values are tunnel handles for this specific wireguard tunnel instance. + /// Negative return values signify errors. + pub fn wgTurnOn( + #[cfg(not(target_os = "android"))] mtu: isize, + settings: *const c_char, + fd: Fd, + logging_callback: Option<LoggingCallback>, + logging_context: LoggingContext, + ) -> i32; + + /// Pass a handle that was created by wgTurnOn to stop a wireguard tunnel. + /// + /// Negative return values signify errors. + pub fn wgTurnOff(handle: i32) -> i32; + + /// Get the config of the WireGuard interface. Returns null in case of error. + /// + /// # Safety: + /// - The function returns an owned pointer to a null-terminated UTF-8 string. + /// - The pointer may only be freed using [wgFreePtr]. + pub fn wgGetConfig(handle: i32) -> *mut c_char; + + /// Set the config of the WireGuard interface. + /// + /// Negative return values signify errors. + /// + /// # Safety: + /// - `settings` must point to a null-terminated UTF-8 string. + /// - The pointer will not be read from after `wgActivateDaita` has returned. + pub fn wgSetConfig(handle: i32, settings: *const c_char) -> i32; + + /// Activate DAITA for the specified peer. + /// + /// `tunnel_handle` must come from [wgTurnOn]. `machines` is a string containing + /// LF-separated maybenot machines. + /// + /// Negative return values signify errors. + /// + /// # Safety: + /// - `peer_public_key` must point to a 32 byte array. + /// - `machines` must point to a null-terminated UTF-8 string. + /// - Neither pointer will be read from after `wgActivateDaita` has returned. + #[cfg(any(target_os = "windows", target_os = "linux"))] + pub fn wgActivateDaita( + tunnel_handle: i32, + peer_public_key: *const u8, + machines: *const c_char, + events_capacity: u32, + actions_capacity: u32, + ) -> i32; + + /// Free a pointer allocated by the go runtime - useful to free return value of wgGetConfig + pub fn wgFreePtr(ptr: *mut c_void); + + /// Get the file descriptor of the tunnel IPv4 socket. + #[cfg(target_os = "android")] + pub fn wgGetSocketV4(handle: i32) -> Fd; + + /// Get the file descriptor of the tunnel IPv6 socket. + #[cfg(target_os = "android")] + pub fn wgGetSocketV6(handle: i32) -> Fd; + } +} diff --git a/wireguard-go-rs/src/util.rs b/wireguard-go-rs/src/util.rs new file mode 100644 index 0000000000..4c2df2f9bb --- /dev/null +++ b/wireguard-go-rs/src/util.rs @@ -0,0 +1,15 @@ +pub struct OnDrop<F: FnOnce()>(Option<F>); + +impl<F: FnOnce()> OnDrop<F> { + pub fn new(f: F) -> Self { + OnDrop(Some(f)) + } +} + +impl<F: FnOnce()> Drop for OnDrop<F> { + fn drop(&mut self) { + if let Some(f) = self.0.take() { + f() + } + } +} |
