diff options
| author | Emīls Piņķis <emils@mullvad.net> | 2022-05-17 17:09:17 +0100 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2022-06-14 12:37:02 +0200 |
| commit | ec5ef50c03e6b8ea953e88b752886072dd7ebe9e (patch) | |
| tree | 62741b5e50f68bd2ccc96038b537f8a25a4f6d71 | |
| parent | 0f1106cf9b0c3aee13dc331a3e14678520464614 (diff) | |
| download | mullvadvpn-ec5ef50c03e6b8ea953e88b752886072dd7ebe9e.tar.xz mullvadvpn-ec5ef50c03e6b8ea953e88b752886072dd7ebe9e.zip | |
Implement set_config for Linux in-kernel tunnel
5 files changed, 61 insertions, 20 deletions
diff --git a/talpid-core/src/tunnel/wireguard/config.rs b/talpid-core/src/tunnel/wireguard/config.rs index 026c6c9fae..cd87a90124 100644 --- a/talpid-core/src/tunnel/wireguard/config.rs +++ b/talpid-core/src/tunnel/wireguard/config.rs @@ -6,6 +6,7 @@ use std::{ use talpid_types::net::{obfuscation::ObfuscatorConfig, wireguard, GenericTunnelOptions}; /// Config required to set up a single WireGuard tunnel +#[derive(Clone)] pub struct Config { /// Contains tunnel endpoint specific config pub tunnel: wireguard::TunnelConfig, diff --git a/talpid-core/src/tunnel/wireguard/mod.rs b/talpid-core/src/tunnel/wireguard/mod.rs index dc2300b0ab..33a6001361 100644 --- a/talpid-core/src/tunnel/wireguard/mod.rs +++ b/talpid-core/src/tunnel/wireguard/mod.rs @@ -8,6 +8,7 @@ use futures::{channel::mpsc, StreamExt}; use futures::{ channel::oneshot, future::{abortable, AbortHandle as FutureAbortHandle}, + Future, }; #[cfg(target_os = "linux")] use lazy_static::lazy_static; @@ -21,6 +22,7 @@ use std::{ convert::Infallible, net::IpAddr, path::Path, + pin::Pin, sync::{mpsc as sync_mpsc, Arc, Mutex}, }; #[cfg(windows)] @@ -308,11 +310,18 @@ impl WireguardMonitor { } } - log::trace!("Ephemeral pubkey: {}", config.tunnel.private_key.public_key()); + log::trace!( + "Ephemeral pubkey: {}", + config.tunnel.private_key.public_key() + ); - if let Some(tunnel) = &*tunnel.lock().unwrap() { - tunnel - .set_config(&config) + let set_config_future = tunnel + .lock() + .unwrap() + .as_ref() + .map(|tunnel| tunnel.set_config(config.clone())); + if let Some(f) = set_config_future { + f.await .map_err(Error::TunnelError) .map_err(CloseMsg::SetupError)?; } @@ -632,7 +641,10 @@ pub(crate) trait Tunnel: Send { fn get_interface_name(&self) -> String; fn stop(self: Box<Self>) -> std::result::Result<(), TunnelError>; fn get_tunnel_stats(&self) -> std::result::Result<stats::StatsMap, TunnelError>; - fn set_config(&self, _config: &Config) -> std::result::Result<(), TunnelError> { + fn set_config( + &self, + _config: Config, + ) -> Pin<Box<dyn Future<Output = std::result::Result<(), TunnelError>> + Send + 'static>> { unimplemented!() } } diff --git a/talpid-core/src/tunnel/wireguard/wireguard_kernel/netlink_tunnel.rs b/talpid-core/src/tunnel/wireguard/wireguard_kernel/netlink_tunnel.rs index 8ab3234fd5..109ae75367 100644 --- a/talpid-core/src/tunnel/wireguard/wireguard_kernel/netlink_tunnel.rs +++ b/talpid-core/src/tunnel/wireguard/wireguard_kernel/netlink_tunnel.rs @@ -1,3 +1,7 @@ +use std::pin::Pin; + +use futures::Future; + use super::{ super::stats::{Stats, StatsMap}, wg_message::DeviceNla, @@ -109,4 +113,20 @@ impl Tunnel for NetlinkTunnel { result } + + fn set_config( + &self, + config: Config, + ) -> Pin<Box<dyn Future<Output = std::result::Result<(), TunnelError>> + Send + 'static>> { + let mut wg = self.netlink_connections.wg_handle.clone(); + let interface_index = self.interface_index; + Box::pin(async move { + wg.set_config(interface_index, &config) + .await + .map_err(|err| { + log::error!("Failed to fetch WireGuard device config: {}", err); + TunnelError::SetConfigError + }) + }) + } } diff --git a/talpid-core/src/tunnel/wireguard/wireguard_kernel/wg_message.rs b/talpid-core/src/tunnel/wireguard/wireguard_kernel/wg_message.rs index 007acb8df7..dd17bc0cb3 100644 --- a/talpid-core/src/tunnel/wireguard/wireguard_kernel/wg_message.rs +++ b/talpid-core/src/tunnel/wireguard/wireguard_kernel/wg_message.rs @@ -81,12 +81,16 @@ impl DeviceMessage { for peer in config.peers.iter() { let peer_endpoint = InetAddr::from_std(&peer.endpoint); let allowed_ips = peer.allowed_ips.iter().map(From::from).collect(); - peers.push(PeerMessage(vec![ + let mut peer_nlas = vec![ PeerNla::PublicKey(*peer.public_key.as_bytes()), PeerNla::Endpoint(peer_endpoint), PeerNla::AllowedIps(allowed_ips), PeerNla::Flags(WGPEER_F_REPLACE_ALLOWEDIPS), - ])); + ]; + if let Some(psk) = peer.psk.as_ref() { + peer_nlas.push(PeerNla::PresharedKey(psk.as_bytes().clone())); + } + peers.push(PeerMessage(peer_nlas)); } let nlas = vec![ diff --git a/talpid-core/src/tunnel/wireguard/wireguard_nt.rs b/talpid-core/src/tunnel/wireguard/wireguard_nt.rs index 31222d2ce8..0ae469144b 100644 --- a/talpid-core/src/tunnel/wireguard/wireguard_nt.rs +++ b/talpid-core/src/tunnel/wireguard/wireguard_nt.rs @@ -11,11 +11,14 @@ use ipnetwork::IpNetwork; use lazy_static::lazy_static; use std::{ ffi::CStr, - fmt, io, mem, + fmt, + future::Future, + io, mem, mem::MaybeUninit, net::{IpAddr, Ipv4Addr, Ipv6Addr}, os::windows::io::RawHandle, path::Path, + pin::Pin, ptr, sync::{Arc, Mutex}, }; @@ -986,22 +989,22 @@ impl Tunnel for WgNtTunnel { Ok(()) } - fn set_config(&self, config: &Config) -> std::result::Result<(), super::TunnelError> { - if let Some(ref device) = &*self.device.lock().unwrap() { - if let Err(error) = device.set_config(&config) { + fn set_config( + &self, + config: Config, + ) -> Pin<Box<dyn Future<Output = std::result::Result<(), super::TunnelError>> + Send>> { + let device = self.device.clone(); + Box::pin(async move { + let guard = device.lock().unwrap(); + let device = guard.as_ref().ok_or(super::TunnelError::SetConfigError)?; + device.set_config(&config).map_err(|error| { log::error!( "{}", error.display_chain_with_msg("Failed to set wg-nt tunnel config") ); - Err(super::TunnelError::SetConfigError) - } else { - Ok(()) - } - } else { - Err(super::TunnelError::StatsError( - super::stats::Error::NoTunnelDevice, - )) - } + super::TunnelError::SetConfigError + }) + }) } } @@ -1033,6 +1036,7 @@ mod tests { public_key: WG_PUBLIC_KEY.clone(), allowed_ips: vec!["1.3.3.0/24".parse().unwrap()], endpoint: "1.2.3.4:1234".parse().unwrap(), + psk: None, }], ipv4_gateway: "0.0.0.0".parse().unwrap(), ipv6_gateway: None, |
