diff options
46 files changed, 450 insertions, 423 deletions
diff --git a/android/translations-converter/src/android/string_value.rs b/android/translations-converter/src/android/string_value.rs index 04ad23653e..f11453d020 100644 --- a/android/translations-converter/src/android/string_value.rs +++ b/android/translations-converter/src/android/string_value.rs @@ -84,13 +84,6 @@ impl StringValue { } } -impl StringValue { - /// Clones the internal string value. - pub fn to_string(&self) -> String { - self.0.clone() - } -} - impl Deref for StringValue { type Target = str; diff --git a/clippy.toml b/clippy.toml new file mode 100644 index 0000000000..c2ffede55a --- /dev/null +++ b/clippy.toml @@ -0,0 +1 @@ +enum-variant-size-threshold = 1000 diff --git a/mullvad-api/src/address_cache.rs b/mullvad-api/src/address_cache.rs index a8b1a91a44..d37d429e89 100644 --- a/mullvad-api/src/address_cache.rs +++ b/mullvad-api/src/address_cache.rs @@ -10,19 +10,16 @@ use tokio::{ #[error(no_from)] pub enum Error { #[error(display = "Failed to open the address cache file")] - OpenAddressCache(#[error(source)] io::Error), + Open(#[error(source)] io::Error), #[error(display = "Failed to read the address cache file")] - ReadAddressCache(#[error(source)] io::Error), + Read(#[error(source)] io::Error), #[error(display = "Failed to parse the address cache file")] - ParseAddressCache, + Parse, #[error(display = "Failed to update the address cache file")] - WriteAddressCache(#[error(source)] io::Error), - - #[error(display = "The address cache is empty")] - EmptyAddressCache, + Write(#[error(source)] io::Error), } #[derive(Clone)] @@ -68,7 +65,7 @@ impl AddressCache { self.inner.lock().await.address } - pub async fn set_address(&self, address: SocketAddr) -> io::Result<()> { + pub async fn set_address(&self, address: SocketAddr) -> Result<(), Error> { let mut inner = self.inner.lock().await; if address != inner.address { self.save_to_disk(&address).await?; @@ -77,17 +74,21 @@ impl AddressCache { Ok(()) } - async fn save_to_disk(&self, address: &SocketAddr) -> io::Result<()> { + async fn save_to_disk(&self, address: &SocketAddr) -> Result<(), Error> { let write_path = match self.write_path.as_ref() { Some(write_path) => write_path, None => return Ok(()), }; - let mut file = crate::fs::AtomicFile::new(write_path.to_path_buf()).await?; + let mut file = crate::fs::AtomicFile::new(write_path.to_path_buf()) + .await + .map_err(Error::Open)?; let mut contents = address.to_string(); contents += "\n"; - file.write_all(contents.as_bytes()).await?; - file.finalize().await + file.write_all(contents.as_bytes()) + .await + .map_err(Error::Write)?; + file.finalize().await.map_err(Error::Write) } } @@ -103,12 +104,10 @@ impl AddressCacheInner { } async fn read_address_file(path: &Path) -> Result<SocketAddr, Error> { - let mut file = fs::File::open(path) - .await - .map_err(Error::OpenAddressCache)?; + let mut file = fs::File::open(path).await.map_err(Error::Open)?; let mut address = String::new(); file.read_to_string(&mut address) .await - .map_err(Error::ReadAddressCache)?; - address.trim().parse().map_err(|_| Error::ParseAddressCache) + .map_err(Error::Read)?; + address.trim().parse().map_err(|_| Error::Parse) } diff --git a/mullvad-api/src/https_client_with_sni.rs b/mullvad-api/src/https_client_with_sni.rs index 2920cf989d..0d85ddd790 100644 --- a/mullvad-api/src/https_client_with_sni.rs +++ b/mullvad-api/src/https_client_with_sni.rs @@ -304,7 +304,7 @@ impl Service<Uri> for HttpsConnectorWithSni { ) .await?; let tls_stream = TlsStream::connect_https(socket, &hostname).await?; - Ok::<_, io::Error>(ApiConnection::Direct(tls_stream)) + Ok::<_, io::Error>(ApiConnection::Direct(Box::new(tls_stream))) } InnerConnectionMode::Proxied(proxy_config) => { let socket = Self::open_socket( @@ -320,7 +320,7 @@ impl Service<Uri> for HttpsConnectorWithSni { addr, ); let tls_stream = TlsStream::connect_https(proxy, &hostname).await?; - Ok(ApiConnection::Proxied(tls_stream)) + Ok(ApiConnection::Proxied(Box::new(tls_stream))) } } }; diff --git a/mullvad-api/src/lib.rs b/mullvad-api/src/lib.rs index b1019ef155..d967791192 100644 --- a/mullvad-api/src/lib.rs +++ b/mullvad-api/src/lib.rs @@ -236,7 +236,7 @@ impl Runtime { new_address_callback: impl ApiEndpointUpdateCallback + Send + Sync + 'static, #[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>, ) -> rest::RequestServiceHandle { - let service_handle = rest::RequestService::new( + let service_handle = rest::RequestService::spawn( sni_hostname, self.api_availability.handle(), self.address_cache.clone(), diff --git a/mullvad-api/src/proxy.rs b/mullvad-api/src/proxy.rs index 21fa39c9c6..cc7ee3aa6c 100644 --- a/mullvad-api/src/proxy.rs +++ b/mullvad-api/src/proxy.rs @@ -132,8 +132,8 @@ impl ApiConnectionMode { /// Stream that is either a regular TLS stream or TLS via shadowsocks pub enum ApiConnection { - Direct(TlsStream<TcpStream>), - Proxied(TlsStream<ProxyClientStream<TcpStream>>), + Direct(Box<TlsStream<TcpStream>>), + Proxied(Box<TlsStream<ProxyClientStream<TcpStream>>>), } impl AsyncRead for ApiConnection { diff --git a/mullvad-api/src/relay_list.rs b/mullvad-api/src/relay_list.rs index 6bd4523652..a2d699e248 100644 --- a/mullvad-api/src/relay_list.rs +++ b/mullvad-api/src/relay_list.rs @@ -130,7 +130,7 @@ impl ServerRelayList { ) { let openvpn_endpoint_data = openvpn.ports; for mut openvpn_relay in openvpn.relays.into_iter() { - openvpn_relay.to_lower(); + openvpn_relay.convert_to_lowercase(); if let Some((country_code, city_code)) = split_location_code(&openvpn_relay.location) { if let Some(country) = countries.get_mut(country_code) { if let Some(city) = country @@ -184,7 +184,7 @@ impl ServerRelayList { }; for mut wireguard_relay in relays { - wireguard_relay.relay.to_lower(); + wireguard_relay.relay.convert_to_lowercase(); if let Some((country_code, city_code)) = split_location_code(&wireguard_relay.relay.location) { @@ -235,7 +235,7 @@ impl ServerRelayList { } = bridges; for mut bridge_relay in relays { - bridge_relay.to_lower(); + bridge_relay.convert_to_lowercase(); if let Some((country_code, city_code)) = split_location_code(&bridge_relay.location) { if let Some(country) = countries.get_mut(country_code) { if let Some(city) = country @@ -345,7 +345,7 @@ struct Relay { } impl Relay { - fn to_lower(&mut self) { + fn convert_to_lowercase(&mut self) { self.hostname = self.hostname.to_lowercase(); self.location = self.location.to_lowercase(); } diff --git a/mullvad-api/src/rest.rs b/mullvad-api/src/rest.rs index 8a04c62e39..c80f01049a 100644 --- a/mullvad-api/src/rest.rs +++ b/mullvad-api/src/rest.rs @@ -130,7 +130,7 @@ impl< > RequestService<T, F> { /// Constructs a new request service. - pub async fn new( + pub async fn spawn( sni_hostname: Option<String>, api_availability: ApiAvailabilityHandle, address_cache: AddressCache, diff --git a/mullvad-daemon/src/cleanup.rs b/mullvad-daemon/src/cleanup.rs index b4765761e0..1f0cbdf309 100644 --- a/mullvad-daemon/src/cleanup.rs +++ b/mullvad-daemon/src/cleanup.rs @@ -7,26 +7,26 @@ use tokio::{fs, io}; #[error(no_from)] pub enum Error { #[error(display = "Failed to get path")] - PathError(#[error(source)] mullvad_paths::Error), + Path(#[error(source)] mullvad_paths::Error), #[error(display = "Failed to remove directory {}", _0)] - RemoveDirError(String, #[error(source)] io::Error), + RemoveDir(String, #[error(source)] io::Error), #[cfg(not(target_os = "windows"))] #[error(display = "Failed to create directory {}", _0)] - CreateDirError(String, #[error(source)] io::Error), + CreateDir(String, #[error(source)] io::Error), #[cfg(target_os = "windows")] #[error(display = "Failed to get file type info")] - FileTypeError(#[error(source)] io::Error), + FileType(#[error(source)] io::Error), #[cfg(target_os = "windows")] #[error(display = "Failed to get dir entry")] - FileEntryError(#[error(source)] io::Error), + FileEntry(#[error(source)] io::Error), #[cfg(target_os = "windows")] #[error(display = "Failed to read dir entries")] - ReadDirError(#[error(source)] io::Error), + ReadDir(#[error(source)] io::Error), } pub async fn clear_directories() -> Result<(), Error> { @@ -35,12 +35,12 @@ pub async fn clear_directories() -> Result<(), Error> { } async fn clear_log_directory() -> Result<(), Error> { - let log_dir = mullvad_paths::get_log_dir().map_err(Error::PathError)?; + let log_dir = mullvad_paths::get_log_dir().map_err(Error::Path)?; clear_directory(&log_dir).await } async fn clear_cache_directory() -> Result<(), Error> { - let cache_dir = mullvad_paths::cache_dir().map_err(Error::PathError)?; + let cache_dir = mullvad_paths::cache_dir().map_err(Error::Path)?; clear_directory(&cache_dir).await } @@ -49,22 +49,22 @@ async fn clear_directory(path: &Path) -> Result<(), Error> { { fs::remove_dir_all(path) .await - .map_err(|e| Error::RemoveDirError(path.display().to_string(), e))?; + .map_err(|e| Error::RemoveDir(path.display().to_string(), e))?; fs::create_dir_all(path) .await - .map_err(|e| Error::CreateDirError(path.display().to_string(), e)) + .map_err(|e| Error::CreateDir(path.display().to_string(), e)) } #[cfg(target_os = "windows")] { - let mut dir = fs::read_dir(&path).await.map_err(Error::ReadDirError)?; + let mut dir = fs::read_dir(&path).await.map_err(Error::ReadDir)?; let mut result = Ok(()); - while let Some(entry) = dir.next_entry().await.map_err(Error::FileEntryError)? { + while let Some(entry) = dir.next_entry().await.map_err(Error::FileEntry)? { let entry_type = match entry.file_type().await { Ok(entry_type) => entry_type, Err(error) => { - result = result.and(Err(Error::FileTypeError(error))); + result = result.and(Err(Error::FileType(error))); continue; } }; @@ -74,9 +74,8 @@ async fn clear_directory(path: &Path) -> Result<(), Error> { } else { fs::remove_dir_all(entry.path()).await }; - result = result.and( - removal.map_err(|e| Error::RemoveDirError(entry.path().display().to_string(), e)), - ); + result = result + .and(removal.map_err(|e| Error::RemoveDir(entry.path().display().to_string(), e))); } result } diff --git a/mullvad-daemon/src/device/api.rs b/mullvad-daemon/src/device/api.rs index cee0186987..6e00b669ca 100644 --- a/mullvad-daemon/src/device/api.rs +++ b/mullvad-daemon/src/device/api.rs @@ -35,10 +35,10 @@ impl CurrentApiCall { } pub fn is_validating(&self) -> bool { - match &self.current_call { - Some(Call::Validation(_)) | Some(Call::OneshotKeyRotation(_)) => true, - _ => false, - } + matches!( + &self.current_call, + Some(Call::Validation(_)) | Some(Call::OneshotKeyRotation(_)) + ) } pub fn is_running_timed_totation(&self) -> bool { @@ -51,10 +51,7 @@ impl CurrentApiCall { pub fn is_logging_in(&self) -> bool { use Call::*; - match &self.current_call { - Some(Login(..)) => true, - _ => false, - } + matches!(&self.current_call, Some(Login(..))) } } diff --git a/mullvad-daemon/src/exception_logging/unix.rs b/mullvad-daemon/src/exception_logging/unix.rs index 8d87be2da6..430fedb74f 100644 --- a/mullvad-daemon/src/exception_logging/unix.rs +++ b/mullvad-daemon/src/exception_logging/unix.rs @@ -5,7 +5,7 @@ use nix::sys::signal::{sigaction, SaFlags, SigAction, SigHandler, SigSet, Signal use std::{convert::TryFrom, sync::Once}; -const INIT_ONCE: Once = Once::new(); +static INIT_ONCE: Once = Once::new(); const FAULT_SIGNALS: [Signal; 5] = [ // Access to invalid memory address diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs index 95039cd662..c6f024eb49 100644 --- a/mullvad-daemon/src/lib.rs +++ b/mullvad-daemon/src/lib.rs @@ -31,7 +31,7 @@ use crate::target_state::PersistentTargetState; use device::{PrivateAccountAndDevice, PrivateDeviceEvent}; use futures::{ channel::{mpsc, oneshot}, - future::{abortable, AbortHandle, Future}, + future::{abortable, AbortHandle, Future, LocalBoxFuture}, StreamExt, }; use mullvad_relay_selector::{ @@ -385,6 +385,12 @@ pub struct DaemonCommandChannel { receiver: mpsc::UnboundedReceiver<InternalDaemonEvent>, } +impl Default for DaemonCommandChannel { + fn default() -> Self { + Self::new() + } +} + impl DaemonCommandChannel { pub fn new() -> Self { let (untracked_sender, receiver) = mpsc::unbounded(); @@ -472,13 +478,13 @@ impl<E> Sender<E> for DaemonEventSender<E> where InternalDaemonEvent: From<E>, { - fn send(&self, event: E) -> Result<(), ()> { + fn send(&self, event: E) -> Result<(), talpid_core::mpsc::Error> { if let Some(sender) = self.sender.upgrade() { sender .unbounded_send(InternalDaemonEvent::from(event)) - .map_err(|_| ()) + .map_err(|_| talpid_core::mpsc::Error::ChannelClosed) } else { - Err(()) + Err(talpid_core::mpsc::Error::ChannelClosed) } } } @@ -684,7 +690,7 @@ where relay_list_listener.notify_relay_list(relay_list.clone()); }; - let mut relay_list_updater = RelayListUpdater::new( + let mut relay_list_updater = RelayListUpdater::spawn( relay_selector.clone(), api_handle.clone(), &cache_dir, @@ -785,11 +791,11 @@ where /// Shuts down the daemon without shutting down the underlying event listener and the shutdown /// callbacks - fn shutdown( + fn shutdown<'a>( self, ) -> ( L, - Vec<Pin<Box<dyn Future<Output = ()>>>>, + Vec<LocalBoxFuture<'a, ()>>, mullvad_api::Runtime, TunnelStateMachineHandle, ) { @@ -845,11 +851,11 @@ where TunnelStateTransition::Disconnected => TunnelState::Disconnected, TunnelStateTransition::Connecting(endpoint) => TunnelState::Connecting { endpoint, - location: self.parameters_generator.get_last_location(), + location: self.parameters_generator.get_last_location().await, }, TunnelStateTransition::Connected(endpoint) => TunnelState::Connected { endpoint, - location: self.parameters_generator.get_last_location(), + location: self.parameters_generator.get_last_location().await, }, TunnelStateTransition::Disconnecting(after_disconnect) => { TunnelState::Disconnecting(after_disconnect) @@ -1184,7 +1190,7 @@ where } Disconnecting(..) => Self::oneshot_send( tx, - self.parameters_generator.get_last_location(), + self.parameters_generator.get_last_location().await, "current location", ), Connected { location, .. } => { @@ -1703,7 +1709,8 @@ where Self::oneshot_send(tx, Ok(()), "use_wireguard_nt response"); if settings_changed { self.parameters_generator - .set_tunnel_options(&self.settings.tunnel_options); + .set_tunnel_options(&self.settings.tunnel_options) + .await; self.event_listener .notify_settings(self.settings.to_settings()); if let Some(TunnelType::Wireguard) = self.get_target_tunnel_type() { @@ -1854,7 +1861,8 @@ where Self::oneshot_send(tx, Ok(()), "set_openvpn_mssfix response"); if settings_changed { self.parameters_generator - .set_tunnel_options(&self.settings.tunnel_options); + .set_tunnel_options(&self.settings.tunnel_options) + .await; self.event_listener .notify_settings(self.settings.to_settings()); if self.get_target_tunnel_type() == Some(TunnelType::OpenVpn) { @@ -1963,7 +1971,8 @@ where Self::oneshot_send(tx, Ok(()), "set_enable_ipv6 response"); if settings_changed { self.parameters_generator - .set_tunnel_options(&self.settings.tunnel_options); + .set_tunnel_options(&self.settings.tunnel_options) + .await; self.event_listener .notify_settings(self.settings.to_settings()); log::info!("Initiating tunnel restart because the enable IPv6 setting changed"); @@ -1991,7 +2000,8 @@ where Self::oneshot_send(tx, Ok(()), "set_quantum_resistant_tunnel response"); if settings_changed { self.parameters_generator - .set_tunnel_options(&self.settings.tunnel_options); + .set_tunnel_options(&self.settings.tunnel_options) + .await; self.event_listener .notify_settings(self.settings.to_settings()); if self.get_target_tunnel_type() == Some(TunnelType::Wireguard) { @@ -2021,7 +2031,8 @@ where let resolvers = dns::addresses_from_options(&settings.tunnel_options.dns_options); self.parameters_generator - .set_tunnel_options(&settings.tunnel_options); + .set_tunnel_options(&settings.tunnel_options) + .await; self.event_listener.notify_settings(settings); self.send_tunnel_command(TunnelCommand::Dns(resolvers)); } @@ -2044,7 +2055,8 @@ where Self::oneshot_send(tx, Ok(()), "set_wireguard_mtu response"); if settings_changed { self.parameters_generator - .set_tunnel_options(&self.settings.tunnel_options); + .set_tunnel_options(&self.settings.tunnel_options) + .await; self.event_listener .notify_settings(self.settings.to_settings()); if let Some(TunnelType::Wireguard) = self.get_connected_tunnel_type() { @@ -2086,7 +2098,8 @@ where ); } self.parameters_generator - .set_tunnel_options(&self.settings.tunnel_options); + .set_tunnel_options(&self.settings.tunnel_options) + .await; self.event_listener .notify_settings(self.settings.to_settings()); } diff --git a/mullvad-daemon/src/management_interface.rs b/mullvad-daemon/src/management_interface.rs index 7999fd9c55..26ad5e39c0 100644 --- a/mullvad-daemon/src/management_interface.rs +++ b/mullvad-daemon/src/management_interface.rs @@ -151,7 +151,7 @@ impl ManagementService for ManagementServiceImpl { self.send_command_to_daemon(DaemonCommand::GetVersionInfo(tx))?; self.wait_for_result(rx) .await? - .ok_or(Status::not_found("no version cache")) + .ok_or_else(|| Status::not_found("no version cache")) .map(types::AppVersionInfo::from) .map(Response::new) } diff --git a/mullvad-daemon/src/migrations/account_history.rs b/mullvad-daemon/src/migrations/account_history.rs index 22534c7f88..520717e405 100644 --- a/mullvad-daemon/src/migrations/account_history.rs +++ b/mullvad-daemon/src/migrations/account_history.rs @@ -53,12 +53,12 @@ pub async fn migrate_formats(settings_dir: &Path, settings: &mut serde_json::Val .read(true) .open(path) .await - .map_err(Error::ReadHistoryError)?; + .map_err(Error::ReadHistory)?; let mut bytes = vec![]; file.read_to_end(&mut bytes) .await - .map_err(Error::ReadHistoryError)?; + .map_err(Error::ReadHistory)?; if is_format_v3(&bytes) { return Ok(()); @@ -92,16 +92,16 @@ fn is_format_v3(bytes: &[u8]) -> bool { } async fn write_format_v3(mut file: File, token: Option<AccountToken>) -> Result<()> { - file.set_len(0).await.map_err(Error::WriteHistoryError)?; + file.set_len(0).await.map_err(Error::WriteHistory)?; file.seek(io::SeekFrom::Start(0)) .await - .map_err(Error::WriteHistoryError)?; + .map_err(Error::WriteHistory)?; if let Some(token) = token { file.write_all(token.as_bytes()) .await - .map_err(Error::WriteHistoryError)?; + .map_err(Error::WriteHistory)?; } - file.sync_all().await.map_err(Error::WriteHistoryError) + file.sync_all().await.map_err(Error::WriteHistory) } fn try_format_v2(bytes: &[u8]) -> Result<Option<(AccountToken, serde_json::Value)>> { diff --git a/mullvad-daemon/src/migrations/mod.rs b/mullvad-daemon/src/migrations/mod.rs index bb1e9d1ba0..b4a1d18979 100644 --- a/mullvad-daemon/src/migrations/mod.rs +++ b/mullvad-daemon/src/migrations/mod.rs @@ -57,31 +57,31 @@ const SETTINGS_FILE: &str = "settings.json"; #[error(no_from)] pub enum Error { #[error(display = "Failed to read the settings")] - ReadError(#[error(source)] io::Error), + Read(#[error(source)] io::Error), #[error(display = "Malformed settings")] - ParseError(#[error(source)] serde_json::Error), + Parse(#[error(source)] serde_json::Error), #[error(display = "Unable to read any version of the settings")] NoMatchingVersion, #[error(display = "Unable to serialize settings to JSON")] - SerializeError(#[error(source)] serde_json::Error), + Serialize(#[error(source)] serde_json::Error), #[error(display = "Unable to open settings for writing")] - OpenError(#[error(source)] io::Error), + Open(#[error(source)] io::Error), #[error(display = "Unable to write new settings")] - WriteError(#[error(source)] io::Error), + Write(#[error(source)] io::Error), #[error(display = "Unable to sync settings to disk")] - SyncError(#[error(source)] io::Error), + SyncSettings(#[error(source)] io::Error), #[error(display = "Failed to read the account history")] - ReadHistoryError(#[error(source)] io::Error), + ReadHistory(#[error(source)] io::Error), #[error(display = "Failed to write new account history")] - WriteHistoryError(#[error(source)] io::Error), + WriteHistory(#[error(source)] io::Error), #[error(display = "Failed to parse account history")] ParseHistoryError, @@ -129,10 +129,10 @@ pub(crate) async fn migrate_all( return Ok(None); } - let settings_bytes = fs::read(&path).await.map_err(Error::ReadError)?; + let settings_bytes = fs::read(&path).await.map_err(Error::Read)?; let mut settings: serde_json::Value = - serde_json::from_reader(&settings_bytes[..]).map_err(Error::ParseError)?; + serde_json::from_reader(&settings_bytes[..]).map_err(Error::Parse)?; if !settings.is_object() { return Err(Error::NoMatchingVersion); @@ -155,7 +155,7 @@ pub(crate) async fn migrate_all( return Ok(migration_data); } - let buffer = serde_json::to_string_pretty(&settings).map_err(Error::SerializeError)?; + let buffer = serde_json::to_string_pretty(&settings).map_err(Error::Serialize)?; let mut options = fs::OpenOptions::new(); #[cfg(unix)] @@ -168,11 +168,11 @@ pub(crate) async fn migrate_all( .truncate(true) .open(&path) .await - .map_err(Error::OpenError)?; + .map_err(Error::Open)?; file.write_all(&buffer.into_bytes()) .await - .map_err(Error::WriteError)?; - file.sync_data().await.map_err(Error::SyncError)?; + .map_err(Error::Write)?; + file.sync_data().await.map_err(Error::SyncSettings)?; log::debug!("Migrated settings. Wrote settings to {}", path.display()); diff --git a/mullvad-daemon/src/migrations/v2.rs b/mullvad-daemon/src/migrations/v2.rs index e91a0d08e8..d4acbf77a7 100644 --- a/mullvad-daemon/src/migrations/v2.rs +++ b/mullvad-daemon/src/migrations/v2.rs @@ -1,3 +1,4 @@ +#![allow(clippy::identity_op)] use super::{Error, Result}; use mullvad_types::settings::SettingsVersion; use std::time::Duration; diff --git a/mullvad-daemon/src/migrations/v3.rs b/mullvad-daemon/src/migrations/v3.rs index cf8631e121..7d2800e32a 100644 --- a/mullvad-daemon/src/migrations/v3.rs +++ b/mullvad-daemon/src/migrations/v3.rs @@ -66,7 +66,7 @@ pub fn migrate(settings: &mut serde_json::Value) -> Result<()> { DnsState::Default }; let addresses = if let Some(addrs) = options.get("addresses") { - serde_json::from_value(addrs.clone()).map_err(Error::ParseError)? + serde_json::from_value(addrs.clone()).map_err(Error::Parse)? } else { vec![] }; diff --git a/mullvad-daemon/src/migrations/v4.rs b/mullvad-daemon/src/migrations/v4.rs index 6908b13010..77b72bcb3e 100644 --- a/mullvad-daemon/src/migrations/v4.rs +++ b/mullvad-daemon/src/migrations/v4.rs @@ -43,8 +43,7 @@ pub fn migrate(settings: &mut serde_json::Value) -> Result<()> { if let Some(constraints) = wireguard_constraints { let (port, protocol): (Constraint<u16>, TransportProtocol) = if let Some(port) = constraints.get("port") { - let port_constraint = - serde_json::from_value(port.clone()).map_err(Error::ParseError)?; + let port_constraint = serde_json::from_value(port.clone()).map_err(Error::Parse)?; match port_constraint { Constraint::Any => (Constraint::Any, TransportProtocol::Udp), Constraint::Only(port) => (Constraint::Only(port), wg_protocol_from_port(port)), @@ -77,13 +76,13 @@ pub fn migrate(settings: &mut serde_json::Value) -> Result<()> { if let Some(constraints) = openvpn_constraints { let port: Constraint<u16> = if let Some(port) = constraints.get("port") { - serde_json::from_value(port.clone()).map_err(Error::ParseError)? + serde_json::from_value(port.clone()).map_err(Error::Parse)? } else { Constraint::Any }; let transport_constraint: Constraint<TransportProtocol> = if let Some(protocol) = constraints.get("protocol") { - serde_json::from_value(protocol.clone()).map_err(Error::ParseError)? + serde_json::from_value(protocol.clone()).map_err(Error::Parse)? } else { Constraint::Any }; diff --git a/mullvad-daemon/src/migrations/v5.rs b/mullvad-daemon/src/migrations/v5.rs index 9f0fdc4b94..cfc74d8438 100644 --- a/mullvad-daemon/src/migrations/v5.rs +++ b/mullvad-daemon/src/migrations/v5.rs @@ -95,7 +95,7 @@ pub(crate) async fn migrate(settings: &mut serde_json::Value) -> Result<Option<M // if let Some(port) = wireguard_constraints.get("port") { let port_constraint: Constraint<TransportPort> = - serde_json::from_value(port.clone()).map_err(Error::ParseError)?; + serde_json::from_value(port.clone()).map_err(Error::Parse)?; if let Some(transport_port) = port_constraint.option() { let (port, obfuscation_settings) = match transport_port.protocol { TransportProtocol::Udp => (serde_json::json!(transport_port.port), None), @@ -116,8 +116,7 @@ pub(crate) async fn migrate(settings: &mut serde_json::Value) -> Result<Option<M let migration_data = if let Some(token) = settings.get("account_token").filter(|t| !t.is_null()) { - let token: AccountToken = - serde_json::from_value(token.clone()).map_err(Error::ParseError)?; + let token: AccountToken = serde_json::from_value(token.clone()).map_err(Error::Parse)?; let migration_data = if let Some(wg_data) = settings.get("wireguard").filter(|wg| !wg.is_null()) { Some(MigrationData { diff --git a/mullvad-daemon/src/tunnel.rs b/mullvad-daemon/src/tunnel.rs index e95850c4f4..18b8bf2c20 100644 --- a/mullvad-daemon/src/tunnel.rs +++ b/mullvad-daemon/src/tunnel.rs @@ -1,8 +1,6 @@ -use std::{ - future::Future, - pin::Pin, - sync::{Arc, Mutex}, -}; +use std::{future::Future, pin::Pin, sync::Arc}; + +use tokio::sync::Mutex; use mullvad_relay_selector::{RelaySelector, SelectedBridge, SelectedObfuscator, SelectedRelay}; use mullvad_types::{ @@ -32,7 +30,7 @@ pub enum Error { NoBridgeAvailable, #[error(display = "Failed to resolve hostname for custom relay")] - ResolveCustomHostnameError, + ResolveCustomHostname, } #[derive(Clone)] @@ -65,13 +63,13 @@ impl ParametersGenerator { } /// Sets the tunnel options to use when generating new tunnel parameters. - pub fn set_tunnel_options(&self, tunnel_options: &TunnelOptions) { - self.0.lock().unwrap().tunnel_options = tunnel_options.clone(); + pub async fn set_tunnel_options(&self, tunnel_options: &TunnelOptions) { + self.0.lock().await.tunnel_options = tunnel_options.clone(); } /// Gets the location associated with the last generated tunnel parameters. - pub fn get_last_location(&self) -> Option<GeoIpLocation> { - let inner = self.0.lock().unwrap(); + pub async fn get_last_location(&self) -> Option<GeoIpLocation> { + let inner = self.0.lock().await; let relays = inner.last_generated_relays.as_ref()?; @@ -131,7 +129,7 @@ impl InnerParametersGenerator { .to_tunnel_parameters(self.tunnel_options.clone(), None) .map_err(|e| { log::error!("Failed to resolve hostname for custom tunnel config: {}", e); - Error::ResolveCustomHostnameError + Error::ResolveCustomHostname }) } Ok((SelectedRelay::Normal(constraints), bridge, obfuscator)) => { @@ -246,13 +244,13 @@ impl TunnelParametersGenerator for ParametersGenerator { ) -> Pin<Box<dyn Future<Output = Result<TunnelParameters, ParameterGenerationError>>>> { let generator = self.0.clone(); Box::pin(async move { - let mut inner = generator.lock().unwrap(); + let mut inner = generator.lock().await; inner .generate(retry_attempt) .await .map_err(|error| match error { Error::NoBridgeAvailable => ParameterGenerationError::NoMatchingBridgeRelay, - Error::ResolveCustomHostnameError => { + Error::ResolveCustomHostname => { ParameterGenerationError::CustomTunnelHostResultionError } error => { diff --git a/mullvad-daemon/src/version_check.rs b/mullvad-daemon/src/version_check.rs index ffa5abc979..10dbf7bfb4 100644 --- a/mullvad-daemon/src/version_check.rs +++ b/mullvad-daemon/src/version_check.rs @@ -14,6 +14,7 @@ use std::{ future::Future, io, path::{Path, PathBuf}, + str::FromStr, time::Duration, }; use talpid_core::mpsc::Sender; @@ -297,10 +298,10 @@ impl VersionUpdater { if !*IS_DEV_BUILD { let stable_version = latest_stable .as_ref() - .and_then(|stable| ParsedAppVersion::from_str(stable)); + .and_then(|stable| ParsedAppVersion::from_str(stable).ok()); let beta_version = if show_beta { - ParsedAppVersion::from_str(latest_beta) + ParsedAppVersion::from_str(latest_beta).ok() } else { None }; @@ -339,10 +340,10 @@ impl VersionUpdater { let mut check_delay = next_delay(); let mut version_check = futures::future::Fuse::terminated(); - // If this is a dev build ,there's no need to pester the API for version checks. + // If this is a dev build, there's no need to pester the API for version checks. if *IS_DEV_BUILD { log::warn!("Not checking for updates because this is a development build"); - while let Some(_) = rx.next().await {} + while rx.next().await.is_some() {} return; } diff --git a/mullvad-relay-selector/src/lib.rs b/mullvad-relay-selector/src/lib.rs index 158086b1d4..e7b85fbc4c 100644 --- a/mullvad-relay-selector/src/lib.rs +++ b/mullvad-relay-selector/src/lib.rs @@ -384,7 +384,7 @@ impl RelaySelector { let mut relay_matcher = RelayMatcher { location: location.clone(), providers: providers.clone(), - ownership: ownership.clone(), + ownership: *ownership, tunnel: openvpn_constraints, }; @@ -492,7 +492,7 @@ impl RelaySelector { let mut entry_relay_matcher = RelayMatcher { location: location.clone(), providers: providers.clone(), - ownership: ownership.clone(), + ownership: *ownership, tunnel: wireguard_constraints.clone().into(), }; @@ -532,7 +532,7 @@ impl RelaySelector { .clone(), ..matcher.clone() } - .to_wireguard_matcher(); + .into_wireguard_matcher(); // Pick the entry relay first if its location constraint is a subset of the exit location. if relay_constraints.wireguard_constraints.use_multihop { @@ -746,7 +746,7 @@ impl RelaySelector { let bridge_constraints = InternalBridgeConstraints { location: settings.location.clone(), providers: settings.providers.clone(), - ownership: settings.ownership.clone(), + ownership: settings.ownership, // FIXME: This is temporary while talpid-core only supports TCP proxies transport_protocol: Constraint::Only(TransportProtocol::Tcp), }; @@ -791,7 +791,7 @@ impl RelaySelector { BridgeSettings::Normal(settings) => InternalBridgeConstraints { location: settings.location.clone(), providers: settings.providers.clone(), - ownership: settings.ownership.clone(), + ownership: settings.ownership, transport_protocol: Constraint::Only(TransportProtocol::Tcp), }, BridgeSettings::Custom(_bridge_settings) => InternalBridgeConstraints { @@ -1064,7 +1064,7 @@ impl RelaySelector { let addr_in = endpoint .as_ref() .map(|endpoint| endpoint.to_endpoint().address.ip()) - .unwrap_or(IpAddr::from(selected_relay.ipv4_addr_in)); + .unwrap_or_else(|| IpAddr::from(selected_relay.ipv4_addr_in)); log::info!("Selected relay {} at {}", selected_relay.hostname, addr_in); endpoint.map(|endpoint| NormalSelectedRelay::new(endpoint, selected_relay.clone())) }) diff --git a/mullvad-relay-selector/src/matcher.rs b/mullvad-relay-selector/src/matcher.rs index 13e16646ab..350a106745 100644 --- a/mullvad-relay-selector/src/matcher.rs +++ b/mullvad-relay-selector/src/matcher.rs @@ -34,7 +34,7 @@ impl From<RelayConstraints> for RelayMatcher<AnyTunnelMatcher> { } impl RelayMatcher<AnyTunnelMatcher> { - pub fn to_wireguard_matcher(self) -> RelayMatcher<WireguardMatcher> { + pub fn into_wireguard_matcher(self) -> RelayMatcher<WireguardMatcher> { RelayMatcher { tunnel: self.tunnel.wireguard, location: self.location, diff --git a/mullvad-relay-selector/src/updater.rs b/mullvad-relay-selector/src/updater.rs index 480be795fb..31299eea14 100644 --- a/mullvad-relay-selector/src/updater.rs +++ b/mullvad-relay-selector/src/updater.rs @@ -57,7 +57,7 @@ pub struct RelayListUpdater { } impl RelayListUpdater { - pub fn new( + pub fn spawn( selector: super::RelaySelector, api_handle: MullvadRestHandle, cache_dir: &Path, diff --git a/mullvad-setup/src/main.rs b/mullvad-setup/src/main.rs index b257d92080..9d353b78a6 100644 --- a/mullvad-setup/src/main.rs +++ b/mullvad-setup/src/main.rs @@ -2,7 +2,7 @@ use clap::{crate_authors, crate_description, crate_name, App}; use mullvad_api::{self, proxy::ApiConnectionMode}; use mullvad_management_interface::new_rpc_client; use mullvad_types::version::ParsedAppVersion; -use std::{path::PathBuf, process, time::Duration}; +use std::{path::PathBuf, process, str::FromStr, time::Duration}; use talpid_core::{ firewall::{self, Firewall}, future_retry::{constant_interval, retry_future_n}, @@ -133,7 +133,7 @@ async fn main() { async fn is_older_version(old_version: &str) -> Result<ExitStatus, Error> { let parsed_version = - ParsedAppVersion::from_str(old_version).ok_or(Error::ParseVersionStringError)?; + ParsedAppVersion::from_str(old_version).map_err(|_| Error::ParseVersionStringError)?; Ok(if parsed_version < *APP_VERSION { ExitStatus::Ok @@ -152,7 +152,7 @@ async fn prepare_restart() -> Result<(), Error> { async fn reset_firewall() -> Result<(), Error> { // Ensure that the daemon isn't running - if let Ok(_) = new_rpc_client().await { + if new_rpc_client().await.is_ok() { return Err(Error::DaemonIsRunning); } diff --git a/mullvad-types/src/relay_constraints.rs b/mullvad-types/src/relay_constraints.rs index 4ea0a8dfc5..ffbb317bd3 100644 --- a/mullvad-types/src/relay_constraints.rs +++ b/mullvad-types/src/relay_constraints.rs @@ -625,19 +625,15 @@ impl RelaySettingsUpdate { RelaySettingsUpdate::CustomTunnelEndpoint(endpoint) => { endpoint.endpoint().protocol == TransportProtocol::Tcp } - RelaySettingsUpdate::Normal(update) => { - if let Some(constraints) = &update.openvpn_constraints { - !matches!( - &constraints.port, - Constraint::Only(TransportPort { - protocol: TransportProtocol::Udp, - .. - }) - ) - } else { - true - } - } + RelaySettingsUpdate::Normal(update) => !matches!( + &update.openvpn_constraints, + Some(OpenVpnConstraints { + port: Constraint::Only(TransportPort { + protocol: TransportProtocol::Udp, + .. + }) + }) + ), } } } diff --git a/mullvad-types/src/version.rs b/mullvad-types/src/version.rs index 7daf50e74b..8040b181b0 100644 --- a/mullvad-types/src/version.rs +++ b/mullvad-types/src/version.rs @@ -2,7 +2,10 @@ use jnix::IntoJava; use regex::Regex; use serde::{Deserialize, Serialize}; -use std::cmp::{Ord, Ordering, PartialOrd}; +use std::{ + cmp::{Ord, Ordering, PartialOrd}, + str::FromStr, +}; lazy_static::lazy_static! { static ref STABLE_REGEX: Regex = Regex::new(r"^(\d{4})\.(\d+)$").unwrap(); @@ -44,30 +47,33 @@ pub enum ParsedAppVersion { Dev(u32, u32, Option<u32>, String), } -impl ParsedAppVersion { - pub fn from_str(version: &str) -> Option<Self> { +impl FromStr for ParsedAppVersion { + type Err = (); + fn from_str(version: &str) -> Result<Self, Self::Err> { let get_int = |cap: ®ex::Captures<'_>, idx| cap.get(idx)?.as_str().parse().ok(); if let Some(caps) = STABLE_REGEX.captures(version) { - let year = get_int(&caps, 1)?; - let version = get_int(&caps, 2)?; - Some(Self::Stable(year, version)) + let year = get_int(&caps, 1).ok_or(())?; + let version = get_int(&caps, 2).ok_or(())?; + Ok(Self::Stable(year, version)) } else if let Some(caps) = BETA_REGEX.captures(version) { - let year = get_int(&caps, 1)?; - let version = get_int(&caps, 2)?; - let beta_version = get_int(&caps, 3)?; - Some(Self::Beta(year, version, beta_version)) + let year = get_int(&caps, 1).ok_or(())?; + let version = get_int(&caps, 2).ok_or(())?; + let beta_version = get_int(&caps, 3).ok_or(())?; + Ok(Self::Beta(year, version, beta_version)) } else if let Some(caps) = DEV_REGEX.captures(version) { - let year = get_int(&caps, 1)?; - let version = get_int(&caps, 2)?; + let year = get_int(&caps, 1).ok_or(())?; + let version = get_int(&caps, 2).ok_or(())?; let beta_version = caps.get(4).map(|_| get_int(&caps, 5).unwrap()); - let dev_hash = caps.get(6)?.as_str().to_string(); - Some(Self::Dev(year, version, beta_version, dev_hash)) + let dev_hash = caps.get(6).ok_or(())?.as_str().to_string(); + Ok(Self::Dev(year, version, beta_version, dev_hash)) } else { - None + Err(()) } } +} +impl ParsedAppVersion { pub fn is_dev(&self) -> bool { matches!(self, ParsedAppVersion::Dev(..)) } @@ -191,7 +197,7 @@ mod test { ]; for (input, expected_output) in tests { - assert_eq!(ParsedAppVersion::from_str(input), expected_output,); + assert_eq!(ParsedAppVersion::from_str(input).ok(), expected_output,); } } } diff --git a/talpid-core/src/dns/linux/resolvconf.rs b/talpid-core/src/dns/linux/resolvconf.rs index ea2f3f3704..97db14b622 100644 --- a/talpid-core/src/dns/linux/resolvconf.rs +++ b/talpid-core/src/dns/linux/resolvconf.rs @@ -22,16 +22,16 @@ pub enum Error { RunResolvconf(#[error(source)] io::Error), #[error(display = "Using 'resolvconf' to add a record failed: {}", stderr)] - AddRecordError { stderr: String }, + AddRecord { stderr: String }, #[error(display = "Using 'resolvconf' to delete a record failed")] - DeleteRecordError, + DeleteRecord, #[error(display = "Detected dnsmasq is runing and misconfigured")] - DnsmasqMisconfigurationError, + DnsmasqMisconfiguration, #[error(display = "Current /etc/resolv.conf is not generated by resolvconf")] - ResolvconfNotInUseError, + ResolvconfNotInUse, } pub struct Resolvconf { @@ -50,15 +50,15 @@ impl Resolvconf { // Check if resolvconf is managing DNS by /etc/resolv.conf if !is_dnsmasq_running - && !(Self::check_if_resolvconf_is_symlinked_correctly() - || Self::check_if_resolvconf_was_generated()) + && !Self::check_if_resolvconf_is_symlinked_correctly() + && !Self::check_if_resolvconf_was_generated() { - return Err(Error::ResolvconfNotInUseError); + return Err(Error::ResolvconfNotInUse); } // Check if resolvconf can manage DNS via dnsmasq if is_dnsmasq_running && Self::is_dnsmasq_configured_wrong() { - return Err(Error::DnsmasqMisconfigurationError); + return Err(Error::DnsmasqMisconfiguration); } Ok(Resolvconf { @@ -94,7 +94,7 @@ impl Resolvconf { if !output.status.success() { let stderr = String::from_utf8_lossy(&output.stderr).to_string(); - return Err(Error::AddRecordError { stderr }); + return Err(Error::AddRecord { stderr }); } self.record_names.insert(record_name); @@ -118,7 +118,7 @@ impl Resolvconf { record_name, String::from_utf8_lossy(&output.stderr) ); - result = Err(Error::DeleteRecordError); + result = Err(Error::DeleteRecord); } } diff --git a/talpid-core/src/dns/linux/static_resolv_conf.rs b/talpid-core/src/dns/linux/static_resolv_conf.rs index 196fb31003..691d7b468b 100644 --- a/talpid-core/src/dns/linux/static_resolv_conf.rs +++ b/talpid-core/src/dns/linux/static_resolv_conf.rs @@ -28,7 +28,7 @@ pub enum Error { ReadResolvConf(&'static str, #[error(source)] io::Error), #[error(display = "resolv.conf at {} could not be parsed", _0)] - ParseError(&'static str, #[error(source)] resolv_conf::ParseError), + Parse(&'static str, #[error(source)] resolv_conf::ParseError), #[error(display = "Failed to remove stale resolv.conf backup at {}", _0)] RemoveBackup(&'static str, #[error(source)] io::Error), @@ -179,7 +179,7 @@ fn read_config() -> Result<Config> { let contents = fs::read_to_string(RESOLV_CONF_PATH) .map_err(|e| Error::ReadResolvConf(RESOLV_CONF_PATH, e))?; - let config = Config::parse(&contents).map_err(|e| Error::ParseError(RESOLV_CONF_PATH, e))?; + let config = Config::parse(&contents).map_err(|e| Error::Parse(RESOLV_CONF_PATH, e))?; Ok(config) } @@ -198,8 +198,8 @@ fn restore_from_backup() -> Result<()> { match fs::read_to_string(RESOLV_CONF_BACKUP_PATH) { Ok(backup) => { log::info!("Restoring DNS state from backup"); - let config = Config::parse(&backup) - .map_err(|e| Error::ParseError(RESOLV_CONF_BACKUP_PATH, e))?; + let config = + Config::parse(&backup).map_err(|e| Error::Parse(RESOLV_CONF_BACKUP_PATH, e))?; write_config(&config)?; diff --git a/talpid-core/src/mpsc.rs b/talpid-core/src/mpsc.rs index 8c6424bc01..6492796cfc 100644 --- a/talpid-core/src/mpsc.rs +++ b/talpid-core/src/mpsc.rs @@ -1,11 +1,20 @@ +/// Error type for `Sender` trait. +#[derive(err_derive::Error, Debug)] +pub enum Error { + /// The underlying channel is closed. + #[error(display = "Channel is closed")] + ChannelClosed, +} + /// Abstraction over any type that can be used similarly to an `std::mpsc::Sender`. pub trait Sender<T> { /// Sends an item over the underlying channel, failing only if the channel is closed. - fn send(&self, item: T) -> Result<(), ()>; + fn send(&self, item: T) -> Result<(), Error>; } impl<E> Sender<E> for futures::channel::mpsc::UnboundedSender<E> { - fn send(&self, content: E) -> Result<(), ()> { - self.unbounded_send(content).map_err(|_| ()) + fn send(&self, content: E) -> Result<(), Error> { + self.unbounded_send(content) + .map_err(|_| Error::ChannelClosed) } } diff --git a/talpid-core/src/ping_monitor/icmp.rs b/talpid-core/src/ping_monitor/icmp.rs index 67f5b70cb5..0bcd9da72f 100644 --- a/talpid-core/src/ping_monitor/icmp.rs +++ b/talpid-core/src/ping_monitor/icmp.rs @@ -183,7 +183,7 @@ fn construct_icmpv4_packet_inner( let checksum = internet_checksum::checksum(buffer); (&mut buffer[ICMP_CHECKSUM_OFFSET..]) - .write(&checksum) + .write_all(&checksum) .unwrap(); true diff --git a/talpid-core/src/routing/linux.rs b/talpid-core/src/routing/linux.rs index 092ad6f52a..4b039fe9eb 100644 --- a/talpid-core/src/routing/linux.rs +++ b/talpid-core/src/routing/linux.rs @@ -87,13 +87,13 @@ pub type Result<T> = std::result::Result<T, Error>; #[error(no_from)] pub enum Error { #[error(display = "Failed to open a netlink connection")] - ConnectError(#[error(source)] io::Error), + Connect(#[error(source)] io::Error), #[error(display = "Failed to bind netlink socket")] - BindError(#[error(source)] io::Error), + Bind(#[error(source)] io::Error), #[error(display = "Netlink error")] - NetlinkError(#[error(source)] rtnetlink::Error), + Netlink(#[error(source)] rtnetlink::Error), #[error(display = "Route without a valid node")] InvalidRoute, @@ -108,16 +108,16 @@ pub enum Error { UnknownDeviceIndex(u32), #[error(display = "Failed to get a route for the given IP address")] - GetRouteError(#[error(source)] rtnetlink::Error), + GetRoute(#[error(source)] rtnetlink::Error), #[error(display = "No netlink response for route query")] - NoRouteError, + NoRoute, #[error(display = "Route node was malformed")] InvalidRouteNode, #[error(display = "No link found")] - LinkNotFoundError, + LinkNotFound, /// Unable to create routing table for tagged connections and packets. #[error(display = "Cannot find a free routing table ID")] @@ -140,14 +140,11 @@ pub struct RouteManagerImpl { impl RouteManagerImpl { pub async fn new(required_routes: HashSet<RequiredRoute>) -> Result<Self> { let (mut connection, handle, messages) = - rtnetlink::new_connection().map_err(Error::ConnectError)?; + rtnetlink::new_connection().map_err(Error::Connect)?; let mgroup_flags = RTMGRP_IPV4_ROUTE | RTMGRP_IPV6_ROUTE | RTMGRP_LINK | RTMGRP_NOTIFY; let addr = SocketAddr::new(0, mgroup_flags); - connection - .socket_mut() - .bind(&addr) - .map_err(Error::BindError)?; + connection.socket_mut().bind(&addr).map_err(Error::Bind)?; tokio::spawn(connection); @@ -179,11 +176,11 @@ impl RouteManagerImpl { let mut req = NetlinkMessage::from(RtnlMessage::NewRule((*rule).clone())); req.header.flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_CREATE | NLM_F_REPLACE; - let mut response = self.handle.request(req).map_err(Error::NetlinkError)?; + let mut response = self.handle.request(req).map_err(Error::Netlink)?; while let Some(message) = response.next().await { if let NetlinkPayload::Error(error) = message.payload { - return Err(Error::NetlinkError(rtnetlink::Error::NetlinkError(error))); + return Err(Error::Netlink(rtnetlink::Error::NetlinkError(error))); } } } @@ -236,7 +233,7 @@ impl RouteManagerImpl { let mut req = NetlinkMessage::from(RtnlMessage::GetRule(RuleMessage::default())); req.header.flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_DUMP; - let mut response = self.handle.request(req).map_err(Error::NetlinkError)?; + let mut response = self.handle.request(req).map_err(Error::Netlink)?; let mut rules = vec![]; @@ -246,7 +243,7 @@ impl RouteManagerImpl { rules.push(rule); } NetlinkPayload::Error(error) => { - return Err(Error::NetlinkError(rtnetlink::Error::NetlinkError(error))); + return Err(Error::Netlink(rtnetlink::Error::NetlinkError(error))); } _ => (), } @@ -260,12 +257,12 @@ impl RouteManagerImpl { let mut req = NetlinkMessage::from(RtnlMessage::DelRule(rule)); req.header.flags = NLM_F_REQUEST | NLM_F_ACK; - let mut response = self.handle.request(req).map_err(Error::NetlinkError)?; + let mut response = self.handle.request(req).map_err(Error::Netlink)?; while let Some(message) = response.next().await { if let NetlinkPayload::Error(error) = message.payload { if error.to_io().kind() != io::ErrorKind::NotFound { - return Err(Error::NetlinkError(rtnetlink::Error::NetlinkError(error))); + return Err(Error::Netlink(rtnetlink::Error::NetlinkError(error))); } } } @@ -296,7 +293,7 @@ impl RouteManagerImpl { ) -> Result<BTreeMap<u32, NetworkInterface>> { let mut link_map = BTreeMap::new(); let mut link_request = handle.link().get().execute(); - while let Some(link) = link_request.try_next().await.map_err(Error::NetlinkError)? { + while let Some(link) = link_request.try_next().await.map_err(Error::Netlink)? { if let Some((idx, device)) = Self::map_interface(link) { link_map.insert(idx, device); } @@ -543,7 +540,7 @@ impl RouteManagerImpl { async fn delete_route_if_exists(&self, route: &Route) -> Result<()> { if let Err(error) = self.delete_route(route).await { - if let Error::NetlinkError(rtnetlink::Error::NetlinkError(msg)) = &error { + if let Error::Netlink(rtnetlink::Error::NetlinkError(msg)) = &error { if msg.code == -libc::ESRCH { return Ok(()); } @@ -619,7 +616,7 @@ impl RouteManagerImpl { .del(route_message) .execute() .await - .map_err(Error::NetlinkError) + .map_err(Error::Netlink) } async fn add_route_direct(&mut self, route: Route) -> Result<()> { @@ -693,11 +690,11 @@ impl RouteManagerImpl { let mut req = NetlinkMessage::from(RtnlMessage::NewRoute(add_message)); req.header.flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_CREATE | NLM_F_REPLACE; - let mut response = self.handle.request(req).map_err(Error::NetlinkError)?; + let mut response = self.handle.request(req).map_err(Error::Netlink)?; while let Some(message) = response.next().await { if let NetlinkPayload::Error(err) = message.payload { - return Err(Error::NetlinkError(rtnetlink::Error::NetlinkError(err))); + return Err(Error::Netlink(rtnetlink::Error::NetlinkError(err))); } } Ok(()) @@ -759,7 +756,7 @@ impl RouteManagerImpl { } None => { log::error!("No route detected when assigning the mtu to the Wireguard tunnel"); - return Err(Error::NoRouteError); + return Err(Error::NoRoute); } } } @@ -767,17 +764,13 @@ impl RouteManagerImpl { "Retried {} times looking for the correct device and could not find it", RECURSION_LIMIT ); - Err(Error::NoRouteError) + Err(Error::NoRoute) } async fn get_device_mtu(&self, device: String) -> Result<u16> { let mut links = self.handle.link().get().execute(); let target_device = LinkNla::IfName(device); - while let Some(msg) = links - .try_next() - .await - .map_err(|_| Error::LinkNotFoundError)? - { + while let Some(msg) = links.try_next().await.map_err(|_| Error::LinkNotFound)? { let found = msg.nlas.iter().any(|e| *e == target_device); if found { if let Some(LinkNla::Mtu(mtu)) = @@ -788,7 +781,7 @@ impl RouteManagerImpl { } } } - Err(Error::LinkNotFoundError) + Err(Error::LinkNotFound) } async fn get_destination_route( @@ -813,11 +806,11 @@ impl RouteManagerImpl { let mut stream = execute_route_get_request(self.handle.clone(), message.clone()); match stream.try_next().await { Ok(Some(route_msg)) => self.parse_route_message(route_msg), - Ok(None) => Err(Error::NoRouteError), + Ok(None) => Err(Error::NoRoute), Err(rtnetlink::Error::NetlinkError(nl_err)) if nl_err.code == -libc::ENETUNREACH => { Ok(None) } - Err(err) => Err(Error::GetRouteError(err)), + Err(err) => Err(Error::GetRoute(err)), } } } diff --git a/talpid-core/src/routing/unix.rs b/talpid-core/src/routing/unix.rs index edfbdd2b85..326fb1fad1 100644 --- a/talpid-core/src/routing/unix.rs +++ b/talpid-core/src/routing/unix.rs @@ -19,16 +19,19 @@ use futures::stream::Stream; #[cfg(target_os = "linux")] use std::net::IpAddr; +#[allow(clippy::module_inception)] #[cfg(target_os = "macos")] #[path = "macos.rs"] mod imp; #[cfg(target_os = "macos")] pub(crate) use imp::listen_for_default_route_changes; +#[allow(clippy::module_inception)] #[cfg(target_os = "linux")] #[path = "linux.rs"] mod imp; +#[allow(clippy::module_inception)] #[cfg(target_os = "android")] #[path = "android.rs"] mod imp; diff --git a/talpid-core/src/tunnel/mod.rs b/talpid-core/src/tunnel/mod.rs index 5da3c092a4..f6ada1c2cf 100644 --- a/talpid-core/src/tunnel/mod.rs +++ b/talpid-core/src/tunnel/mod.rs @@ -1,6 +1,6 @@ use self::tun_provider::TunProvider; use crate::{logging, routing::RouteManagerHandle}; -use futures::channel::oneshot; +use futures::{channel::oneshot, future::BoxFuture}; use std::{ net::{IpAddr, Ipv4Addr, Ipv6Addr}, path::{Path, PathBuf}, @@ -98,6 +98,20 @@ pub struct TunnelMonitor { monitor: InternalTunnelMonitor, } +/// Arguments for creating a tunnel. +pub struct TunnelArgs<'a, L> +where + // L: (Fn(TunnelEvent) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>) + L: (Fn(TunnelEvent) -> BoxFuture<'static, ()>) + Send + Clone + Sync + 'static, +{ + /// Resource directory. + pub resource_dir: &'a Path, + /// Callback function called when an event happens. + pub on_event: L, + /// Receiver oneshot channel for closing the tunnel. + pub tunnel_close_rx: oneshot::Receiver<()>, +} + // TODO(emilsp) move most of the openvpn tunnel details to OpenVpnTunnelMonitor impl TunnelMonitor { /// Creates a new `TunnelMonitor` that connects to the given remote and notifies `on_event` @@ -107,12 +121,10 @@ impl TunnelMonitor { runtime: tokio::runtime::Handle, tunnel_parameters: &mut TunnelParameters, log_dir: &Option<PathBuf>, - resource_dir: &Path, - on_event: L, tun_provider: Arc<Mutex<TunProvider>>, - route_manager: RouteManagerHandle, retry_attempt: u32, - tunnel_close_rx: oneshot::Receiver<()>, + route_manager: RouteManagerHandle, + init_args: TunnelArgs<'_, L>, ) -> Result<Self> where L: (Fn(TunnelEvent) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>) @@ -129,9 +141,9 @@ impl TunnelMonitor { TunnelParameters::OpenVpn(config) => runtime.block_on(Self::start_openvpn_tunnel( config, log_file, - resource_dir, - on_event, - tunnel_close_rx, + init_args.resource_dir, + init_args.on_event, + init_args.tunnel_close_rx, #[cfg(target_os = "linux")] route_manager, )), @@ -142,12 +154,10 @@ impl TunnelMonitor { runtime, config, log_file, - resource_dir, - on_event, tun_provider, - route_manager, retry_attempt, - tunnel_close_rx, + route_manager, + init_args, ), } } @@ -178,12 +188,10 @@ impl TunnelMonitor { runtime: tokio::runtime::Handle, params: &mut wireguard_types::TunnelParameters, log: Option<PathBuf>, - resource_dir: &Path, - on_event: L, tun_provider: Arc<Mutex<TunProvider>>, - route_manager: RouteManagerHandle, retry_attempt: u32, - tunnel_close_rx: oneshot::Receiver<()>, + route_manager: RouteManagerHandle, + init_args: TunnelArgs<'_, L>, ) -> Result<Self> where L: (Fn(TunnelEvent) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>) @@ -211,12 +219,10 @@ impl TunnelMonitor { None }, log.as_deref(), - resource_dir, - on_event, tun_provider, - route_manager, retry_attempt, - tunnel_close_rx, + route_manager, + init_args, )?; Ok(TunnelMonitor { monitor: InternalTunnelMonitor::Wireguard(monitor), diff --git a/talpid-core/src/tunnel/openvpn/mod.rs b/talpid-core/src/tunnel/openvpn/mod.rs index 910d5bb49e..9fdfb3e80b 100644 --- a/talpid-core/src/tunnel/openvpn/mod.rs +++ b/talpid-core/src/tunnel/openvpn/mod.rs @@ -310,10 +310,19 @@ impl OpenVpnMonitor<OpenVpnCommand> { let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger(); + let openvpn_init_args = OpenVpnTunnelInitArgs { + event_server_abort_tx: event_server_abort_tx.clone(), + event_server_abort_rx, + plugin_path, + log_path, + user_pass_file, + proxy_auth_file, + proxy_monitor, + tunnel_close_rx, + }; Self::new_internal( cmd, - event_server_abort_tx.clone(), - event_server_abort_rx, + openvpn_init_args, event_server::OpenvpnEventProxyImpl { on_event, user_pass_file_path: user_pass_file_path.clone(), @@ -324,12 +333,6 @@ impl OpenVpnMonitor<OpenVpnCommand> { #[cfg(target_os = "linux")] ipv6_enabled, }, - plugin_path, - log_path, - user_pass_file, - proxy_auth_file, - proxy_monitor, - tunnel_close_rx, #[cfg(windows)] Box::new(wintun), ) @@ -371,23 +374,36 @@ fn extract_routes(env: &HashMap<String, String>) -> Result<HashSet<RequiredRoute Ok(routes) } +struct OpenVpnTunnelInitArgs { + event_server_abort_tx: triggered::Trigger, + event_server_abort_rx: triggered::Listener, + plugin_path: PathBuf, + log_path: Option<PathBuf>, + user_pass_file: mktemp::TempFile, + proxy_auth_file: Option<mktemp::TempFile>, + proxy_monitor: Option<Box<dyn ProxyMonitor>>, + tunnel_close_rx: oneshot::Receiver<()>, +} + impl<C: OpenVpnBuilder + Send + 'static> OpenVpnMonitor<C> { async fn new_internal<L>( mut cmd: C, - event_server_abort_tx: triggered::Trigger, - event_server_abort_rx: triggered::Listener, + init_args: OpenVpnTunnelInitArgs, on_event: L, - plugin_path: PathBuf, - log_path: Option<PathBuf>, - user_pass_file: mktemp::TempFile, - proxy_auth_file: Option<mktemp::TempFile>, - proxy_monitor: Option<Box<dyn ProxyMonitor>>, - tunnel_close_rx: oneshot::Receiver<()>, #[cfg(windows)] wintun: Box<dyn WintunContext>, ) -> Result<OpenVpnMonitor<C>> where L: event_server::OpenvpnEventProxy + Send + Sync + 'static, { + let event_server_abort_tx = init_args.event_server_abort_tx; + let event_server_abort_rx = init_args.event_server_abort_rx; + let plugin_path = init_args.plugin_path; + let log_path = init_args.log_path; + let user_pass_file = init_args.user_pass_file; + let proxy_auth_file = init_args.proxy_auth_file; + let proxy_monitor = init_args.proxy_monitor; + let tunnel_close_rx = init_args.tunnel_close_rx; + let (server_join_handle, ipc_path) = event_server::start(on_event, event_server_abort_rx) .await .map_err(Error::EventDispatcherError)?; @@ -1220,23 +1236,37 @@ mod tests { .map_err(Error::RuntimeError) } + fn create_init_args_plugin_log( + plugin_path: PathBuf, + log_path: Option<PathBuf>, + ) -> OpenVpnTunnelInitArgs { + let (_close_tx, close_rx) = oneshot::channel(); + let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger(); + OpenVpnTunnelInitArgs { + event_server_abort_tx, + event_server_abort_rx, + plugin_path, + log_path, + user_pass_file: TempFile::new(), + proxy_auth_file: None, + proxy_monitor: None, + tunnel_close_rx: close_rx, + } + } + + fn create_init_args() -> OpenVpnTunnelInitArgs { + create_init_args_plugin_log("".into(), None) + } + #[test] fn sets_plugin() { let builder = TestOpenVpnBuilder::default(); - let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger(); - let (_close_tx, close_rx) = oneshot::channel(); let runtime = new_runtime().unwrap(); + let openvpn_init_args = create_init_args_plugin_log("./my_test_plugin".into(), None); let _ = runtime.block_on(OpenVpnMonitor::new_internal( builder.clone(), - event_server_abort_tx, - event_server_abort_rx, + openvpn_init_args, TestOpenvpnEventProxy {}, - "./my_test_plugin".into(), - None, - TempFile::new(), - None, - None, - close_rx, #[cfg(windows)] Box::new(TestWintunContext {}), )); @@ -1249,20 +1279,13 @@ mod tests { #[test] fn sets_log() { let builder = TestOpenVpnBuilder::default(); - let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger(); - let (_close_tx, close_rx) = oneshot::channel(); let runtime = new_runtime().unwrap(); + let openvpn_init_args = + create_init_args_plugin_log("".into(), Some(PathBuf::from("./my_test_log_file"))); let _ = runtime.block_on(OpenVpnMonitor::new_internal( builder.clone(), - event_server_abort_tx, - event_server_abort_rx, + openvpn_init_args, TestOpenvpnEventProxy {}, - "".into(), - Some(PathBuf::from("./my_test_log_file")), - TempFile::new(), - None, - None, - close_rx, #[cfg(windows)] Box::new(TestWintunContext {}), )); @@ -1276,21 +1299,13 @@ mod tests { fn exit_successfully() { let mut builder = TestOpenVpnBuilder::default(); builder.process_handle = Some(TestProcessHandle(0)); - let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger(); - let (_close_tx, close_rx) = oneshot::channel(); let runtime = new_runtime().unwrap(); + let openvpn_init_args = create_init_args(); let testee = runtime .block_on(OpenVpnMonitor::new_internal( builder, - event_server_abort_tx, - event_server_abort_rx, + openvpn_init_args, TestOpenvpnEventProxy {}, - "".into(), - None, - TempFile::new(), - None, - None, - close_rx, #[cfg(windows)] Box::new(TestWintunContext {}), )) @@ -1302,21 +1317,13 @@ mod tests { fn exit_error() { let mut builder = TestOpenVpnBuilder::default(); builder.process_handle = Some(TestProcessHandle(1)); - let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger(); - let (_close_tx, close_rx) = oneshot::channel(); let runtime = new_runtime().unwrap(); + let openvpn_init_args = create_init_args(); let testee = runtime .block_on(OpenVpnMonitor::new_internal( builder, - event_server_abort_tx, - event_server_abort_rx, + openvpn_init_args, TestOpenvpnEventProxy {}, - "".into(), - None, - TempFile::new(), - None, - None, - close_rx, #[cfg(windows)] Box::new(TestWintunContext {}), )) @@ -1328,21 +1335,13 @@ mod tests { fn wait_closed() { let mut builder = TestOpenVpnBuilder::default(); builder.process_handle = Some(TestProcessHandle(1)); - let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger(); - let (_close_tx, close_rx) = oneshot::channel(); let runtime = new_runtime().unwrap(); + let openvpn_init_args = create_init_args(); let testee = runtime .block_on(OpenVpnMonitor::new_internal( builder, - event_server_abort_tx, - event_server_abort_rx, + openvpn_init_args, TestOpenvpnEventProxy {}, - "".into(), - None, - TempFile::new(), - None, - None, - close_rx, #[cfg(windows)] Box::new(TestWintunContext {}), )) @@ -1354,21 +1353,13 @@ mod tests { #[test] fn failed_process_start() { let builder = TestOpenVpnBuilder::default(); - let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger(); - let (_close_tx, close_rx) = oneshot::channel(); let runtime = new_runtime().unwrap(); + let openvpn_init_args = create_init_args(); let result = runtime .block_on(OpenVpnMonitor::new_internal( builder, - event_server_abort_tx, - event_server_abort_rx, + openvpn_init_args, TestOpenvpnEventProxy {}, - "".into(), - None, - TempFile::new(), - None, - None, - close_rx, #[cfg(windows)] Box::new(TestWintunContext {}), )) diff --git a/talpid-core/src/tunnel/tun_provider/unix.rs b/talpid-core/src/tunnel/tun_provider/unix.rs index d8d3b7ce01..5c48a3c663 100644 --- a/talpid-core/src/tunnel/tun_provider/unix.rs +++ b/talpid-core/src/tunnel/tun_provider/unix.rs @@ -22,6 +22,12 @@ pub enum Error { /// Factory of tunnel devices on Unix systems. pub struct UnixTunProvider; +impl Default for UnixTunProvider { + fn default() -> Self { + Self::new() + } +} + impl UnixTunProvider { pub fn new() -> Self { UnixTunProvider diff --git a/talpid-core/src/tunnel/wireguard/logging.rs b/talpid-core/src/tunnel/wireguard/logging.rs index 1a7a52ed8b..35ec10fc2f 100644 --- a/talpid-core/src/tunnel/wireguard/logging.rs +++ b/talpid-core/src/tunnel/wireguard/logging.rs @@ -112,7 +112,7 @@ pub unsafe extern "system" fn wg_go_logging_callback( let level = match level { WG_GO_LOG_VERBOSE => LogLevel::Verbose, - WG_GO_LOG_ERROR | _ => LogLevel::Error, + _ => LogLevel::Error, }; log_inner(logfile, level, "wireguard-go", &managed_msg); } @@ -121,5 +121,5 @@ pub unsafe extern "system" fn wg_go_logging_callback( pub type WgLogLevel = u32; // wireguard-go supports log levels 0 through 3 with 3 being the most verbose // const WG_GO_LOG_SILENT: WgLogLevel = 0; -const WG_GO_LOG_ERROR: WgLogLevel = 1; +// const WG_GO_LOG_ERROR: WgLogLevel = 1; const WG_GO_LOG_VERBOSE: WgLogLevel = 2; diff --git a/talpid-core/src/tunnel/wireguard/mod.rs b/talpid-core/src/tunnel/wireguard/mod.rs index 7f17726c33..e49286cb30 100644 --- a/talpid-core/src/tunnel/wireguard/mod.rs +++ b/talpid-core/src/tunnel/wireguard/mod.rs @@ -1,15 +1,11 @@ use self::config::Config; #[cfg(not(windows))] use super::tun_provider; -use super::{tun_provider::TunProvider, TunnelEvent, TunnelMetadata}; +use super::{tun_provider::TunProvider, TunnelArgs, TunnelEvent, TunnelMetadata}; use crate::routing::{self, RequiredRoute, RouteManagerHandle}; +use futures::future::{abortable, AbortHandle as FutureAbortHandle, BoxFuture, Future}; #[cfg(windows)] use futures::{channel::mpsc, StreamExt}; -use futures::{ - channel::oneshot, - future::{abortable, AbortHandle as FutureAbortHandle}, - Future, -}; #[cfg(target_os = "linux")] use lazy_static::lazy_static; #[cfg(target_os = "linux")] @@ -54,6 +50,7 @@ mod wireguard_nt; use self::wireguard_go::WgGoTunnel; type Result<T> = std::result::Result<T, Error>; +type EventCallback = Box<dyn (Fn(TunnelEvent) -> BoxFuture<'static, ()>) + Send + Sync + 'static>; /// Errors that can happen in the Wireguard tunnel monitor. #[derive(err_derive::Error, Debug)] @@ -104,12 +101,7 @@ pub struct WireguardMonitor { /// Tunnel implementation tunnel: Arc<Mutex<Option<Box<dyn Tunnel>>>>, /// Callback to signal tunnel events - event_callback: Box< - dyn (Fn(TunnelEvent) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>) - + Send - + Sync - + 'static, - >, + event_callback: EventCallback, close_msg_receiver: sync_mpsc::Receiver<CloseMsg>, pinger_stop_sender: sync_mpsc::Sender<()>, _obfuscator: Option<ObfuscatorHandle>, @@ -208,13 +200,13 @@ impl WireguardMonitor { mut config: Config, psk_negotiation: Option<PublicKey>, log_path: Option<&Path>, - resource_dir: &Path, - on_event: F, tun_provider: Arc<Mutex<TunProvider>>, - route_manager: RouteManagerHandle, retry_attempt: u32, - tunnel_close_rx: oneshot::Receiver<()>, + route_manager: RouteManagerHandle, + init_args: TunnelArgs<'_, F>, ) -> Result<WireguardMonitor> { + let on_event = init_args.on_event; + let endpoint_addrs: Vec<IpAddr> = config.peers.iter().map(|peer| peer.endpoint.ip()).collect(); let (close_msg_sender, close_msg_receiver) = sync_mpsc::channel(); @@ -228,7 +220,7 @@ impl WireguardMonitor { runtime.clone(), &Self::patch_allowed_ips(&config, psk_negotiation.is_some()), log_path, - resource_dir, + init_args.resource_dir, tun_provider, #[cfg(target_os = "windows")] setup_done_tx, @@ -351,7 +343,7 @@ impl WireguardMonitor { }); tokio::spawn(async move { - if tunnel_close_rx.await.is_ok() { + if init_args.tunnel_close_rx.await.is_ok() { monitor_handle.abort(); let _ = close_msg_sender.send(CloseMsg::Stop); } diff --git a/talpid-core/src/tunnel/wireguard/stats.rs b/talpid-core/src/tunnel/wireguard/stats.rs index bda8af2e1f..cec033f611 100644 --- a/talpid-core/src/tunnel/wireguard/stats.rs +++ b/talpid-core/src/tunnel/wireguard/stats.rs @@ -4,10 +4,10 @@ use super::wireguard_kernel::wg_message::{DeviceMessage, DeviceNla, PeerNla}; #[derive(err_derive::Error, Debug, PartialEq)] pub enum Error { #[error(display = "Failed to parse peer pubkey from string \"_0\"")] - PubKeyParseError(String, #[error(source)] hex::FromHexError), + PubKeyParse(String, #[error(source)] hex::FromHexError), #[error(display = "Failed to parse integer from string \"_0\"")] - IntParseError(String, #[error(source)] std::num::ParseIntError), + IntParse(String, #[error(source)] std::num::ParseIntError), #[error(display = "Device no longer exists")] NoTunnelDevice, @@ -47,7 +47,7 @@ impl Stats { "public_key" => { let mut buffer = [0u8; 32]; hex::decode_to_slice(value, &mut buffer) - .map_err(|err| Error::PubKeyParseError(value.to_string(), err))?; + .map_err(|err| Error::PubKeyParse(value.to_string(), err))?; peer = Some(buffer); tx_bytes = None; rx_bytes = None; @@ -57,7 +57,7 @@ impl Stats { value .trim() .parse() - .map_err(|err| Error::IntParseError(value.to_string(), err))?, + .map_err(|err| Error::IntParse(value.to_string(), err))?, ); } "tx_bytes" => { @@ -65,7 +65,7 @@ impl Stats { value .trim() .parse() - .map_err(|err| Error::IntParseError(value.to_string(), err))?, + .map_err(|err| Error::IntParse(value.to_string(), err))?, ); } @@ -145,7 +145,7 @@ mod test { assert_eq!( Stats::parse_config_str(invalid_input), - Err(Error::IntParseError(invalid_str, int_err)) + Err(Error::IntParse(invalid_str, int_err)) ); } } diff --git a/talpid-core/src/tunnel/wireguard/wireguard_kernel/mod.rs b/talpid-core/src/tunnel/wireguard/wireguard_kernel/mod.rs index 0f3866500e..5b7b6a1e12 100644 --- a/talpid-core/src/tunnel/wireguard/wireguard_kernel/mod.rs +++ b/talpid-core/src/tunnel/wireguard/wireguard_kernel/mod.rs @@ -33,16 +33,16 @@ pub use nm_tunnel::NetworkManagerTunnel; #[error(no_from)] pub enum Error { #[error(display = "Failed to decode netlink message")] - DecodeError(#[error(source)] DecodeError), + Decode(#[error(source)] DecodeError), #[error(display = "Failed to execute netlink control request")] - NetlinkControlMessageError(#[error(source)] nl_message::Error), + NetlinkControlMessage(#[error(source)] nl_message::Error), #[error(display = "Failed to open netlink socket")] - NetlinkSocketError(#[error(source)] std::io::Error), + NetlinkSocket(#[error(source)] std::io::Error), #[error(display = "Failed to send netlink control request")] - NetlinkRequestError(#[error(source)] netlink_proto::Error<NetlinkControlMessage>), + NetlinkRequest(#[error(source)] netlink_proto::Error<NetlinkControlMessage>), #[error(display = "WireGuard netlink interface unavailable. Is the kernel module loaded?")] WireguardNetlinkInterfaceUnavailable, @@ -60,25 +60,25 @@ pub enum Error { NoDevice, #[error(display = "Failed to get config: _0")] - WgGetConfError(netlink_packet_core::error::ErrorMessage), + WgGetConf(netlink_packet_core::error::ErrorMessage), #[error(display = "Failed to apply config: _0")] - WgSetConfError(netlink_packet_core::error::ErrorMessage), + WgSetConf(netlink_packet_core::error::ErrorMessage), #[error(display = "Interface name too long")] - InterfaceNameError, + InterfaceName, #[error(display = "Send request error")] - SendRequestError(#[error(source)] NetlinkError<DeviceMessage>), + SendRequest(#[error(source)] NetlinkError<DeviceMessage>), #[error(display = "Create device error")] - NetlinkCreateDeviceError(#[error(source)] rtnetlink::Error), + NetlinkCreateDevice(#[error(source)] rtnetlink::Error), #[error(display = "Add IP to device error")] - NetlinkSetIpError(rtnetlink::Error), + NetlinkSetIp(rtnetlink::Error), #[error(display = "Failed to delete device")] - DeleteDeviceError(#[error(source)] rtnetlink::Error), + DeleteDevice(#[error(source)] rtnetlink::Error), #[error(display = "NetworkManager error")] NetworkManager(#[error(source)] nm_tunnel::Error), @@ -98,7 +98,7 @@ impl Handle { pub async fn connect() -> Result<Self, Error> { let message_type = Self::get_wireguard_message_type().await?; let (conn, wireguard_connection, _messages) = - netlink_proto::new_connection(NETLINK_GENERIC).map_err(Error::NetlinkSocketError)?; + netlink_proto::new_connection(NETLINK_GENERIC).map_err(Error::NetlinkSocket)?; let wg_handle = WireguardConnection { message_type, connection: wireguard_connection, @@ -106,7 +106,7 @@ impl Handle { let (abortable_connection, wg_abort_handle) = abortable(conn); tokio::spawn(abortable_connection); let (conn, route_handle, _messages) = - rtnetlink::new_connection().map_err(Error::NetlinkSocketError)?; + rtnetlink::new_connection().map_err(Error::NetlinkSocket)?; let (abortable_connection, route_abort_handle) = abortable(conn); tokio::spawn(abortable_connection); @@ -120,21 +120,21 @@ impl Handle { async fn get_wireguard_message_type() -> Result<u16, Error> { let (conn, mut handle, _messages) = - netlink_proto::new_connection(NETLINK_GENERIC).map_err(Error::NetlinkSocketError)?; + netlink_proto::new_connection(NETLINK_GENERIC).map_err(Error::NetlinkSocket)?; let (conn, abort_handle) = abortable(conn); tokio::spawn(conn); let result = async move { let mut message: NetlinkMessage<NetlinkControlMessage> = NetlinkControlMessage::get_netlink_family_id(CString::new("wireguard").unwrap()) - .map_err(Error::NetlinkControlMessageError)? + .map_err(Error::NetlinkControlMessage)? .into(); message.header.flags = NLM_F_REQUEST | NLM_F_ACK; let mut req = handle .request(message, SocketAddr::new(0, 0)) - .map_err(Error::NetlinkRequestError)?; + .map_err(Error::NetlinkRequest)?; let response = req.next().await; if let Some(response) = response { if let NetlinkPayload::InnerMessage(msg) = response.payload { @@ -177,14 +177,14 @@ impl Handle { let mut response = self .route_handle .request(add_request) - .map_err(Error::NetlinkCreateDeviceError)?; + .map_err(Error::NetlinkCreateDevice)?; while let Some(response_message) = response.next().await { if let NetlinkPayload::Error(err) = response_message.payload { // if the device exists, verify that it's a wireguard device if -err.code != libc::EEXIST { - return Err(Error::NetlinkCreateDeviceError( - rtnetlink::Error::NetlinkError(err), - )); + return Err(Error::NetlinkCreateDevice(rtnetlink::Error::NetlinkError( + err, + ))); } } } @@ -208,9 +208,9 @@ impl Handle { let mut response = self .route_handle .request(request) - .map_err(Error::NetlinkSetIpError)?; + .map_err(Error::NetlinkSetIp)?; while let Some(response_message) = response.next().await { - consume_netlink_error(response_message, Error::NetlinkSetIpError)?; + consume_netlink_error(response_message, Error::NetlinkSetIp)?; } Ok(()) @@ -226,9 +226,9 @@ impl Handle { let mut response = self .route_handle .request(request) - .map_err(Error::DeleteDeviceError)?; + .map_err(Error::DeleteDevice)?; while let Some(message) = response.next().await { - consume_netlink_error(message, Error::DeleteDeviceError)?; + consume_netlink_error(message, Error::DeleteDevice)?; } Ok(()) @@ -269,7 +269,7 @@ impl WireguardConnection { let mut response = self .connection .request(netlink_message, SocketAddr::new(0, 0)) - .map_err(Error::SendRequestError)?; + .map_err(Error::SendRequest)?; match response.next().await { Some(received_message) => match received_message.payload { NetlinkPayload::InnerMessage(inner) => Ok(inner), @@ -277,7 +277,7 @@ impl WireguardConnection { if err.code == -libc::ENODEV { Err(Error::NoDevice) } else { - Err(Error::WgGetConfError(err)) + Err(Error::WgGetConf(err)) } } anything_else => { @@ -297,11 +297,11 @@ impl WireguardConnection { let mut request = self .connection .request(netlink_message, SocketAddr::new(0, 0)) - .map_err(Error::SendRequestError)?; + .map_err(Error::SendRequest)?; while let Some(response) = request.next().await { if let NetlinkPayload::Error(err) = response.payload { - return Err(Error::WgSetConfError(err)); + return Err(Error::WgSetConf(err)); } } Ok(()) diff --git a/talpid-core/src/tunnel/wireguard/wireguard_kernel/wg_message.rs b/talpid-core/src/tunnel/wireguard/wireguard_kernel/wg_message.rs index be2231f771..f2de334762 100644 --- a/talpid-core/src/tunnel/wireguard/wireguard_kernel/wg_message.rs +++ b/talpid-core/src/tunnel/wireguard/wireguard_kernel/wg_message.rs @@ -110,9 +110,9 @@ impl DeviceMessage { } pub fn get_by_name(message_type: u16, name: String) -> Result<Self, Error> { - let c_name = CString::new(name).map_err(|_| Error::InterfaceNameError)?; + let c_name = CString::new(name).map_err(|_| Error::InterfaceName)?; if c_name.as_bytes_with_nul().len() > libc::IFNAMSIZ { - return Err(Error::InterfaceNameError); + return Err(Error::InterfaceName); } Ok(Self { @@ -178,9 +178,7 @@ impl NetlinkDeserializable<DeviceMessage> for DeviceMessage { let new_payload = &payload[mem::size_of::<libc::genlmsghdr>()..]; let mut nlas = vec![]; for buf in NlasIterator::new(new_payload) { - nlas.push( - DeviceNla::parse(&buf.map_err(Error::DecodeError)?).map_err(Error::DecodeError)?, - ); + nlas.push(DeviceNla::parse(&buf.map_err(Error::Decode)?).map_err(Error::Decode)?); } Ok(DeviceMessage { @@ -391,13 +389,13 @@ impl Nla for PeerNla { InetAddr::V4(sockaddr_in) => { // SAFETY: `sockaddr_in` has no padding bytes buffer - .write(unsafe { struct_as_slice(sockaddr_in) }) + .write_all(unsafe { struct_as_slice(sockaddr_in) }) .expect("Buffer too small for sockaddr_in"); } InetAddr::V6(sockaddr_in6) => { // SAFETY: `sockaddr_in` has no padding bytes buffer - .write(unsafe { struct_as_slice(sockaddr_in6) }) + .write_all(unsafe { struct_as_slice(sockaddr_in6) }) .expect("Buffer too small for sockaddr_in6"); } }, @@ -408,7 +406,7 @@ impl Nla for PeerNla { let timespec: &libc::timespec = last_handshake.as_ref(); // SAFETY: `timespec` has no padding bytes buffer - .write(unsafe { struct_as_slice(timespec) }) + .write_all(unsafe { struct_as_slice(timespec) }) .expect("Buffer too small for timespec"); } RxBytes(num_bytes) | TxBytes(num_bytes) => NativeEndian::write_u64(buffer, *num_bytes), @@ -535,7 +533,7 @@ impl Nla for AllowedIpNla { } IpAddr(ip_addr) => { buffer - .write(&ip_addr_to_bytes(ip_addr)) + .write_all(&ip_addr_to_bytes(ip_addr)) .expect("Buffer too small for AllowedIpNla::IpAddr"); } CidrMask(cidr_mask) => buffer[0] = *cidr_mask, diff --git a/talpid-core/src/tunnel_state_machine/connecting_state.rs b/talpid-core/src/tunnel_state_machine/connecting_state.rs index 7536b26b09..e787729c04 100644 --- a/talpid-core/src/tunnel_state_machine/connecting_state.rs +++ b/talpid-core/src/tunnel_state_machine/connecting_state.rs @@ -6,7 +6,9 @@ use super::{ use crate::{ firewall::FirewallPolicy, routing::RouteManager, - tunnel::{self, tun_provider::TunProvider, TunnelEvent, TunnelMetadata, TunnelMonitor}, + tunnel::{ + self, tun_provider::TunProvider, TunnelArgs, TunnelEvent, TunnelMetadata, TunnelMonitor, + }, }; use cfg_if::cfg_if; use futures::{ @@ -142,16 +144,20 @@ impl ConnectingState { } }; + let init_args = TunnelArgs { + resource_dir: &resource_dir, + on_event: on_tunnel_event, + tunnel_close_rx, + }; + let block_reason = match TunnelMonitor::start( runtime, &mut tunnel_parameters, &log_dir, - &resource_dir, - on_tunnel_event, tun_provider, - route_manager_handle, retry_attempt, - tunnel_close_rx, + route_manager_handle, + init_args, ) { Ok(monitor) => { let reason = Self::wait_for_tunnel_monitor(monitor, retry_attempt); diff --git a/talpid-core/src/tunnel_state_machine/mod.rs b/talpid-core/src/tunnel_state_machine/mod.rs index 3552eeab61..4c3eda0ecb 100644 --- a/talpid-core/src/tunnel_state_machine/mod.rs +++ b/talpid-core/src/tunnel_state_machine/mod.rs @@ -132,23 +132,25 @@ pub async fn spawn( let (shutdown_tx, shutdown_rx) = oneshot::channel(); let weak_command_tx = Arc::downgrade(&command_tx); - let state_machine = TunnelStateMachine::new( - initial_settings, - weak_command_tx, - offline_state_listener, + + let init_args = TunnelStateMachineInitArgs { + settings: initial_settings, + command_tx: weak_command_tx, + offline_state_tx: offline_state_listener, tunnel_parameters_generator, tun_provider, log_dir, resource_dir, - command_rx, + commands_rx: command_rx, #[cfg(target_os = "windows")] volume_update_rx, #[cfg(target_os = "macos")] exclusion_gid, #[cfg(target_os = "android")] android_context, - ) - .await?; + }; + + let state_machine = TunnelStateMachine::new(init_args).await?; #[cfg(windows)] let split_tunnel = state_machine.shared_values.split_tunnel.handle(); @@ -219,20 +221,35 @@ struct TunnelStateMachine { shared_values: SharedTunnelStateValues, } +/// Tunnel state machine initialization arguments arguments +struct TunnelStateMachineInitArgs<G: TunnelParametersGenerator> { + settings: InitialTunnelState, + command_tx: std::sync::Weak<mpsc::UnboundedSender<TunnelCommand>>, + offline_state_tx: mpsc::UnboundedSender<bool>, + tunnel_parameters_generator: G, + tun_provider: TunProvider, + log_dir: Option<PathBuf>, + resource_dir: PathBuf, + commands_rx: mpsc::UnboundedReceiver<TunnelCommand>, + #[cfg(target_os = "windows")] + volume_update_rx: mpsc::UnboundedReceiver<()>, + #[cfg(target_os = "macos")] + exclusion_gid: u32, + #[cfg(target_os = "android")] + android_context: AndroidContext, +} + impl TunnelStateMachine { async fn new( - settings: InitialTunnelState, - command_tx: std::sync::Weak<mpsc::UnboundedSender<TunnelCommand>>, - offline_state_tx: mpsc::UnboundedSender<bool>, - tunnel_parameters_generator: impl TunnelParametersGenerator, - tun_provider: TunProvider, - log_dir: Option<PathBuf>, - resource_dir: PathBuf, - commands_rx: mpsc::UnboundedReceiver<TunnelCommand>, - #[cfg(target_os = "windows")] volume_update_rx: mpsc::UnboundedReceiver<()>, - #[cfg(target_os = "macos")] exclusion_gid: u32, - #[cfg(target_os = "android")] android_context: AndroidContext, + args: TunnelStateMachineInitArgs<impl TunnelParametersGenerator>, ) -> Result<Self, Error> { + #[cfg(target_os = "windows")] + let volume_update_rx = args.volume_update_rx; + #[cfg(target_os = "macos")] + let exclusion_gid = args.exclusion_gid; + #[cfg(target_os = "android")] + let android_context = args.android_context; + let runtime = tokio::runtime::Handle::current(); #[cfg(target_os = "macos")] @@ -242,20 +259,24 @@ impl TunnelStateMachine { let power_mgmt_rx = crate::windows::window::PowerManagementListener::new(); #[cfg(windows)] - let split_tunnel = - split_tunnel::SplitTunnel::new(runtime.clone(), command_tx.clone(), volume_update_rx) - .map_err(Error::InitSplitTunneling)?; + let split_tunnel = split_tunnel::SplitTunnel::new( + runtime.clone(), + args.command_tx.clone(), + volume_update_rx, + ) + .map_err(Error::InitSplitTunneling)?; - let args = FirewallArguments { - initial_state: if settings.block_when_disconnected || !settings.reset_firewall { - InitialFirewallState::Blocked(settings.allowed_endpoint.clone()) + let fw_args = FirewallArguments { + initial_state: if args.settings.block_when_disconnected || !args.settings.reset_firewall + { + InitialFirewallState::Blocked(args.settings.allowed_endpoint.clone()) } else { InitialFirewallState::None }, - allow_lan: settings.allow_lan, + allow_lan: args.settings.allow_lan, }; - let firewall = Firewall::from_args(args).map_err(Error::InitFirewallError)?; + let firewall = Firewall::from_args(fw_args).map_err(Error::InitFirewallError)?; let route_manager = RouteManager::new(HashSet::new()) .await .map_err(Error::InitRouteManagerError)?; @@ -267,20 +288,20 @@ impl TunnelStateMachine { .handle() .map_err(Error::InitRouteManagerError)?, #[cfg(target_os = "macos")] - command_tx.clone(), + args.command_tx.clone(), ) .map_err(Error::InitDnsMonitorError)?; let (offline_tx, mut offline_rx) = mpsc::unbounded(); - let initial_offline_state_tx = offline_state_tx.clone(); + let initial_offline_state_tx = args.offline_state_tx.clone(); tokio::spawn(async move { while let Some(offline) = offline_rx.next().await { - if let Some(tx) = command_tx.upgrade() { + if let Some(tx) = args.command_tx.upgrade() { let _ = tx.unbounded_send(TunnelCommand::IsOffline(offline)); } else { break; } - let _ = offline_state_tx.unbounded_send(offline); + let _ = args.offline_state_tx.unbounded_send(offline); } }); let mut offline_monitor = offline::spawn_monitor( @@ -301,7 +322,7 @@ impl TunnelStateMachine { #[cfg(windows)] split_tunnel - .set_paths_sync(&settings.exclude_paths) + .set_paths_sync(&args.settings.exclude_paths) .map_err(Error::InitSplitTunneling)?; let mut shared_values = SharedTunnelStateValues { @@ -312,15 +333,15 @@ impl TunnelStateMachine { dns_monitor, route_manager, _offline_monitor: offline_monitor, - allow_lan: settings.allow_lan, - block_when_disconnected: settings.block_when_disconnected, + allow_lan: args.settings.allow_lan, + block_when_disconnected: args.settings.block_when_disconnected, is_offline, - dns_servers: settings.dns_servers, - allowed_endpoint: settings.allowed_endpoint, - tunnel_parameters_generator: Box::new(tunnel_parameters_generator), - tun_provider: Arc::new(Mutex::new(tun_provider)), - log_dir, - resource_dir, + dns_servers: args.settings.dns_servers, + allowed_endpoint: args.settings.allowed_endpoint, + tunnel_parameters_generator: Box::new(args.tunnel_parameters_generator), + tun_provider: Arc::new(Mutex::new(args.tun_provider)), + log_dir: args.log_dir, + resource_dir: args.resource_dir, #[cfg(target_os = "linux")] connectivity_check_was_enabled: None, #[cfg(target_os = "macos")] @@ -331,11 +352,11 @@ impl TunnelStateMachine { tokio::task::spawn_blocking(move || { let (initial_state, _) = - DisconnectedState::enter(&mut shared_values, settings.reset_firewall); + DisconnectedState::enter(&mut shared_values, args.settings.reset_firewall); Ok(TunnelStateMachine { current_state: Some(initial_state), - commands: commands_rx.fuse(), + commands: args.commands_rx.fuse(), shared_values, }) }) diff --git a/talpid-dbus/src/network_manager.rs b/talpid-dbus/src/network_manager.rs index 87566891d1..5d89a69035 100644 --- a/talpid-dbus/src/network_manager.rs +++ b/talpid-dbus/src/network_manager.rs @@ -59,6 +59,7 @@ const MAXIMUM_SUPPORTED_MINOR_VERSION: u32 = 26; const NM_DEVICE_STATE_CHANGED: &str = "StateChanged"; pub type Result<T> = std::result::Result<T, Error>; +type NetworkSettings<'a> = HashMap<String, HashMap<String, Variant<Box<dyn RefArg + 'a>>>>; #[derive(err_derive::Error, Debug)] pub enum Error { @@ -447,10 +448,8 @@ impl NetworkManager { let device = self.as_path(&device_path); // Get the last applied connection - let (mut settings, version_id): ( - HashMap<String, HashMap<String, Variant<Box<dyn RefArg>>>>, - u64, - ) = device.method_call(NM_DEVICE, "GetAppliedConnection", (0u32,))?; + let (mut settings, version_id): (NetworkSettings, u64) = + device.method_call(NM_DEVICE, "GetAppliedConnection", (0u32,))?; // Keep changed routes. // These routes were modified outside NM, likely by RouteManager. @@ -576,7 +575,7 @@ impl NetworkManager { } fn update_dns_config<'a, T>( - settings: &mut HashMap<String, HashMap<String, Variant<Box<dyn RefArg + 'a>>>>, + settings: &mut NetworkSettings<'a>, ip_protocol: &'static str, servers: T, ) where diff --git a/talpid-dbus/src/systemd_resolved.rs b/talpid-dbus/src/systemd_resolved.rs index d664872557..6ba8603508 100644 --- a/talpid-dbus/src/systemd_resolved.rs +++ b/talpid-dbus/src/systemd_resolved.rs @@ -349,7 +349,7 @@ impl SystemdResolved { .map_err(Error::DBusRpcError) } - fn link_disable_dns_over_tls<'a, 'b: 'a>(&'a self, interface_index: u32) -> Result<()> { + fn link_disable_dns_over_tls(&self, interface_index: u32) -> Result<()> { let link_object_path = self .fetch_link(interface_index) .map_err(|e| Error::GetLinkError(Box::new(e)))?; diff --git a/talpid-types/src/net/wireguard.rs b/talpid-types/src/net/wireguard.rs index ff7ebef090..6a8bbfd521 100644 --- a/talpid-types/src/net/wireguard.rs +++ b/talpid-types/src/net/wireguard.rs @@ -95,6 +95,7 @@ fn default_wgnt_setting() -> bool { true } +#[allow(clippy::derivable_impls)] impl Default for TunnelOptions { fn default() -> Self { Self { |
