summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorMarkus Pettersson <markus.pettersson@mullvad.net>2024-11-22 17:54:43 +0100
committerMarkus Pettersson <markus.pettersson@mullvad.net>2024-11-22 17:54:43 +0100
commitb1737f5543ed8896c45f28a4e37c022991f22adf (patch)
tree6ab95a5148c2c1aff752e3ebcc2801085f064b34
parentb2d3287552a6530901e7e954daa5bb446307f672 (diff)
parent3516c3c5f987a47b922670aa6d6f34c8c864af8a (diff)
downloadmullvadvpn-b1737f5543ed8896c45f28a4e37c022991f22adf.tar.xz
mullvadvpn-b1737f5543ed8896c45f28a4e37c022991f22adf.zip
Merge branch 'implement-wgturnonmultihop-for-android-droid-1365'
-rw-r--r--mullvad-relay-selector/src/relay_selector/mod.rs6
-rw-r--r--talpid-types/src/net/mod.rs11
-rw-r--r--talpid-wireguard/src/config.rs91
-rw-r--r--talpid-wireguard/src/connectivity/check.rs (renamed from talpid-wireguard/src/connectivity_check.rs)593
-rw-r--r--talpid-wireguard/src/connectivity/constants.rs22
-rw-r--r--talpid-wireguard/src/connectivity/error.rs14
-rw-r--r--talpid-wireguard/src/connectivity/mock.rs133
-rw-r--r--talpid-wireguard/src/connectivity/mod.rs13
-rw-r--r--talpid-wireguard/src/connectivity/monitor.rs174
-rw-r--r--talpid-wireguard/src/connectivity/pinger/android.rs (renamed from talpid-wireguard/src/ping_monitor/android.rs)0
-rw-r--r--talpid-wireguard/src/connectivity/pinger/icmp.rs (renamed from talpid-wireguard/src/ping_monitor/icmp.rs)0
-rw-r--r--talpid-wireguard/src/connectivity/pinger/mod.rs (renamed from talpid-wireguard/src/ping_monitor/mod.rs)0
-rw-r--r--talpid-wireguard/src/ephemeral.rs72
-rw-r--r--talpid-wireguard/src/lib.rs213
-rw-r--r--talpid-wireguard/src/wireguard_go/mod.rs425
-rw-r--r--wireguard-go-rs/libwg/go.mod5
-rw-r--r--wireguard-go-rs/libwg/go.sum6
-rw-r--r--wireguard-go-rs/libwg/libwg.go3
-rw-r--r--wireguard-go-rs/libwg/libwg_android.go170
-rw-r--r--wireguard-go-rs/libwg/libwg_daita.go10
-rw-r--r--wireguard-go-rs/libwg/tunnelcontainer/tunnelcontainer.go1
-rw-r--r--wireguard-go-rs/src/lib.rs52
22 files changed, 1378 insertions, 636 deletions
diff --git a/mullvad-relay-selector/src/relay_selector/mod.rs b/mullvad-relay-selector/src/relay_selector/mod.rs
index 550d1955a0..86f0b300e5 100644
--- a/mullvad-relay-selector/src/relay_selector/mod.rs
+++ b/mullvad-relay-selector/src/relay_selector/mod.rs
@@ -722,12 +722,6 @@ impl RelaySelector {
custom_lists: &CustomListsSettings,
parsed_relays: &ParsedRelays,
) -> Result<WireguardConfig, Error> {
- // TODO: Remove when Android gets support for multihop.
- if cfg!(target_os = "android") {
- let relay = Self::get_wireguard_singlehop_config(query, custom_lists, parsed_relays)
- .ok_or(Error::NoRelay)?;
- return Ok(WireguardConfig::from(relay));
- }
let inner = if query.singlehop() {
match Self::get_wireguard_singlehop_config(query, custom_lists, parsed_relays) {
Some(exit) => WireguardConfig::from(exit),
diff --git a/talpid-types/src/net/mod.rs b/talpid-types/src/net/mod.rs
index 1ec8ba46c5..e53b3fa54a 100644
--- a/talpid-types/src/net/mod.rs
+++ b/talpid-types/src/net/mod.rs
@@ -435,15 +435,26 @@ impl AllowedClients {
}
}
+/// What [`Endpoint`]s to allow the client to send traffic to and receive from.
+///
+/// In some cases we want to restrict what IP addresses the client may communicate with even
+/// inside of the tunnel, for example while negotiating a PQ-safe PSK with an ephemeral peer.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub enum AllowedTunnelTraffic {
+ /// Block all traffic inside the tunnel.
None,
+ /// Allow all traffic inside the tunnel. This is the normal mode of operation.
All,
+ /// Only allow communication with this specific endpoint. This will usually be a relay during a
+ /// short amount of time.
One(Endpoint),
+ /// Only allow communication with these two specific endpoints. The intended use case for this
+ /// is while negotiating for example a PSK with both the entry & exit relays in a multihop setup.
Two(Endpoint, Endpoint),
}
impl AllowedTunnelTraffic {
+ /// Do we currently allow traffic to all endpoints?
pub fn all(&self) -> bool {
matches!(self, AllowedTunnelTraffic::All)
}
diff --git a/talpid-wireguard/src/config.rs b/talpid-wireguard/src/config.rs
index 5326427d13..0469273545 100644
--- a/talpid-wireguard/src/config.rs
+++ b/talpid-wireguard/src/config.rs
@@ -3,6 +3,7 @@ use std::{
ffi::CString,
net::{Ipv4Addr, Ipv6Addr},
};
+use talpid_types::net::wireguard::{PeerConfig, PrivateKey};
use talpid_types::net::{obfuscation::ObfuscatorConfig, wireguard, GenericTunnelOptions};
/// Name to use for the tunnel device
@@ -121,38 +122,12 @@ impl Config {
/// Returns a CString with the appropriate config for WireGuard-go
// TODO: Consider outputting both overriding and additive configs
pub fn to_userspace_format(&self) -> CString {
- // the order of insertion matters, public key entry denotes a new peer entry
- let mut wg_conf = WgConfigBuffer::new();
- wg_conf
- .add::<&[u8]>("private_key", self.tunnel.private_key.to_bytes().as_ref())
- .add("listen_port", "0");
-
- #[cfg(target_os = "linux")]
- if let Some(fwmark) = &self.fwmark {
- wg_conf.add("fwmark", fwmark.to_string().as_str());
- }
-
- wg_conf.add("replace_peers", "true");
-
- for peer in self.peers() {
- wg_conf
- .add::<&[u8]>("public_key", peer.public_key.as_bytes().as_ref())
- .add("endpoint", peer.endpoint.to_string().as_str())
- .add("replace_allowed_ips", "true");
- if let Some(ref psk) = peer.psk {
- wg_conf.add::<&[u8]>("preshared_key", psk.as_bytes().as_ref());
- }
- for addr in &peer.allowed_ips {
- wg_conf.add("allowed_ip", addr.to_string().as_str());
- }
- #[cfg(daita)]
- if peer.constant_packet_size {
- wg_conf.add("constant_packet_size", "true");
- }
- }
-
- let bytes = wg_conf.into_config();
- CString::new(bytes).expect("null bytes inside config")
+ userspace_format(
+ &self.tunnel.private_key,
+ self.peers(),
+ #[cfg(target_os = "linux")]
+ self.fwmark,
+ )
}
/// Return whether the config connects to an exit peer from another remote peer.
@@ -185,6 +160,13 @@ impl Config {
.into_iter()
.chain(std::iter::once(&mut self.entry_peer))
}
+
+ /// Return routes for all allowed IPs.
+ pub fn get_tunnel_destinations(&self) -> impl Iterator<Item = ipnetwork::IpNetwork> + '_ {
+ self.peers()
+ .flat_map(|peer| peer.allowed_ips.iter())
+ .cloned()
+ }
}
enum ConfValue<'a> {
@@ -235,3 +217,48 @@ impl WgConfigBuffer {
self.buf
}
}
+
+/// Returns a CString with the appropriate config for WireGuard-go
+#[allow(single_use_lifetimes)]
+pub fn userspace_format<'a>(
+ private_key: &PrivateKey,
+ peers: impl Iterator<Item = &'a PeerConfig>,
+ #[cfg(target_os = "linux")] fwmark: Option<u32>,
+) -> CString {
+ // the order of insertion matters, public key entry denotes a new peer entry
+ let mut wg_conf = WgConfigBuffer::new();
+ wg_conf
+ .add::<&[u8]>("private_key", private_key.to_bytes().as_ref())
+ .add("listen_port", "0");
+
+ #[cfg(target_os = "linux")]
+ if let Some(fwmark) = fwmark {
+ wg_conf.add("fwmark", fwmark.to_string().as_str());
+ }
+
+ wg_conf.add("replace_peers", "true");
+
+ for peer in peers {
+ write_peer_to_config(&mut wg_conf, peer)
+ }
+
+ let bytes = wg_conf.into_config();
+ CString::new(bytes).expect("null bytes inside config")
+}
+
+fn write_peer_to_config(wg_conf: &mut WgConfigBuffer, peer: &PeerConfig) {
+ wg_conf
+ .add::<&[u8]>("public_key", peer.public_key.as_bytes().as_ref())
+ .add("endpoint", peer.endpoint.to_string().as_str())
+ .add("replace_allowed_ips", "true");
+ if let Some(ref psk) = peer.psk {
+ wg_conf.add::<&[u8]>("preshared_key", psk.as_bytes().as_ref());
+ }
+ for addr in &peer.allowed_ips {
+ wg_conf.add("allowed_ip", addr.to_string().as_str());
+ }
+ #[cfg(daita)]
+ if peer.constant_packet_size {
+ wg_conf.add("constant_packet_size", "true");
+ }
+}
diff --git a/talpid-wireguard/src/connectivity_check.rs b/talpid-wireguard/src/connectivity/check.rs
index 608002d1a6..527931563b 100644
--- a/talpid-wireguard/src/connectivity_check.rs
+++ b/talpid-wireguard/src/connectivity/check.rs
@@ -1,52 +1,17 @@
-use crate::{
- ping_monitor::{new_pinger, Pinger},
- stats::StatsMap,
-};
-use std::{
- cmp,
- net::Ipv4Addr,
- sync::{mpsc, Weak},
- time::{Duration, Instant},
-};
-use tokio::sync::Mutex;
+use std::cmp;
+use std::net::Ipv4Addr;
+use std::sync::mpsc;
+use std::time::{Duration, Instant};
-use super::{Tunnel, TunnelError};
+use super::constants::*;
+use super::error::Error;
+use super::pinger;
-/// Sleep time used when initially establishing connectivity
-const DELAY_ON_INITIAL_SETUP: Duration = Duration::from_millis(50);
-/// Sleep time used when checking if an established connection is still working.
-const REGULAR_LOOP_SLEEP: Duration = Duration::from_secs(1);
-
-/// Timeout for waiting on receiving traffic after sending outgoing traffic. Once this timeout is
-/// hit, a ping will be sent every `SECONDS_PER_PING` until `PING_TIMEOUT` is reached, or traffic
-/// is received.
-const BYTES_RX_TIMEOUT: Duration = Duration::from_secs(5);
-/// Timeout for waiting on receiving or sending any traffic. Once this timeout is hit, a ping will
-/// be sent every `SECONDS_PER_PING` until `PING_TIMEOUT` is reached or traffic is received.
-const TRAFFIC_TIMEOUT: Duration = Duration::from_secs(120);
-/// Timeout for waiting on receiving traffic after sending the first ICMP packet. Once this
-/// timeout is reached, it is assumed that the connection is lost.
-const PING_TIMEOUT: Duration = Duration::from_secs(15);
-/// Timeout for receiving traffic when establishing a connection.
-const ESTABLISH_TIMEOUT: Duration = Duration::from_secs(4);
-/// `ESTABLISH_TIMEOUT` is multiplied by this after each failed connection attempt.
-const ESTABLISH_TIMEOUT_MULTIPLIER: u32 = 2;
-/// Maximum timeout for establishing a connection.
-const MAX_ESTABLISH_TIMEOUT: Duration = PING_TIMEOUT;
-/// Number of seconds to wait between sending ICMP packets
-const SECONDS_PER_PING: Duration = Duration::from_secs(3);
-
-/// Connectivity monitor errors
-#[derive(thiserror::Error, Debug)]
-pub enum Error {
- /// Failed to read tunnel's configuration
- #[error("Failed to read tunnel's configuration")]
- ConfigReadError(TunnelError),
-
- /// Failed to send ping
- #[error("Ping monitor failed")]
- PingError(#[from] crate::ping_monitor::Error),
-}
+use crate::stats::StatsMap;
+#[cfg(target_os = "android")]
+use crate::Tunnel;
+use crate::{TunnelError, TunnelType};
+use pinger::Pinger;
/// Verifies if a connection to a tunnel is working.
/// The connectivity monitor is biased to receiving traffic - it is expected that all outgoing
@@ -70,60 +35,126 @@ pub enum Error {
///
/// Once a connection established, a connection is only considered broken once the connectivity
/// monitor has started pinging and no traffic has been received for a duration of `PING_TIMEOUT`.
-pub struct ConnectivityMonitor {
- tunnel_handle: Weak<Mutex<Option<Box<dyn Tunnel>>>>,
+pub struct Check<Strategy = Timeout> {
conn_state: ConnState,
- initial_ping_timestamp: Option<Instant>,
- num_pings_sent: u32,
- pinger: Box<dyn Pinger>,
+ ping_state: PingState,
+ strategy: Strategy,
+ retry_attempt: u32,
+}
+
+// Define the type state of [Check]
+pub(crate) trait Strategy {
+ fn should_shut_down(&mut self, timeout: Duration) -> bool;
+}
+
+/// An uncancellable [Check] that will run [Check::establish_connectivity] until
+/// completion or until it times out.
+pub struct Timeout;
+
+impl Strategy for Timeout {
+ /// The Timeout strategy cannot receive shut down signals so this function always returns false.
+ fn should_shut_down(&mut self, _timeout: Duration) -> bool {
+ false
+ }
+}
+
+/// A cancellable [Check] may be cancelled before it will time out by sending
+/// a signal on the channel returned by [Check::with_cancellation]. Otherwise,
+/// it behaves as [Timeout].
+pub struct Cancellable {
close_receiver: mpsc::Receiver<()>,
}
-impl ConnectivityMonitor {
- pub(super) fn new(
+impl Strategy for Cancellable {
+ /// Returns true if monitor should be shut down
+ fn should_shut_down(&mut self, timeout: Duration) -> bool {
+ match self.close_receiver.recv_timeout(timeout) {
+ Ok(()) | Err(mpsc::RecvTimeoutError::Disconnected) => true,
+ Err(mpsc::RecvTimeoutError::Timeout) => false,
+ }
+ }
+}
+
+impl Check<Timeout> {
+ pub fn new(
addr: Ipv4Addr,
#[cfg(any(target_os = "macos", target_os = "linux"))] interface: String,
- tunnel_handle: Weak<Mutex<Option<Box<dyn Tunnel>>>>,
- close_receiver: mpsc::Receiver<()>,
- ) -> Result<Self, Error> {
- let pinger = new_pinger(
- addr,
- #[cfg(any(target_os = "macos", target_os = "linux"))]
- interface,
- )
- .map_err(Error::PingError)?;
+ retry_attempt: u32,
+ ) -> Result<Check<Timeout>, Error> {
+ Ok(Check {
+ conn_state: ConnState::new(Instant::now(), Default::default()),
+ ping_state: PingState::new(
+ addr,
+ #[cfg(any(target_os = "macos", target_os = "linux"))]
+ interface,
+ )?,
+ strategy: Timeout,
+ retry_attempt,
+ })
+ }
- let now = Instant::now();
+ /// Cancel a [Check] preemptively by sennding a message on the channel or by dropping
+ /// the returned channel.
+ pub fn with_cancellation(self) -> (Check<Cancellable>, mpsc::Sender<()>) {
+ let (cancellation_tx, cancellation_rx) = mpsc::channel();
+ let check = Check {
+ conn_state: self.conn_state,
+ ping_state: self.ping_state,
+ strategy: Cancellable {
+ close_receiver: cancellation_rx,
+ },
+ retry_attempt: self.retry_attempt,
+ };
+ (check, cancellation_tx)
+ }
- Ok(Self {
- tunnel_handle,
- conn_state: ConnState::new(now, Default::default()),
- initial_ping_timestamp: None,
- num_pings_sent: 0,
- pinger,
- close_receiver,
- })
+ #[cfg(test)]
+ /// Create a new [Check] with a custom initial state. To use the [Cancellable] strategy,
+ /// see [Check::with_cancellation].
+ pub(super) fn mock(conn_state: ConnState, ping_state: PingState) -> Self {
+ Check {
+ conn_state,
+ ping_state,
+ strategy: Timeout,
+ retry_attempt: 0,
+ }
}
+}
+impl<S: Strategy> Check<S> {
// checks if the tunnel has ever worked. Intended to check if a connection to a tunnel is
// successful at the start of a connection.
- pub(super) fn establish_connectivity(&mut self, retry_attempt: u32) -> Result<bool, Error> {
+ pub fn establish_connectivity(&mut self, tunnel_handle: &TunnelType) -> Result<bool, Error> {
// Send initial ping to prod WireGuard into connecting.
- self.pinger.send_icmp().map_err(Error::PingError)?;
+ self.ping_state
+ .pinger
+ .send_icmp()
+ .map_err(Error::PingError)?;
self.establish_connectivity_inner(
- retry_attempt,
+ self.retry_attempt,
ESTABLISH_TIMEOUT,
ESTABLISH_TIMEOUT_MULTIPLIER,
MAX_ESTABLISH_TIMEOUT,
+ tunnel_handle,
)
}
+ pub(crate) fn reset(&mut self, current_iteration: Instant) {
+ self.ping_state.reset();
+ self.conn_state.reset_after_suspension(current_iteration);
+ }
+
+ pub(crate) fn should_shut_down(&mut self, timeout: Duration) -> bool {
+ self.strategy.should_shut_down(timeout)
+ }
+
fn establish_connectivity_inner(
&mut self,
retry_attempt: u32,
timeout_initial: Duration,
timeout_multiplier: u32,
max_timeout: Duration,
+ tunnel_handle: &TunnelType,
) -> Result<bool, Error> {
if self.conn_state.connected() {
return Ok(true);
@@ -136,7 +167,7 @@ impl ConnectivityMonitor {
let start = Instant::now();
while start.elapsed() < check_timeout {
- if self.check_connectivity_interval(Instant::now(), check_timeout)? {
+ if self.check_connectivity_interval(Instant::now(), check_timeout, tunnel_handle)? {
return Ok(true);
}
if self.should_shut_down(DELAY_ON_INITIAL_SETUP) {
@@ -146,46 +177,13 @@ impl ConnectivityMonitor {
Ok(false)
}
- pub(super) fn run(&mut self) -> Result<(), Error> {
- self.wait_loop(REGULAR_LOOP_SLEEP)
- }
-
- /// Returns true if monitor should be shut down
- fn should_shut_down(&mut self, timeout: Duration) -> bool {
- match self.close_receiver.recv_timeout(timeout) {
- Ok(()) | Err(mpsc::RecvTimeoutError::Disconnected) => true,
- Err(mpsc::RecvTimeoutError::Timeout) => false,
- }
- }
-
- fn wait_loop(&mut self, iter_delay: Duration) -> Result<(), Error> {
- let mut last_iteration = Instant::now();
- while !self.should_shut_down(iter_delay) {
- let mut current_iteration = Instant::now();
- let time_slept = current_iteration - last_iteration;
- if time_slept < (iter_delay * 2) {
- if !self.check_connectivity(Instant::now())? {
- return Ok(());
- }
-
- let end = Instant::now();
- if end - current_iteration > Duration::from_secs(1) {
- current_iteration = end;
- }
- } else {
- // Loop was suspended for too long, so it's safer to assume that the host still has
- // connectivity.
- self.reset_pinger();
- self.conn_state.reset_after_suspension(current_iteration);
- }
- last_iteration = current_iteration;
- }
- Ok(())
- }
-
/// Returns true if connection is established
- fn check_connectivity(&mut self, now: Instant) -> Result<bool, Error> {
- self.check_connectivity_interval(now, PING_TIMEOUT)
+ pub(crate) fn check_connectivity(
+ &mut self,
+ now: Instant,
+ tunnel_handle: &TunnelType,
+ ) -> Result<bool, Error> {
+ self.check_connectivity_interval(now, PING_TIMEOUT, tunnel_handle)
}
/// Returns true if connection is established
@@ -193,19 +191,18 @@ impl ConnectivityMonitor {
&mut self,
now: Instant,
timeout: Duration,
+ tunnel_handle: &TunnelType,
) -> Result<bool, Error> {
- match self.get_stats() {
+ match Self::get_stats(tunnel_handle).map_err(Error::ConfigReadError)? {
None => Ok(false),
Some(new_stats) => {
- let new_stats = new_stats?;
-
if self.conn_state.update(now, new_stats) {
- self.reset_pinger();
+ self.ping_state.reset();
return Ok(true);
}
self.maybe_send_ping(now)?;
- Ok(!self.ping_timed_out(timeout) && self.conn_state.connected())
+ Ok(!self.ping_state.ping_timed_out(timeout) && self.conn_state.connected())
}
}
}
@@ -214,19 +211,14 @@ impl ConnectivityMonitor {
/// calls will also return None.
///
/// NOTE: will panic if called from within a tokio runtime.
- fn get_stats(&self) -> Option<Result<StatsMap, Error>> {
- self.tunnel_handle
- .upgrade()?
- .blocking_lock()
- .as_ref()
- .and_then(|tunnel| match tunnel.get_tunnel_stats() {
- Ok(stats) if stats.is_empty() => {
- log::error!("Tunnel unexpectedly shut down");
- None
- }
- Ok(stats) => Some(Ok(stats)),
- Err(error) => Some(Err(Error::ConfigReadError(error))),
- })
+ fn get_stats(tunnel_handle: &TunnelType) -> Result<Option<StatsMap>, TunnelError> {
+ let stats = tunnel_handle.get_tunnel_stats()?;
+ if stats.is_empty() {
+ log::error!("Tunnel unexpectedly shut down");
+ Ok(None)
+ } else {
+ Ok(Some(stats))
+ }
}
fn maybe_send_ping(&mut self, now: Instant) -> Result<(), Error> {
@@ -235,20 +227,55 @@ impl ConnectivityMonitor {
// 3 seconds.
if (self.conn_state.rx_timed_out() || self.conn_state.traffic_timed_out())
&& self
+ .ping_state
.initial_ping_timestamp
.map(|initial_ping_timestamp| {
- initial_ping_timestamp.elapsed() / self.num_pings_sent < SECONDS_PER_PING
+ initial_ping_timestamp.elapsed() / self.ping_state.num_pings_sent
+ < SECONDS_PER_PING
})
.unwrap_or(true)
{
- self.pinger.send_icmp().map_err(Error::PingError)?;
- if self.initial_ping_timestamp.is_none() {
- self.initial_ping_timestamp = Some(now);
+ self.ping_state
+ .pinger
+ .send_icmp()
+ .map_err(Error::PingError)?;
+ if self.ping_state.initial_ping_timestamp.is_none() {
+ self.ping_state.initial_ping_timestamp = Some(now);
}
- self.num_pings_sent += 1;
+ self.ping_state.num_pings_sent += 1;
}
Ok(())
}
+}
+
+pub(super) struct PingState {
+ initial_ping_timestamp: Option<Instant>,
+ num_pings_sent: u32,
+ pinger: Box<dyn Pinger>,
+}
+
+impl PingState {
+ pub(super) fn new(
+ addr: Ipv4Addr,
+ #[cfg(any(target_os = "macos", target_os = "linux"))] interface: String,
+ ) -> Result<Self, Error> {
+ let pinger = pinger::new_pinger(
+ addr,
+ #[cfg(any(target_os = "macos", target_os = "linux"))]
+ interface,
+ )
+ .map_err(Error::PingError)?;
+
+ Ok(Self::new_with(pinger))
+ }
+
+ pub(super) fn new_with(pinger: Box<dyn Pinger>) -> Self {
+ Self {
+ initial_ping_timestamp: None,
+ num_pings_sent: 0,
+ pinger,
+ }
+ }
fn ping_timed_out(&self, timeout: Duration) -> bool {
self.initial_ping_timestamp
@@ -257,14 +284,14 @@ impl ConnectivityMonitor {
}
/// Reset timeouts - assume that the last time bytes were received is now.
- fn reset_pinger(&mut self) {
+ fn reset(&mut self) {
self.initial_ping_timestamp = None;
self.num_pings_sent = 0;
self.pinger.reset();
}
}
-enum ConnState {
+pub(super) enum ConnState {
Connecting {
start: Instant,
stats: StatsMap,
@@ -397,21 +424,8 @@ impl ConnState {
#[cfg(test)]
mod test {
- use futures::Future;
-
use super::*;
- use crate::{
- config::Config,
- stats::{self, Stats},
- Tunnel,
- };
- use std::{
- pin::Pin,
- sync::{
- atomic::{AtomicBool, Ordering},
- Arc,
- },
- };
+ use crate::connectivity::mock::*;
/// Test if a newly created ConnState won't have timed out or consider itself connected
#[test]
@@ -517,300 +531,76 @@ mod test {
assert!(!conn_state.traffic_timed_out());
}
- #[derive(Default)]
- struct MockPinger {
- on_send_ping: Option<Box<dyn FnMut() + Send>>,
- }
-
- impl Pinger for MockPinger {
- fn send_icmp(&mut self) -> Result<(), crate::ping_monitor::Error> {
- if let Some(callback) = self.on_send_ping.as_mut() {
- (callback)();
- }
- Ok(())
- }
- }
-
- struct MockTunnel {
- on_get_stats: Box<dyn Fn() -> Result<stats::StatsMap, TunnelError> + Send>,
- }
-
- impl MockTunnel {
- const PEER: [u8; 32] = [0u8; 32];
-
- fn new<F: Fn() -> Result<stats::StatsMap, TunnelError> + Send + 'static>(f: F) -> Self {
- Self {
- on_get_stats: Box::new(f),
- }
- }
-
- fn always_incrementing() -> Self {
- let mut map = stats::StatsMap::new();
- map.insert(
- Self::PEER,
- stats::Stats {
- tx_bytes: 0,
- rx_bytes: 0,
- },
- );
- let peers = std::sync::Mutex::new(map);
- Self {
- on_get_stats: Box::new(move || {
- let mut peers = peers.lock().unwrap();
- for traffic in peers.values_mut() {
- traffic.tx_bytes += 1;
- traffic.rx_bytes += 1;
- }
- Ok(peers.clone())
- }),
- }
- }
-
- fn never_incrementing() -> Self {
- Self {
- on_get_stats: Box::new(|| {
- let mut map = stats::StatsMap::new();
- map.insert(
- Self::PEER,
- stats::Stats {
- tx_bytes: 0,
- rx_bytes: 0,
- },
- );
- Ok(map)
- }),
- }
- }
-
- #[allow(clippy::type_complexity)]
- fn into_locked(
- self,
- ) -> (
- Arc<Mutex<Option<Box<dyn Tunnel>>>>,
- Weak<Mutex<Option<Box<dyn Tunnel>>>>,
- ) {
- let dyn_tunnel: Box<dyn Tunnel> = Box::new(self);
- let arc = Arc::new(Mutex::new(Some(dyn_tunnel)));
- let weak_ref = Arc::downgrade(&arc);
- (arc, weak_ref)
- }
- }
-
- impl Tunnel for MockTunnel {
- fn get_interface_name(&self) -> String {
- "mock-tunnel".to_string()
- }
-
- fn stop(self: Box<Self>) -> Result<(), TunnelError> {
- Ok(())
- }
-
- fn get_tunnel_stats(&self) -> Result<stats::StatsMap, TunnelError> {
- (self.on_get_stats)()
- }
-
- fn set_config(
- &mut self,
- _config: Config,
- ) -> Pin<Box<dyn Future<Output = std::result::Result<(), TunnelError>> + Send>> {
- Box::pin(async { Ok(()) })
- }
-
- #[cfg(daita)]
- fn start_daita(&mut self) -> std::result::Result<(), TunnelError> {
- Ok(())
- }
- }
-
- fn mock_monitor(
- now: Instant,
- pinger: Box<dyn Pinger>,
- tunnel_handle: Weak<Mutex<Option<Box<dyn Tunnel>>>>,
- close_receiver: mpsc::Receiver<()>,
- ) -> ConnectivityMonitor {
- ConnectivityMonitor {
- conn_state: ConnState::new(now, Default::default()),
- initial_ping_timestamp: None,
- num_pings_sent: 0,
- pinger,
- close_receiver,
- tunnel_handle,
- }
- }
-
- fn connected_state(timestamp: Instant) -> ConnState {
- const PEER: [u8; 32] = [0u8; 32];
- let mut stats = stats::StatsMap::new();
- stats.insert(
- PEER,
- stats::Stats {
- tx_bytes: 0,
- rx_bytes: 0,
- },
- );
- ConnState::Connected {
- rx_timestamp: timestamp,
- tx_timestamp: timestamp,
- stats,
- }
- }
-
#[test]
/// Verify that `check_connectivity()` returns `false` if the tunnel is connected and traffic is
/// not flowing after `BYTES_RX_TIMEOUT` and `PING_TIMEOUT`.
fn test_ping_times_out() {
- let (_tunnel_anchor, tunnel) = MockTunnel::never_incrementing().into_locked();
- let (_tx, rx) = mpsc::channel();
+ let tunnel = MockTunnel::never_incrementing().boxed();
let pinger = MockPinger::default();
let now = Instant::now();
let start = now
.checked_sub(BYTES_RX_TIMEOUT + PING_TIMEOUT + Duration::from_secs(10))
.unwrap();
- let mut monitor = mock_monitor(start, Box::new(pinger), tunnel, rx);
+ let mut checker = mock_checker(start, Box::new(pinger));
// Mock the state - connectivity has been established
- monitor.conn_state = connected_state(start);
+ checker.conn_state = connected_state(start);
// A ping was sent to verify connectivity
- monitor.maybe_send_ping(start).unwrap();
- assert!(!monitor.check_connectivity(now).unwrap())
+ checker.maybe_send_ping(start).unwrap();
+ assert!(!checker.check_connectivity(now, &tunnel).unwrap())
}
#[test]
/// Verify that `check_connectivity()` returns `true` if the tunnel is connected and traffic is
/// flowing constantly.
fn test_no_connection_on_start() {
- let (_tunnel_anchor, tunnel) = MockTunnel::never_incrementing().into_locked();
- let (_tx, rx) = mpsc::channel();
+ let tunnel = MockTunnel::never_incrementing().boxed();
let pinger = MockPinger::default();
let now = Instant::now();
let start = now.checked_sub(Duration::from_secs(1)).unwrap();
- let mut monitor = mock_monitor(start, Box::new(pinger), tunnel, rx);
+ let mut monitor = mock_checker(start, Box::new(pinger));
- assert!(!monitor.check_connectivity(now).unwrap())
+ assert!(!monitor.check_connectivity(now, &tunnel).unwrap())
}
#[test]
/// Verify that `check_connectivity()` returns `true` if the tunnel is connected and traffic is
/// flowing constantly.
fn test_connection_works() {
- let (_tunnel_anchor, tunnel) = MockTunnel::always_incrementing().into_locked();
- let (_tx, rx) = mpsc::channel();
+ let tunnel = MockTunnel::always_incrementing().boxed();
let pinger = MockPinger::default();
let now = Instant::now();
let start = now.checked_sub(Duration::from_secs(1)).unwrap();
- let mut monitor = mock_monitor(start, Box::new(pinger), tunnel, rx);
+ let mut monitor = mock_checker(start, Box::new(pinger));
// Mock the state - connectivity has been established
monitor.conn_state = connected_state(start);
- assert!(monitor.check_connectivity(now).unwrap())
- }
-
- #[test]
- /// Verify that the connectivity monitor doesn't fail if the tunnel constantly sends traffic,
- /// and it shuts down properly.
- fn test_wait_loop() {
- let (result_tx, result_rx) = mpsc::channel();
- let (_tunnel_anchor, tunnel) = MockTunnel::always_incrementing().into_locked();
- let pinger = MockPinger::default();
- let (stop_tx, stop_rx) = mpsc::channel();
- std::thread::spawn(move || {
- let now = Instant::now();
- let start = now.checked_sub(Duration::from_secs(1)).unwrap();
- let mut monitor = mock_monitor(start, Box::new(pinger), tunnel, stop_rx);
-
- let start_result = monitor.establish_connectivity(0);
- result_tx.send(start_result).unwrap();
-
- let result = monitor.run().map(|_| true);
- result_tx.send(result).unwrap();
- });
-
- std::thread::sleep(Duration::from_secs(1));
- assert!(result_rx.try_recv().unwrap().unwrap());
- stop_tx.send(()).unwrap();
- std::thread::sleep(Duration::from_secs(1));
- assert!(result_rx.try_recv().unwrap().is_ok());
- }
-
- #[test]
- /// Verify that the connectivity monitor detects the tunnel timing out after no longer than
- /// `BYTES_RX_TIMEOUT` and `PING_TIMEOUT` combined.
- fn test_wait_loop_timeout() {
- let should_stop = Arc::new(AtomicBool::new(false));
- let should_stop_inner = should_stop.clone();
-
- let mut map = stats::StatsMap::new();
- map.insert(
- [0u8; 32],
- stats::Stats {
- tx_bytes: 0,
- rx_bytes: 0,
- },
- );
- let tunnel_stats = std::sync::Mutex::new(map);
-
- let pinger = MockPinger::default();
- let (_tunnel_anchor, tunnel) = MockTunnel::new(move || {
- let mut tunnel_stats = tunnel_stats.lock().unwrap();
- if !should_stop_inner.load(Ordering::SeqCst) {
- for traffic in tunnel_stats.values_mut() {
- traffic.rx_bytes += 1;
- }
- }
- for traffic in tunnel_stats.values_mut() {
- traffic.tx_bytes += 1;
- }
- Ok(tunnel_stats.clone())
- })
- .into_locked();
-
- let (result_tx, result_rx) = mpsc::channel();
-
- let (_stop_tx, stop_rx) = mpsc::channel();
- std::thread::spawn(move || {
- let now = Instant::now();
- let start = now.checked_sub(Duration::from_secs(1)).unwrap();
- let mut monitor = mock_monitor(start, Box::new(pinger), tunnel, stop_rx);
- let start_result = monitor.establish_connectivity(0);
- result_tx.send(start_result).unwrap();
- let end_result = monitor.run().map(|_| true);
- result_tx.send(end_result).expect("Failed to send result");
- });
- assert!(result_rx
- .recv_timeout(Duration::from_secs(1))
- .unwrap()
- .unwrap());
- should_stop.store(true, Ordering::SeqCst);
- assert!(result_rx
- .recv_timeout(BYTES_RX_TIMEOUT + PING_TIMEOUT + Duration::from_secs(2))
- .unwrap()
- .is_ok());
+ assert!(monitor.check_connectivity(now, &tunnel).unwrap())
}
#[test]
/// Verify that the timeout for setting up a tunnel works as expected.
fn test_establish_timeout() {
- let mut tunnel_stats = stats::StatsMap::new();
- tunnel_stats.insert(
- [0u8; 32],
- stats::Stats {
- tx_bytes: 0,
- rx_bytes: 0,
- },
- );
-
let pinger = MockPinger::default();
- let (_tunnel_anchor, tunnel) =
- MockTunnel::new(move || Ok(tunnel_stats.clone())).into_locked();
+ let tunnel = {
+ let mut tunnel_stats = StatsMap::new();
+ tunnel_stats.insert(
+ [0u8; 32],
+ Stats {
+ tx_bytes: 0,
+ rx_bytes: 0,
+ },
+ );
+ MockTunnel::new(move || Ok(tunnel_stats.clone())).boxed()
+ };
let (result_tx, result_rx) = mpsc::channel();
- let (_stop_tx, stop_rx) = mpsc::channel();
std::thread::spawn(move || {
let now = Instant::now();
let start = now.checked_sub(Duration::from_secs(1)).unwrap();
- let mut monitor = mock_monitor(start, Box::new(pinger), tunnel, stop_rx);
+ let mut monitor = mock_checker(start, Box::new(pinger));
const ESTABLISH_TIMEOUT_MULTIPLIER: u32 = 2;
const ESTABLISH_TIMEOUT: Duration = Duration::from_millis(500);
@@ -823,6 +613,7 @@ mod test {
ESTABLISH_TIMEOUT,
ESTABLISH_TIMEOUT_MULTIPLIER,
MAX_ESTABLISH_TIMEOUT,
+ &tunnel,
))
.unwrap();
}
diff --git a/talpid-wireguard/src/connectivity/constants.rs b/talpid-wireguard/src/connectivity/constants.rs
new file mode 100644
index 0000000000..a8d6752ddd
--- /dev/null
+++ b/talpid-wireguard/src/connectivity/constants.rs
@@ -0,0 +1,22 @@
+use std::time::Duration;
+
+/// Sleep time used when initially establishing connectivity
+pub(crate) const DELAY_ON_INITIAL_SETUP: Duration = Duration::from_millis(50);
+/// Timeout for waiting on receiving traffic after sending outgoing traffic. Once this timeout is
+/// hit, a ping will be sent every `SECONDS_PER_PING` until `PING_TIMEOUT` is reached, or traffic
+/// is received.
+pub(crate) const BYTES_RX_TIMEOUT: Duration = Duration::from_secs(5);
+/// Timeout for waiting on receiving or sending any traffic. Once this timeout is hit, a ping will
+/// be sent every `SECONDS_PER_PING` until `PING_TIMEOUT` is reached or traffic is received.
+pub(crate) const TRAFFIC_TIMEOUT: Duration = Duration::from_secs(120);
+/// Timeout for waiting on receiving traffic after sending the first ICMP packet. Once this
+/// timeout is reached, it is assumed that the connection is lost.
+pub(crate) const PING_TIMEOUT: Duration = Duration::from_secs(15);
+/// Timeout for receiving traffic when establishing a connection.
+pub(crate) const ESTABLISH_TIMEOUT: Duration = Duration::from_secs(4);
+/// `ESTABLISH_TIMEOUT` is multiplied by this after each failed connection attempt.
+pub(crate) const ESTABLISH_TIMEOUT_MULTIPLIER: u32 = 2;
+/// Maximum timeout for establishing a connection.
+pub(crate) const MAX_ESTABLISH_TIMEOUT: Duration = PING_TIMEOUT;
+/// Number of seconds to wait between sending ICMP packets
+pub(crate) const SECONDS_PER_PING: Duration = Duration::from_secs(3);
diff --git a/talpid-wireguard/src/connectivity/error.rs b/talpid-wireguard/src/connectivity/error.rs
new file mode 100644
index 0000000000..9e8c98a751
--- /dev/null
+++ b/talpid-wireguard/src/connectivity/error.rs
@@ -0,0 +1,14 @@
+use super::pinger;
+use crate::TunnelError;
+
+/// Connectivity monitor errors
+#[derive(thiserror::Error, Debug)]
+pub enum Error {
+ /// Failed to read tunnel's configuration
+ #[error("Failed to read tunnel's configuration")]
+ ConfigReadError(TunnelError),
+
+ /// Failed to send ping
+ #[error("Ping failed")]
+ PingError(#[from] pinger::Error),
+}
diff --git a/talpid-wireguard/src/connectivity/mock.rs b/talpid-wireguard/src/connectivity/mock.rs
new file mode 100644
index 0000000000..892f3966ea
--- /dev/null
+++ b/talpid-wireguard/src/connectivity/mock.rs
@@ -0,0 +1,133 @@
+use std::future::Future;
+use std::pin::Pin;
+use std::time::Instant;
+
+use super::check::{ConnState, PingState, Timeout};
+use super::pinger;
+use super::Check;
+
+use crate::{Config, Tunnel, TunnelError};
+use pinger::Pinger;
+
+// Convenient re-exports
+pub use crate::stats::{Stats, StatsMap};
+
+#[derive(Default)]
+pub(crate) struct MockPinger {
+ on_send_ping: Option<Box<dyn FnMut() + Send>>,
+}
+
+pub(crate) struct MockTunnel {
+ on_get_stats: Box<dyn Fn() -> Result<StatsMap, TunnelError> + Send>,
+}
+
+pub fn mock_checker(now: Instant, pinger: Box<dyn Pinger>) -> Check<Timeout> {
+ let conn_state = ConnState::new(now, Default::default());
+ let ping_state = PingState::new_with(pinger);
+ Check::mock(conn_state, ping_state)
+}
+
+pub fn connected_state(timestamp: Instant) -> ConnState {
+ const PEER: [u8; 32] = [0u8; 32];
+ let mut stats = StatsMap::new();
+ stats.insert(
+ PEER,
+ Stats {
+ tx_bytes: 0,
+ rx_bytes: 0,
+ },
+ );
+ ConnState::Connected {
+ rx_timestamp: timestamp,
+ tx_timestamp: timestamp,
+ stats,
+ }
+}
+
+impl MockTunnel {
+ const PEER: [u8; 32] = [0u8; 32];
+
+ pub fn new<F: Fn() -> Result<StatsMap, TunnelError> + Send + 'static>(f: F) -> Self {
+ Self {
+ on_get_stats: Box::new(f),
+ }
+ }
+
+ /// Convert self to the more general [TunnelType].
+ pub fn boxed(self) -> Box<dyn Tunnel> {
+ Box::new(self)
+ }
+
+ pub fn always_incrementing() -> Self {
+ let mut map = StatsMap::new();
+ map.insert(
+ Self::PEER,
+ Stats {
+ tx_bytes: 0,
+ rx_bytes: 0,
+ },
+ );
+ let peers = std::sync::Mutex::new(map);
+ Self {
+ on_get_stats: Box::new(move || {
+ let mut peers = peers.lock().unwrap();
+ for traffic in peers.values_mut() {
+ traffic.tx_bytes += 1;
+ traffic.rx_bytes += 1;
+ }
+ Ok(peers.clone())
+ }),
+ }
+ }
+
+ pub fn never_incrementing() -> Self {
+ Self {
+ on_get_stats: Box::new(|| {
+ let mut map = StatsMap::new();
+ map.insert(
+ Self::PEER,
+ Stats {
+ tx_bytes: 0,
+ rx_bytes: 0,
+ },
+ );
+ Ok(map)
+ }),
+ }
+ }
+}
+
+impl Tunnel for MockTunnel {
+ fn get_interface_name(&self) -> String {
+ "mock-tunnel".to_string()
+ }
+
+ fn stop(self: Box<Self>) -> Result<(), TunnelError> {
+ Ok(())
+ }
+
+ fn get_tunnel_stats(&self) -> Result<StatsMap, TunnelError> {
+ (self.on_get_stats)()
+ }
+
+ fn set_config(
+ &mut self,
+ _config: Config,
+ ) -> Pin<Box<dyn Future<Output = std::result::Result<(), TunnelError>> + Send>> {
+ Box::pin(async { Ok(()) })
+ }
+
+ #[cfg(daita)]
+ fn start_daita(&mut self) -> std::result::Result<(), TunnelError> {
+ Ok(())
+ }
+}
+
+impl Pinger for MockPinger {
+ fn send_icmp(&mut self) -> Result<(), pinger::Error> {
+ if let Some(callback) = self.on_send_ping.as_mut() {
+ (callback)();
+ }
+ Ok(())
+ }
+}
diff --git a/talpid-wireguard/src/connectivity/mod.rs b/talpid-wireguard/src/connectivity/mod.rs
new file mode 100644
index 0000000000..512d8715f1
--- /dev/null
+++ b/talpid-wireguard/src/connectivity/mod.rs
@@ -0,0 +1,13 @@
+mod check;
+mod constants;
+mod error;
+#[cfg(test)]
+mod mock;
+mod monitor;
+mod pinger;
+
+#[cfg(target_os = "android")]
+pub use check::Cancellable;
+pub use check::Check;
+pub use error::Error;
+pub use monitor::Monitor;
diff --git a/talpid-wireguard/src/connectivity/monitor.rs b/talpid-wireguard/src/connectivity/monitor.rs
new file mode 100644
index 0000000000..583b8d9589
--- /dev/null
+++ b/talpid-wireguard/src/connectivity/monitor.rs
@@ -0,0 +1,174 @@
+use std::{
+ sync::Weak,
+ time::{Duration, Instant},
+};
+
+use tokio::sync::Mutex;
+
+use crate::TunnelType;
+
+use super::check::{Cancellable, Check};
+use super::error::Error;
+
+/// Sleep time used when checking if an established connection is still working.
+const REGULAR_LOOP_SLEEP: Duration = Duration::from_secs(1);
+
+pub struct Monitor {
+ connectivity_check: Check<Cancellable>,
+}
+
+impl Monitor {
+ pub fn init(connectivity_check: Check<Cancellable>) -> Self {
+ Self { connectivity_check }
+ }
+
+ pub fn run(self, tunnel_handle: Weak<Mutex<Option<TunnelType>>>) -> Result<(), Error> {
+ self.wait_loop(REGULAR_LOOP_SLEEP, tunnel_handle)
+ }
+
+ fn wait_loop(
+ mut self,
+ iter_delay: Duration,
+ tunnel_handle: Weak<Mutex<Option<TunnelType>>>,
+ ) -> Result<(), Error> {
+ let mut last_iteration = Instant::now();
+ while !self.connectivity_check.should_shut_down(iter_delay) {
+ let mut current_iteration = Instant::now();
+ let time_slept = current_iteration - last_iteration;
+ if time_slept < (iter_delay * 2) {
+ let Some(tunnel) = tunnel_handle.upgrade() else {
+ return Ok(());
+ };
+ let lock = tunnel.blocking_lock();
+ let Some(tunnel) = lock.as_ref() else {
+ return Ok(());
+ };
+
+ if !self
+ .connectivity_check
+ .check_connectivity(Instant::now(), tunnel)?
+ {
+ return Ok(());
+ }
+ drop(lock);
+
+ let end = Instant::now();
+ if end - current_iteration > Duration::from_secs(1) {
+ current_iteration = end;
+ }
+ } else {
+ // Loop was suspended for too long, so it's safer to assume that the host still has
+ // connectivity.
+ self.connectivity_check.reset(current_iteration);
+ }
+ last_iteration = current_iteration;
+ }
+ Ok(())
+ }
+}
+
+#[cfg(test)]
+mod test {
+ use super::*;
+
+ // TODO: Port to async + tokio to reduce cost of testing?
+ use std::sync::atomic::{AtomicBool, Ordering};
+ use std::sync::mpsc;
+ use std::sync::Arc;
+ use std::time::Duration;
+ use std::time::Instant;
+
+ use tokio::sync::Mutex;
+
+ use crate::connectivity::constants::*;
+ use crate::connectivity::mock::*;
+
+ #[test]
+ /// Verify that the connectivity monitor doesn't fail if the tunnel constantly sends traffic,
+ /// and it shuts down properly.
+ fn test_wait_loop() {
+ use std::sync::mpsc;
+ let (result_tx, result_rx) = mpsc::channel();
+ let tunnel = MockTunnel::always_incrementing().boxed();
+ let pinger = MockPinger::default();
+ let (mut checker, stop_tx) = {
+ let now = Instant::now();
+ let start = now.checked_sub(Duration::from_secs(1)).unwrap();
+ mock_checker(start, Box::new(pinger)).with_cancellation()
+ };
+ std::thread::spawn(move || {
+ let start_result = checker.establish_connectivity(&tunnel);
+ result_tx.send(start_result).unwrap();
+ // Pointer dance
+ let tunnel = Arc::new(Mutex::new(Some(tunnel)));
+ let _tunnel = Arc::downgrade(&tunnel);
+ let result = Monitor::init(checker).run(_tunnel).map(|_| true);
+ result_tx.send(result).unwrap();
+ });
+
+ std::thread::sleep(Duration::from_secs(1));
+ assert!(result_rx.try_recv().unwrap().unwrap());
+ stop_tx.send(()).unwrap();
+ std::thread::sleep(Duration::from_secs(1));
+ assert!(result_rx.try_recv().unwrap().is_ok());
+ }
+
+ #[test]
+ /// Verify that the connectivity monitor detects the tunnel timing out after no longer than
+ /// `BYTES_RX_TIMEOUT` and `PING_TIMEOUT` combined.
+ fn test_wait_loop_timeout() {
+ let should_stop = Arc::new(AtomicBool::new(false));
+ let should_stop_inner = should_stop.clone();
+
+ let mut map = StatsMap::new();
+ map.insert(
+ [0u8; 32],
+ Stats {
+ tx_bytes: 0,
+ rx_bytes: 0,
+ },
+ );
+ let tunnel_stats = std::sync::Mutex::new(map);
+
+ let pinger = MockPinger::default();
+ let tunnel = MockTunnel::new(move || {
+ let mut tunnel_stats = tunnel_stats.lock().unwrap();
+ if !should_stop_inner.load(Ordering::SeqCst) {
+ for traffic in tunnel_stats.values_mut() {
+ traffic.rx_bytes += 1;
+ }
+ }
+ for traffic in tunnel_stats.values_mut() {
+ traffic.tx_bytes += 1;
+ }
+ Ok(tunnel_stats.clone())
+ })
+ .boxed();
+
+ let (result_tx, result_rx) = mpsc::channel();
+
+ std::thread::spawn(move || {
+ let (mut checker, _cancellation_token) = {
+ let now = Instant::now();
+ let start = now.checked_sub(Duration::from_secs(1)).unwrap();
+ mock_checker(start, Box::new(pinger)).with_cancellation()
+ };
+ let start_result = checker.establish_connectivity(&tunnel);
+ result_tx.send(start_result).unwrap();
+ // Pointer dance
+ let _tunnel = Arc::new(Mutex::new(Some(tunnel)));
+ let tunnel = Arc::downgrade(&_tunnel);
+ let end_result = Monitor::init(checker).run(tunnel).map(|_| true);
+ result_tx.send(end_result).expect("Failed to send result");
+ });
+ assert!(result_rx
+ .recv_timeout(Duration::from_secs(1))
+ .unwrap()
+ .unwrap());
+ should_stop.store(true, Ordering::SeqCst);
+ assert!(result_rx
+ .recv_timeout(BYTES_RX_TIMEOUT + PING_TIMEOUT + Duration::from_secs(2))
+ .unwrap()
+ .is_ok());
+ }
+}
diff --git a/talpid-wireguard/src/ping_monitor/android.rs b/talpid-wireguard/src/connectivity/pinger/android.rs
index 00ad4d8fd3..00ad4d8fd3 100644
--- a/talpid-wireguard/src/ping_monitor/android.rs
+++ b/talpid-wireguard/src/connectivity/pinger/android.rs
diff --git a/talpid-wireguard/src/ping_monitor/icmp.rs b/talpid-wireguard/src/connectivity/pinger/icmp.rs
index 0e5d739425..0e5d739425 100644
--- a/talpid-wireguard/src/ping_monitor/icmp.rs
+++ b/talpid-wireguard/src/connectivity/pinger/icmp.rs
diff --git a/talpid-wireguard/src/ping_monitor/mod.rs b/talpid-wireguard/src/connectivity/pinger/mod.rs
index ef2394f1b7..ef2394f1b7 100644
--- a/talpid-wireguard/src/ping_monitor/mod.rs
+++ b/talpid-wireguard/src/connectivity/pinger/mod.rs
diff --git a/talpid-wireguard/src/ephemeral.rs b/talpid-wireguard/src/ephemeral.rs
index 5440a142f6..a9283fcb2e 100644
--- a/talpid-wireguard/src/ephemeral.rs
+++ b/talpid-wireguard/src/ephemeral.rs
@@ -1,7 +1,10 @@
//! This module takes care of obtaining ephemeral peers, updating the WireGuard configuration and
//! restarting obfuscation and WG tunnels when necessary.
-use super::{config::Config, obfuscation::ObfuscatorHandle, CloseMsg, Error, Tunnel};
+#[cfg(target_os = "android")] // On Android, the Tunnel trait is not imported by default.
+use super::Tunnel;
+use super::{config::Config, obfuscation::ObfuscatorHandle, CloseMsg, Error, TunnelType};
+
#[cfg(target_os = "android")]
use std::sync::Mutex;
use std::{
@@ -22,7 +25,7 @@ const PSK_EXCHANGE_TIMEOUT_MULTIPLIER: u32 = 2;
#[cfg(windows)]
pub async fn config_ephemeral_peers(
- tunnel: &Arc<AsyncMutex<Option<Box<dyn Tunnel>>>>,
+ tunnel: &Arc<AsyncMutex<Option<TunnelType>>>,
config: &mut Config,
retry_attempt: u32,
obfuscator: Arc<AsyncMutex<Option<ObfuscatorHandle>>>,
@@ -66,13 +69,13 @@ fn try_set_ipv4_mtu(alias: &str, mtu: u16) {
#[cfg(not(windows))]
pub async fn config_ephemeral_peers(
- tunnel: &Arc<AsyncMutex<Option<Box<dyn Tunnel>>>>,
+ tunnel: &Arc<AsyncMutex<Option<TunnelType>>>,
config: &mut Config,
retry_attempt: u32,
obfuscator: Arc<AsyncMutex<Option<ObfuscatorHandle>>>,
close_obfs_sender: sync_mpsc::Sender<CloseMsg>,
#[cfg(target_os = "android")] tun_provider: Arc<Mutex<TunProvider>>,
-) -> std::result::Result<(), CloseMsg> {
+) -> Result<(), CloseMsg> {
config_ephemeral_peers_inner(
tunnel,
config,
@@ -86,13 +89,13 @@ pub async fn config_ephemeral_peers(
}
async fn config_ephemeral_peers_inner(
- tunnel: &Arc<AsyncMutex<Option<Box<dyn Tunnel>>>>,
+ tunnel: &Arc<AsyncMutex<Option<TunnelType>>>,
config: &mut Config,
retry_attempt: u32,
obfuscator: Arc<AsyncMutex<Option<ObfuscatorHandle>>>,
close_obfs_sender: sync_mpsc::Sender<CloseMsg>,
#[cfg(target_os = "android")] tun_provider: Arc<Mutex<TunProvider>>,
-) -> std::result::Result<(), CloseMsg> {
+) -> Result<(), CloseMsg> {
let ephemeral_private_key = PrivateKey::new_from_random();
let close_obfs_sender = close_obfs_sender.clone();
@@ -111,6 +114,7 @@ async fn config_ephemeral_peers_inner(
if config.is_multihop() {
// Set up tunnel to lead to entry
let mut entry_tun_config = config.clone();
+ entry_tun_config.exit_peer = None;
entry_tun_config
.entry_peer
.allowed_ips
@@ -126,6 +130,7 @@ async fn config_ephemeral_peers_inner(
&tun_provider,
)
.await?;
+
let entry_psk = request_ephemeral_peer(
retry_attempt,
&entry_config,
@@ -173,15 +178,16 @@ async fn config_ephemeral_peers_inner(
Ok(())
}
+#[cfg(target_os = "android")]
/// Reconfigures the tunnel to use the provided config while potentially modifying the config
/// and restarting the obfuscation provider. Returns the new config used by the new tunnel.
async fn reconfigure_tunnel(
- tunnel: &Arc<AsyncMutex<Option<Box<dyn Tunnel>>>>,
+ tunnel: &Arc<AsyncMutex<Option<TunnelType>>>,
mut config: Config,
obfuscator: Arc<AsyncMutex<Option<ObfuscatorHandle>>>,
close_obfs_sender: sync_mpsc::Sender<CloseMsg>,
- #[cfg(target_os = "android")] tun_provider: &Arc<Mutex<TunProvider>>,
-) -> std::result::Result<Config, CloseMsg> {
+ tun_provider: &Arc<Mutex<TunProvider>>,
+) -> Result<Config, CloseMsg> {
let mut obfs_guard = obfuscator.lock().await;
if let Some(obfuscator_handle) = obfs_guard.take() {
obfuscator_handle.abort();
@@ -194,17 +200,49 @@ async fn reconfigure_tunnel(
.await
.map_err(CloseMsg::ObfuscatorFailed)?;
}
+ {
+ let mut shared_tunnel = tunnel.lock().await;
+ let tunnel = shared_tunnel.take().expect("tunnel was None");
- let mut tunnel = tunnel.lock().await;
-
- let set_config_future = tunnel
- .as_mut()
- .map(|tunnel| tunnel.set_config(config.clone()));
-
- if let Some(f) = set_config_future {
- f.await
+ let updated_tunnel = tunnel
+ .set_config(&config)
.map_err(Error::TunnelError)
.map_err(CloseMsg::SetupError)?;
+
+ *shared_tunnel = Some(updated_tunnel);
+ }
+ Ok(config)
+}
+
+#[cfg(not(target_os = "android"))]
+/// Reconfigures the tunnel to use the provided config while potentially modifying the config
+/// and restarting the obfuscation provider. Returns the new config used by the new tunnel.
+async fn reconfigure_tunnel(
+ tunnel: &Arc<AsyncMutex<Option<TunnelType>>>,
+ mut config: Config,
+ obfuscator: Arc<AsyncMutex<Option<ObfuscatorHandle>>>,
+ close_obfs_sender: sync_mpsc::Sender<CloseMsg>,
+) -> Result<Config, CloseMsg> {
+ let mut obfs_guard = obfuscator.lock().await;
+ if let Some(obfuscator_handle) = obfs_guard.take() {
+ obfuscator_handle.abort();
+ *obfs_guard = super::obfuscation::apply_obfuscation_config(&mut config, close_obfs_sender)
+ .await
+ .map_err(CloseMsg::ObfuscatorFailed)?;
+ }
+
+ {
+ let mut tunnel = tunnel.lock().await;
+
+ let set_config_future = tunnel
+ .as_mut()
+ .map(|tunnel| tunnel.set_config(config.clone()));
+
+ if let Some(f) = set_config_future {
+ f.await
+ .map_err(Error::TunnelError)
+ .map_err(CloseMsg::SetupError)?;
+ }
}
Ok(config)
diff --git a/talpid-wireguard/src/lib.rs b/talpid-wireguard/src/lib.rs
index d1e09ff570..7c93b39f18 100644
--- a/talpid-wireguard/src/lib.rs
+++ b/talpid-wireguard/src/lib.rs
@@ -35,11 +35,10 @@ use tokio::sync::Mutex as AsyncMutex;
/// WireGuard config data-types
pub mod config;
-mod connectivity_check;
+mod connectivity;
mod ephemeral;
mod logging;
mod obfuscation;
-mod ping_monitor;
mod stats;
#[cfg(wireguard_go)]
mod wireguard_go;
@@ -54,6 +53,12 @@ mod mtu_detection;
#[cfg(wireguard_go)]
use self::wireguard_go::WgGoTunnel;
+// On android we only have Wireguard Go tunnel
+#[cfg(not(target_os = "android"))]
+type TunnelType = Box<dyn Tunnel>;
+#[cfg(target_os = "android")]
+type TunnelType = WgGoTunnel;
+
type Result<T> = std::result::Result<T, Error>;
type EventCallback = Box<dyn (Fn(TunnelEvent) -> BoxFuture<'static, ()>) + Send + Sync + 'static>;
@@ -82,7 +87,7 @@ pub enum Error {
/// Failed to set up connectivity monitor
#[error("Connectivity monitor failed")]
- ConnectivityMonitorError(#[source] connectivity_check::Error),
+ ConnectivityMonitorError(#[source] connectivity::Error),
/// Failed while negotiating ephemeral peer
#[error("Failed while negotiating ephemeral peer")]
@@ -134,7 +139,7 @@ impl Error {
pub struct WireguardMonitor {
runtime: tokio::runtime::Handle,
/// Tunnel implementation
- tunnel: Arc<AsyncMutex<Option<Box<dyn Tunnel>>>>,
+ tunnel: Arc<AsyncMutex<Option<TunnelType>>>,
/// Callback to signal tunnel events
event_callback: EventCallback,
close_msg_receiver: sync_mpsc::Receiver<CloseMsg>,
@@ -210,8 +215,17 @@ impl WireguardMonitor {
let obfuscator = Arc::new(AsyncMutex::new(obfuscator));
+ let gateway = config.ipv4_gateway;
+ let (mut connectivity_monitor, pinger_tx) = connectivity::Check::new(
+ gateway,
+ #[cfg(any(target_os = "macos", target_os = "linux"))]
+ iface_name.clone(),
+ args.retry_attempt,
+ )
+ .map_err(Error::ConnectivityMonitorError)?
+ .with_cancellation();
+
let event_callback = Box::new(on_event.clone());
- let (pinger_tx, pinger_rx) = sync_mpsc::channel();
let monitor = WireguardMonitor {
runtime: args.runtime.clone(),
tunnel: Arc::new(AsyncMutex::new(Some(tunnel))),
@@ -221,16 +235,6 @@ impl WireguardMonitor {
obfuscator,
};
- let gateway = config.ipv4_gateway;
- let mut connectivity_monitor = connectivity_check::ConnectivityMonitor::new(
- gateway,
- #[cfg(any(target_os = "macos", target_os = "linux"))]
- iface_name.clone(),
- Arc::downgrade(&monitor.tunnel),
- pinger_rx,
- )
- .map_err(Error::ConnectivityMonitorError)?;
-
let moved_tunnel = monitor.tunnel.clone();
let moved_close_obfs_sender = close_obfs_sender.clone();
let moved_obfuscator = monitor.obfuscator.clone();
@@ -315,8 +319,12 @@ impl WireguardMonitor {
});
}
- let mut connectivity_monitor = tokio::task::spawn_blocking(move || {
- match connectivity_monitor.establish_connectivity(args.retry_attempt) {
+ let cloned_tunnel = Arc::clone(&tunnel);
+
+ let connectivity_check = tokio::task::spawn_blocking(move || {
+ let lock = cloned_tunnel.blocking_lock();
+ let tunnel = lock.as_ref().expect("The tunnel was dropped unexpectedly");
+ match connectivity_monitor.establish_connectivity(tunnel) {
Ok(true) => Ok(connectivity_monitor),
Ok(false) => {
log::warn!("Timeout while checking tunnel connection");
@@ -344,8 +352,11 @@ impl WireguardMonitor {
let metadata = Self::tunnel_metadata(&iface_name, &config);
(on_event)(TunnelEvent::Up(metadata)).await;
+ let monitored_tunnel = Arc::downgrade(&tunnel);
tokio::task::spawn_blocking(move || {
- if let Err(error) = connectivity_monitor.run() {
+ if let Err(error) =
+ connectivity::Monitor::init(connectivity_check).run(monitored_tunnel)
+ {
log::error!(
"{}",
error.display_chain_with_msg("Connectivity monitor failed")
@@ -396,8 +407,8 @@ impl WireguardMonitor {
args: TunnelArgs<'_, F>,
) -> Result<WireguardMonitor> {
let desired_mtu = get_desired_mtu(params);
- let mut config = crate::config::Config::from_parameters(params, desired_mtu)
- .map_err(Error::WireguardConfigError)?;
+ let mut config =
+ Config::from_parameters(params, desired_mtu).map_err(Error::WireguardConfigError)?;
let (close_obfs_sender, close_obfs_listener) = sync_mpsc::channel();
// Start obfuscation server and patch the WireGuard config to point the endpoint to it.
@@ -417,8 +428,13 @@ impl WireguardMonitor {
}
let should_negotiate_ephemeral_peer = config.quantum_resistant || config.daita;
- let tunnel = Self::open_tunnel(
- args.runtime.clone(),
+
+ let (connectivity_check, pinger_tx) =
+ connectivity::Check::new(config.ipv4_gateway, args.retry_attempt)
+ .map_err(Error::ConnectivityMonitorError)?
+ .with_cancellation();
+
+ let tunnel = Self::open_wireguard_go_tunnel(
&config,
log_path,
args.resource_dir,
@@ -427,77 +443,34 @@ impl WireguardMonitor {
// that we only allows traffic to/from the gateway. This is only needed on Android
// since we lack a firewall there.
should_negotiate_ephemeral_peer,
+ connectivity_check,
)?;
let iface_name = tunnel.get_interface_name();
-
- let (pinger_tx, pinger_rx) = sync_mpsc::channel();
+ let tunnel = Arc::new(AsyncMutex::new(Some(tunnel)));
let monitor = WireguardMonitor {
runtime: args.runtime.clone(),
- tunnel: Arc::new(AsyncMutex::new(Some(tunnel))),
+ tunnel: Arc::clone(&tunnel),
event_callback: Box::new(args.on_event.clone()),
close_msg_receiver: close_obfs_listener,
pinger_stop_sender: pinger_tx,
obfuscator: Arc::new(AsyncMutex::new(obfuscator)),
};
- let gateway = config.ipv4_gateway;
- let connectivity_monitor = connectivity_check::ConnectivityMonitor::new(
- gateway,
- Arc::downgrade(&monitor.tunnel),
- pinger_rx,
- )
- .map_err(Error::ConnectivityMonitorError)?;
-
- let moved_tunnel = monitor.tunnel.clone();
let moved_close_obfs_sender = close_obfs_sender.clone();
let moved_obfuscator = monitor.obfuscator.clone();
let tunnel_fut = async move {
- let tunnel = moved_tunnel;
let close_obfs_sender: sync_mpsc::Sender<CloseMsg> = moved_close_obfs_sender;
let obfuscator = moved_obfuscator;
- let connectivity_monitor = Arc::new(Mutex::new(connectivity_monitor));
let metadata = Self::tunnel_metadata(&iface_name, &config);
let allowed_traffic = Self::allowed_traffic_during_tunnel_config(&config);
- (args.on_event.clone())(TunnelEvent::InterfaceUp(metadata.clone(), allowed_traffic))
+ args.on_event.clone()(TunnelEvent::InterfaceUp(metadata.clone(), allowed_traffic))
.await;
- let handle_ping = |ping_result: std::result::Result<
- bool,
- connectivity_check::Error,
- >| match ping_result {
- Ok(true) => Ok(()),
- Ok(false) => {
- log::warn!("Timeout while checking tunnel connection");
- Err(CloseMsg::PingErr)
- }
- Err(error) => {
- log::error!(
- "{}",
- error.display_chain_with_msg("Failed to check tunnel connection")
- );
- Err(CloseMsg::PingErr)
- }
- };
-
- // Prepare a closure which pings inside the tunnel when executed.
- let ping = || {
- let connectivity_monitor_arc = connectivity_monitor.clone();
- let retry_attempt = args.retry_attempt;
- move || {
- let ping_result = connectivity_monitor_arc
- .lock()
- .unwrap()
- .establish_connectivity(retry_attempt);
- handle_ping(ping_result)
- }
- };
-
if should_negotiate_ephemeral_peer {
- // Ping before negotiating the ephemeral peer to make sure that the tunnel works.
- tokio::task::spawn_blocking(ping()).await.unwrap()?;
let ephemeral_obfs_sender = close_obfs_sender.clone();
+
ephemeral::config_ephemeral_peers(
&tunnel,
&mut config,
@@ -509,21 +482,31 @@ impl WireguardMonitor {
.await?;
let metadata = Self::tunnel_metadata(&iface_name, &config);
- (args.on_event.clone())(TunnelEvent::InterfaceUp(
+ args.on_event.clone()(TunnelEvent::InterfaceUp(
metadata,
Self::allowed_traffic_after_tunnel_config(),
))
.await;
}
- // Make sure the tunnel works (after potentially having negotiated an ephemeral peer).
- tokio::task::spawn_blocking(ping()).await.unwrap()?;
-
let metadata = Self::tunnel_metadata(&iface_name, &config);
- (args.on_event.clone())(TunnelEvent::Up(metadata)).await;
+ args.on_event.clone()(TunnelEvent::Up(metadata)).await;
+
+ // HACK: The tunnel does not need the connectivity::Check anymore, so lets take it
+ let connectivity_check = {
+ let mut tunnel_lock = tunnel.lock().await;
+ let Some(tunnel) = tunnel_lock.as_mut() else {
+ log::debug!("Tunnel is no longer running");
+ return Err::<Infallible, CloseMsg>(CloseMsg::PingErr);
+ };
+ tunnel
+ .take_checker()
+ .expect("connectivity checker unexpectedly dropped")
+ };
tokio::task::spawn_blocking(move || {
- if let Err(error) = connectivity_monitor.lock().unwrap().run() {
+ let tunnel = Arc::downgrade(&tunnel);
+ if let Err(error) = connectivity::Monitor::init(connectivity_check).run(tunnel) {
log::error!(
"{}",
error.display_chain_with_msg("Connectivity monitor failed")
@@ -585,6 +568,7 @@ impl WireguardMonitor {
/// Replace `0.0.0.0/0`/`::/0` with the gateway IPs when `gateway_only` is true.
/// Used to block traffic to other destinations while connecting on Android.
+ ///
#[cfg(target_os = "android")]
fn patch_allowed_ips(config: &Config, gateway_only: bool) -> Cow<'_, Config> {
if gateway_only {
@@ -654,16 +638,16 @@ impl WireguardMonitor {
}
#[allow(unused_variables)]
+ #[cfg(not(target_os = "android"))]
fn open_tunnel(
runtime: tokio::runtime::Handle,
config: &Config,
log_path: Option<&Path>,
resource_dir: &Path,
tun_provider: Arc<Mutex<TunProvider>>,
- #[cfg(target_os = "android")] gateway_only: bool,
#[cfg(windows)] route_manager: talpid_routing::RouteManagerHandle,
#[cfg(windows)] setup_done_tx: mpsc::Sender<std::result::Result<(), BoxedError>>,
- ) -> Result<Box<dyn Tunnel>> {
+ ) -> Result<TunnelType> {
log::debug!("Tunnel MTU: {}", config.mtu);
#[cfg(target_os = "linux")]
@@ -743,12 +727,15 @@ impl WireguardMonitor {
#[cfg(daita)] resource_dir: &Path,
tun_provider: Arc<Mutex<TunProvider>>,
#[cfg(target_os = "android")] gateway_only: bool,
+ #[cfg(target_os = "android")] connectivity_check: connectivity::Check<
+ connectivity::Cancellable,
+ >,
) -> Result<WgGoTunnel> {
- let routes = Self::get_tunnel_destinations(config).flat_map(Self::replace_default_prefixes);
-
- #[cfg(target_os = "android")]
- let config = Self::patch_allowed_ips(config, gateway_only);
+ let routes = config
+ .get_tunnel_destinations()
+ .flat_map(Self::replace_default_prefixes);
+ #[cfg(not(target_os = "android"))]
let tunnel = WgGoTunnel::start_tunnel(
#[allow(clippy::needless_borrow)]
&config,
@@ -760,6 +747,44 @@ impl WireguardMonitor {
)
.map_err(Error::TunnelError)?;
+ // Android uses multihop implemented in Mullvad's wireguard-go fork. When negotiating
+ // with an ephemeral peer, this multihop strategy require us to restart the tunnel
+ // every time we want to reconfigure it. As such, we will actually start a multihop
+ // tunnel at a later stage, after we have negotiated with the first ephemeral peer.
+ // At this point, when the tunnel *is first started*, we establish a regular, singlehop
+ // tunnel to where the ephemeral peer resides.
+ //
+ // Refer to `docs/architecture.md` for details on how to use multihop + PQ.
+ #[cfg(target_os = "android")]
+ let config = Self::patch_allowed_ips(config, gateway_only);
+
+ #[cfg(target_os = "android")]
+ let tunnel = if let Some(exit_peer) = &config.exit_peer {
+ WgGoTunnel::start_multihop_tunnel(
+ &config,
+ exit_peer,
+ log_path,
+ tun_provider,
+ routes,
+ #[cfg(daita)]
+ resource_dir,
+ connectivity_check,
+ )
+ .map_err(Error::TunnelError)?
+ } else {
+ WgGoTunnel::start_tunnel(
+ #[allow(clippy::needless_borrow)]
+ &config,
+ log_path,
+ tun_provider,
+ routes,
+ #[cfg(daita)]
+ resource_dir,
+ connectivity_check,
+ )
+ .map_err(Error::TunnelError)?
+ };
+
Ok(tunnel)
}
@@ -865,7 +890,8 @@ impl WireguardMonitor {
gateway_routes.map(|route| Self::apply_route_mtu_for_multihop(route, config));
let routes = gateway_routes.chain(
- Self::get_tunnel_destinations(config)
+ config
+ .get_tunnel_destinations()
.filter(|allowed_ip| allowed_ip.prefix() != 0)
.map(move |allowed_ip| {
if allowed_ip.is_ipv4() {
@@ -886,7 +912,8 @@ impl WireguardMonitor {
config: &'a Config,
) -> impl Iterator<Item = RequiredRoute> + 'a {
let (node_v4, node_v6) = Self::get_tunnel_nodes(iface_name, config);
- let iter = Self::get_tunnel_destinations(config)
+ let iter = config
+ .get_tunnel_destinations()
.filter(|allowed_ip| allowed_ip.prefix() == 0)
.flat_map(Self::replace_default_prefixes)
.map(move |allowed_ip| {
@@ -928,14 +955,6 @@ impl WireguardMonitor {
}
}
- /// Return routes for all allowed IPs.
- fn get_tunnel_destinations(config: &Config) -> impl Iterator<Item = ipnetwork::IpNetwork> + '_ {
- config
- .peers()
- .flat_map(|peer| peer.allowed_ips.iter())
- .cloned()
- }
-
/// Replace default (0-prefix) routes with more specific routes.
fn replace_default_prefixes(network: ipnetwork::IpNetwork) -> Vec<ipnetwork::IpNetwork> {
#[cfg(windows)]
@@ -973,6 +992,7 @@ enum CloseMsg {
ObfuscatorFailed(Error),
}
+#[allow(unused)]
pub(crate) trait Tunnel: Send {
fn get_interface_name(&self) -> String;
fn stop(self: Box<Self>) -> std::result::Result<(), TunnelError>;
@@ -1067,6 +1087,15 @@ pub enum TunnelError {
#[cfg(daita)]
#[error("Failed to start DAITA - tunnel implemenation does not support DAITA")]
DaitaNotSupported,
+
+ /// [connectivity] error.
+ #[error(transparent)]
+ Connectivity(#[from] Box<connectivity::Error>),
+
+ /// Tunnel seemingly does not serve any traffic
+ #[cfg(target_os = "android")]
+ #[error("Tunnel seemingly does not serve any traffic")]
+ TunnelUp,
}
#[cfg(target_os = "linux")]
diff --git a/talpid-wireguard/src/wireguard_go/mod.rs b/talpid-wireguard/src/wireguard_go/mod.rs
index 25ebb45a38..d283758f3b 100644
--- a/talpid-wireguard/src/wireguard_go/mod.rs
+++ b/talpid-wireguard/src/wireguard_go/mod.rs
@@ -1,3 +1,14 @@
+#[cfg(target_os = "android")]
+use super::config;
+use super::{
+ stats::{Stats, StatsMap},
+ Config, Tunnel, TunnelError,
+};
+#[cfg(target_os = "linux")]
+use crate::config::MULLVAD_INTERFACE_NAME;
+#[cfg(target_os = "android")]
+use crate::connectivity;
+use crate::logging::{clean_up_logging, initialize_logging};
use ipnetwork::IpNetwork;
#[cfg(daita)]
use once_cell::sync::OnceCell;
@@ -13,16 +24,10 @@ use std::{
#[cfg(target_os = "android")]
use talpid_tunnel::tun_provider::Error as TunProviderError;
use talpid_tunnel::tun_provider::{Tun, TunProvider};
+#[cfg(target_os = "android")]
+use talpid_types::net::wireguard::PeerConfig;
use talpid_types::BoxedError;
-use super::{
- stats::{Stats, StatsMap},
- Config, Tunnel, TunnelError,
-};
-#[cfg(target_os = "linux")]
-use crate::config::MULLVAD_INTERFACE_NAME;
-use crate::logging::{clean_up_logging, initialize_logging};
-
const MAX_PREPARE_TUN_ATTEMPTS: usize = 4;
/// Maximum number of events that can be stored in the underlying buffer
@@ -35,21 +40,129 @@ const DAITA_ACTIONS_CAPACITY: u32 = 1000;
type Result<T> = std::result::Result<T, TunnelError>;
-struct LoggingContext(u64);
+struct LoggingContext {
+ ordinal: u64,
+ #[allow(dead_code)]
+ path: Option<PathBuf>,
+}
+
+impl LoggingContext {
+ fn new(ordinal: u64, path: Option<PathBuf>) -> Self {
+ LoggingContext { ordinal, path }
+ }
+}
impl Drop for LoggingContext {
fn drop(&mut self) {
- clean_up_logging(self.0);
+ clean_up_logging(self.ordinal);
+ }
+}
+
+#[cfg(not(target_os = "android"))]
+pub struct WgGoTunnel(WgGoTunnelState);
+
+#[cfg(target_os = "android")]
+pub enum WgGoTunnel {
+ Multihop(WgGoTunnelState),
+ Singlehop(WgGoTunnelState),
+}
+
+#[cfg(not(target_os = "android"))]
+impl WgGoTunnel {
+ fn into_state(self) -> WgGoTunnelState {
+ self.0
+ }
+
+ fn as_state(&self) -> &WgGoTunnelState {
+ &self.0
+ }
+
+ fn as_state_mut(&mut self) -> &mut WgGoTunnelState {
+ &mut self.0
+ }
+}
+
+#[cfg(target_os = "android")]
+impl WgGoTunnel {
+ fn into_state(self) -> WgGoTunnelState {
+ match self {
+ WgGoTunnel::Multihop(state) => state,
+ WgGoTunnel::Singlehop(state) => state,
+ }
+ }
+
+ fn as_state(&self) -> &WgGoTunnelState {
+ match self {
+ WgGoTunnel::Multihop(state) => state,
+ WgGoTunnel::Singlehop(state) => state,
+ }
+ }
+
+ fn as_state_mut(&mut self) -> &mut WgGoTunnelState {
+ match self {
+ WgGoTunnel::Multihop(state) => state,
+ WgGoTunnel::Singlehop(state) => state,
+ }
+ }
+
+ pub fn set_config(mut self, config: &Config) -> Result<Self> {
+ let connectivity_checker = self
+ .take_checker()
+ .expect("connectivity checker unexpectedly dropped");
+ let state = self.as_state();
+ let log_path = state._logging_context.path.clone();
+ let tun_provider = Arc::clone(&state.tun_provider);
+ let routes = config.get_tunnel_destinations();
+ #[cfg(daita)]
+ let resource_dir = state.resource_dir.clone();
+
+ match self {
+ WgGoTunnel::Multihop(state) if !config.is_multihop() => {
+ state.stop()?;
+ Self::start_tunnel(
+ config,
+ log_path.as_deref(),
+ tun_provider,
+ routes,
+ &resource_dir,
+ connectivity_checker,
+ )
+ }
+ WgGoTunnel::Singlehop(state) if config.is_multihop() => {
+ state.stop()?;
+ Self::start_multihop_tunnel(
+ config,
+ &config.exit_peer.clone().unwrap().clone(),
+ log_path.as_deref(),
+ tun_provider,
+ routes,
+ &resource_dir,
+ connectivity_checker,
+ )
+ }
+ WgGoTunnel::Singlehop(mut state) => {
+ state.set_config(config.clone())?;
+ Ok(WgGoTunnel::Singlehop(state))
+ }
+ WgGoTunnel::Multihop(mut state) => {
+ state.set_config(config.clone())?;
+ Ok(WgGoTunnel::Multihop(state))
+ }
+ }
+ }
+
+ pub fn stop(self) -> Result<()> {
+ self.into_state().stop()
}
}
-pub struct WgGoTunnel {
+pub(crate) struct WgGoTunnelState {
interface_name: String,
tunnel_handle: wireguard_go_rs::Tunnel,
// holding on to the tunnel device and the log file ensures that the associated file handles
// live long enough and get closed when the tunnel is stopped
_tunnel_device: Tun,
- // context that maps to fs::File instance, used with logging callback
+ // context that maps to fs::File instance and stores the file path, used with logging callback
_logging_context: LoggingContext,
#[cfg(target_os = "android")]
tun_provider: Arc<Mutex<TunProvider>>,
@@ -57,9 +170,53 @@ pub struct WgGoTunnel {
resource_dir: PathBuf,
#[cfg(daita)]
config: Config,
+ // HACK: Check is not Clone, so we have to pass this around ..
+ // This is conceptually the connection between this Tunnel and the currently running
+ // WireguardMonitor, and it is used to allow WireguardMonitor to cancel the setup of
+ // a new Tunnel during the "ensure_connectivity" phase. This field should be removed
+ // as soon as we implement a better way to cancel Check asynchronously.
+ #[cfg(target_os = "android")]
+ connectivity_checker: Option<connectivity::Check<connectivity::Cancellable>>,
+}
+
+impl WgGoTunnelState {
+ fn stop(self) -> Result<()> {
+ self.tunnel_handle
+ .turn_off()
+ .map_err(|e| TunnelError::StopWireguardError(Box::new(e)))
+ }
+
+ fn set_config(&mut self, config: Config) -> Result<()> {
+ let wg_config_str = config.to_userspace_format();
+
+ self.tunnel_handle
+ .set_config(&wg_config_str)
+ .map_err(|_| TunnelError::SetConfigError)?;
+
+ #[cfg(target_os = "android")]
+ let tun_provider = self.tun_provider.clone();
+
+ // When reapplying the config, the endpoint socket may be discarded
+ // and needs to be excluded again
+ #[cfg(target_os = "android")]
+ {
+ let socket_v4 = self.tunnel_handle.get_socket_v4();
+ let socket_v6 = self.tunnel_handle.get_socket_v6();
+ let mut provider = tun_provider.lock().unwrap();
+ provider
+ .bypass(socket_v4)
+ .map_err(super::TunnelError::BypassError)?;
+ provider
+ .bypass(socket_v6)
+ .map_err(super::TunnelError::BypassError)?;
+ }
+
+ Ok(())
+ }
}
impl WgGoTunnel {
+ #[cfg(not(target_os = "android"))]
pub fn start_tunnel(
config: &Config,
log_path: Option<&Path>,
@@ -67,60 +224,35 @@ impl WgGoTunnel {
routes: impl Iterator<Item = IpNetwork>,
#[cfg(daita)] resource_dir: &Path,
) -> Result<Self> {
- #[cfg(target_os = "android")]
- let tun_provider_clone = tun_provider.clone();
-
- #[cfg_attr(not(target_os = "android"), allow(unused_mut))]
- let (mut tunnel_device, tunnel_fd) = Self::get_tunnel(tun_provider, config, routes)?;
+ let (tunnel_device, tunnel_fd) = Self::get_tunnel(tun_provider, config, routes)?;
let interface_name: String = tunnel_device.interface_name().to_string();
let wg_config_str = config.to_userspace_format();
let logging_context = initialize_logging(log_path)
- .map(LoggingContext)
+ .map(|ordinal| LoggingContext::new(ordinal, log_path.map(Path::to_owned)))
.map_err(TunnelError::LoggingError)?;
- #[cfg(not(target_os = "android"))]
let mtu = config.mtu as isize;
+
let handle = wireguard_go_rs::Tunnel::turn_on(
- #[cfg(not(target_os = "android"))]
mtu,
&wg_config_str,
tunnel_fd,
Some(logging::wg_go_logging_callback),
- logging_context.0,
+ logging_context.ordinal,
)
.map_err(|e| TunnelError::FatalStartWireguardError(Box::new(e)))?;
- #[cfg(target_os = "android")]
- Self::bypass_tunnel_sockets(&handle, &mut tunnel_device)
- .map_err(TunnelError::BypassError)?;
-
- Ok(WgGoTunnel {
+ Ok(WgGoTunnel(WgGoTunnelState {
interface_name,
tunnel_handle: handle,
_tunnel_device: tunnel_device,
_logging_context: logging_context,
- #[cfg(target_os = "android")]
- tun_provider: tun_provider_clone,
#[cfg(daita)]
resource_dir: resource_dir.to_owned(),
#[cfg(daita)]
config: config.clone(),
- })
- }
-
- #[cfg(target_os = "android")]
- fn bypass_tunnel_sockets(
- handle: &wireguard_go_rs::Tunnel,
- tunnel_device: &mut Tun,
- ) -> std::result::Result<(), TunProviderError> {
- let socket_v4 = handle.get_socket_v4();
- let socket_v6 = handle.get_socket_v6();
-
- tunnel_device.bypass(socket_v4)?;
- tunnel_device.bypass(socket_v6)?;
-
- Ok(())
+ }))
}
fn get_tunnel(
@@ -162,13 +294,171 @@ impl WgGoTunnel {
}
}
+#[cfg(target_os = "android")]
+impl WgGoTunnel {
+ pub fn start_tunnel(
+ config: &Config,
+ log_path: Option<&Path>,
+ tun_provider: Arc<Mutex<TunProvider>>,
+ routes: impl Iterator<Item = IpNetwork>,
+ #[cfg(daita)] resource_dir: &Path,
+ mut connectivity_check: connectivity::Check<connectivity::Cancellable>,
+ ) -> Result<Self> {
+ let (mut tunnel_device, tunnel_fd) =
+ Self::get_tunnel(Arc::clone(&tun_provider), config, routes)?;
+
+ let interface_name: String = tunnel_device.interface_name().to_string();
+ let logging_context = initialize_logging(log_path)
+ .map(|ordinal| LoggingContext::new(ordinal, log_path.map(Path::to_owned)))
+ .map_err(TunnelError::LoggingError)?;
+
+ let wg_config_str = config.to_userspace_format();
+
+ let handle = wireguard_go_rs::Tunnel::turn_on(
+ &wg_config_str,
+ tunnel_fd,
+ Some(logging::wg_go_logging_callback),
+ logging_context.ordinal,
+ )
+ .map_err(|e| TunnelError::FatalStartWireguardError(Box::new(e)))?;
+
+ Self::bypass_tunnel_sockets(&handle, &mut tunnel_device)
+ .map_err(TunnelError::BypassError)?;
+
+ let mut tunnel = WgGoTunnel::Singlehop(WgGoTunnelState {
+ interface_name,
+ tunnel_handle: handle,
+ _tunnel_device: tunnel_device,
+ _logging_context: logging_context,
+ tun_provider,
+ #[cfg(daita)]
+ resource_dir: resource_dir.to_owned(),
+ #[cfg(daita)]
+ config: config.clone(),
+ connectivity_checker: None,
+ });
+
+ // HACK: Check if the tunnel is working by sending a ping in the tunnel.
+ tunnel.ensure_tunnel_is_running(&mut connectivity_check)?;
+ tunnel.as_state_mut().connectivity_checker = Some(connectivity_check);
+
+ Ok(tunnel)
+ }
+
+ pub fn start_multihop_tunnel(
+ config: &Config,
+ exit_peer: &PeerConfig,
+ log_path: Option<&Path>,
+ tun_provider: Arc<Mutex<TunProvider>>,
+ routes: impl Iterator<Item = IpNetwork>,
+ #[cfg(daita)] resource_dir: &Path,
+ mut connectivity_check: connectivity::Check<connectivity::Cancellable>,
+ ) -> Result<Self> {
+ let (mut tunnel_device, tunnel_fd) =
+ Self::get_tunnel(Arc::clone(&tun_provider), config, routes)?;
+
+ let interface_name: String = tunnel_device.interface_name().to_string();
+ let logging_context = initialize_logging(log_path)
+ .map(|ordinal| LoggingContext::new(ordinal, log_path.map(Path::to_owned)))
+ .map_err(TunnelError::LoggingError)?;
+
+ let entry_config_str = config::userspace_format(
+ &config.tunnel.private_key,
+ std::iter::once(&config.entry_peer),
+ );
+
+ let exit_config_str =
+ config::userspace_format(&config.tunnel.private_key, std::iter::once(exit_peer));
+
+ let private_ip = config
+ .tunnel
+ .addresses
+ .iter()
+ .find(|addr| addr.is_ipv4())
+ .map(|addr| CString::new(addr.to_string()).unwrap())
+ .ok_or(TunnelError::SetConfigError)?;
+
+ let handle = wireguard_go_rs::Tunnel::turn_on_multihop(
+ &exit_config_str,
+ &entry_config_str,
+ &private_ip,
+ tunnel_fd,
+ Some(logging::wg_go_logging_callback),
+ logging_context.ordinal,
+ )
+ .map_err(|e| TunnelError::FatalStartWireguardError(Box::new(e)))?;
+
+ Self::bypass_tunnel_sockets(&handle, &mut tunnel_device)
+ .map_err(TunnelError::BypassError)?;
+
+ let mut tunnel = WgGoTunnel::Multihop(WgGoTunnelState {
+ interface_name,
+ tunnel_handle: handle,
+ _tunnel_device: tunnel_device,
+ _logging_context: logging_context,
+ tun_provider,
+ #[cfg(daita)]
+ resource_dir: resource_dir.to_owned(),
+ #[cfg(daita)]
+ config: config.clone(),
+ connectivity_checker: None,
+ });
+
+ // HACK: Check if the tunnel is working by sending a ping in the tunnel.
+ tunnel.ensure_tunnel_is_running(&mut connectivity_check)?;
+ tunnel.as_state_mut().connectivity_checker = Some(connectivity_check);
+
+ Ok(tunnel)
+ }
+
+ fn bypass_tunnel_sockets(
+ handle: &wireguard_go_rs::Tunnel,
+ tunnel_device: &mut Tun,
+ ) -> std::result::Result<(), TunProviderError> {
+ let socket_v4 = handle.get_socket_v4();
+ let socket_v6 = handle.get_socket_v6();
+
+ tunnel_device.bypass(socket_v4)?;
+ tunnel_device.bypass(socket_v6)?;
+
+ Ok(())
+ }
+
+ pub fn take_checker(&mut self) -> Option<connectivity::Check<connectivity::Cancellable>> {
+ self.as_state_mut().connectivity_checker.take()
+ }
+
+ /// There is a brief period of time between setting up a Wireguard-go tunnel and the tunnel being ready to serve
+ /// traffic. This function blocks until the tunnel starts to serve traffic or until [connectivity::Check] times out.
+ fn ensure_tunnel_is_running(
+ &self,
+ checker: &mut connectivity::Check<connectivity::Cancellable>,
+ ) -> Result<()> {
+ let connectivity_err = |e| TunnelError::Connectivity(Box::new(e));
+ let connection_established = checker
+ .establish_connectivity(self)
+ .map_err(connectivity_err)?;
+
+ // Timed out
+ if !connection_established {
+ return Err(TunnelError::TunnelUp);
+ }
+ Ok(())
+ }
+}
+
impl Tunnel for WgGoTunnel {
fn get_interface_name(&self) -> String {
- self.interface_name.clone()
+ self.as_state().interface_name.clone()
+ }
+
+ fn stop(self: Box<Self>) -> Result<()> {
+ self.into_state().stop()
}
fn get_tunnel_stats(&self) -> Result<StatsMap> {
- self.tunnel_handle
+ self.as_state()
+ .tunnel_handle
.get_config(|cstr| {
Stats::parse_config_str(cstr.to_str().expect("Go strings are always UTF-8"))
})
@@ -176,54 +466,25 @@ impl Tunnel for WgGoTunnel {
.map_err(|error| TunnelError::StatsError(BoxedError::new(error)))
}
- fn stop(self: Box<Self>) -> Result<()> {
- self.tunnel_handle
- .turn_off()
- .map_err(|e| TunnelError::StopWireguardError(Box::new(e)))
- }
-
fn set_config(
&mut self,
config: Config,
) -> Pin<Box<dyn Future<Output = Result<()>> + Send + '_>> {
- Box::pin(async move {
- let wg_config_str = config.to_userspace_format();
-
- self.tunnel_handle
- .set_config(&wg_config_str)
- .map_err(|_| TunnelError::SetConfigError)?;
-
- #[cfg(target_os = "android")]
- let tun_provider = self.tun_provider.clone();
-
- // When reapplying the config, the endpoint socket may be discarded
- // and needs to be excluded again
- #[cfg(target_os = "android")]
- {
- let socket_v4 = self.tunnel_handle.get_socket_v4();
- let socket_v6 = self.tunnel_handle.get_socket_v6();
- let mut provider = tun_provider.lock().unwrap();
- provider
- .bypass(socket_v4)
- .map_err(super::TunnelError::BypassError)?;
- provider
- .bypass(socket_v6)
- .map_err(super::TunnelError::BypassError)?;
- }
-
- Ok(())
- })
+ Box::pin(async move { self.as_state_mut().set_config(config) })
}
#[cfg(daita)]
fn start_daita(&mut self) -> Result<()> {
static MAYBENOT_MACHINES: OnceCell<CString> = OnceCell::new();
- let machines =
- MAYBENOT_MACHINES.get_or_try_init(|| load_maybenot_machines(&self.resource_dir))?;
+ let machines = MAYBENOT_MACHINES
+ .get_or_try_init(|| load_maybenot_machines(&self.as_state().resource_dir))?;
log::info!("Initializing DAITA for wireguard device");
- let peer_public_key = &self.config.entry_peer.public_key;
- self.tunnel_handle
+ let config = &self.as_state().config;
+ let peer_public_key = &config.entry_peer.public_key;
+
+ self.as_state()
+ .tunnel_handle
.activate_daita(
peer_public_key.as_bytes(),
machines,
diff --git a/wireguard-go-rs/libwg/go.mod b/wireguard-go-rs/libwg/go.mod
index 76627dcb7f..8f463a64ad 100644
--- a/wireguard-go-rs/libwg/go.mod
+++ b/wireguard-go-rs/libwg/go.mod
@@ -4,13 +4,16 @@ go 1.21
require (
golang.org/x/sys v0.19.0
- golang.zx2c4.com/wireguard v0.0.0-20230223181233-21636207a675
+ golang.zx2c4.com/wireguard v0.0.0-20230223181233-21636207a67
)
require (
+ github.com/google/btree v1.0.1 // indirect
golang.org/x/crypto v0.22.0 // indirect
golang.org/x/net v0.24.0 // indirect
+ golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
+ gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 // indirect
)
replace golang.zx2c4.com/wireguard => ./wireguard-go
diff --git a/wireguard-go-rs/libwg/go.sum b/wireguard-go-rs/libwg/go.sum
index b41c5842d1..d04296cf67 100644
--- a/wireguard-go-rs/libwg/go.sum
+++ b/wireguard-go-rs/libwg/go.sum
@@ -1,8 +1,14 @@
+github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4=
+github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA=
golang.org/x/crypto v0.22.0 h1:g1v0xeRhjcugydODzvb3mEM9SQ0HGp9s/nh3COQ/C30=
golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M=
golang.org/x/net v0.24.0 h1:1PcaxkF854Fu3+lvBIx5SYn9wRlBzzcnHZSiaFFAb0w=
golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8=
golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o=
golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
+golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 h1:vVKdlvoWBphwdxWKrFZEuM0kGgGLxUOYcY4U/2Vjg44=
+golang.org/x/time v0.0.0-20220210224613-90d013bbcef8/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
+gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 h1:TbRPT0HtzFP3Cno1zZo7yPzEEnfu8EjLfl6IU9VfqkQ=
+gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259/go.mod h1:AVgIgHMwK63XvmAzWG9vLQ41YnVHN0du0tEC46fI7yY=
diff --git a/wireguard-go-rs/libwg/libwg.go b/wireguard-go-rs/libwg/libwg.go
index 6cfbd0ba55..5dcc9141b2 100644
--- a/wireguard-go-rs/libwg/libwg.go
+++ b/wireguard-go-rs/libwg/libwg.go
@@ -65,6 +65,9 @@ func wgTurnOff(tunnelHandle int32) {
return
}
tunnel.Device.Close()
+ if tunnel.EntryDevice != nil {
+ tunnel.EntryDevice.Close()
+ }
}
// Calling twice convinces the GC to release NOW.
runtime.GC()
diff --git a/wireguard-go-rs/libwg/libwg_android.go b/wireguard-go-rs/libwg/libwg_android.go
index d623b7711d..caca9b04d0 100644
--- a/wireguard-go-rs/libwg/libwg_android.go
+++ b/wireguard-go-rs/libwg/libwg_android.go
@@ -11,6 +11,8 @@ import "C"
import (
"bufio"
+ "errors"
+ "net/netip"
"strings"
"unsafe"
@@ -19,6 +21,7 @@ import (
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
+ "golang.zx2c4.com/wireguard/tun/multihoptun"
"github.com/mullvad/mullvadvpn-app/wireguard/libwg/logging"
"github.com/mullvad/mullvadvpn-app/wireguard/libwg/tunnelcontainer"
@@ -29,6 +32,12 @@ import (
type LogSink = unsafe.Pointer
type LogContext = C.uint64_t
+type tunnelHandle struct {
+ exit *device.Device
+ entry *device.Device
+ logger *device.Logger
+}
+
//export wgTurnOn
func wgTurnOn(cSettings *C.char, fd int, logSink LogSink, logContext LogContext) C.int32_t {
logger := logging.NewLogger(logSink, logging.LogContext(logContext))
@@ -77,13 +86,166 @@ func wgTurnOn(cSettings *C.char, fd int, logSink LogSink, logContext LogContext)
return C.int32_t(handle)
}
+//export wgTurnOnMultihop
+func wgTurnOnMultihop(cExitSettings *C.char, cEntrySettings *C.char, privateIp *C.char, fd int, logSink LogSink, logContext LogContext) C.int32_t {
+ logger := logging.NewLogger(logSink, logging.LogContext(logContext))
+ if cExitSettings == nil {
+ logger.Errorf("cExitSettings is null\n")
+ return ERROR_INVALID_ARGUMENT
+ }
+ exitSettings := goStringFixed(cExitSettings)
+
+ if cEntrySettings == nil {
+ logger.Errorf("cEntrySettings is null\n")
+ return ERROR_INVALID_ARGUMENT
+ }
+ entrySettings := goStringFixed(cEntrySettings)
+
+ exitEndpoint := parseEndpointFromConfig(exitSettings)
+
+ if exitEndpoint == nil {
+ logger.Errorf("exitEndpoint is null\n")
+ return ERROR_INVALID_ARGUMENT
+ }
+
+ // Set up a two tunnel devices: One 'fake' device for the exit relay and one 'real' device for the entry relay
+
+ tunDevice, _, err := tun.CreateUnmonitoredTUNFromFD(fd)
+ if err != nil {
+ logger.Errorf("%s\n", err)
+ unix.Close(fd)
+ if err.Error() == "bad file descriptor" {
+ return ERROR_INTERMITTENT_FAILURE
+ }
+ return ERROR_GENERAL_FAILURE
+ }
+
+ ip, err := netip.ParseAddr(goStringFixed(privateIp))
+ if err != nil {
+ logger.Errorf("%s\n", err)
+ tunDevice.Close()
+ return ERROR_INVALID_ARGUMENT
+ }
+
+ mtu, err := tunDevice.MTU()
+ if err != nil {
+ logger.Errorf("%s\n", err)
+ tunDevice.Close()
+ return ERROR_GENERAL_FAILURE
+ }
+
+ singleTunMtu := mtu - 80 //Internet mtu - Wireguard header size - ipv4 UDP header
+ singletun := multihoptun.NewMultihopTun(ip, exitEndpoint.Addr(), exitEndpoint.Port(), singleTunMtu)
+
+ entryDevice := device.NewDevice(&singletun, conn.NewStdNetBind(), logger)
+ exitDevice := device.NewDevice(tunDevice, singletun.Binder(), logger)
+
+ setErr := entryDevice.IpcSetOperation(bufio.NewReader(strings.NewReader(entrySettings)))
+ if setErr != nil {
+ logger.Errorf("%s\n", setErr)
+ exitDevice.Close()
+ entryDevice.Close()
+ return ERROR_INTERMITTENT_FAILURE
+ }
+
+ entryDevice.DisableSomeRoamingForBrokenMobileSemantics()
+
+ setErr = exitDevice.IpcSetOperation(bufio.NewReader(strings.NewReader(exitSettings)))
+ if setErr != nil {
+ logger.Errorf("%s\n", setErr)
+ exitDevice.Close()
+ entryDevice.Close()
+ return ERROR_INTERMITTENT_FAILURE
+ }
+
+ exitDevice.DisableSomeRoamingForBrokenMobileSemantics()
+
+ exitDevice.Up()
+ entryDevice.Up()
+
+ // Create the stuff that needs
+
+ context := tunnelcontainer.Context{
+ Device: exitDevice,
+ EntryDevice: entryDevice,
+ Logger: logger,
+ }
+
+ handle, err := tunnels.Insert(context)
+ if err != nil {
+ logger.Errorf("%s\n", err)
+ entryDevice.Close()
+ exitDevice.Close()
+ return ERROR_GENERAL_FAILURE
+ }
+
+ return C.int32_t(handle)
+
+}
+
+func addTunnelFromDevice(exitDev *device.Device, entryDev *device.Device, exitSettings string, entrySettings string, logger *device.Logger) (*tunnelHandle, error) {
+ err := bringUpDevice(exitDev, exitSettings, logger)
+ if err != nil {
+ return nil, errors.New("Could not bring up exit device") // errBadWgConfig
+ }
+
+ if entryDev != nil {
+ err = bringUpDevice(entryDev, entrySettings, logger)
+ if err != nil {
+ exitDev.Close()
+ return nil, errors.New("Could not bring up entry device")
+ }
+ }
+
+ return &tunnelHandle{exitDev, entryDev, logger}, nil
+}
+
+func bringUpDevice(dev *device.Device, settings string, logger *device.Logger) error {
+ err := dev.IpcSet(settings)
+ if err != nil {
+ logger.Errorf("Unable to set IPC settings: %v", err)
+ dev.Close()
+ return err
+ }
+
+ dev.Up()
+ logger.Verbosef("Device started")
+ return nil
+}
+
+// Parse a wireguard config and return the first endpoint address it finds and
+// parses successfully.gi b
+func parseEndpointFromConfig(config string) *netip.AddrPort {
+ scanner := bufio.NewScanner(strings.NewReader(config))
+ for scanner.Scan() {
+ line := scanner.Text()
+ key, value, ok := strings.Cut(line, "=")
+ if !ok {
+ continue
+ }
+
+ if key == "endpoint" {
+ endpoint, err := netip.ParseAddrPort(value)
+ if err == nil {
+ return &endpoint
+ }
+ }
+
+ }
+ return nil
+}
+
//export wgGetSocketV4
func wgGetSocketV4(tunnelHandle int32) C.int32_t {
tunnel, err := tunnels.Get(tunnelHandle)
if err != nil {
return ERROR_UNKNOWN_TUNNEL
}
- peek := tunnel.Device.Bind().(conn.PeekLookAtSocketFd)
+ device := tunnel.EntryDevice
+ if device == nil {
+ device = tunnel.Device
+ }
+ peek := device.Bind().(conn.PeekLookAtSocketFd)
fd, err := peek.PeekLookAtSocketFd4()
if err != nil {
return ERROR_GENERAL_FAILURE
@@ -97,7 +259,11 @@ func wgGetSocketV6(tunnelHandle int32) C.int32_t {
if err != nil {
return ERROR_UNKNOWN_TUNNEL
}
- peek := tunnel.Device.Bind().(conn.PeekLookAtSocketFd)
+ device := tunnel.EntryDevice
+ if device == nil {
+ device = tunnel.Device
+ }
+ peek := device.Bind().(conn.PeekLookAtSocketFd)
fd, err := peek.PeekLookAtSocketFd6()
if err != nil {
return ERROR_GENERAL_FAILURE
diff --git a/wireguard-go-rs/libwg/libwg_daita.go b/wireguard-go-rs/libwg/libwg_daita.go
index 3b1fedda4c..fbfceec8f0 100644
--- a/wireguard-go-rs/libwg/libwg_daita.go
+++ b/wireguard-go-rs/libwg/libwg_daita.go
@@ -32,7 +32,15 @@ func wgActivateDaita(tunnelHandle C.int32_t, peerPubkey *C.uint8_t, machines *C.
var publicKey device.NoisePublicKey
copy(publicKey[:], C.GoBytes(unsafe.Pointer(peerPubkey), device.NoisePublicKeySize))
- peer := tunnel.Device.LookupPeer(publicKey)
+
+ var peer *device.Peer
+ if tunnel.EntryDevice != nil {
+ // TODO: Document me
+ peer = tunnel.EntryDevice.LookupPeer(publicKey)
+ } else {
+ // TODO: Document me
+ peer = tunnel.Device.LookupPeer(publicKey)
+ }
if peer == nil {
return ERROR_UNKNOWN_PEER
diff --git a/wireguard-go-rs/libwg/tunnelcontainer/tunnelcontainer.go b/wireguard-go-rs/libwg/tunnelcontainer/tunnelcontainer.go
index 91291dcf4b..79eacc2a17 100644
--- a/wireguard-go-rs/libwg/tunnelcontainer/tunnelcontainer.go
+++ b/wireguard-go-rs/libwg/tunnelcontainer/tunnelcontainer.go
@@ -16,6 +16,7 @@ import (
type Context struct {
Device *device.Device
+ EntryDevice *device.Device
Uapi net.Listener
Logger *device.Logger
}
diff --git a/wireguard-go-rs/src/lib.rs b/wireguard-go-rs/src/lib.rs
index a77b48c0bd..851fd47b9f 100644
--- a/wireguard-go-rs/src/lib.rs
+++ b/wireguard-go-rs/src/lib.rs
@@ -8,10 +8,10 @@
#![cfg(unix)]
-use core::slice;
-use std::{
+use core::{
ffi::{c_char, CStr},
mem::{ManuallyDrop, MaybeUninit},
+ slice,
};
use util::OnDrop;
use zeroize::Zeroize;
@@ -105,6 +105,37 @@ impl Tunnel {
result_from_code(code)
}
+ /// Special function for android multihop since that behavior is different from desktop
+ /// and android non-multihop.
+ ///
+ /// The `logging_callback` let's you provide a Rust function that receives any logging output
+ /// from wireguard-go. `logging_context` is a value that will be passed to each invocation of
+ /// `logging_callback`.
+ #[cfg(target_os = "android")]
+ pub fn turn_on_multihop(
+ exit_settings: &CStr,
+ entry_settings: &CStr,
+ private_ip: &CStr,
+ device: Fd,
+ logging_callback: Option<LoggingCallback>,
+ logging_context: LoggingContext,
+ ) -> Result<Self, Error> {
+ // SAFETY: pointer is valid for the the lifetime of this function
+ let code = unsafe {
+ ffi::wgTurnOnMultihop(
+ exit_settings.as_ptr(),
+ entry_settings.as_ptr(),
+ private_ip.as_ptr(),
+ device,
+ logging_callback,
+ logging_context,
+ )
+ };
+
+ result_from_code(code)?;
+ Ok(Tunnel { handle: code })
+ }
+
/// Get the config of the WireGuard interface and make it available in the provided function.
///
/// This takes a function to make sure the cstr get's zeroed and freed afterwards.
@@ -180,12 +211,14 @@ impl Tunnel {
/// Get the file descriptor of the tunnel IPv4 socket.
#[cfg(target_os = "android")]
pub fn get_socket_v4(&self) -> Fd {
+ // SAFETY: self.handle is a valid pointer to an active wireguard-go tunnel.
unsafe { ffi::wgGetSocketV4(self.handle) }
}
/// Get the file descriptor of the tunnel IPv6 socket.
#[cfg(target_os = "android")]
pub fn get_socket_v6(&self) -> Fd {
+ // SAFETY: self.handle is a valid pointer to an active wireguard-go tunnel.
unsafe { ffi::wgGetSocketV6(self.handle) }
}
}
@@ -257,6 +290,21 @@ mod ffi {
logging_context: LoggingContext,
) -> i32;
+ /// Creates a new wireguard tunnel, uses the specific interface name, and file descriptors
+ /// for the tunnel device and logging.
+ ///
+ /// Positive return values are tunnel handles for this specific wireguard tunnel instance.
+ /// Negative return values signify errors.
+ #[cfg(target_os = "android")]
+ pub fn wgTurnOnMultihop(
+ exit_settings: *const c_char,
+ entry_settings: *const c_char,
+ private_ip: *const c_char,
+ fd: Fd,
+ logging_callback: Option<LoggingCallback>,
+ logging_context: LoggingContext,
+ ) -> i32;
+
/// Pass a handle that was created by wgTurnOn to stop a wireguard tunnel.
///
/// Negative return values signify errors.