summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorSebastian Holmin <sebastian.holmin@mullvad.net>2024-05-02 11:54:13 +0200
committerMarkus Pettersson <markus.pettersson@mullvad.net>2024-06-25 17:31:32 +0200
commitda95b2603470841b64518959ceac0d03aab0068a (patch)
tree83028a23bd2439813a0dcf826eac123ea83d956b
parent9dfdf2fa52422709ce3df7a50643e0abc6ade319 (diff)
downloadmullvadvpn-da95b2603470841b64518959ceac0d03aab0068a.tar.xz
mullvadvpn-da95b2603470841b64518959ceac0d03aab0068a.zip
Add a safe FFI wrapper in `wireguard-go-rs`
- Add local wireguard go import - Activate DAITA and add `wgActivateDaita` and `wgReceiveEvent` FFI - Implement `start_daita` on Wireguard-go tunnel type - Mention DAITA in `wireguard-go-rs` description - Do not compile `wireguard-go-rs` on Windows - Handle DAITA closed on `nil` event - Handle daita action timeouts in libwg - Remove noisy log lines - Remove `maybenot_on_action` callback - Remove unused link to `../build/lib` for `talpid-wireguard` - Bump the `wireguard-go` submodule to a signed release tag in Mullvad's `wireguard-go` fork. - Update path to `libwg/go.sum` in verification script Also: - Use u64 instead of *mut void as log context - Make Tunnel::set_config take a &mut self - Use dyn Error instead of i32s for wg errors Co-authored-by: Joakim Hulthe <joakim@hulthe.net>
-rw-r--r--.github/workflows/verify-locked-down-signatures.yml2
-rw-r--r--.gitignore1
-rw-r--r--.gitmodules3
-rw-r--r--Cargo.lock18
-rw-r--r--Cargo.toml1
-rwxr-xr-xbuild.sh6
-rwxr-xr-xci/check-rust.sh5
-rw-r--r--talpid-wireguard/Cargo.toml4
-rw-r--r--talpid-wireguard/build.rs16
-rw-r--r--talpid-wireguard/src/config.rs4
-rw-r--r--talpid-wireguard/src/connectivity_check.rs16
-rw-r--r--talpid-wireguard/src/lib.rs57
-rw-r--r--talpid-wireguard/src/logging.rs35
-rw-r--r--talpid-wireguard/src/wireguard_go/mod.rs (renamed from talpid-wireguard/src/wireguard_go.rs)252
-rw-r--r--talpid-wireguard/src/wireguard_kernel/netlink_tunnel.rs9
-rw-r--r--talpid-wireguard/src/wireguard_kernel/nm_tunnel.rs9
-rw-r--r--talpid-wireguard/src/wireguard_nt/mod.rs12
-rw-r--r--wireguard-go-rs/Cargo.toml10
-rw-r--r--wireguard-go-rs/README.md10
-rwxr-xr-xwireguard-go-rs/build-wireguard-go.sh (renamed from wireguard/build-wireguard-go.sh)64
-rw-r--r--wireguard-go-rs/build.rs65
-rw-r--r--wireguard-go-rs/libwg/Android.mk (renamed from wireguard/libwg/Android.mk)0
-rw-r--r--wireguard-go-rs/libwg/README.md (renamed from wireguard/libwg/README.md)0
-rwxr-xr-xwireguard-go-rs/libwg/build-android.sh (renamed from wireguard/libwg/build-android.sh)0
-rw-r--r--wireguard-go-rs/libwg/go.mod (renamed from wireguard/libwg/go.mod)4
-rw-r--r--wireguard-go-rs/libwg/go.sum (renamed from wireguard/libwg/go.sum)6
-rw-r--r--wireguard-go-rs/libwg/goruntime-boottime-over-monotonic.diff171
-rw-r--r--wireguard-go-rs/libwg/libwg.go (renamed from wireguard/libwg/libwg.go)37
-rw-r--r--wireguard-go-rs/libwg/libwg.h8
-rw-r--r--wireguard-go-rs/libwg/libwg_android.go (renamed from wireguard/libwg/libwg_android.go)26
-rw-r--r--wireguard-go-rs/libwg/libwg_daita.go41
-rw-r--r--wireguard-go-rs/libwg/libwg_default.go (renamed from wireguard/libwg/libwg_default.go)13
-rw-r--r--wireguard-go-rs/libwg/logging/logging.go (renamed from wireguard/libwg/logging/logging.go)7
-rw-r--r--wireguard-go-rs/libwg/tunnelcontainer/tunnelcontainer.go (renamed from wireguard/libwg/tunnelcontainer/tunnelcontainer.go)0
m---------wireguard-go-rs/libwg/wireguard-go0
-rw-r--r--wireguard-go-rs/src/lib.rs288
-rw-r--r--wireguard-go-rs/src/util.rs15
37 files changed, 901 insertions, 314 deletions
diff --git a/.github/workflows/verify-locked-down-signatures.yml b/.github/workflows/verify-locked-down-signatures.yml
index 0e36154a15..118e44914a 100644
--- a/.github/workflows/verify-locked-down-signatures.yml
+++ b/.github/workflows/verify-locked-down-signatures.yml
@@ -11,7 +11,7 @@ on:
- deny.toml
- test/deny.toml
- gui/package-lock.json
- - wireguard/libwg/go.sum
+ - wireguard-go-rs/libwg/go.sum
- ci/keys/**
- ci/verify-locked-down-signatures.sh
- ios/MullvadVPN.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved
diff --git a/.gitignore b/.gitignore
index 1efb0161fc..569d6ca706 100644
--- a/.gitignore
+++ b/.gitignore
@@ -26,7 +26,6 @@
/android/keystore.properties
/android/local.properties
/android/play-api-key.json
-/wireguard/libwg/libwg.h
/wireguard/libwg/libwg.exp
/wireguard/libwg/exports.def
**/.vs/
diff --git a/.gitmodules b/.gitmodules
index e8c3ba4187..39aa8b722d 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -7,3 +7,6 @@
[submodule "windows/windows-libraries"]
path = windows/windows-libraries
url = https://github.com/mullvad/windows-libraries
+[submodule "wireguard-go-rs/libwg/wireguard-go"]
+ path = wireguard-go-rs/libwg/wireguard-go
+ url = https://github.com/mullvad/wireguard-go/
diff --git a/Cargo.lock b/Cargo.lock
index c6a1e328fb..8709388967 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2026,9 +2026,9 @@ checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94"
[[package]]
name = "maybenot"
-version = "1.1.0"
+version = "1.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "94ed977e86fc65a7ffae967a6a973e6f7a90b5d747ebd755703d5718804f7c16"
+checksum = "a7fe205734d700937dabf0b8687e290f8574fac996f8a9d04bd7a62d7c2c1dad"
dependencies = [
"byteorder",
"hex",
@@ -4144,6 +4144,7 @@ dependencies = [
"tunnel-obfuscation",
"widestring",
"windows-sys 0.52.0",
+ "wireguard-go-rs",
"zeroize",
]
@@ -5079,6 +5080,15 @@ dependencies = [
]
[[package]]
+name = "wireguard-go-rs"
+version = "0.0.0"
+dependencies = [
+ "log",
+ "thiserror",
+ "zeroize",
+]
+
+[[package]]
name = "x25519-dalek"
version = "2.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -5112,9 +5122,9 @@ dependencies = [
[[package]]
name = "zeroize"
-version = "1.7.0"
+version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "525b4ec142c6b68a2d10f01f7bbf6755599ca3f81ea53b8431b7dd348f5fdb2d"
+checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde"
dependencies = [
"zeroize_derive",
]
diff --git a/Cargo.toml b/Cargo.toml
index aa3b9145c8..e126c6a13e 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -39,6 +39,7 @@ members = [
"talpid-windows",
"talpid-wireguard",
"tunnel-obfuscation",
+ "wireguard-go-rs"
]
# Keep all lints in sync with `test/Cargo.toml`
diff --git a/build.sh b/build.sh
index 7b34896ca6..a3cbb50e31 100755
--- a/build.sh
+++ b/build.sh
@@ -213,11 +213,6 @@ function build {
# Compile and link all binaries.
################################################################################
- if [[ "$(uname -s)" != "MINGW"* ]]; then
- log_header "Building wireguard-go$for_target_string"
- ./wireguard/build-wireguard-go.sh "$current_target"
- fi
-
log_header "Building Rust code in $RUST_BUILD_MODE mode using $RUSTC_VERSION$for_target_string"
local cargo_target_arg=()
@@ -312,6 +307,7 @@ for t in "${TARGETS[@]:-""}"; do
build "$t"
done
+
################################################################################
# Package app.
################################################################################
diff --git a/ci/check-rust.sh b/ci/check-rust.sh
index d42784d36e..cb48fbed3f 100755
--- a/ci/check-rust.sh
+++ b/ci/check-rust.sh
@@ -4,11 +4,6 @@ set -eux
export RUSTFLAGS="--deny warnings"
-# Build WireGuard Go
-if [[ "$(uname -s)" != "MINGW"* ]]; then
- ./wireguard/build-wireguard-go.sh
-fi
-
# Build Rust crates
source env.sh
time cargo build --locked --verbose
diff --git a/talpid-wireguard/Cargo.toml b/talpid-wireguard/Cargo.toml
index c1eb335766..09b7baa57f 100644
--- a/talpid-wireguard/Cargo.toml
+++ b/talpid-wireguard/Cargo.toml
@@ -30,6 +30,9 @@ tunnel-obfuscation = { path = "../tunnel-obfuscation" }
rand = "0.8.5"
surge-ping = "0.8.0"
+[target.'cfg(not(windows))'.dependencies]
+wireguard-go-rs = { path = "../wireguard-go-rs"}
+
[target.'cfg(target_os="android")'.dependencies]
duct = "0.13"
@@ -42,7 +45,6 @@ tokio-stream = { version = "0.1", features = ["io-util"] }
[target.'cfg(unix)'.dependencies]
nix = "0.23"
-[target.'cfg(target_os = "linux")'.dependencies]
rtnetlink = "0.11"
netlink-packet-core = "0.4.2"
netlink-packet-route = "0.13"
diff --git a/talpid-wireguard/build.rs b/talpid-wireguard/build.rs
index fe5b45b819..3abec6abe2 100644
--- a/talpid-wireguard/build.rs
+++ b/talpid-wireguard/build.rs
@@ -4,22 +4,6 @@ fn main() {
let target_os = env::var("CARGO_CFG_TARGET_OS").expect("CARGO_CFG_TARGET_OS not set");
declare_libs_dir("../dist-assets/binaries");
- declare_libs_dir("../build/lib");
-
- let link_type = match target_os.as_str() {
- "android" => "",
- "linux" | "macos" => "=static",
- // We would like to avoid panicking on windows even if we can not link correctly
- // because we would like to be able to run check and clippy.
- // This does not allow for correct linking or buijding.
- #[cfg(not(windows))]
- "windows" => "",
- #[cfg(windows)]
- "windows" => "dylib",
- _ => panic!("Unsupported platform: {target_os}"),
- };
-
- println!("cargo:rustc-link-lib{link_type}=wg");
add_wireguard_go_cfg(&target_os);
}
diff --git a/talpid-wireguard/src/config.rs b/talpid-wireguard/src/config.rs
index f10a0e4859..e9059c6cca 100644
--- a/talpid-wireguard/src/config.rs
+++ b/talpid-wireguard/src/config.rs
@@ -97,9 +97,9 @@ impl Config {
enable_ipv6: generic_options.enable_ipv6,
obfuscator_config: obfuscator_config.to_owned(),
quantum_resistant: wg_options.quantum_resistant,
- #[cfg(target_os = "windows")]
+ #[cfg(any(target_os = "windows", target_os = "linux"))]
daita: wg_options.daita,
- #[cfg(not(target_os = "windows"))]
+ #[cfg(not(any(target_os = "windows", target_os = "linux")))]
daita: false,
};
diff --git a/talpid-wireguard/src/connectivity_check.rs b/talpid-wireguard/src/connectivity_check.rs
index 70f88e6872..41ad0b5bf3 100644
--- a/talpid-wireguard/src/connectivity_check.rs
+++ b/talpid-wireguard/src/connectivity_check.rs
@@ -5,9 +5,10 @@ use crate::{
use std::{
cmp,
net::Ipv4Addr,
- sync::{mpsc, Mutex, Weak},
+ sync::{mpsc, Weak},
time::{Duration, Instant},
};
+use tokio::sync::Mutex;
use super::{Tunnel, TunnelError};
@@ -211,11 +212,12 @@ impl ConnectivityMonitor {
/// If None is returned, then the underlying tunnel has already been closed and all subsequent
/// calls will also return None.
+ ///
+ /// NOTE: will panic if called from within a tokio runtime.
fn get_stats(&self) -> Option<Result<StatsMap, Error>> {
self.tunnel_handle
.upgrade()?
- .lock()
- .ok()?
+ .blocking_lock()
.as_ref()
.and_then(|tunnel| match tunnel.get_tunnel_stats() {
Ok(stats) if stats.is_empty() => {
@@ -550,7 +552,7 @@ mod test {
rx_bytes: 0,
},
);
- let peers = Mutex::new(map);
+ let peers = std::sync::Mutex::new(map);
Self {
on_get_stats: Box::new(move || {
let mut peers = peers.lock().unwrap();
@@ -607,13 +609,13 @@ mod test {
}
fn set_config(
- &self,
+ &mut self,
_config: Config,
) -> Pin<Box<dyn Future<Output = std::result::Result<(), TunnelError>> + Send>> {
Box::pin(async { Ok(()) })
}
- #[cfg(target_os = "windows")]
+ #[cfg(any(target_os = "windows", target_os = "linux"))]
fn start_daita(&mut self) -> std::result::Result<(), TunnelError> {
Ok(())
}
@@ -745,7 +747,7 @@ mod test {
rx_bytes: 0,
},
);
- let tunnel_stats = Mutex::new(map);
+ let tunnel_stats = std::sync::Mutex::new(map);
let pinger = MockPinger::default();
let (_tunnel_anchor, tunnel) = MockTunnel::new(move || {
diff --git a/talpid-wireguard/src/lib.rs b/talpid-wireguard/src/lib.rs
index 7c01538b60..e8334f43d1 100644
--- a/talpid-wireguard/src/lib.rs
+++ b/talpid-wireguard/src/lib.rs
@@ -115,7 +115,7 @@ impl Error {
Error::CreateObfuscatorError(_) => true,
Error::ObfuscatorError(_) => true,
Error::PskNegotiationError(_) => true,
- Error::TunnelError(TunnelError::RecoverableStartWireguardError) => true,
+ Error::TunnelError(TunnelError::RecoverableStartWireguardError(..)) => true,
Error::SetupRoutingError(error) => error.is_recoverable(),
@@ -144,7 +144,7 @@ impl Error {
pub struct WireguardMonitor {
runtime: tokio::runtime::Handle,
/// Tunnel implementation
- tunnel: Arc<Mutex<Option<Box<dyn Tunnel>>>>,
+ tunnel: Arc<AsyncMutex<Option<Box<dyn Tunnel>>>>,
/// Callback to signal tunnel events
event_callback: EventCallback,
close_msg_receiver: sync_mpsc::Receiver<CloseMsg>,
@@ -306,7 +306,7 @@ impl WireguardMonitor {
let (pinger_tx, pinger_rx) = sync_mpsc::channel();
let monitor = WireguardMonitor {
runtime: args.runtime.clone(),
- tunnel: Arc::new(Mutex::new(Some(tunnel))),
+ tunnel: Arc::new(AsyncMutex::new(Some(tunnel))),
event_callback,
close_msg_receiver: close_obfs_listener,
pinger_stop_sender: pinger_tx,
@@ -473,7 +473,7 @@ impl WireguardMonitor {
#[allow(clippy::too_many_arguments)]
async fn config_ephemeral_peers<F>(
- tunnel: &Arc<Mutex<Option<Box<dyn Tunnel>>>>,
+ tunnel: &Arc<AsyncMutex<Option<Box<dyn Tunnel>>>>,
config: &mut Config,
retry_attempt: u32,
on_event: F,
@@ -576,10 +576,10 @@ impl WireguardMonitor {
)
.await?;
- #[cfg(target_os = "windows")]
+ #[cfg(any(target_os = "windows", target_os = "linux"))]
if config.daita {
// Start local DAITA machines
- let mut tunnel = tunnel.lock().unwrap();
+ let mut tunnel = tunnel.lock().await;
if let Some(tunnel) = tunnel.as_mut() {
tunnel
.start_daita()
@@ -601,7 +601,7 @@ impl WireguardMonitor {
/// Reconfigures the tunnel to use the provided config while potentially modifying the config
/// and restarting the obfuscation provider. Returns the new config used by the new tunnel.
async fn reconfigure_tunnel(
- tunnel: &Arc<Mutex<Option<Box<dyn Tunnel>>>>,
+ tunnel: &Arc<AsyncMutex<Option<Box<dyn Tunnel>>>>,
mut config: Config,
obfuscator: Arc<AsyncMutex<Option<ObfuscatorHandle>>>,
close_obfs_sender: sync_mpsc::Sender<CloseMsg>,
@@ -625,11 +625,12 @@ impl WireguardMonitor {
}
}
+ let mut tunnel = tunnel.lock().await;
+
let set_config_future = tunnel
- .lock()
- .unwrap()
- .as_ref()
+ .as_mut()
.map(|tunnel| tunnel.set_config(config.clone()));
+
if let Some(f) = set_config_future {
f.await
.map_err(Error::TunnelError)
@@ -817,6 +818,7 @@ impl WireguardMonitor {
log_path,
tun_provider,
routes,
+ resource_dir,
)
.map_err(Error::TunnelError)?,
))
@@ -843,8 +845,11 @@ impl WireguardMonitor {
wait_result
}
+ /// Tear down the tunnel.
+ ///
+ /// NOTE: will panic if called from within a tokio runtime.
fn stop_tunnel(&mut self) {
- match self.tunnel.lock().expect("Tunnel lock poisoned").take() {
+ match self.tunnel.blocking_lock().take() {
Some(tunnel) => {
if let Err(e) = tunnel.stop() {
log::error!("{}", e.display_chain_with_msg("Failed to stop tunnel"));
@@ -1025,11 +1030,12 @@ pub(crate) trait Tunnel: Send {
fn get_interface_name(&self) -> String;
fn stop(self: Box<Self>) -> std::result::Result<(), TunnelError>;
fn get_tunnel_stats(&self) -> std::result::Result<stats::StatsMap, TunnelError>;
- fn set_config(
- &self,
+ fn set_config<'a>(
+ &'a mut self,
_config: Config,
- ) -> Pin<Box<dyn Future<Output = std::result::Result<(), TunnelError>> + Send>>;
- #[cfg(target_os = "windows")]
+ ) -> Pin<Box<dyn Future<Output = std::result::Result<(), TunnelError>> + Send + 'a>>;
+ #[cfg(any(target_os = "windows", target_os = "linux"))]
+ /// A [`Tunnel`] capable of using DAITA.
fn start_daita(&mut self) -> std::result::Result<(), TunnelError>;
}
@@ -1041,7 +1047,7 @@ pub enum TunnelError {
/// This is an error returned by the implementation that indicates that trying to establish the
/// tunnel again should work normally. The error encountered is known to be sporadic.
#[error("Recoverable error while starting wireguard tunnel")]
- RecoverableStartWireguardError,
+ RecoverableStartWireguardError(#[source] Box<dyn std::error::Error + Send>),
/// An unrecoverable error occurred while starting the wireguard tunnel
///
@@ -1049,14 +1055,11 @@ pub enum TunnelError {
/// tunnel again will likely fail with the same error. An error was encountered during tunnel
/// configuration which can't be dealt with gracefully.
#[error("Failed to start wireguard tunnel")]
- FatalStartWireguardError,
+ FatalStartWireguardError(#[source] Box<dyn std::error::Error + Send>),
/// Failed to tear down wireguard tunnel.
- #[error("Failed to stop wireguard tunnel. Status: {status}")]
- StopWireguardError {
- /// Returned error code
- status: i32,
- },
+ #[error("Failed to tear down wireguard tunnel")]
+ StopWireguardError(#[source] Box<dyn std::error::Error + Send>),
/// Error whilst trying to parse the WireGuard config to read the stats
#[error("Reading tunnel stats failed")]
@@ -1107,6 +1110,16 @@ pub enum TunnelError {
/// Failure to set up logging
#[error("Failed to set up logging")]
LoggingError(#[source] logging::Error),
+
+ /// Failed to receive DAITA event
+ #[cfg(any(target_os = "windows", target_os = "linux"))]
+ #[error("Failed to start DAITA")]
+ StartDaita(#[source] Box<dyn std::error::Error + Send>),
+
+ /// This tunnel does not support DAITA.
+ #[cfg(any(target_os = "windows", target_os = "linux"))]
+ #[error("Failed to start DAITA - tunnel implemenation does not support DAITA")]
+ DaitaNotSupported,
}
#[cfg(target_os = "linux")]
diff --git a/talpid-wireguard/src/logging.rs b/talpid-wireguard/src/logging.rs
index 6d1d364342..a4d8c7f240 100644
--- a/talpid-wireguard/src/logging.rs
+++ b/talpid-wireguard/src/logging.rs
@@ -2,9 +2,13 @@ use once_cell::sync::Lazy;
use parking_lot::Mutex;
use std::{collections::HashMap, fmt, fs, io::Write, path::Path};
-static LOG_MUTEX: Lazy<Mutex<HashMap<u32, fs::File>>> = Lazy::new(|| Mutex::new(HashMap::new()));
+static LOG_MUTEX: Lazy<Mutex<LogState>> = Lazy::new(|| Mutex::new(LogState::default()));
-static mut LOG_CONTEXT_NEXT_ORDINAL: u32 = 0;
+#[derive(Default)]
+struct LogState {
+ map: HashMap<u64, fs::File>,
+ next_ordinal: u64,
+}
/// Errors encountered when initializing logging
#[derive(thiserror::Error, Debug)]
@@ -14,18 +18,15 @@ pub enum Error {
PrepareLogFileError(#[from] std::io::Error),
}
-pub fn initialize_logging(log_path: Option<&Path>) -> Result<u32, Error> {
+pub fn initialize_logging(log_path: Option<&Path>) -> Result<u64, Error> {
let log_file = create_log_file(log_path)?;
- let log_context_ordinal = unsafe {
- let mut map = LOG_MUTEX.lock();
- let ordinal = LOG_CONTEXT_NEXT_ORDINAL;
- LOG_CONTEXT_NEXT_ORDINAL += 1;
- map.insert(ordinal, log_file);
- ordinal
- };
+ let mut state = LOG_MUTEX.lock();
+ let ordinal = state.next_ordinal;
+ state.next_ordinal += 1;
+ state.map.insert(ordinal, log_file);
- Ok(log_context_ordinal)
+ Ok(ordinal)
}
#[cfg(target_os = "windows")]
@@ -39,9 +40,9 @@ fn create_log_file(log_path: Option<&Path>) -> Result<fs::File, Error> {
.map_err(Error::PrepareLogFileError)
}
-pub fn clean_up_logging(ordinal: u32) {
- let mut map = LOG_MUTEX.lock();
- map.remove(&ordinal);
+pub fn clean_up_logging(ordinal: u64) {
+ let mut state = LOG_MUTEX.lock();
+ state.map.remove(&ordinal);
}
pub enum LogLevel {
@@ -71,9 +72,9 @@ impl AsRef<str> for LogLevel {
}
}
-pub fn log(context: u32, level: LogLevel, tag: &str, msg: &str) {
- let mut map = LOG_MUTEX.lock();
- if let Some(logfile) = map.get_mut(&{ context }) {
+pub fn log(context: u64, level: LogLevel, tag: &str, msg: &str) {
+ let mut state = LOG_MUTEX.lock();
+ if let Some(logfile) = state.map.get_mut(&context) {
log_inner(logfile, level, tag, msg);
}
}
diff --git a/talpid-wireguard/src/wireguard_go.rs b/talpid-wireguard/src/wireguard_go/mod.rs
index b08b241bb9..32181beaea 100644
--- a/talpid-wireguard/src/wireguard_go.rs
+++ b/talpid-wireguard/src/wireguard_go/mod.rs
@@ -1,35 +1,40 @@
-use super::{
- stats::{Stats, StatsMap},
- Config, Tunnel, TunnelError,
-};
-use crate::logging::{clean_up_logging, initialize_logging};
use ipnetwork::IpNetwork;
+#[cfg(any(target_os = "windows", target_os = "linux"))]
+use once_cell::sync::OnceCell;
+#[cfg(any(target_os = "windows", target_os = "linux"))]
+use std::{ffi::CString, fs, path::PathBuf};
use std::{
- ffi::{c_char, c_void, CStr},
future::Future,
+ net::IpAddr,
+ os::unix::io::{AsRawFd, RawFd},
path::Path,
pin::Pin,
+ sync::{Arc, Mutex},
};
-use talpid_tunnel::tun_provider::TunProvider;
-use talpid_types::BoxedError;
-use zeroize::Zeroize;
-
#[cfg(target_os = "android")]
-use talpid_tunnel::tun_provider;
+use talpid_tunnel::tun_provider::Error as TunProviderError;
+use talpid_tunnel::tun_provider::{Tun, TunConfig, TunProvider};
+use talpid_types::BoxedError;
-use std::{
- net::IpAddr,
- os::unix::io::{AsRawFd, RawFd},
+use super::{
+ stats::{Stats, StatsMap},
+ Config, Tunnel, TunnelError,
};
-use talpid_tunnel::tun_provider::{Tun, TunConfig};
+use crate::logging::{clean_up_logging, initialize_logging};
-type Result<T> = std::result::Result<T, TunnelError>;
+const MAX_PREPARE_TUN_ATTEMPTS: usize = 4;
-use std::sync::{Arc, Mutex};
+/// Maximum number of events that can be stored in the underlying buffer
+#[cfg(any(target_os = "windows", target_os = "linux"))]
+const DAITA_EVENTS_CAPACITY: u32 = 1000;
-const MAX_PREPARE_TUN_ATTEMPTS: usize = 4;
+/// Maximum number of actions that can be stored in the underlying buffer
+#[cfg(any(target_os = "windows", target_os = "linux"))]
+const DAITA_ACTIONS_CAPACITY: u32 = 1000;
+
+type Result<T> = std::result::Result<T, TunnelError>;
-struct LoggingContext(u32);
+struct LoggingContext(u64);
impl Drop for LoggingContext {
fn drop(&mut self) {
@@ -39,7 +44,7 @@ impl Drop for LoggingContext {
pub struct WgGoTunnel {
interface_name: String,
- handle: Option<i32>,
+ tunnel_handle: wireguard_go_rs::Tunnel,
// holding on to the tunnel device and the log file ensures that the associated file handles
// live long enough and get closed when the tunnel is stopped
_tunnel_device: Tun,
@@ -47,6 +52,9 @@ pub struct WgGoTunnel {
_logging_context: LoggingContext,
#[cfg(target_os = "android")]
tun_provider: Arc<Mutex<TunProvider>>,
+ #[cfg(any(target_os = "windows", target_os = "linux"))]
+ resource_dir: PathBuf,
+ config: Config,
}
impl WgGoTunnel {
@@ -55,6 +63,7 @@ impl WgGoTunnel {
log_path: Option<&Path>,
tun_provider: Arc<Mutex<TunProvider>>,
routes: impl Iterator<Item = IpNetwork>,
+ resource_dir: &Path,
) -> Result<Self> {
#[cfg(target_os = "android")]
let tun_provider_clone = tun_provider.clone();
@@ -70,29 +79,30 @@ impl WgGoTunnel {
#[cfg(not(target_os = "android"))]
let mtu = config.mtu as isize;
- let handle = unsafe {
- wgTurnOn(
- #[cfg(not(target_os = "android"))]
- mtu,
- wg_config_str.as_ptr() as _,
- tunnel_fd,
- Some(logging::wg_go_logging_callback),
- logging_context.0 as *mut c_void,
- )
- };
- check_wg_status(handle)?;
+ let handle = wireguard_go_rs::Tunnel::turn_on(
+ #[cfg(not(target_os = "android"))]
+ mtu,
+ &wg_config_str,
+ tunnel_fd,
+ Some(logging::wg_go_logging_callback),
+ logging_context.0,
+ )
+ .map_err(|e| TunnelError::FatalStartWireguardError(Box::new(e)))?;
#[cfg(target_os = "android")]
- Self::bypass_tunnel_sockets(&mut tunnel_device, handle)
+ Self::bypass_tunnel_sockets(&handle, &mut tunnel_device)
.map_err(TunnelError::BypassError)?;
Ok(WgGoTunnel {
interface_name,
- handle: Some(handle),
+ tunnel_handle: handle,
_tunnel_device: tunnel_device,
_logging_context: logging_context,
#[cfg(target_os = "android")]
tun_provider: tun_provider_clone,
+ resource_dir: resource_dir.to_owned(),
+ #[cfg(any(target_os = "windows", target_os = "linux"))]
+ config: config.clone(),
})
}
@@ -130,11 +140,11 @@ impl WgGoTunnel {
#[cfg(target_os = "android")]
fn bypass_tunnel_sockets(
+ handle: &wireguard_go_rs::Tunnel,
tunnel_device: &mut Tun,
- handle: i32,
- ) -> std::result::Result<(), tun_provider::Error> {
- let socket_v4 = unsafe { wgGetSocketV4(handle) };
- let socket_v6 = unsafe { wgGetSocketV6(handle) };
+ ) -> std::result::Result<(), TunProviderError> {
+ let socket_v4 = handle.get_socket_v4();
+ let socket_v6 = handle.get_socket_v6();
tunnel_device.bypass(socket_v4)?;
tunnel_device.bypass(socket_v6)?;
@@ -142,16 +152,6 @@ impl WgGoTunnel {
Ok(())
}
- fn stop_tunnel(&mut self) -> Result<()> {
- if let Some(handle) = self.handle.take() {
- let status = unsafe { wgTurnOff(handle) };
- if status < 0 {
- return Err(TunnelError::StopWireguardError { status });
- }
- }
- Ok(())
- }
-
fn get_tunnel(
tun_provider: Arc<Mutex<TunProvider>>,
config: &Config,
@@ -187,71 +187,46 @@ impl WgGoTunnel {
}
}
-impl Drop for WgGoTunnel {
- fn drop(&mut self) {
- if let Err(e) = self.stop_tunnel() {
- log::error!("Failed to stop tunnel: {}", e);
- }
- }
-}
-
impl Tunnel for WgGoTunnel {
fn get_interface_name(&self) -> String {
self.interface_name.clone()
}
fn get_tunnel_stats(&self) -> Result<StatsMap> {
- let config_str = unsafe {
- let ptr = wgGetConfig(self.handle.unwrap());
- if ptr.is_null() {
- log::error!("Failed to get config !");
- return Err(TunnelError::GetConfigError);
- }
-
- CStr::from_ptr(ptr)
- };
-
- let result =
- Stats::parse_config_str(config_str.to_str().expect("Go strings are always UTF-8"))
- .map_err(|error| TunnelError::StatsError(BoxedError::new(error)));
- unsafe {
- // Zeroing out config string to not leave private key in memory.
- let slice = std::slice::from_raw_parts_mut(
- config_str.as_ptr() as *mut c_char,
- config_str.to_bytes().len(),
- );
- slice.zeroize();
-
- wgFreePtr(config_str.as_ptr() as *mut c_void);
- }
-
- result
+ self.tunnel_handle
+ .get_config(|cstr| {
+ Stats::parse_config_str(cstr.to_str().expect("Go strings are always UTF-8"))
+ })
+ .ok_or(TunnelError::GetConfigError)?
+ .map_err(|error| TunnelError::StatsError(BoxedError::new(error)))
}
- fn stop(mut self: Box<Self>) -> Result<()> {
- self.stop_tunnel()
+ fn stop(self: Box<Self>) -> Result<()> {
+ self.tunnel_handle
+ .turn_off()
+ .map_err(|e| TunnelError::StopWireguardError(Box::new(e)))
}
fn set_config(
- &self,
+ &mut self,
config: Config,
- ) -> Pin<Box<dyn Future<Output = std::result::Result<(), super::TunnelError>> + Send>> {
- let wg_config_str = config.to_userspace_format();
- let handle = self.handle.unwrap();
- #[cfg(target_os = "android")]
- let tun_provider = self.tun_provider.clone();
+ ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + '_>> {
Box::pin(async move {
- let status = unsafe { wgSetConfig(handle, wg_config_str.as_ptr() as _) };
- if status != 0 {
- return Err(TunnelError::SetConfigError);
- }
+ let wg_config_str = config.to_userspace_format();
+
+ self.tunnel_handle
+ .set_config(&wg_config_str)
+ .map_err(|_| TunnelError::SetConfigError)?;
+
+ #[cfg(target_os = "android")]
+ let tun_provider = self.tun_provider.clone();
// When reapplying the config, the endpoint socket may be discarded
// and needs to be excluded again
#[cfg(target_os = "android")]
{
- let socket_v4 = unsafe { wgGetSocketV4(handle) };
- let socket_v6 = unsafe { wgGetSocketV6(handle) };
+ let socket_v4 = self.tunnel_handle.get_socket_v4();
+ let socket_v6 = self.tunnel_handle.get_socket_v6();
let mut provider = tun_provider.lock().unwrap();
provider
.bypass(socket_v4)
@@ -264,68 +239,32 @@ impl Tunnel for WgGoTunnel {
Ok(())
})
}
-}
-fn check_wg_status(wg_code: i32) -> Result<()> {
- match wg_code {
- ERROR_GENERAL_FAILURE => Err(TunnelError::FatalStartWireguardError),
- ERROR_INTERMITTENT_FAILURE => Err(TunnelError::RecoverableStartWireguardError),
- 0.. => Ok(()),
- _ => {
- log::error!("Unknown status code returned from wireguard-go");
- Err(TunnelError::FatalStartWireguardError)
- }
- }
-}
-
-pub type Fd = std::os::unix::io::RawFd;
-
-const ERROR_GENERAL_FAILURE: i32 = -1;
-const ERROR_INTERMITTENT_FAILURE: i32 = -2;
-
-extern "C" {
- /// Creates a new wireguard tunnel, uses the specific interface name, MTU and file descriptors
- /// for the tunnel device and logging.
- ///
- /// Positive return values are tunnel handles for this specific wireguard tunnel instance.
- /// Negative return values signify errors. All error codes are opaque.
- #[cfg(not(target_os = "android"))]
- fn wgTurnOn(
- mtu: isize,
- settings: *const i8,
- fd: Fd,
- logging_callback: Option<logging::LoggingCallback>,
- logging_context: *mut c_void,
- ) -> i32;
-
- // Android
- #[cfg(target_os = "android")]
- fn wgTurnOn(
- settings: *const i8,
- fd: Fd,
- logging_callback: Option<logging::LoggingCallback>,
- logging_context: *mut c_void,
- ) -> i32;
-
- // Pass a handle that was created by wgTurnOn to stop a wireguard tunnel.
- fn wgTurnOff(handle: i32) -> i32;
+ fn start_daita(&mut self) -> Result<()> {
+ static MAYBENOT_MACHINES: OnceCell<CString> = OnceCell::new();
+ let machines = MAYBENOT_MACHINES.get_or_try_init(|| {
+ let path = self.resource_dir.join("maybenot_machines");
+ log::debug!("Reading maybenot machines from {}", path.display());
- // Returns the file descriptor of the tunnel IPv4 socket.
- fn wgGetConfig(handle: i32) -> *mut c_char;
+ // TODO: errors
+ let machines = fs::read_to_string(path).unwrap();
+ let machines = CString::new(machines).unwrap();
+ Ok(machines)
+ })?;
- // Sets the config of the WireGuard interface.
- fn wgSetConfig(handle: i32, settings: *const i8) -> i32;
-
- // Frees a pointer allocated by the go runtime - useful to free return value of wgGetConfig
- fn wgFreePtr(ptr: *mut c_void);
-
- // Returns the file descriptor of the tunnel IPv4 socket.
- #[cfg(target_os = "android")]
- fn wgGetSocketV4(handle: i32) -> Fd;
+ log::info!("Initializing DAITA for wireguard device");
+ let peer_public_key = &self.config.entry_peer.public_key;
+ self.tunnel_handle
+ .activate_daita(
+ peer_public_key.as_bytes(),
+ machines,
+ DAITA_EVENTS_CAPACITY,
+ DAITA_ACTIONS_CAPACITY,
+ )
+ .map_err(|e| TunnelError::StartDaita(Box::new(e)))?;
- // Returns the file descriptor of the tunnel IPv6 socket.
- #[cfg(target_os = "android")]
- fn wgGetSocketV6(handle: i32) -> Fd;
+ Ok(())
+ }
}
mod stats {
@@ -438,13 +377,13 @@ mod stats {
mod logging {
use super::super::logging::{log, LogLevel};
- use std::ffi::{c_char, c_void};
+ use std::ffi::c_char;
// Callback that receives messages from WireGuard
pub unsafe extern "system" fn wg_go_logging_callback(
level: WgLogLevel,
msg: *const c_char,
- context: *mut c_void,
+ context: u64,
) {
let managed_msg = if !msg.is_null() {
std::ffi::CStr::from_ptr(msg).to_string_lossy().to_string()
@@ -457,7 +396,7 @@ mod logging {
_ => LogLevel::Error,
};
- log(context as u32, level, "wireguard-go", &managed_msg);
+ log(context, level, "wireguard-go", &managed_msg);
}
// wireguard-go supports log levels 0 through 3 with 3 being the most verbose
@@ -466,7 +405,4 @@ mod logging {
const WG_GO_LOG_VERBOSE: WgLogLevel = 2;
pub type WgLogLevel = u32;
-
- pub type LoggingCallback =
- unsafe extern "system" fn(level: WgLogLevel, msg: *const c_char, context: *mut c_void);
}
diff --git a/talpid-wireguard/src/wireguard_kernel/netlink_tunnel.rs b/talpid-wireguard/src/wireguard_kernel/netlink_tunnel.rs
index 579bcde65a..bdc187e1bd 100644
--- a/talpid-wireguard/src/wireguard_kernel/netlink_tunnel.rs
+++ b/talpid-wireguard/src/wireguard_kernel/netlink_tunnel.rs
@@ -93,7 +93,7 @@ impl Tunnel for NetlinkTunnel {
tokio_handle.block_on(async move {
if let Err(err) = netlink_connections.delete_device(interface_index).await {
log::error!("Failed to remove WireGuard device: {}", err);
- Err(TunnelError::FatalStartWireguardError)
+ Err(TunnelError::FatalStartWireguardError(Box::new(err)))
} else {
Ok(())
}
@@ -113,7 +113,7 @@ impl Tunnel for NetlinkTunnel {
}
fn set_config(
- &self,
+ &mut self,
config: Config,
) -> Pin<Box<dyn Future<Output = std::result::Result<(), TunnelError>> + Send + 'static>> {
let mut wg = self.netlink_connections.wg_handle.clone();
@@ -127,4 +127,9 @@ impl Tunnel for NetlinkTunnel {
})
})
}
+
+ // TODO: We shouldn't force `NetlinkTunnel` to implement `start_daita`
+ fn start_daita(&mut self) -> std::result::Result<(), TunnelError> {
+ unreachable!("Netlink tunnel does not support DAITA")
+ }
}
diff --git a/talpid-wireguard/src/wireguard_kernel/nm_tunnel.rs b/talpid-wireguard/src/wireguard_kernel/nm_tunnel.rs
index 7c24c42a70..9fc9351591 100644
--- a/talpid-wireguard/src/wireguard_kernel/nm_tunnel.rs
+++ b/talpid-wireguard/src/wireguard_kernel/nm_tunnel.rs
@@ -70,7 +70,7 @@ impl Tunnel for NetworkManagerTunnel {
if let Some(tunnel) = self.tunnel.take() {
if let Err(err) = self.network_manager.remove_tunnel(tunnel) {
log::error!("Failed to remove WireGuard tunnel via NM: {}", err);
- Err(TunnelError::StopWireguardError { status: 0 })
+ Err(TunnelError::StopWireguardError(Box::new(err)))
} else {
Ok(())
}
@@ -94,7 +94,7 @@ impl Tunnel for NetworkManagerTunnel {
}
fn set_config(
- &self,
+ &mut self,
config: Config,
) -> Pin<Box<dyn Future<Output = std::result::Result<(), TunnelError>> + Send>> {
let interface_name = self.interface_name.clone();
@@ -110,6 +110,11 @@ impl Tunnel for NetworkManagerTunnel {
})
})
}
+
+ // TODO: We shouldn't force `NetworkManagerTunnel` tunnel to implement `start_daita`
+ fn start_daita(&mut self) -> std::result::Result<(), TunnelError> {
+ unreachable!("NetworkManager tunnel does not support DAITA")
+ }
}
fn convert_config_to_dbus(config: &Config) -> DeviceConfig {
diff --git a/talpid-wireguard/src/wireguard_nt/mod.rs b/talpid-wireguard/src/wireguard_nt/mod.rs
index 375f28844a..de32c8d83e 100644
--- a/talpid-wireguard/src/wireguard_nt/mod.rs
+++ b/talpid-wireguard/src/wireguard_nt/mod.rs
@@ -70,7 +70,6 @@ type WireGuardGetConfigurationFn = unsafe extern "stdcall" fn(
type WireGuardSetStateFn =
unsafe extern "stdcall" fn(adapter: RawHandle, state: WgAdapterState) -> BOOL;
-#[cfg(windows)]
#[repr(C)]
#[allow(dead_code)]
enum LogLevel {
@@ -79,7 +78,6 @@ enum LogLevel {
Err = 2,
}
-#[cfg(windows)]
impl From<LogLevel> for logging::LogLevel {
fn from(level: LogLevel) -> Self {
match level {
@@ -430,7 +428,7 @@ impl WgNtTunnel {
match error {
Error::CreateTunnelDevice(error) => super::TunnelError::SetupTunnelDevice(error),
- _ => super::TunnelError::FatalStartWireguardError,
+ _ => super::TunnelError::FatalStartWireguardError(Box::new(error)),
}
})
}
@@ -542,11 +540,11 @@ impl Drop for WgNtTunnel {
}
}
-static LOG_CONTEXT: Lazy<Mutex<Option<u32>>> = Lazy::new(|| Mutex::new(None));
+static LOG_CONTEXT: Lazy<Mutex<Option<u64>>> = Lazy::new(|| Mutex::new(None));
struct LoggerHandle {
dll: &'static WgNtDll,
- context: u32,
+ context: u64,
}
impl LoggerHandle {
@@ -1080,7 +1078,7 @@ impl Tunnel for WgNtTunnel {
}
fn set_config(
- &self,
+ &mut self,
config: Config,
) -> Pin<Box<dyn Future<Output = std::result::Result<(), super::TunnelError>> + Send>> {
let device = self.device.clone();
@@ -1145,7 +1143,6 @@ mod tests {
allowed_ips: vec!["1.3.3.0/24".parse().unwrap()],
endpoint: "1.2.3.4:1234".parse().unwrap(),
psk: None,
- #[cfg(target_os = "windows")]
constant_packet_size: false,
},
exit_peer: None,
@@ -1181,7 +1178,6 @@ mod tests {
rx_bytes: 0,
last_handshake: 0,
allowed_ips_count: 1,
- #[cfg(target_os = "windows")]
constant_packet_size: 0,
},
p0_allowed_ip_0: WgAllowedIp {
diff --git a/wireguard-go-rs/Cargo.toml b/wireguard-go-rs/Cargo.toml
new file mode 100644
index 0000000000..3725787bf6
--- /dev/null
+++ b/wireguard-go-rs/Cargo.toml
@@ -0,0 +1,10 @@
+[package]
+name = "wireguard-go-rs"
+description = "Rust bindings to wireguard-go with DAITA support"
+edition = "2021"
+license.workspace = true
+
+[dependencies]
+thiserror.workspace = true
+log.workspace = true
+zeroize = "1.8.1"
diff --git a/wireguard-go-rs/README.md b/wireguard-go-rs/README.md
new file mode 100644
index 0000000000..24a08fa54b
--- /dev/null
+++ b/wireguard-go-rs/README.md
@@ -0,0 +1,10 @@
+# `wireguard-go-rs`
+This crate wraps `libwg`, which in turn wraps [Mullvad VPN's fork of wireguard-go](https://github.com/mullvad/wireguard-go) which extends `wireguard-go` with [DAITA](https://mullvad.net/en/blog/introducing-defense-against-ai-guided-traffic-analysis-daita).
+
+## Known limitation
+To extend `wireguard-go` with DAITA capabilities, it statically links against [maybenot](https://github.com/maybenot-io/maybenot/), which at the time of writing will cause issues if it in turn is statically linked from another Rust crate: https://github.com/rust-lang/rust/issues/104707.
+As such, `libwg` is built as a shared object which you have to link to dynamically.
+To get rid of this limitation, you could compile `wireguard-go` without DAITA support. See [build-wireguard-go.sh](./build-wireguard-go.sh) for details.
+
+## Upgrading `wireguard-go`
+Upgrading `wireguard-go` involves updating the git submodule found in `libwg/wireguard-go`. This module uses [Mullvad VPN's fork of wireguard-go](https://github.com/mullvad/wireguard-go).
diff --git a/wireguard/build-wireguard-go.sh b/wireguard-go-rs/build-wireguard-go.sh
index 7e3b11910d..e270fb8219 100755
--- a/wireguard/build-wireguard-go.sh
+++ b/wireguard-go-rs/build-wireguard-go.sh
@@ -1,30 +1,27 @@
#!/usr/bin/env bash
# This script is used to build wireguard-go libraries for all the platforms.
+#
+# If "DAITA" support should be enabled, pass the `--daita` flag when invoking this script.
set -eu
-function is_android_build {
- for arg in "$@"
- do
- case "$arg" in
- "--android")
- return 0
- esac
- done
- return 1
-}
+# If Wireguard-go should be built with DAITA-support.
+DAITA="false"
+# If the target OS is Adnroid.
+ANDROID="false"
-function is_docker_build {
- for arg in "$@"
- do
- case "$arg" in
- "--no-docker")
- return 1
- esac
- done
- return 0
-}
+while [[ "$#" -gt 0 ]]; do
+ case $1 in
+ --android) ANDROID="true";;
+ --daita) DAITA="true";;
+ *)
+ log_error "Unknown parameter: $1"
+ exit 1
+ ;;
+ esac
+ shift
+done
function unix_target_triple {
local platform
@@ -48,6 +45,7 @@ function unix_target_triple {
function build_unix {
+ # TODO: consider using `log_header` here
echo "Building wireguard-go for $1"
# Flags for cross compiling
@@ -80,33 +78,35 @@ function build_unix {
fi
fi
- pushd libwg
- target_triple_dir="../../build/lib/$1"
- mkdir -p "$target_triple_dir"
- go build -v -o "$target_triple_dir"/libwg.a -buildmode c-archive
+ # Build wiregaurd-go as a library
+ pushd libwg
+ if [[ "$DAITA" == "true" ]]; then
+ pushd wireguard-go
+ make libmaybenot.a LIBDEST="$OUT_DIR"
+ popd
+ go build -v --tags daita -o "$OUT_DIR"/libwg.a -buildmode c-archive
+ else
+ go build -v -o "$OUT_DIR"/libwg.a -buildmode c-archive
+ fi
popd
}
function build_android {
- echo "Building for android"
+ echo "Building wireguard-go for android"
- if is_docker_build "$@"; then
- ../building/container-run.sh android wireguard/libwg/build-android.sh
- else
- ./libwg/build-android.sh
- fi
+ ./libwg/build-android.sh
}
function build_wireguard_go {
- if is_android_build "$@"; then
+ if [[ "$ANDROID" == "true" ]]; then
build_android "$@"
return
fi
local platform
platform="$(uname -s)";
- case "$platform" in
+ case "$platform" in
Linux*|Darwin*) build_unix "${1:-$(unix_target_triple)}";;
*)
echo "Unsupported platform"
diff --git a/wireguard-go-rs/build.rs b/wireguard-go-rs/build.rs
new file mode 100644
index 0000000000..82a8ff254b
--- /dev/null
+++ b/wireguard-go-rs/build.rs
@@ -0,0 +1,65 @@
+use core::{panic, str};
+use std::{env, path::PathBuf};
+
+fn main() {
+ let out_dir = env::var("OUT_DIR").expect("Missing OUT_DIR");
+ eprintln!("OUT_DIR: {out_dir}");
+
+ let target_os = env::var("CARGO_CFG_TARGET_OS").expect("Missing 'CARGO_CFG_TARGET_OS");
+ let mut cmd = std::process::Command::new("bash");
+ cmd.arg("./build-wireguard-go.sh");
+
+ match target_os.as_str() {
+ "linux" => {
+ // Enable DAITA & Tell rustc to link libmaybenot
+ println!("cargo::rustc-link-lib=static=maybenot");
+ // Tell the build script to build wireguard-go with DAITA support
+ cmd.arg("--daita");
+ }
+ "android" => {
+ cmd.arg("--android");
+ }
+ "macos" => {}
+ // building wireguard-go-rs for windows is not implemented
+ _ => return,
+ }
+
+ let output = cmd.output().expect("build-wireguard-go.sh failed");
+ if !output.status.success() {
+ let stdout = str::from_utf8(&output.stdout).unwrap();
+ let stderr = str::from_utf8(&output.stderr).unwrap();
+ eprintln!("build-wireguard-go.sh failed.");
+ eprintln!("stdout:\n{stdout}");
+ eprintln!("stderr:\n{stderr}");
+ panic!();
+ }
+
+ if target_os == "android" {
+ // NOTE: Go programs does not support being statically linked on android
+ // so we need to dynamically link to libwg
+ println!("cargo::rustc-link-lib=wg");
+ declare_libs_dir("../build/lib");
+ } else {
+ // other platforms can statically link to libwg just fine
+ // TODO: consider doing dynamic linking everywhere, to keep things simpler
+ println!("cargo::rustc-link-lib=static=wg");
+ println!("cargo::rustc-link-search={out_dir}");
+ }
+
+ println!("cargo::rerun-if-changed=libwg");
+}
+
+/// Tell linker to check `base`/$TARGET for shared libraries.
+fn declare_libs_dir(base: &str) {
+ let target_triplet = env::var("TARGET").expect("TARGET is not set");
+ let lib_dir = manifest_dir().join(base).join(target_triplet);
+ println!("cargo::rerun-if-changed={}", lib_dir.display());
+ println!("cargo::rustc-link-search={}", lib_dir.display());
+}
+
+/// Get the directory containing `Cargo.toml`
+fn manifest_dir() -> PathBuf {
+ env::var("CARGO_MANIFEST_DIR")
+ .map(PathBuf::from)
+ .expect("CARGO_MANIFEST_DIR env var not set")
+}
diff --git a/wireguard/libwg/Android.mk b/wireguard-go-rs/libwg/Android.mk
index acf2f6fe88..acf2f6fe88 100644
--- a/wireguard/libwg/Android.mk
+++ b/wireguard-go-rs/libwg/Android.mk
diff --git a/wireguard/libwg/README.md b/wireguard-go-rs/libwg/README.md
index 39ad48e3e0..39ad48e3e0 100644
--- a/wireguard/libwg/README.md
+++ b/wireguard-go-rs/libwg/README.md
diff --git a/wireguard/libwg/build-android.sh b/wireguard-go-rs/libwg/build-android.sh
index 750438c503..750438c503 100755
--- a/wireguard/libwg/build-android.sh
+++ b/wireguard-go-rs/libwg/build-android.sh
diff --git a/wireguard/libwg/go.mod b/wireguard-go-rs/libwg/go.mod
index 2f65c7bf7b..76627dcb7f 100644
--- a/wireguard/libwg/go.mod
+++ b/wireguard-go-rs/libwg/go.mod
@@ -10,5 +10,7 @@ require (
require (
golang.org/x/crypto v0.22.0 // indirect
golang.org/x/net v0.24.0 // indirect
- golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224 // indirect
+ golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
)
+
+replace golang.zx2c4.com/wireguard => ./wireguard-go
diff --git a/wireguard/libwg/go.sum b/wireguard-go-rs/libwg/go.sum
index c90f293b17..b41c5842d1 100644
--- a/wireguard/libwg/go.sum
+++ b/wireguard-go-rs/libwg/go.sum
@@ -4,7 +4,5 @@ golang.org/x/net v0.24.0 h1:1PcaxkF854Fu3+lvBIx5SYn9wRlBzzcnHZSiaFFAb0w=
golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8=
golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o=
golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
-golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224 h1:Ug9qvr1myri/zFN6xL17LSCBGFDnphBBhzmILHsM5TY=
-golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
-golang.zx2c4.com/wireguard v0.0.0-20230223181233-21636207a675 h1:/J/RVnr7ng4fWPRH3xa4WtBJ1Jp+Auu4YNLmGiPv5QU=
-golang.zx2c4.com/wireguard v0.0.0-20230223181233-21636207a675/go.mod h1:whfbyDBt09xhCYQWtO2+3UVjlaq6/9hDZrjg2ZE6SyA=
+golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
+golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
diff --git a/wireguard-go-rs/libwg/goruntime-boottime-over-monotonic.diff b/wireguard-go-rs/libwg/goruntime-boottime-over-monotonic.diff
new file mode 100644
index 0000000000..5d78242b13
--- /dev/null
+++ b/wireguard-go-rs/libwg/goruntime-boottime-over-monotonic.diff
@@ -0,0 +1,171 @@
+From 61f3ae8298d1c503cbc31539e0f3a73446c7db9d Mon Sep 17 00:00:00 2001
+From: "Jason A. Donenfeld" <Jason@zx2c4.com>
+Date: Tue, 21 Mar 2023 15:33:56 +0100
+Subject: [PATCH] [release-branch.go1.20] runtime: use CLOCK_BOOTTIME in
+ nanotime on Linux
+
+This makes timers account for having expired while a computer was
+asleep, which is quite common on mobile devices. Note that BOOTTIME is
+identical to MONOTONIC, except that it takes into account time spent
+in suspend. In Linux 4.17, the kernel will actually make MONOTONIC act
+like BOOTTIME anyway, so this switch will additionally unify the
+timer behavior across kernels.
+
+BOOTTIME was introduced into Linux 2.6.39-rc1 with 70a08cca1227d in
+2011.
+
+Fixes #24595
+
+Change-Id: I7b2a6ca0c5bc5fce57ec0eeafe7b68270b429321
+---
+ src/runtime/sys_linux_386.s | 4 ++--
+ src/runtime/sys_linux_amd64.s | 2 +-
+ src/runtime/sys_linux_arm.s | 4 ++--
+ src/runtime/sys_linux_arm64.s | 4 ++--
+ src/runtime/sys_linux_mips64x.s | 4 ++--
+ src/runtime/sys_linux_mipsx.s | 2 +-
+ src/runtime/sys_linux_ppc64x.s | 2 +-
+ src/runtime/sys_linux_s390x.s | 2 +-
+ 8 files changed, 12 insertions(+), 12 deletions(-)
+
+diff --git a/src/runtime/sys_linux_386.s b/src/runtime/sys_linux_386.s
+index 12a294153d..17e3524b40 100644
+--- a/src/runtime/sys_linux_386.s
++++ b/src/runtime/sys_linux_386.s
+@@ -352,13 +352,13 @@ noswitch:
+
+ LEAL 8(SP), BX // &ts (struct timespec)
+ MOVL BX, 4(SP)
+- MOVL $1, 0(SP) // CLOCK_MONOTONIC
++ MOVL $7, 0(SP) // CLOCK_BOOTTIME
+ CALL AX
+ JMP finish
+
+ fallback:
+ MOVL $SYS_clock_gettime, AX
+- MOVL $1, BX // CLOCK_MONOTONIC
++ MOVL $7, BX // CLOCK_BOOTTIME
+ LEAL 8(SP), CX
+ INVOKE_SYSCALL
+
+diff --git a/src/runtime/sys_linux_amd64.s b/src/runtime/sys_linux_amd64.s
+index c7a89ba536..01f0a6a26e 100644
+--- a/src/runtime/sys_linux_amd64.s
++++ b/src/runtime/sys_linux_amd64.s
+@@ -255,7 +255,7 @@ noswitch:
+ SUBQ $16, SP // Space for results
+ ANDQ $~15, SP // Align for C code
+
+- MOVL $1, DI // CLOCK_MONOTONIC
++ MOVL $7, DI // CLOCK_BOOTTIME
+ LEAQ 0(SP), SI
+ MOVQ runtime·vdsoClockgettimeSym(SB), AX
+ CMPQ AX, $0
+diff --git a/src/runtime/sys_linux_arm.s b/src/runtime/sys_linux_arm.s
+index 7b8c4f0e04..9798a1334e 100644
+--- a/src/runtime/sys_linux_arm.s
++++ b/src/runtime/sys_linux_arm.s
+@@ -11,7 +11,7 @@
+ #include "textflag.h"
+
+ #define CLOCK_REALTIME 0
+-#define CLOCK_MONOTONIC 1
++#define CLOCK_BOOTTIME 7
+
+ // for EABI, as we don't support OABI
+ #define SYS_BASE 0x0
+@@ -374,7 +374,7 @@ finish:
+
+ // func nanotime1() int64
+ TEXT runtime·nanotime1(SB),NOSPLIT,$12-8
+- MOVW $CLOCK_MONOTONIC, R0
++ MOVW $CLOCK_BOOTTIME, R0
+ MOVW $spec-12(SP), R1 // timespec
+
+ MOVW runtime·vdsoClockgettimeSym(SB), R4
+diff --git a/src/runtime/sys_linux_arm64.s b/src/runtime/sys_linux_arm64.s
+index 38ff6ac330..6b819c5441 100644
+--- a/src/runtime/sys_linux_arm64.s
++++ b/src/runtime/sys_linux_arm64.s
+@@ -14,7 +14,7 @@
+ #define AT_FDCWD -100
+
+ #define CLOCK_REALTIME 0
+-#define CLOCK_MONOTONIC 1
++#define CLOCK_BOOTTIME 7
+
+ #define SYS_exit 93
+ #define SYS_read 63
+@@ -338,7 +338,7 @@ noswitch:
+ BIC $15, R1
+ MOVD R1, RSP
+
+- MOVW $CLOCK_MONOTONIC, R0
++ MOVW $CLOCK_BOOTTIME, R0
+ MOVD runtime·vdsoClockgettimeSym(SB), R2
+ CBZ R2, fallback
+
+diff --git a/src/runtime/sys_linux_mips64x.s b/src/runtime/sys_linux_mips64x.s
+index 47f2da524d..a8b387f193 100644
+--- a/src/runtime/sys_linux_mips64x.s
++++ b/src/runtime/sys_linux_mips64x.s
+@@ -326,7 +326,7 @@ noswitch:
+ AND $~15, R1 // Align for C code
+ MOVV R1, R29
+
+- MOVW $1, R4 // CLOCK_MONOTONIC
++ MOVW $7, R4 // CLOCK_BOOTTIME
+ MOVV $0(R29), R5
+
+ MOVV runtime·vdsoClockgettimeSym(SB), R25
+@@ -336,7 +336,7 @@ noswitch:
+ // see walltime for detail
+ BEQ R2, R0, finish
+ MOVV R0, runtime·vdsoClockgettimeSym(SB)
+- MOVW $1, R4 // CLOCK_MONOTONIC
++ MOVW $7, R4 // CLOCK_BOOTTIME
+ MOVV $0(R29), R5
+ JMP fallback
+
+diff --git a/src/runtime/sys_linux_mipsx.s b/src/runtime/sys_linux_mipsx.s
+index 5e6b6c1504..7f5fd2a80e 100644
+--- a/src/runtime/sys_linux_mipsx.s
++++ b/src/runtime/sys_linux_mipsx.s
+@@ -243,7 +243,7 @@ TEXT runtime·walltime(SB),NOSPLIT,$8-12
+ RET
+
+ TEXT runtime·nanotime1(SB),NOSPLIT,$8-8
+- MOVW $1, R4 // CLOCK_MONOTONIC
++ MOVW $7, R4 // CLOCK_BOOTTIME
+ MOVW $4(R29), R5
+ MOVW $SYS_clock_gettime, R2
+ SYSCALL
+diff --git a/src/runtime/sys_linux_ppc64x.s b/src/runtime/sys_linux_ppc64x.s
+index d0427a4807..05ee9fede9 100644
+--- a/src/runtime/sys_linux_ppc64x.s
++++ b/src/runtime/sys_linux_ppc64x.s
+@@ -298,7 +298,7 @@ fallback:
+ JMP return
+
+ TEXT runtime·nanotime1(SB),NOSPLIT,$16-8
+- MOVD $1, R3 // CLOCK_MONOTONIC
++ MOVD $7, R3 // CLOCK_BOOTTIME
+
+ MOVD R1, R15 // R15 is unchanged by C code
+ MOVD g_m(g), R21 // R21 = m
+diff --git a/src/runtime/sys_linux_s390x.s b/src/runtime/sys_linux_s390x.s
+index 1448670b91..7d2ee3231c 100644
+--- a/src/runtime/sys_linux_s390x.s
++++ b/src/runtime/sys_linux_s390x.s
+@@ -296,7 +296,7 @@ fallback:
+ RET
+
+ TEXT runtime·nanotime1(SB),NOSPLIT,$32-8
+- MOVW $1, R2 // CLOCK_MONOTONIC
++ MOVW $7, R2 // CLOCK_BOOTTIME
+
+ MOVD R15, R7 // Backup stack pointer
+
+--
+2.17.1
+
diff --git a/wireguard/libwg/libwg.go b/wireguard-go-rs/libwg/libwg.go
index e26ea7b7da..aaa03ef838 100644
--- a/wireguard/libwg/libwg.go
+++ b/wireguard-go-rs/libwg/libwg.go
@@ -6,7 +6,9 @@
package main
+// #include <stdio.h>
// #include <stdlib.h>
+// #include <stdint.h>
import "C"
import (
@@ -17,11 +19,31 @@ import (
"unsafe"
"github.com/mullvad/mullvadvpn-app/wireguard/libwg/tunnelcontainer"
+ "golang.zx2c4.com/wireguard/device"
)
+// FFI integer result codes
+// NOTE: Must be kept in sync with the Error enum in wireguard-go-rs
const (
- ERROR_GENERAL_FAILURE = -1
- ERROR_INTERMITTENT_FAILURE = -2
+ OK = C.int32_t(-iota)
+
+ // Something went wrong.
+ ERROR_GENERAL_FAILURE
+
+ // Something went wrong, but trying again might help.
+ ERROR_INTERMITTENT_FAILURE
+
+ // A bad argument was provided to libwg.
+ ERROR_INVALID_ARGUMENT
+
+ // The provided tunnel handle did not refer to an existing tunnel.
+ ERROR_UNKNOWN_TUNNEL
+
+ // The provided public key did not refer to an existing peer.
+ ERROR_UNKNOWN_PEER
+
+ // Something went wrong when enabling DAITA.
+ ERROR_ENABLE_DAITA
)
var tunnels tunnelcontainer.Container
@@ -30,6 +52,11 @@ func init() {
tunnels = tunnelcontainer.New()
}
+type EventContext struct {
+ tunnelHandle int32
+ peer device.NoisePublicKey
+}
+
//export wgTurnOff
func wgTurnOff(tunnelHandle int32) {
{
@@ -61,14 +88,14 @@ func wgGetConfig(tunnelHandle int32) *C.char {
}
//export wgSetConfig
-func wgSetConfig(tunnelHandle int32, cSettings *C.char) int32 {
+func wgSetConfig(tunnelHandle int32, cSettings *C.char) C.int32_t {
tunnel, err := tunnels.Get(tunnelHandle)
if err != nil {
- return ERROR_GENERAL_FAILURE
+ return ERROR_UNKNOWN_TUNNEL
}
if cSettings == nil {
tunnel.Logger.Errorf("cSettings is null\n")
- return ERROR_GENERAL_FAILURE
+ return ERROR_INVALID_ARGUMENT
}
settings := C.GoString(cSettings)
diff --git a/wireguard-go-rs/libwg/libwg.h b/wireguard-go-rs/libwg/libwg.h
new file mode 100644
index 0000000000..23c87092cc
--- /dev/null
+++ b/wireguard-go-rs/libwg/libwg.h
@@ -0,0 +1,8 @@
+#include <stdint.h>
+#include <stdbool.h>
+
+/// Activate DAITA for the specified tunnel.
+int32_t wgActivateDaita(int32_t tunnelHandle, uint8_t* noisePublic, char* machines, uint32_t eventsCapacity, uint32_t actionsCapacity);
+char* wgGetConfig(int32_t tunnelHandle);
+int32_t wgSetConfig(int32_t tunnelHandle, char* cSettings);
+void wgFreePtr(void*);
diff --git a/wireguard/libwg/libwg_android.go b/wireguard-go-rs/libwg/libwg_android.go
index df54d4cc8b..86410721f5 100644
--- a/wireguard/libwg/libwg_android.go
+++ b/wireguard-go-rs/libwg/libwg_android.go
@@ -6,8 +6,10 @@
package main
+// #include <stdint.h>
+import "C"
+
import (
- "C"
"bufio"
"strings"
"unsafe"
@@ -25,15 +27,15 @@ import (
// Redefined here because otherwise the compiler doesn't realize it's a type alias for a type that's safe to export.
// Taken from the contained logging package.
type LogSink = unsafe.Pointer
-type LogContext = unsafe.Pointer
+type LogContext = C.uint64_t
//export wgTurnOn
-func wgTurnOn(cSettings *C.char, fd int, logSink LogSink, logContext LogContext) int32 {
- logger := logging.NewLogger(logSink, logContext)
+func wgTurnOn(cSettings *C.char, fd int, logSink LogSink, logContext LogContext) C.int32_t {
+ logger := logging.NewLogger(logSink, logging.LogContext(logContext))
if cSettings == nil {
logger.Errorf("cSettings is null\n")
- return ERROR_GENERAL_FAILURE
+ return ERROR_INVALID_ARGUMENT
}
settings := C.GoString(cSettings)
@@ -71,33 +73,33 @@ func wgTurnOn(cSettings *C.char, fd int, logSink LogSink, logContext LogContext)
return ERROR_GENERAL_FAILURE
}
- return handle
+ return C.int32_t(handle)
}
//export wgGetSocketV4
-func wgGetSocketV4(tunnelHandle int32) int32 {
+func wgGetSocketV4(tunnelHandle int32) C.int32_t {
tunnel, err := tunnels.Get(tunnelHandle)
if err != nil {
- return ERROR_GENERAL_FAILURE
+ return ERROR_UNKNOWN_TUNNEL
}
peek := tunnel.Device.Bind().(conn.PeekLookAtSocketFd)
fd, err := peek.PeekLookAtSocketFd4()
if err != nil {
return ERROR_GENERAL_FAILURE
}
- return int32(fd)
+ return C.int32_t(fd)
}
//export wgGetSocketV6
-func wgGetSocketV6(tunnelHandle int32) int32 {
+func wgGetSocketV6(tunnelHandle int32) C.int32_t {
tunnel, err := tunnels.Get(tunnelHandle)
if err != nil {
- return ERROR_GENERAL_FAILURE
+ return ERROR_UNKNOWN_TUNNEL
}
peek := tunnel.Device.Bind().(conn.PeekLookAtSocketFd)
fd, err := peek.PeekLookAtSocketFd6()
if err != nil {
return ERROR_GENERAL_FAILURE
}
- return int32(fd)
+ return C.int32_t(fd)
}
diff --git a/wireguard-go-rs/libwg/libwg_daita.go b/wireguard-go-rs/libwg/libwg_daita.go
new file mode 100644
index 0000000000..e33de84e49
--- /dev/null
+++ b/wireguard-go-rs/libwg/libwg_daita.go
@@ -0,0 +1,41 @@
+//go:build daita
+// +build daita
+
+package main
+
+// #include <stdio.h>
+// #include <stdlib.h>
+// #include <stdint.h>
+import "C"
+
+import (
+ "unsafe"
+
+ "golang.zx2c4.com/wireguard/device"
+)
+
+const maxPaddingBytes = 0.0
+const maxBlockingBytes = 0.0
+
+//export wgActivateDaita
+func wgActivateDaita(tunnelHandle C.int32_t, peerPubkey *C.uint8_t, machines *C.char, eventsCapacity C.uint32_t, actionsCapacity C.uint32_t) C.int32_t {
+
+ tunnel, err := tunnels.Get(int32(tunnelHandle))
+ if err != nil {
+ return ERROR_UNKNOWN_TUNNEL
+ }
+
+ var publicKey device.NoisePublicKey
+ copy(publicKey[:], C.GoBytes(unsafe.Pointer(peerPubkey), device.NoisePublicKeySize))
+ peer := tunnel.Device.LookupPeer(publicKey)
+
+ if peer == nil {
+ return ERROR_UNKNOWN_PEER
+ }
+
+ if !peer.EnableDaita(C.GoString((*C.char)(machines)), uint(eventsCapacity), uint(actionsCapacity), maxPaddingBytes, maxBlockingBytes) {
+ return ERROR_ENABLE_DAITA
+ }
+
+ return OK
+}
diff --git a/wireguard/libwg/libwg_default.go b/wireguard-go-rs/libwg/libwg_default.go
index 7282c0ca8a..6741de715f 100644
--- a/wireguard/libwg/libwg_default.go
+++ b/wireguard-go-rs/libwg/libwg_default.go
@@ -1,3 +1,4 @@
+//go:build (darwin || linux) && !android
// +build darwin linux
// +build !android
@@ -10,6 +11,7 @@
package main
// #include <stdlib.h>
+// #include <stdint.h>
import "C"
import (
"bufio"
@@ -28,16 +30,15 @@ import (
// Redefined here because otherwise the compiler doesn't realize it's a type alias for a type that's safe to export.
// Taken from the contained logging package.
type LogSink = unsafe.Pointer
-type LogContext = unsafe.Pointer
+type LogContext = C.uint64_t
//export wgTurnOn
-func wgTurnOn(mtu int, cSettings *C.char, fd int, logSink LogSink, logContext LogContext) int32 {
-
- logger := logging.NewLogger(logSink, logContext)
+func wgTurnOn(mtu int, cSettings *C.char, fd int, logSink LogSink, logContext LogContext) C.int32_t {
+ logger := logging.NewLogger(logSink, logging.LogContext(logContext))
if cSettings == nil {
logger.Errorf("cSettings is null\n")
- return ERROR_GENERAL_FAILURE
+ return ERROR_INVALID_ARGUMENT
}
settings := C.GoString(cSettings)
@@ -74,5 +75,5 @@ func wgTurnOn(mtu int, cSettings *C.char, fd int, logSink LogSink, logContext Lo
return ERROR_GENERAL_FAILURE
}
- return handle
+ return C.int32_t(handle)
}
diff --git a/wireguard/libwg/logging/logging.go b/wireguard-go-rs/libwg/logging/logging.go
index a917e96493..a6782ec39a 100644
--- a/wireguard/libwg/logging/logging.go
+++ b/wireguard-go-rs/libwg/logging/logging.go
@@ -7,12 +7,13 @@
package logging
// #include <stdlib.h>
+// #include <stdint.h>
// #include <sys/types.h>
// #ifndef WIN32
// #define __stdcall
// #endif
-// typedef void (__stdcall *LogSink)(unsigned int, const char *, void *);
-// static void callLogSink(void *logSink, int level, const char *message, void *context)
+// typedef void (__stdcall *LogSink)(unsigned int, const char *, uint64_t);
+// static void callLogSink(void *logSink, int level, const char *message, uint64_t context)
// {
// ((LogSink)logSink)((unsigned int)level, message, context);
// }
@@ -27,7 +28,7 @@ import (
// Define type aliases.
type LogSink = unsafe.Pointer
-type LogContext = unsafe.Pointer
+type LogContext = C.uint64_t
type Logger struct {
sink LogSink
diff --git a/wireguard/libwg/tunnelcontainer/tunnelcontainer.go b/wireguard-go-rs/libwg/tunnelcontainer/tunnelcontainer.go
index 91291dcf4b..91291dcf4b 100644
--- a/wireguard/libwg/tunnelcontainer/tunnelcontainer.go
+++ b/wireguard-go-rs/libwg/tunnelcontainer/tunnelcontainer.go
diff --git a/wireguard-go-rs/libwg/wireguard-go b/wireguard-go-rs/libwg/wireguard-go
new file mode 160000
+Subproject f4bc3aefeb6c6ae50567a4bfc177593337eb9cb
diff --git a/wireguard-go-rs/src/lib.rs b/wireguard-go-rs/src/lib.rs
new file mode 100644
index 0000000000..f87b2586d0
--- /dev/null
+++ b/wireguard-go-rs/src/lib.rs
@@ -0,0 +1,288 @@
+//! This crate provides Rust bindings to wireguard-go with DAITA support.
+//!
+//! The bindings on the Go side are provided by `libwg`, which is a Go package that wraps
+//! `wireguard-go` and provides a C FFI that we can use from Rust. On the Rust side, the FFI is
+//! in the private `ffi` module below. It needs to be kept in sync with any changes to libwg.
+//!
+//! The [`Tunnel`] type provides a safe Rust wrapper around the C FFI.
+
+#![cfg(unix)]
+
+use core::slice;
+use std::{
+ ffi::{c_char, CStr},
+ mem::ManuallyDrop,
+};
+use util::OnDrop;
+use zeroize::Zeroize;
+
+mod util;
+
+pub type Fd = std::os::unix::io::RawFd;
+
+pub type WgLogLevel = u32;
+
+pub type LoggingContext = u64;
+pub type LoggingCallback =
+ unsafe extern "system" fn(level: WgLogLevel, msg: *const c_char, context: LoggingContext);
+
+/// A wireguard-go tunnel
+pub struct Tunnel {
+ /// wireguard-go handle to the tunnel.
+ handle: i32,
+}
+
+// NOTE: Must be kept in sync with libwg.go
+// NOTE: must be kept in sync with `result_from_code`
+// INVARIANT: Will always be represented as a negative i32
+#[repr(i32)]
+#[non_exhaustive]
+#[derive(Clone, Copy, Debug, thiserror::Error)]
+pub enum Error {
+ #[error("Something went wrong.")]
+ GeneralFailure = -1,
+
+ #[error("Something went wrong, but trying again might help.")]
+ IntermittentFailure = -2,
+
+ #[error("An argument you provided was invalid.")]
+ InvalidArgument = -3,
+
+ #[error("The tunnel handle did not refer to an existing tunnel.")]
+ UnknownTunnel = -4,
+
+ #[error("The provided public key did not refer to an existing peer.")]
+ UnknownPeer = -5,
+
+ #[error("Something went wrong when enabling DAITA.")]
+ EnableDaita = -6,
+
+ #[error("`libwg` provided an unknown error code. This is a bug.")]
+ Other = i32::MIN,
+}
+
+impl Tunnel {
+ /// Creates a new wireguard tunnel, uses the specific interface name, and file descriptors
+ /// for the tunnel device and logging. For targets other than android, this also takes an MTU
+ /// value.
+ ///
+ /// The `logging_callback` let's you provide a Rust function that receives any logging output
+ /// from wireguard-go. `logging_context` is a value that will be passed to each invocation of
+ /// `logging_callback`.
+ pub fn turn_on(
+ #[cfg(not(target_os = "android"))] mtu: isize,
+ settings: &CStr,
+ device: Fd,
+ logging_callback: Option<LoggingCallback>,
+ logging_context: LoggingContext,
+ ) -> Result<Self, Error> {
+ // SAFETY: pointer is valid for the the lifetime of this function
+ let code = unsafe {
+ ffi::wgTurnOn(
+ #[cfg(not(target_os = "android"))]
+ mtu,
+ settings.as_ptr(),
+ device,
+ logging_callback,
+ logging_context,
+ )
+ };
+
+ result_from_code(code)?;
+ Ok(Tunnel { handle: code })
+ }
+
+ /// Stop the wireguard tunnel. This also happens automatically if the [`Tunnel`] is dropped.
+ pub fn turn_off(self) -> Result<(), Error> {
+ // we manually turn off the tunnel here, so wrap it in ManuallyDrop to prevent the Drop
+ // impl from doing the same.
+ let code = unsafe { ffi::wgTurnOff(self.handle) };
+ let _ = ManuallyDrop::new(self);
+ result_from_code(code)
+ }
+
+ /// Get the config of the WireGuard interface and make it available in the provided function.
+ ///
+ /// This takes a function to make sure the cstr get's zeroed and freed afterwards.
+ /// Returns `None` if the call to wgGetConfig returned nil.
+ ///
+ /// **NOTE:** You should take extra care to avoid copying any secrets from the config without
+ /// zeroizing them afterwards.
+ // NOTE: this could return a guard type with a custom Drop impl instead, but me lazy.
+ pub fn get_config<T>(&self, f: impl FnOnce(&CStr) -> T) -> Option<T> {
+ let ptr = unsafe { ffi::wgGetConfig(self.handle) };
+
+ if ptr.is_null() {
+ return None;
+ }
+
+ // SAFETY: we checked for null, and wgGetConfig promises that this is a valid cstr
+ let config = unsafe { CStr::from_ptr(ptr) };
+ let config_len = config.to_bytes().len();
+
+ // execute cleanup code on Drop to make sure that it happens even if `f` panics
+ let on_drop = OnDrop::new(|| {
+ {
+ // SAFETY:
+ // we checked for null, and wgGetConfig promises that this is a valid cstr.
+ // config_len comes from the CStr above, so it should be good.
+ let config_bytes = unsafe { slice::from_raw_parts_mut(ptr, config_len) };
+ config_bytes.zeroize();
+ }
+
+ // SAFETY: the pointer was created by wgGetConfig, and we are no longer using it.
+ unsafe { ffi::wgFreePtr(ptr.cast()) };
+ });
+
+ let t = f(config);
+ let _ = config;
+ drop(on_drop);
+
+ Some(t)
+ }
+
+ /// Set the config of the WireGuard interface.
+ pub fn set_config(&self, config: &CStr) -> Result<(), Error> {
+ // SAFETY: pointer is valid for the lifetime of this function.
+ let code = unsafe { ffi::wgSetConfig(self.handle, config.as_ptr()) };
+ result_from_code(code)
+ }
+
+ /// Activate DAITA for the specified peer.
+ ///
+ /// `machines` is a string containing LF-separated maybenot machines.
+ #[cfg(any(target_os = "windows", target_os = "linux"))]
+ pub fn activate_daita(
+ &self,
+ peer_public_key: &[u8; 32],
+ machines: &CStr,
+ events_capacity: u32,
+ actions_capacity: u32,
+ ) -> Result<(), Error> {
+ // SAFETY: pointers are valid for the lifetime of this function.
+ let code = unsafe {
+ ffi::wgActivateDaita(
+ self.handle,
+ peer_public_key.as_ptr(),
+ machines.as_ptr(),
+ events_capacity,
+ actions_capacity,
+ )
+ };
+
+ result_from_code(code)
+ }
+
+ /// Get the file descriptor of the tunnel IPv4 socket.
+ #[cfg(target_os = "android")]
+ pub fn get_socket_v4(&self) -> Fd {
+ unsafe { ffi::wgGetSocketV4(self.handle) }
+ }
+
+ /// Get the file descriptor of the tunnel IPv6 socket.
+ #[cfg(target_os = "android")]
+ pub fn get_socket_v6(&self) -> Fd {
+ unsafe { ffi::wgGetSocketV6(self.handle) }
+ }
+}
+
+impl Drop for Tunnel {
+ fn drop(&mut self) {
+ let code = unsafe { ffi::wgTurnOff(self.handle) };
+ if let Err(e) = result_from_code(code) {
+ log::error!("Failed to stop wireguard-go tunnel,oerror_code={code} ({e:?})")
+ }
+ }
+}
+
+fn result_from_code(code: i32) -> Result<(), Error> {
+ // NOTE: must be kept in sync with enum definition
+ Err(match code {
+ 0.. => return Ok(()),
+ -1 => Error::GeneralFailure,
+ -2 => Error::IntermittentFailure,
+ -3 => Error::UnknownTunnel,
+ -4 => Error::UnknownPeer,
+ -5 => Error::EnableDaita,
+ _ => Error::Other,
+ })
+}
+
+impl Error {
+ pub const fn as_raw(self) -> i32 {
+ self as i32
+ }
+}
+
+mod ffi {
+ use super::{Fd, LoggingCallback, LoggingContext};
+ use core::ffi::{c_char, c_void};
+
+ extern "C" {
+ /// Creates a new wireguard tunnel, uses the specific interface name, and file descriptors
+ /// for the tunnel device and logging. For targets other than android, this also takes an
+ /// MTU value.
+ ///
+ /// Positive return values are tunnel handles for this specific wireguard tunnel instance.
+ /// Negative return values signify errors.
+ pub fn wgTurnOn(
+ #[cfg(not(target_os = "android"))] mtu: isize,
+ settings: *const c_char,
+ fd: Fd,
+ logging_callback: Option<LoggingCallback>,
+ logging_context: LoggingContext,
+ ) -> i32;
+
+ /// Pass a handle that was created by wgTurnOn to stop a wireguard tunnel.
+ ///
+ /// Negative return values signify errors.
+ pub fn wgTurnOff(handle: i32) -> i32;
+
+ /// Get the config of the WireGuard interface. Returns null in case of error.
+ ///
+ /// # Safety:
+ /// - The function returns an owned pointer to a null-terminated UTF-8 string.
+ /// - The pointer may only be freed using [wgFreePtr].
+ pub fn wgGetConfig(handle: i32) -> *mut c_char;
+
+ /// Set the config of the WireGuard interface.
+ ///
+ /// Negative return values signify errors.
+ ///
+ /// # Safety:
+ /// - `settings` must point to a null-terminated UTF-8 string.
+ /// - The pointer will not be read from after `wgActivateDaita` has returned.
+ pub fn wgSetConfig(handle: i32, settings: *const c_char) -> i32;
+
+ /// Activate DAITA for the specified peer.
+ ///
+ /// `tunnel_handle` must come from [wgTurnOn]. `machines` is a string containing
+ /// LF-separated maybenot machines.
+ ///
+ /// Negative return values signify errors.
+ ///
+ /// # Safety:
+ /// - `peer_public_key` must point to a 32 byte array.
+ /// - `machines` must point to a null-terminated UTF-8 string.
+ /// - Neither pointer will be read from after `wgActivateDaita` has returned.
+ #[cfg(any(target_os = "windows", target_os = "linux"))]
+ pub fn wgActivateDaita(
+ tunnel_handle: i32,
+ peer_public_key: *const u8,
+ machines: *const c_char,
+ events_capacity: u32,
+ actions_capacity: u32,
+ ) -> i32;
+
+ /// Free a pointer allocated by the go runtime - useful to free return value of wgGetConfig
+ pub fn wgFreePtr(ptr: *mut c_void);
+
+ /// Get the file descriptor of the tunnel IPv4 socket.
+ #[cfg(target_os = "android")]
+ pub fn wgGetSocketV4(handle: i32) -> Fd;
+
+ /// Get the file descriptor of the tunnel IPv6 socket.
+ #[cfg(target_os = "android")]
+ pub fn wgGetSocketV6(handle: i32) -> Fd;
+ }
+}
diff --git a/wireguard-go-rs/src/util.rs b/wireguard-go-rs/src/util.rs
new file mode 100644
index 0000000000..4c2df2f9bb
--- /dev/null
+++ b/wireguard-go-rs/src/util.rs
@@ -0,0 +1,15 @@
+pub struct OnDrop<F: FnOnce()>(Option<F>);
+
+impl<F: FnOnce()> OnDrop<F> {
+ pub fn new(f: F) -> Self {
+ OnDrop(Some(f))
+ }
+}
+
+impl<F: FnOnce()> Drop for OnDrop<F> {
+ fn drop(&mut self) {
+ if let Some(f) = self.0.take() {
+ f()
+ }
+ }
+}