diff options
| -rw-r--r-- | Cargo.lock | 1 | ||||
| -rw-r--r-- | mullvad-management-interface/src/types.rs | 1 | ||||
| -rw-r--r-- | mullvad-relay-selector/src/matcher.rs | 1 | ||||
| -rw-r--r-- | talpid-core/Cargo.toml | 1 | ||||
| -rw-r--r-- | talpid-core/src/tunnel/mod.rs | 1 | ||||
| -rw-r--r-- | talpid-core/src/tunnel/wireguard/connectivity_check.rs | 11 | ||||
| -rw-r--r-- | talpid-core/src/tunnel/wireguard/mod.rs | 47 | ||||
| -rw-r--r-- | talpid-core/src/tunnel/wireguard/wireguard_nt.rs | 31 | ||||
| -rw-r--r-- | talpid-types/src/net/wireguard.rs | 29 |
9 files changed, 119 insertions, 4 deletions
diff --git a/Cargo.lock b/Cargo.lock index f108dba034..f5e75ffcee 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3494,6 +3494,7 @@ dependencies = [ "system-configuration", "talpid-dbus", "talpid-platform-metadata", + "talpid-relay-config-client", "talpid-time", "talpid-types", "tempfile", diff --git a/mullvad-management-interface/src/types.rs b/mullvad-management-interface/src/types.rs index 01f16d8220..cad371aa49 100644 --- a/mullvad-management-interface/src/types.rs +++ b/mullvad-management-interface/src/types.rs @@ -1185,6 +1185,7 @@ impl TryFrom<ConnectionConfig> for mullvad_types::ConnectionConfig { public_key, allowed_ips, endpoint, + psk: None, }, exit_peer: None, ipv4_gateway, diff --git a/mullvad-relay-selector/src/matcher.rs b/mullvad-relay-selector/src/matcher.rs index 089510a6c0..13e16646ab 100644 --- a/mullvad-relay-selector/src/matcher.rs +++ b/mullvad-relay-selector/src/matcher.rs @@ -189,6 +189,7 @@ impl WireguardMatcher { public_key: data.public_key, endpoint: SocketAddr::new(host, port), allowed_ips: all_of_the_internet(), + psk: None, }; Some(MullvadEndpoint::Wireguard(MullvadWireguardEndpoint { peer: peer_config, diff --git a/talpid-core/Cargo.toml b/talpid-core/Cargo.toml index e422a537ca..dbaae6a323 100644 --- a/talpid-core/Cargo.toml +++ b/talpid-core/Cargo.toml @@ -26,6 +26,7 @@ regex = "1.1.0" shell-escape = "0.1" talpid-types = { path = "../talpid-types" } talpid-time = { path = "../talpid-time" } +talpid-relay-config-client = { path = "../talpid-relay-config-client" } uuid = { version = "0.8", features = ["v4"] } zeroize = "1" chrono = "0.4.19" diff --git a/talpid-core/src/tunnel/mod.rs b/talpid-core/src/tunnel/mod.rs index ea7ec35133..8820f26c72 100644 --- a/talpid-core/src/tunnel/mod.rs +++ b/talpid-core/src/tunnel/mod.rs @@ -198,6 +198,7 @@ impl TunnelMonitor { let monitor = wireguard::WireguardMonitor::start( runtime, config, + Some(params.connection.peer.public_key.clone()), log.as_deref(), resource_dir, on_event, diff --git a/talpid-core/src/tunnel/wireguard/connectivity_check.rs b/talpid-core/src/tunnel/wireguard/connectivity_check.rs index ec2af873a0..bcb28d7c17 100644 --- a/talpid-core/src/tunnel/wireguard/connectivity_check.rs +++ b/talpid-core/src/tunnel/wireguard/connectivity_check.rs @@ -391,12 +391,16 @@ impl ConnState { #[cfg(test)] mod test { + use futures::Future; + use super::*; use crate::tunnel::wireguard::{ + config::Config, stats::{self, Stats}, TunnelError, }; use std::{ + pin::Pin, sync::{ atomic::{AtomicBool, Ordering}, Arc, Mutex, @@ -598,6 +602,13 @@ mod test { fn get_tunnel_stats(&self) -> Result<stats::StatsMap, TunnelError> { (self.on_get_stats)() } + + fn set_config( + &self, + _config: Config, + ) -> Pin<Box<dyn Future<Output = std::result::Result<(), TunnelError>> + Send>> { + Box::pin(async { Ok(()) }) + } } fn mock_monitor( diff --git a/talpid-core/src/tunnel/wireguard/mod.rs b/talpid-core/src/tunnel/wireguard/mod.rs index 80c8e2ae8e..dc2300b0ab 100644 --- a/talpid-core/src/tunnel/wireguard/mod.rs +++ b/talpid-core/src/tunnel/wireguard/mod.rs @@ -25,7 +25,10 @@ use std::{ }; #[cfg(windows)] use talpid_types::BoxedError; -use talpid_types::{net::obfuscation::ObfuscatorConfig, ErrorExt}; +use talpid_types::{ + net::{obfuscation::ObfuscatorConfig, wireguard::PublicKey}, + ErrorExt, +}; use tunnel_obfuscation::{ create_obfuscator, Error as ObfuscationError, Settings as ObfuscationSettings, Udp2TcpSettings, }; @@ -73,6 +76,10 @@ pub enum Error { #[error(display = "Connectivity monitor failed")] ConnectivityMonitorError(#[error(source)] connectivity_check::Error), + /// Failed to negotiate PQ PSK + #[error(display = "Failed to negotiate PQ PSK")] + PskNegotiationError(talpid_relay_config_client::Error), + /// Failed to set up IP interfaces. #[cfg(windows)] #[error(display = "Failed to set up IP interfaces")] @@ -188,6 +195,7 @@ impl WireguardMonitor { >( runtime: tokio::runtime::Handle, mut config: Config, + psk_negotiation: Option<PublicKey>, log_path: Option<&Path>, resource_dir: &Path, on_event: F, @@ -237,6 +245,7 @@ impl WireguardMonitor { .map_err(Error::ConnectivityMonitorError)?; let metadata = Self::tunnel_metadata(&iface_name, &config); + let tunnel = monitor.tunnel.clone(); let tunnel_fut = async move { #[cfg(windows)] @@ -280,6 +289,35 @@ impl WireguardMonitor { .map_err(Error::SetupRoutingError) .map_err(CloseMsg::SetupError)?; + if let Some(pubkey) = psk_negotiation { + // TODO: add timeout + let (private_key, psk) = talpid_relay_config_client::push_pq_key( + IpAddr::V4(config.ipv4_gateway), + config.tunnel.private_key.public_key(), + ) + .await + .map_err(Error::PskNegotiationError) + .map_err(CloseMsg::SetupError)?; + + config.tunnel.private_key = private_key; + + for peer in &mut config.peers { + if pubkey == peer.public_key { + peer.psk = Some(psk); + break; + } + } + + log::trace!("Ephemeral pubkey: {}", config.tunnel.private_key.public_key()); + + if let Some(tunnel) = &*tunnel.lock().unwrap() { + tunnel + .set_config(&config) + .map_err(Error::TunnelError) + .map_err(CloseMsg::SetupError)?; + } + } + let mut connectivity_monitor = tokio::task::spawn_blocking(move || { match connectivity_monitor.establish_connectivity(retry_attempt) { Ok(true) => Ok(connectivity_monitor), @@ -594,6 +632,9 @@ pub(crate) trait Tunnel: Send { fn get_interface_name(&self) -> String; fn stop(self: Box<Self>) -> std::result::Result<(), TunnelError>; fn get_tunnel_stats(&self) -> std::result::Result<stats::StatsMap, TunnelError>; + fn set_config(&self, _config: &Config) -> std::result::Result<(), TunnelError> { + unimplemented!() + } } /// Errors to be returned from WireGuard implementations, namely implementers of the Tunnel trait @@ -630,6 +671,10 @@ pub enum TunnelError { #[error(display = "Failed to get config of WireGuard tunnel")] GetConfigError, + /// Failed to set WireGuard tunnel config on device + #[error(display = "Failed to set config of WireGuard tunnel")] + SetConfigError, + /// Failed to duplicate tunnel file descriptor for wireguard-go #[cfg(any(target_os = "linux", target_os = "macos", target_os = "android"))] #[error(display = "Failed to duplicate tunnel file descriptor for wireguard-go")] diff --git a/talpid-core/src/tunnel/wireguard/wireguard_nt.rs b/talpid-core/src/tunnel/wireguard/wireguard_nt.rs index 21c9b705ce..31222d2ce8 100644 --- a/talpid-core/src/tunnel/wireguard/wireguard_nt.rs +++ b/talpid-core/src/tunnel/wireguard/wireguard_nt.rs @@ -833,11 +833,20 @@ fn serialize_config(config: &Config) -> Result<Vec<MaybeUninit<u8>>> { buffer.extend(windows::as_uninit_byte_slice(&header)); for peer in &config.peers { + let flags = if peer.psk.is_some() { + WgPeerFlag::HAS_PRESHARED_KEY | WgPeerFlag::HAS_PUBLIC_KEY | WgPeerFlag::HAS_ENDPOINT + } else { + WgPeerFlag::HAS_PUBLIC_KEY | WgPeerFlag::HAS_ENDPOINT + }; let wg_peer = WgPeer { - flags: WgPeerFlag::HAS_PUBLIC_KEY | WgPeerFlag::HAS_ENDPOINT, + flags, reserved: 0, public_key: peer.public_key.as_bytes().clone(), - preshared_key: [0u8; WIREGUARD_KEY_LENGTH], + preshared_key: peer + .psk + .as_ref() + .map(|psk| psk.as_bytes().clone()) + .unwrap_or([0u8; WIREGUARD_KEY_LENGTH]), persistent_keepalive: 0, endpoint: windows::inet_sockaddr_from_socketaddr(peer.endpoint).into(), tx_bytes: 0, @@ -976,6 +985,24 @@ impl Tunnel for WgNtTunnel { self.stop_tunnel(); Ok(()) } + + fn set_config(&self, config: &Config) -> std::result::Result<(), super::TunnelError> { + if let Some(ref device) = &*self.device.lock().unwrap() { + if let Err(error) = device.set_config(&config) { + log::error!( + "{}", + error.display_chain_with_msg("Failed to set wg-nt tunnel config") + ); + Err(super::TunnelError::SetConfigError) + } else { + Ok(()) + } + } else { + Err(super::TunnelError::StatsError( + super::stats::Error::NoTunnelDevice, + )) + } + } } #[cfg(test)] diff --git a/talpid-types/src/net/wireguard.rs b/talpid-types/src/net/wireguard.rs index 56ecc6bb42..16cd7390b7 100644 --- a/talpid-types/src/net/wireguard.rs +++ b/talpid-types/src/net/wireguard.rs @@ -55,6 +55,8 @@ pub struct PeerConfig { pub allowed_ips: Vec<IpNetwork>, /// IP address of the WireGuard server. pub endpoint: SocketAddr, + /// Preshared key. + pub psk: Option<PresharedKey>, } #[derive(Clone, Eq, PartialEq, Deserialize, Serialize, Debug)] @@ -253,15 +255,40 @@ impl fmt::Display for PublicKey { } } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct PresharedKey([u8; 32]); +impl PresharedKey { + /// Get the PSK as bytes + pub fn as_bytes(&self) -> &[u8; 32] { + &self.0 + } +} + impl From<[u8; 32]> for PresharedKey { fn from(key: [u8; 32]) -> PresharedKey { PresharedKey(key) } } +impl Serialize for PresharedKey { + fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> + where + S: Serializer, + { + serialize_key(&self.0, serializer) + } +} + +impl<'de> Deserialize<'de> for PresharedKey { + fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> + where + D: Deserializer<'de>, + { + deserialize_key(deserializer) + } +} + fn serialize_key<S>(key: &[u8; 32], serializer: S) -> Result<S::Ok, S::Error> where S: Serializer, |
