summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorOdd Stranne <odd@mullvad.net>2020-02-10 00:21:27 +0100
committerOdd Stranne <odd@mullvad.net>2020-02-13 11:29:24 +0100
commit57f3160281a6acea74c8bbc44cd96722c7ddcaa4 (patch)
tree41a15a83ae819e393e209ac4ff54baccdd11b867
parent2781b614d5ffd6949385de3c5dd1ea2a6704ff3a (diff)
downloadmullvadvpn-57f3160281a6acea74c8bbc44cd96722c7ddcaa4.tar.xz
mullvadvpn-57f3160281a6acea74c8bbc44cd96722c7ddcaa4.zip
Update WireGuard FFI in talpid-core
-rw-r--r--talpid-core/src/tunnel/wireguard/logging.rs86
-rw-r--r--talpid-core/src/tunnel/wireguard/mod.rs2
-rw-r--r--talpid-core/src/tunnel/wireguard/wireguard_go.rs168
3 files changed, 136 insertions, 120 deletions
diff --git a/talpid-core/src/tunnel/wireguard/logging.rs b/talpid-core/src/tunnel/wireguard/logging.rs
new file mode 100644
index 0000000000..5d177441e3
--- /dev/null
+++ b/talpid-core/src/tunnel/wireguard/logging.rs
@@ -0,0 +1,86 @@
+use super::{Error, Result};
+use chrono;
+use parking_lot::Mutex;
+use std::{collections::HashMap, fs, io::Write, path::Path};
+
+lazy_static::lazy_static! {
+ static ref LOG_MUTEX: Mutex<HashMap<u32, fs::File>> = Mutex::new(HashMap::new());
+}
+
+static mut LOG_CONTEXT_NEXT_ORDINAL: u32 = 0;
+
+pub fn initialize_logging(log_path: Option<&Path>) -> Result<u32> {
+ let log_file = create_log_file(log_path)?;
+
+ let log_context_ordinal = unsafe {
+ let mut map = LOG_MUTEX.lock();
+ let ordinal = LOG_CONTEXT_NEXT_ORDINAL;
+ LOG_CONTEXT_NEXT_ORDINAL += 1;
+ map.insert(ordinal, log_file);
+ ordinal
+ };
+
+ Ok(log_context_ordinal)
+}
+
+#[cfg(target_os = "windows")]
+static NULL_DEVICE: &str = "NUL";
+
+#[cfg(not(target_os = "windows"))]
+static NULL_DEVICE: &str = "/dev/null";
+
+fn create_log_file(log_path: Option<&Path>) -> Result<fs::File> {
+ fs::File::create(log_path.unwrap_or(NULL_DEVICE.as_ref())).map_err(Error::PrepareLogFileError)
+}
+
+pub fn clean_up_logging(ordinal: u32) {
+ let mut map = LOG_MUTEX.lock();
+ map.remove(&ordinal);
+}
+
+// Callback that receives messages from WireGuard
+pub unsafe extern "system" fn logging_callback(
+ level: WgLogLevel,
+ msg: *const libc::c_char,
+ context: *mut libc::c_void,
+) {
+ let map = LOG_MUTEX.lock();
+ if let Some(mut logfile) = map.get(&(context as u32)) {
+ let managed_msg = if !msg.is_null() {
+ #[cfg(not(target_os = "windows"))]
+ let m = std::ffi::CStr::from_ptr(msg).to_string_lossy().to_string();
+ #[cfg(target_os = "windows")]
+ let m = std::ffi::CStr::from_ptr(msg)
+ .to_string_lossy()
+ .to_string()
+ .replace("\n", "\r\n");
+ m
+ } else {
+ "Logging message from WireGuard is NULL".to_string()
+ };
+
+ let level_str = match level {
+ WG_GO_LOG_DEBUG => "DEBUG",
+ WG_GO_LOG_INFO => "INFO",
+ WG_GO_LOG_ERROR | _ => "ERROR",
+ };
+
+ let _ = write!(
+ logfile,
+ "{}[{}][{}] {}",
+ chrono::Local::now().format("[%Y-%m-%d %H:%M:%S%.3f]"),
+ "wireguard-go",
+ level_str,
+ managed_msg
+ );
+ }
+}
+
+// unsafe fn
+
+pub type WgLogLevel = u32;
+// wireguard-go supports log levels 0 through 3 with 3 being the most verbose
+// const WG_GO_LOG_SILENT: WgLogLevel = 0;
+const WG_GO_LOG_ERROR: WgLogLevel = 1;
+const WG_GO_LOG_INFO: WgLogLevel = 2;
+const WG_GO_LOG_DEBUG: WgLogLevel = 3;
diff --git a/talpid-core/src/tunnel/wireguard/mod.rs b/talpid-core/src/tunnel/wireguard/mod.rs
index 6308b1c514..efb44ebbc6 100644
--- a/talpid-core/src/tunnel/wireguard/mod.rs
+++ b/talpid-core/src/tunnel/wireguard/mod.rs
@@ -15,6 +15,7 @@ use talpid_types::ErrorExt;
pub mod config;
mod connectivity_check;
+mod logging;
mod stats;
pub mod wireguard_go;
@@ -134,7 +135,6 @@ impl WireguardMonitor {
route_handle
.add_default_route_callback(Some(WgGoTunnel::default_route_changed_callback), ());
-
let event_callback = Box::new(on_event.clone());
let (close_msg_sender, close_msg_receiver) = mpsc::channel();
let (pinger_tx, pinger_rx) = mpsc::channel();
diff --git a/talpid-core/src/tunnel/wireguard/wireguard_go.rs b/talpid-core/src/tunnel/wireguard/wireguard_go.rs
index f0c79b595d..6adfa934d8 100644
--- a/talpid-core/src/tunnel/wireguard/wireguard_go.rs
+++ b/talpid-core/src/tunnel/wireguard/wireguard_go.rs
@@ -1,5 +1,8 @@
use super::{stats::Stats, Config, Error, Result, Tunnel};
-use crate::tunnel::tun_provider::TunProvider;
+use crate::tunnel::{
+ tun_provider::TunProvider,
+ wireguard::logging::{clean_up_logging, initialize_logging, logging_callback, WgLogLevel},
+};
use ipnetwork::IpNetwork;
use std::{
ffi::{c_void, CStr, CString},
@@ -22,26 +25,11 @@ use {
};
#[cfg(target_os = "windows")]
-use {
- crate::winnet::{self, add_device_ip_addresses},
- chrono,
- parking_lot::Mutex,
- std::{collections::HashMap, fs, io::Write},
-};
-
-
-#[cfg(target_os = "windows")]
-lazy_static::lazy_static! {
- static ref LOG_MUTEX: Mutex<HashMap<u32, fs::File>> = Mutex::new(HashMap::new());
-}
-
-#[cfg(target_os = "windows")]
-static mut LOG_CONTEXT_NEXT_ORDINAL: u32 = 0;
+use crate::winnet::{self, add_device_ip_addresses};
#[cfg(not(target_os = "windows"))]
const MAX_PREPARE_TUN_ATTEMPTS: usize = 4;
-
pub struct WgGoTunnel {
interface_name: String,
handle: Option<i32>,
@@ -49,9 +37,8 @@ pub struct WgGoTunnel {
// live long enough and get closed when the tunnel is stopped
#[cfg(not(target_os = "windows"))]
_tunnel_device: Tun,
- // ordinal that maps to fs::File instance, used with logging callback
- #[cfg(target_os = "windows")]
- log_context_ordinal: u32,
+ // context that maps to fs::File instance, used with logging callback
+ logging_context: u32,
}
impl WgGoTunnel {
@@ -65,29 +52,32 @@ impl WgGoTunnel {
#[cfg_attr(not(target_os = "android"), allow(unused_mut))]
let (mut tunnel_device, tunnel_fd) = Self::get_tunnel(tun_provider, config, routes)?;
let interface_name: String = tunnel_device.interface_name().to_string();
-
let wg_config_str = config.to_userspace_format();
- let iface_name =
- CString::new(interface_name.as_bytes()).map_err(Error::InterfaceNameError)?;
-
- let log_path = log_path.and_then(|path| CString::new(path.to_string_lossy().as_ref()).ok());
- let log_path_ptr = log_path
- .as_ref()
- .map(|path| path.as_ptr())
- .unwrap_or_else(|| ptr::null());
+ let logging_context = initialize_logging(log_path)?;
+ #[cfg(not(target_os = "android"))]
let handle = unsafe {
- wgTurnOnWithFd(
- iface_name.as_ptr() as *const i8,
+ wgTurnOn(
config.mtu as isize,
wg_config_str.as_ptr() as *const i8,
tunnel_fd,
- log_path_ptr as *const i8,
- WG_GO_LOG_DEBUG,
+ Some(logging_callback),
+ logging_context as *mut libc::c_void,
+ )
+ };
+
+ #[cfg(target_os = "android")]
+ let handle = unsafe {
+ wgTurnOn(
+ wg_config_str.as_ptr() as *const i8,
+ tunnel_fd,
+ Some(logging_callback),
+ logging_context as *mut libc::c_void,
)
};
if handle < 0 {
+ clean_up_logging(logging_context);
// Error values returned from the wireguard-go library
return match handle {
-1 => Err(Error::FatalStartWireguardError),
@@ -97,12 +87,16 @@ impl WgGoTunnel {
}
#[cfg(target_os = "android")]
- Self::bypass_tunnel_sockets(&mut tunnel_device, handle).map_err(Error::BypassError)?;
+ Self::bypass_tunnel_sockets(&mut tunnel_device, handle).or_else(|_| {
+ clean_up_logging(logging_context);
+ Err(Error::BypassError)
+ })?;
Ok(WgGoTunnel {
interface_name,
handle: Some(handle),
_tunnel_device: tunnel_device,
+ logging_context,
})
}
@@ -113,46 +107,37 @@ impl WgGoTunnel {
_tun_provider: &mut TunProvider,
_routes: impl Iterator<Item = IpNetwork>,
) -> Result<Self> {
- let log_file = prepare_log_file(log_path)?;
-
- let log_context_ordinal = unsafe {
- let mut map = LOG_MUTEX.lock();
- let ordinal = LOG_CONTEXT_NEXT_ORDINAL;
- LOG_CONTEXT_NEXT_ORDINAL += 1;
- map.insert(ordinal, log_file);
- ordinal
- };
-
let wg_config_str = config.to_userspace_format();
let iface_name: String = "wg-mullvad".to_string();
let cstr_iface_name =
CString::new(iface_name.as_bytes()).map_err(Error::InterfaceNameError)?;
+ let logging_context = initialize_logging(log_path)?;
let handle = unsafe {
wgTurnOn(
cstr_iface_name.as_ptr(),
config.mtu as i64,
wg_config_str.as_ptr(),
- Some(Self::logging_callback),
- log_context_ordinal as *mut libc::c_void,
+ Some(logging_callback),
+ logging_context as *mut libc::c_void,
)
};
if handle < 0 {
- clean_up_log_file(log_context_ordinal);
+ clean_up_logging(logging_context);
return Err(Error::FatalStartWireguardError);
}
if !add_device_ip_addresses(&iface_name, &config.tunnel.addresses) {
// Todo: what kind of clean-up is required?
- clean_up_log_file(log_context_ordinal);
+ clean_up_logging(logging_context);
return Err(Error::SetIpAddressesError);
}
Ok(WgGoTunnel {
interface_name: iface_name.clone(),
handle: Some(handle),
- log_context_ordinal,
+ logging_context,
})
}
@@ -190,41 +175,6 @@ impl WgGoTunnel {
wgRebindTunnelSocket(address_family.to_windows_proto_enum(), iface_idx);
}
- // Callback that receives messages from WireGuard
- #[cfg(target_os = "windows")]
- pub unsafe extern "system" fn logging_callback(
- level: WgLogLevel,
- msg: *const libc::c_char,
- context: *mut libc::c_void,
- ) {
- let map = LOG_MUTEX.lock();
- if let Some(mut logfile) = map.get(&(context as u32)) {
- let managed_msg = if !msg.is_null() {
- std::ffi::CStr::from_ptr(msg)
- .to_string_lossy()
- .to_string()
- .replace("\n", "\r\n")
- } else {
- "Logging message from WireGuard is NULL".to_string()
- };
-
- let level_str = match level {
- WG_GO_LOG_DEBUG => "DEBUG",
- WG_GO_LOG_INFO => "INFO",
- WG_GO_LOG_ERROR | _ => "ERROR",
- };
-
- let _ = write!(
- logfile,
- "{}[{}][{}] {}",
- chrono::Local::now().format("[%Y-%m-%d %H:%M:%S%.3f]"),
- "wireguard-go",
- level_str,
- managed_msg
- );
- }
- }
-
#[cfg(not(target_os = "windows"))]
fn create_tunnel_config(config: &Config, routes: impl Iterator<Item = IpNetwork>) -> TunConfig {
let mut dns_servers = vec![IpAddr::V4(config.ipv4_gateway)];
@@ -305,30 +255,15 @@ impl WgGoTunnel {
}
}
-#[cfg(target_os = "windows")]
-fn clean_up_log_file(ordinal: u32) {
- let mut map = LOG_MUTEX.lock();
- map.remove(&ordinal);
-}
-
impl Drop for WgGoTunnel {
fn drop(&mut self) {
if let Err(e) = self.stop_tunnel() {
log::error!("Failed to stop tunnel - {}", e);
}
- #[cfg(target_os = "windows")]
- clean_up_log_file(self.log_context_ordinal);
+ clean_up_logging(self.logging_context);
}
}
-#[cfg(target_os = "windows")]
-static NULL_DEVICE: &str = "NUL";
-
-#[cfg(target_os = "windows")]
-fn prepare_log_file(log_path: Option<&Path>) -> Result<fs::File> {
- fs::File::create(log_path.unwrap_or(NULL_DEVICE.as_ref())).map_err(Error::PrepareLogFileError)
-}
-
impl Tunnel for WgGoTunnel {
fn get_interface_name(&self) -> &str {
&self.interface_name
@@ -370,19 +305,6 @@ impl Tunnel for WgGoTunnel {
#[cfg(unix)]
pub type Fd = std::os::unix::io::RawFd;
-#[cfg(windows)]
-pub type Fd = std::os::windows::io::RawHandle;
-
-type WgLogLevel = u32;
-// wireguard-go supports log levels 0 through 3 with 3 being the most verbose
-// const WG_GO_LOG_SILENT: WgLogLevel = 0;
-#[cfg(target_os = "windows")]
-const WG_GO_LOG_ERROR: WgLogLevel = 1;
-#[cfg(target_os = "windows")]
-const WG_GO_LOG_INFO: WgLogLevel = 2;
-const WG_GO_LOG_DEBUG: WgLogLevel = 3;
-
-#[cfg(target_os = "windows")]
pub type LoggingCallback = unsafe extern "system" fn(
level: WgLogLevel,
msg: *const libc::c_char,
@@ -395,15 +317,23 @@ extern "C" {
//
// Positive return values are tunnel handles for this specific wireguard tunnel instance.
// Negative return values signify errors. All error codes are opaque.
- #[cfg_attr(target_os = "android", link_name = "wgTurnOnWithFdAndroid")]
+ #[cfg(not(target_os = "android"))]
#[cfg(not(target_os = "windows"))]
- fn wgTurnOnWithFd(
- iface_name: *const i8,
+ fn wgTurnOn(
mtu: isize,
settings: *const i8,
fd: Fd,
- log_path: *const i8,
- logLevel: WgLogLevel,
+ logging_callback: Option<LoggingCallback>,
+ logging_context: *mut libc::c_void,
+ ) -> i32;
+
+ // Android
+ #[cfg(target_os = "android")]
+ fn wgTurnOn(
+ settings: *const i8,
+ fd: Fd,
+ logging_callback: Option<LoggingCallback>,
+ logging_context: *mut libc::c_void,
) -> i32;
// Windows
@@ -416,7 +346,7 @@ extern "C" {
logging_context: *mut libc::c_void,
) -> i32;
- // Pass a handle that was created by wgTurnOnWithFd to stop a wireguard tunnel.
+ // Pass a handle that was created by wgTurnOn to stop a wireguard tunnel.
fn wgTurnOff(handle: i32) -> i32;
// Returns the file descriptor of the tunnel IPv4 socket.