diff options
| author | Jonathan <jonathan@mullvad.net> | 2022-06-13 10:49:46 +0200 |
|---|---|---|
| committer | Jonathan <jonathan@mullvad.net> | 2022-06-21 14:31:40 +0200 |
| commit | d3da8745c8ff9e66d6698d8a239b8139dbe8abfe (patch) | |
| tree | 528a15026535b01bc3324892be783797aa64bbe4 | |
| parent | b6b80b9ffe6521a78ea6b2cdfd0e6965e67479fd (diff) | |
| download | mullvadvpn-d3da8745c8ff9e66d6698d8a239b8139dbe8abfe.tar.xz mullvadvpn-d3da8745c8ff9e66d6698d8a239b8139dbe8abfe.zip | |
Fix the large majority of clippy warnings
This commit fixes most of the remaining clippy warnings in the codebase.
These warnings were the more semantically difficult ones to fix.
There are some warnings that remain from the rebase that will be fixed
in the upcoming PR.
46 files changed, 450 insertions, 423 deletions
diff --git a/android/translations-converter/src/android/string_value.rs b/android/translations-converter/src/android/string_value.rs index 04ad23653e..f11453d020 100644 --- a/android/translations-converter/src/android/string_value.rs +++ b/android/translations-converter/src/android/string_value.rs @@ -84,13 +84,6 @@ impl StringValue { } } -impl StringValue { - /// Clones the internal string value. - pub fn to_string(&self) -> String { - self.0.clone() - } -} - impl Deref for StringValue { type Target = str; diff --git a/clippy.toml b/clippy.toml new file mode 100644 index 0000000000..c2ffede55a --- /dev/null +++ b/clippy.toml @@ -0,0 +1 @@ +enum-variant-size-threshold = 1000 diff --git a/mullvad-api/src/address_cache.rs b/mullvad-api/src/address_cache.rs index a8b1a91a44..d37d429e89 100644 --- a/mullvad-api/src/address_cache.rs +++ b/mullvad-api/src/address_cache.rs @@ -10,19 +10,16 @@ use tokio::{ #[error(no_from)] pub enum Error { #[error(display = "Failed to open the address cache file")] - OpenAddressCache(#[error(source)] io::Error), + Open(#[error(source)] io::Error), #[error(display = "Failed to read the address cache file")] - ReadAddressCache(#[error(source)] io::Error), + Read(#[error(source)] io::Error), #[error(display = "Failed to parse the address cache file")] - ParseAddressCache, + Parse, #[error(display = "Failed to update the address cache file")] - WriteAddressCache(#[error(source)] io::Error), - - #[error(display = "The address cache is empty")] - EmptyAddressCache, + Write(#[error(source)] io::Error), } #[derive(Clone)] @@ -68,7 +65,7 @@ impl AddressCache { self.inner.lock().await.address } - pub async fn set_address(&self, address: SocketAddr) -> io::Result<()> { + pub async fn set_address(&self, address: SocketAddr) -> Result<(), Error> { let mut inner = self.inner.lock().await; if address != inner.address { self.save_to_disk(&address).await?; @@ -77,17 +74,21 @@ impl AddressCache { Ok(()) } - async fn save_to_disk(&self, address: &SocketAddr) -> io::Result<()> { + async fn save_to_disk(&self, address: &SocketAddr) -> Result<(), Error> { let write_path = match self.write_path.as_ref() { Some(write_path) => write_path, None => return Ok(()), }; - let mut file = crate::fs::AtomicFile::new(write_path.to_path_buf()).await?; + let mut file = crate::fs::AtomicFile::new(write_path.to_path_buf()) + .await + .map_err(Error::Open)?; let mut contents = address.to_string(); contents += "\n"; - file.write_all(contents.as_bytes()).await?; - file.finalize().await + file.write_all(contents.as_bytes()) + .await + .map_err(Error::Write)?; + file.finalize().await.map_err(Error::Write) } } @@ -103,12 +104,10 @@ impl AddressCacheInner { } async fn read_address_file(path: &Path) -> Result<SocketAddr, Error> { - let mut file = fs::File::open(path) - .await - .map_err(Error::OpenAddressCache)?; + let mut file = fs::File::open(path).await.map_err(Error::Open)?; let mut address = String::new(); file.read_to_string(&mut address) .await - .map_err(Error::ReadAddressCache)?; - address.trim().parse().map_err(|_| Error::ParseAddressCache) + .map_err(Error::Read)?; + address.trim().parse().map_err(|_| Error::Parse) } diff --git a/mullvad-api/src/https_client_with_sni.rs b/mullvad-api/src/https_client_with_sni.rs index 2920cf989d..0d85ddd790 100644 --- a/mullvad-api/src/https_client_with_sni.rs +++ b/mullvad-api/src/https_client_with_sni.rs @@ -304,7 +304,7 @@ impl Service<Uri> for HttpsConnectorWithSni { ) .await?; let tls_stream = TlsStream::connect_https(socket, &hostname).await?; - Ok::<_, io::Error>(ApiConnection::Direct(tls_stream)) + Ok::<_, io::Error>(ApiConnection::Direct(Box::new(tls_stream))) } InnerConnectionMode::Proxied(proxy_config) => { let socket = Self::open_socket( @@ -320,7 +320,7 @@ impl Service<Uri> for HttpsConnectorWithSni { addr, ); let tls_stream = TlsStream::connect_https(proxy, &hostname).await?; - Ok(ApiConnection::Proxied(tls_stream)) + Ok(ApiConnection::Proxied(Box::new(tls_stream))) } } }; diff --git a/mullvad-api/src/lib.rs b/mullvad-api/src/lib.rs index b1019ef155..d967791192 100644 --- a/mullvad-api/src/lib.rs +++ b/mullvad-api/src/lib.rs @@ -236,7 +236,7 @@ impl Runtime { new_address_callback: impl ApiEndpointUpdateCallback + Send + Sync + 'static, #[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>, ) -> rest::RequestServiceHandle { - let service_handle = rest::RequestService::new( + let service_handle = rest::RequestService::spawn( sni_hostname, self.api_availability.handle(), self.address_cache.clone(), diff --git a/mullvad-api/src/proxy.rs b/mullvad-api/src/proxy.rs index 21fa39c9c6..cc7ee3aa6c 100644 --- a/mullvad-api/src/proxy.rs +++ b/mullvad-api/src/proxy.rs @@ -132,8 +132,8 @@ impl ApiConnectionMode { /// Stream that is either a regular TLS stream or TLS via shadowsocks pub enum ApiConnection { - Direct(TlsStream<TcpStream>), - Proxied(TlsStream<ProxyClientStream<TcpStream>>), + Direct(Box<TlsStream<TcpStream>>), + Proxied(Box<TlsStream<ProxyClientStream<TcpStream>>>), } impl AsyncRead for ApiConnection { diff --git a/mullvad-api/src/relay_list.rs b/mullvad-api/src/relay_list.rs index 6bd4523652..a2d699e248 100644 --- a/mullvad-api/src/relay_list.rs +++ b/mullvad-api/src/relay_list.rs @@ -130,7 +130,7 @@ impl ServerRelayList { ) { let openvpn_endpoint_data = openvpn.ports; for mut openvpn_relay in openvpn.relays.into_iter() { - openvpn_relay.to_lower(); + openvpn_relay.convert_to_lowercase(); if let Some((country_code, city_code)) = split_location_code(&openvpn_relay.location) { if let Some(country) = countries.get_mut(country_code) { if let Some(city) = country @@ -184,7 +184,7 @@ impl ServerRelayList { }; for mut wireguard_relay in relays { - wireguard_relay.relay.to_lower(); + wireguard_relay.relay.convert_to_lowercase(); if let Some((country_code, city_code)) = split_location_code(&wireguard_relay.relay.location) { @@ -235,7 +235,7 @@ impl ServerRelayList { } = bridges; for mut bridge_relay in relays { - bridge_relay.to_lower(); + bridge_relay.convert_to_lowercase(); if let Some((country_code, city_code)) = split_location_code(&bridge_relay.location) { if let Some(country) = countries.get_mut(country_code) { if let Some(city) = country @@ -345,7 +345,7 @@ struct Relay { } impl Relay { - fn to_lower(&mut self) { + fn convert_to_lowercase(&mut self) { self.hostname = self.hostname.to_lowercase(); self.location = self.location.to_lowercase(); } diff --git a/mullvad-api/src/rest.rs b/mullvad-api/src/rest.rs index 8a04c62e39..c80f01049a 100644 --- a/mullvad-api/src/rest.rs +++ b/mullvad-api/src/rest.rs @@ -130,7 +130,7 @@ impl< > RequestService<T, F> { /// Constructs a new request service. - pub async fn new( + pub async fn spawn( sni_hostname: Option<String>, api_availability: ApiAvailabilityHandle, address_cache: AddressCache, diff --git a/mullvad-daemon/src/cleanup.rs b/mullvad-daemon/src/cleanup.rs index b4765761e0..1f0cbdf309 100644 --- a/mullvad-daemon/src/cleanup.rs +++ b/mullvad-daemon/src/cleanup.rs @@ -7,26 +7,26 @@ use tokio::{fs, io}; #[error(no_from)] pub enum Error { #[error(display = "Failed to get path")] - PathError(#[error(source)] mullvad_paths::Error), + Path(#[error(source)] mullvad_paths::Error), #[error(display = "Failed to remove directory {}", _0)] - RemoveDirError(String, #[error(source)] io::Error), + RemoveDir(String, #[error(source)] io::Error), #[cfg(not(target_os = "windows"))] #[error(display = "Failed to create directory {}", _0)] - CreateDirError(String, #[error(source)] io::Error), + CreateDir(String, #[error(source)] io::Error), #[cfg(target_os = "windows")] #[error(display = "Failed to get file type info")] - FileTypeError(#[error(source)] io::Error), + FileType(#[error(source)] io::Error), #[cfg(target_os = "windows")] #[error(display = "Failed to get dir entry")] - FileEntryError(#[error(source)] io::Error), + FileEntry(#[error(source)] io::Error), #[cfg(target_os = "windows")] #[error(display = "Failed to read dir entries")] - ReadDirError(#[error(source)] io::Error), + ReadDir(#[error(source)] io::Error), } pub async fn clear_directories() -> Result<(), Error> { @@ -35,12 +35,12 @@ pub async fn clear_directories() -> Result<(), Error> { } async fn clear_log_directory() -> Result<(), Error> { - let log_dir = mullvad_paths::get_log_dir().map_err(Error::PathError)?; + let log_dir = mullvad_paths::get_log_dir().map_err(Error::Path)?; clear_directory(&log_dir).await } async fn clear_cache_directory() -> Result<(), Error> { - let cache_dir = mullvad_paths::cache_dir().map_err(Error::PathError)?; + let cache_dir = mullvad_paths::cache_dir().map_err(Error::Path)?; clear_directory(&cache_dir).await } @@ -49,22 +49,22 @@ async fn clear_directory(path: &Path) -> Result<(), Error> { { fs::remove_dir_all(path) .await - .map_err(|e| Error::RemoveDirError(path.display().to_string(), e))?; + .map_err(|e| Error::RemoveDir(path.display().to_string(), e))?; fs::create_dir_all(path) .await - .map_err(|e| Error::CreateDirError(path.display().to_string(), e)) + .map_err(|e| Error::CreateDir(path.display().to_string(), e)) } #[cfg(target_os = "windows")] { - let mut dir = fs::read_dir(&path).await.map_err(Error::ReadDirError)?; + let mut dir = fs::read_dir(&path).await.map_err(Error::ReadDir)?; let mut result = Ok(()); - while let Some(entry) = dir.next_entry().await.map_err(Error::FileEntryError)? { + while let Some(entry) = dir.next_entry().await.map_err(Error::FileEntry)? { let entry_type = match entry.file_type().await { Ok(entry_type) => entry_type, Err(error) => { - result = result.and(Err(Error::FileTypeError(error))); + result = result.and(Err(Error::FileType(error))); continue; } }; @@ -74,9 +74,8 @@ async fn clear_directory(path: &Path) -> Result<(), Error> { } else { fs::remove_dir_all(entry.path()).await }; - result = result.and( - removal.map_err(|e| Error::RemoveDirError(entry.path().display().to_string(), e)), - ); + result = result + .and(removal.map_err(|e| Error::RemoveDir(entry.path().display().to_string(), e))); } result } diff --git a/mullvad-daemon/src/device/api.rs b/mullvad-daemon/src/device/api.rs index cee0186987..6e00b669ca 100644 --- a/mullvad-daemon/src/device/api.rs +++ b/mullvad-daemon/src/device/api.rs @@ -35,10 +35,10 @@ impl CurrentApiCall { } pub fn is_validating(&self) -> bool { - match &self.current_call { - Some(Call::Validation(_)) | Some(Call::OneshotKeyRotation(_)) => true, - _ => false, - } + matches!( + &self.current_call, + Some(Call::Validation(_)) | Some(Call::OneshotKeyRotation(_)) + ) } pub fn is_running_timed_totation(&self) -> bool { @@ -51,10 +51,7 @@ impl CurrentApiCall { pub fn is_logging_in(&self) -> bool { use Call::*; - match &self.current_call { - Some(Login(..)) => true, - _ => false, - } + matches!(&self.current_call, Some(Login(..))) } } diff --git a/mullvad-daemon/src/exception_logging/unix.rs b/mullvad-daemon/src/exception_logging/unix.rs index 8d87be2da6..430fedb74f 100644 --- a/mullvad-daemon/src/exception_logging/unix.rs +++ b/mullvad-daemon/src/exception_logging/unix.rs @@ -5,7 +5,7 @@ use nix::sys::signal::{sigaction, SaFlags, SigAction, SigHandler, SigSet, Signal use std::{convert::TryFrom, sync::Once}; -const INIT_ONCE: Once = Once::new(); +static INIT_ONCE: Once = Once::new(); const FAULT_SIGNALS: [Signal; 5] = [ // Access to invalid memory address diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs index 95039cd662..c6f024eb49 100644 --- a/mullvad-daemon/src/lib.rs +++ b/mullvad-daemon/src/lib.rs @@ -31,7 +31,7 @@ use crate::target_state::PersistentTargetState; use device::{PrivateAccountAndDevice, PrivateDeviceEvent}; use futures::{ channel::{mpsc, oneshot}, - future::{abortable, AbortHandle, Future}, + future::{abortable, AbortHandle, Future, LocalBoxFuture}, StreamExt, }; use mullvad_relay_selector::{ @@ -385,6 +385,12 @@ pub struct DaemonCommandChannel { receiver: mpsc::UnboundedReceiver<InternalDaemonEvent>, } +impl Default for DaemonCommandChannel { + fn default() -> Self { + Self::new() + } +} + impl DaemonCommandChannel { pub fn new() -> Self { let (untracked_sender, receiver) = mpsc::unbounded(); @@ -472,13 +478,13 @@ impl<E> Sender<E> for DaemonEventSender<E> where InternalDaemonEvent: From<E>, { - fn send(&self, event: E) -> Result<(), ()> { + fn send(&self, event: E) -> Result<(), talpid_core::mpsc::Error> { if let Some(sender) = self.sender.upgrade() { sender .unbounded_send(InternalDaemonEvent::from(event)) - .map_err(|_| ()) + .map_err(|_| talpid_core::mpsc::Error::ChannelClosed) } else { - Err(()) + Err(talpid_core::mpsc::Error::ChannelClosed) } } } @@ -684,7 +690,7 @@ where relay_list_listener.notify_relay_list(relay_list.clone()); }; - let mut relay_list_updater = RelayListUpdater::new( + let mut relay_list_updater = RelayListUpdater::spawn( relay_selector.clone(), api_handle.clone(), &cache_dir, @@ -785,11 +791,11 @@ where /// Shuts down the daemon without shutting down the underlying event listener and the shutdown /// callbacks - fn shutdown( + fn shutdown<'a>( self, ) -> ( L, - Vec<Pin<Box<dyn Future<Output = ()>>>>, + Vec<LocalBoxFuture<'a, ()>>, mullvad_api::Runtime, TunnelStateMachineHandle, ) { @@ -845,11 +851,11 @@ where TunnelStateTransition::Disconnected => TunnelState::Disconnected, TunnelStateTransition::Connecting(endpoint) => TunnelState::Connecting { endpoint, - location: self.parameters_generator.get_last_location(), + location: self.parameters_generator.get_last_location().await, }, TunnelStateTransition::Connected(endpoint) => TunnelState::Connected { endpoint, - location: self.parameters_generator.get_last_location(), + location: self.parameters_generator.get_last_location().await, }, TunnelStateTransition::Disconnecting(after_disconnect) => { TunnelState::Disconnecting(after_disconnect) @@ -1184,7 +1190,7 @@ where } Disconnecting(..) => Self::oneshot_send( tx, - self.parameters_generator.get_last_location(), + self.parameters_generator.get_last_location().await, "current location", ), Connected { location, .. } => { @@ -1703,7 +1709,8 @@ where Self::oneshot_send(tx, Ok(()), "use_wireguard_nt response"); if settings_changed { self.parameters_generator - .set_tunnel_options(&self.settings.tunnel_options); + .set_tunnel_options(&self.settings.tunnel_options) + .await; self.event_listener .notify_settings(self.settings.to_settings()); if let Some(TunnelType::Wireguard) = self.get_target_tunnel_type() { @@ -1854,7 +1861,8 @@ where Self::oneshot_send(tx, Ok(()), "set_openvpn_mssfix response"); if settings_changed { self.parameters_generator - .set_tunnel_options(&self.settings.tunnel_options); + .set_tunnel_options(&self.settings.tunnel_options) + .await; self.event_listener .notify_settings(self.settings.to_settings()); if self.get_target_tunnel_type() == Some(TunnelType::OpenVpn) { @@ -1963,7 +1971,8 @@ where Self::oneshot_send(tx, Ok(()), "set_enable_ipv6 response"); if settings_changed { self.parameters_generator - .set_tunnel_options(&self.settings.tunnel_options); + .set_tunnel_options(&self.settings.tunnel_options) + .await; self.event_listener .notify_settings(self.settings.to_settings()); log::info!("Initiating tunnel restart because the enable IPv6 setting changed"); @@ -1991,7 +2000,8 @@ where Self::oneshot_send(tx, Ok(()), "set_quantum_resistant_tunnel response"); if settings_changed { self.parameters_generator - .set_tunnel_options(&self.settings.tunnel_options); + .set_tunnel_options(&self.settings.tunnel_options) + .await; self.event_listener .notify_settings(self.settings.to_settings()); if self.get_target_tunnel_type() == Some(TunnelType::Wireguard) { @@ -2021,7 +2031,8 @@ where let resolvers = dns::addresses_from_options(&settings.tunnel_options.dns_options); self.parameters_generator - .set_tunnel_options(&settings.tunnel_options); + .set_tunnel_options(&settings.tunnel_options) + .await; self.event_listener.notify_settings(settings); self.send_tunnel_command(TunnelCommand::Dns(resolvers)); } @@ -2044,7 +2055,8 @@ where Self::oneshot_send(tx, Ok(()), "set_wireguard_mtu response"); if settings_changed { self.parameters_generator - .set_tunnel_options(&self.settings.tunnel_options); + .set_tunnel_options(&self.settings.tunnel_options) + .await; self.event_listener .notify_settings(self.settings.to_settings()); if let Some(TunnelType::Wireguard) = self.get_connected_tunnel_type() { @@ -2086,7 +2098,8 @@ where ); } self.parameters_generator - .set_tunnel_options(&self.settings.tunnel_options); + .set_tunnel_options(&self.settings.tunnel_options) + .await; self.event_listener .notify_settings(self.settings.to_settings()); } diff --git a/mullvad-daemon/src/management_interface.rs b/mullvad-daemon/src/management_interface.rs index 7999fd9c55..26ad5e39c0 100644 --- a/mullvad-daemon/src/management_interface.rs +++ b/mullvad-daemon/src/management_interface.rs @@ -151,7 +151,7 @@ impl ManagementService for ManagementServiceImpl { self.send_command_to_daemon(DaemonCommand::GetVersionInfo(tx))?; self.wait_for_result(rx) .await? - .ok_or(Status::not_found("no version cache")) + .ok_or_else(|| Status::not_found("no version cache")) .map(types::AppVersionInfo::from) .map(Response::new) } diff --git a/mullvad-daemon/src/migrations/account_history.rs b/mullvad-daemon/src/migrations/account_history.rs index 22534c7f88..520717e405 100644 --- a/mullvad-daemon/src/migrations/account_history.rs +++ b/mullvad-daemon/src/migrations/account_history.rs @@ -53,12 +53,12 @@ pub async fn migrate_formats(settings_dir: &Path, settings: &mut serde_json::Val .read(true) .open(path) .await - .map_err(Error::ReadHistoryError)?; + .map_err(Error::ReadHistory)?; let mut bytes = vec![]; file.read_to_end(&mut bytes) .await - .map_err(Error::ReadHistoryError)?; + .map_err(Error::ReadHistory)?; if is_format_v3(&bytes) { return Ok(()); @@ -92,16 +92,16 @@ fn is_format_v3(bytes: &[u8]) -> bool { } async fn write_format_v3(mut file: File, token: Option<AccountToken>) -> Result<()> { - file.set_len(0).await.map_err(Error::WriteHistoryError)?; + file.set_len(0).await.map_err(Error::WriteHistory)?; file.seek(io::SeekFrom::Start(0)) .await - .map_err(Error::WriteHistoryError)?; + .map_err(Error::WriteHistory)?; if let Some(token) = token { file.write_all(token.as_bytes()) .await - .map_err(Error::WriteHistoryError)?; + .map_err(Error::WriteHistory)?; } - file.sync_all().await.map_err(Error::WriteHistoryError) + file.sync_all().await.map_err(Error::WriteHistory) } fn try_format_v2(bytes: &[u8]) -> Result<Option<(AccountToken, serde_json::Value)>> { diff --git a/mullvad-daemon/src/migrations/mod.rs b/mullvad-daemon/src/migrations/mod.rs index bb1e9d1ba0..b4a1d18979 100644 --- a/mullvad-daemon/src/migrations/mod.rs +++ b/mullvad-daemon/src/migrations/mod.rs @@ -57,31 +57,31 @@ const SETTINGS_FILE: &str = "settings.json"; #[error(no_from)] pub enum Error { #[error(display = "Failed to read the settings")] - ReadError(#[error(source)] io::Error), + Read(#[error(source)] io::Error), #[error(display = "Malformed settings")] - ParseError(#[error(source)] serde_json::Error), + Parse(#[error(source)] serde_json::Error), #[error(display = "Unable to read any version of the settings")] NoMatchingVersion, #[error(display = "Unable to serialize settings to JSON")] - SerializeError(#[error(source)] serde_json::Error), + Serialize(#[error(source)] serde_json::Error), #[error(display = "Unable to open settings for writing")] - OpenError(#[error(source)] io::Error), + Open(#[error(source)] io::Error), #[error(display = "Unable to write new settings")] - WriteError(#[error(source)] io::Error), + Write(#[error(source)] io::Error), #[error(display = "Unable to sync settings to disk")] - SyncError(#[error(source)] io::Error), + SyncSettings(#[error(source)] io::Error), #[error(display = "Failed to read the account history")] - ReadHistoryError(#[error(source)] io::Error), + ReadHistory(#[error(source)] io::Error), #[error(display = "Failed to write new account history")] - WriteHistoryError(#[error(source)] io::Error), + WriteHistory(#[error(source)] io::Error), #[error(display = "Failed to parse account history")] ParseHistoryError, @@ -129,10 +129,10 @@ pub(crate) async fn migrate_all( return Ok(None); } - let settings_bytes = fs::read(&path).await.map_err(Error::ReadError)?; + let settings_bytes = fs::read(&path).await.map_err(Error::Read)?; let mut settings: serde_json::Value = - serde_json::from_reader(&settings_bytes[..]).map_err(Error::ParseError)?; + serde_json::from_reader(&settings_bytes[..]).map_err(Error::Parse)?; if !settings.is_object() { return Err(Error::NoMatchingVersion); @@ -155,7 +155,7 @@ pub(crate) async fn migrate_all( return Ok(migration_data); } - let buffer = serde_json::to_string_pretty(&settings).map_err(Error::SerializeError)?; + let buffer = serde_json::to_string_pretty(&settings).map_err(Error::Serialize)?; let mut options = fs::OpenOptions::new(); #[cfg(unix)] @@ -168,11 +168,11 @@ pub(crate) async fn migrate_all( .truncate(true) .open(&path) .await - .map_err(Error::OpenError)?; + .map_err(Error::Open)?; file.write_all(&buffer.into_bytes()) .await - .map_err(Error::WriteError)?; - file.sync_data().await.map_err(Error::SyncError)?; + .map_err(Error::Write)?; + file.sync_data().await.map_err(Error::SyncSettings)?; log::debug!("Migrated settings. Wrote settings to {}", path.display()); diff --git a/mullvad-daemon/src/migrations/v2.rs b/mullvad-daemon/src/migrations/v2.rs index e91a0d08e8..d4acbf77a7 100644 --- a/mullvad-daemon/src/migrations/v2.rs +++ b/mullvad-daemon/src/migrations/v2.rs @@ -1,3 +1,4 @@ +#![allow(clippy::identity_op)] use super::{Error, Result}; use mullvad_types::settings::SettingsVersion; use std::time::Duration; diff --git a/mullvad-daemon/src/migrations/v3.rs b/mullvad-daemon/src/migrations/v3.rs index cf8631e121..7d2800e32a 100644 --- a/mullvad-daemon/src/migrations/v3.rs +++ b/mullvad-daemon/src/migrations/v3.rs @@ -66,7 +66,7 @@ pub fn migrate(settings: &mut serde_json::Value) -> Result<()> { DnsState::Default }; let addresses = if let Some(addrs) = options.get("addresses") { - serde_json::from_value(addrs.clone()).map_err(Error::ParseError)? + serde_json::from_value(addrs.clone()).map_err(Error::Parse)? } else { vec![] }; diff --git a/mullvad-daemon/src/migrations/v4.rs b/mullvad-daemon/src/migrations/v4.rs index 6908b13010..77b72bcb3e 100644 --- a/mullvad-daemon/src/migrations/v4.rs +++ b/mullvad-daemon/src/migrations/v4.rs @@ -43,8 +43,7 @@ pub fn migrate(settings: &mut serde_json::Value) -> Result<()> { if let Some(constraints) = wireguard_constraints { let (port, protocol): (Constraint<u16>, TransportProtocol) = if let Some(port) = constraints.get("port") { - let port_constraint = - serde_json::from_value(port.clone()).map_err(Error::ParseError)?; + let port_constraint = serde_json::from_value(port.clone()).map_err(Error::Parse)?; match port_constraint { Constraint::Any => (Constraint::Any, TransportProtocol::Udp), Constraint::Only(port) => (Constraint::Only(port), wg_protocol_from_port(port)), @@ -77,13 +76,13 @@ pub fn migrate(settings: &mut serde_json::Value) -> Result<()> { if let Some(constraints) = openvpn_constraints { let port: Constraint<u16> = if let Some(port) = constraints.get("port") { - serde_json::from_value(port.clone()).map_err(Error::ParseError)? + serde_json::from_value(port.clone()).map_err(Error::Parse)? } else { Constraint::Any }; let transport_constraint: Constraint<TransportProtocol> = if let Some(protocol) = constraints.get("protocol") { - serde_json::from_value(protocol.clone()).map_err(Error::ParseError)? + serde_json::from_value(protocol.clone()).map_err(Error::Parse)? } else { Constraint::Any }; diff --git a/mullvad-daemon/src/migrations/v5.rs b/mullvad-daemon/src/migrations/v5.rs index 9f0fdc4b94..cfc74d8438 100644 --- a/mullvad-daemon/src/migrations/v5.rs +++ b/mullvad-daemon/src/migrations/v5.rs @@ -95,7 +95,7 @@ pub(crate) async fn migrate(settings: &mut serde_json::Value) -> Result<Option<M // if let Some(port) = wireguard_constraints.get("port") { let port_constraint: Constraint<TransportPort> = - serde_json::from_value(port.clone()).map_err(Error::ParseError)?; + serde_json::from_value(port.clone()).map_err(Error::Parse)?; if let Some(transport_port) = port_constraint.option() { let (port, obfuscation_settings) = match transport_port.protocol { TransportProtocol::Udp => (serde_json::json!(transport_port.port), None), @@ -116,8 +116,7 @@ pub(crate) async fn migrate(settings: &mut serde_json::Value) -> Result<Option<M let migration_data = if let Some(token) = settings.get("account_token").filter(|t| !t.is_null()) { - let token: AccountToken = - serde_json::from_value(token.clone()).map_err(Error::ParseError)?; + let token: AccountToken = serde_json::from_value(token.clone()).map_err(Error::Parse)?; let migration_data = if let Some(wg_data) = settings.get("wireguard").filter(|wg| !wg.is_null()) { Some(MigrationData { diff --git a/mullvad-daemon/src/tunnel.rs b/mullvad-daemon/src/tunnel.rs index e95850c4f4..18b8bf2c20 100644 --- a/mullvad-daemon/src/tunnel.rs +++ b/mullvad-daemon/src/tunnel.rs @@ -1,8 +1,6 @@ -use std::{ - future::Future, - pin::Pin, - sync::{Arc, Mutex}, -}; +use std::{future::Future, pin::Pin, sync::Arc}; + +use tokio::sync::Mutex; use mullvad_relay_selector::{RelaySelector, SelectedBridge, SelectedObfuscator, SelectedRelay}; use mullvad_types::{ @@ -32,7 +30,7 @@ pub enum Error { NoBridgeAvailable, #[error(display = "Failed to resolve hostname for custom relay")] - ResolveCustomHostnameError, + ResolveCustomHostname, } #[derive(Clone)] @@ -65,13 +63,13 @@ impl ParametersGenerator { } /// Sets the tunnel options to use when generating new tunnel parameters. - pub fn set_tunnel_options(&self, tunnel_options: &TunnelOptions) { - self.0.lock().unwrap().tunnel_options = tunnel_options.clone(); + pub async fn set_tunnel_options(&self, tunnel_options: &TunnelOptions) { + self.0.lock().await.tunnel_options = tunnel_options.clone(); } /// Gets the location associated with the last generated tunnel parameters. - pub fn get_last_location(&self) -> Option<GeoIpLocation> { - let inner = self.0.lock().unwrap(); + pub async fn get_last_location(&self) -> Option<GeoIpLocation> { + let inner = self.0.lock().await; let relays = inner.last_generated_relays.as_ref()?; @@ -131,7 +129,7 @@ impl InnerParametersGenerator { .to_tunnel_parameters(self.tunnel_options.clone(), None) .map_err(|e| { log::error!("Failed to resolve hostname for custom tunnel config: {}", e); - Error::ResolveCustomHostnameError + Error::ResolveCustomHostname }) } Ok((SelectedRelay::Normal(constraints), bridge, obfuscator)) => { @@ -246,13 +244,13 @@ impl TunnelParametersGenerator for ParametersGenerator { ) -> Pin<Box<dyn Future<Output = Result<TunnelParameters, ParameterGenerationError>>>> { let generator = self.0.clone(); Box::pin(async move { - let mut inner = generator.lock().unwrap(); + let mut inner = generator.lock().await; inner .generate(retry_attempt) .await .map_err(|error| match error { Error::NoBridgeAvailable => ParameterGenerationError::NoMatchingBridgeRelay, - Error::ResolveCustomHostnameError => { + Error::ResolveCustomHostname => { ParameterGenerationError::CustomTunnelHostResultionError } error => { diff --git a/mullvad-daemon/src/version_check.rs b/mullvad-daemon/src/version_check.rs index ffa5abc979..10dbf7bfb4 100644 --- a/mullvad-daemon/src/version_check.rs +++ b/mullvad-daemon/src/version_check.rs @@ -14,6 +14,7 @@ use std::{ future::Future, io, path::{Path, PathBuf}, + str::FromStr, time::Duration, }; use talpid_core::mpsc::Sender; @@ -297,10 +298,10 @@ impl VersionUpdater { if !*IS_DEV_BUILD { let stable_version = latest_stable .as_ref() - .and_then(|stable| ParsedAppVersion::from_str(stable)); + .and_then(|stable| ParsedAppVersion::from_str(stable).ok()); let beta_version = if show_beta { - ParsedAppVersion::from_str(latest_beta) + ParsedAppVersion::from_str(latest_beta).ok() } else { None }; @@ -339,10 +340,10 @@ impl VersionUpdater { let mut check_delay = next_delay(); let mut version_check = futures::future::Fuse::terminated(); - // If this is a dev build ,there's no need to pester the API for version checks. + // If this is a dev build, there's no need to pester the API for version checks. if *IS_DEV_BUILD { log::warn!("Not checking for updates because this is a development build"); - while let Some(_) = rx.next().await {} + while rx.next().await.is_some() {} return; } diff --git a/mullvad-relay-selector/src/lib.rs b/mullvad-relay-selector/src/lib.rs index 158086b1d4..e7b85fbc4c 100644 --- a/mullvad-relay-selector/src/lib.rs +++ b/mullvad-relay-selector/src/lib.rs @@ -384,7 +384,7 @@ impl RelaySelector { let mut relay_matcher = RelayMatcher { location: location.clone(), providers: providers.clone(), - ownership: ownership.clone(), + ownership: *ownership, tunnel: openvpn_constraints, }; @@ -492,7 +492,7 @@ impl RelaySelector { let mut entry_relay_matcher = RelayMatcher { location: location.clone(), providers: providers.clone(), - ownership: ownership.clone(), + ownership: *ownership, tunnel: wireguard_constraints.clone().into(), }; @@ -532,7 +532,7 @@ impl RelaySelector { .clone(), ..matcher.clone() } - .to_wireguard_matcher(); + .into_wireguard_matcher(); // Pick the entry relay first if its location constraint is a subset of the exit location. if relay_constraints.wireguard_constraints.use_multihop { @@ -746,7 +746,7 @@ impl RelaySelector { let bridge_constraints = InternalBridgeConstraints { location: settings.location.clone(), providers: settings.providers.clone(), - ownership: settings.ownership.clone(), + ownership: settings.ownership, // FIXME: This is temporary while talpid-core only supports TCP proxies transport_protocol: Constraint::Only(TransportProtocol::Tcp), }; @@ -791,7 +791,7 @@ impl RelaySelector { BridgeSettings::Normal(settings) => InternalBridgeConstraints { location: settings.location.clone(), providers: settings.providers.clone(), - ownership: settings.ownership.clone(), + ownership: settings.ownership, transport_protocol: Constraint::Only(TransportProtocol::Tcp), }, BridgeSettings::Custom(_bridge_settings) => InternalBridgeConstraints { @@ -1064,7 +1064,7 @@ impl RelaySelector { let addr_in = endpoint .as_ref() .map(|endpoint| endpoint.to_endpoint().address.ip()) - .unwrap_or(IpAddr::from(selected_relay.ipv4_addr_in)); + .unwrap_or_else(|| IpAddr::from(selected_relay.ipv4_addr_in)); log::info!("Selected relay {} at {}", selected_relay.hostname, addr_in); endpoint.map(|endpoint| NormalSelectedRelay::new(endpoint, selected_relay.clone())) }) diff --git a/mullvad-relay-selector/src/matcher.rs b/mullvad-relay-selector/src/matcher.rs index 13e16646ab..350a106745 100644 --- a/mullvad-relay-selector/src/matcher.rs +++ b/mullvad-relay-selector/src/matcher.rs @@ -34,7 +34,7 @@ impl From<RelayConstraints> for RelayMatcher<AnyTunnelMatcher> { } impl RelayMatcher<AnyTunnelMatcher> { - pub fn to_wireguard_matcher(self) -> RelayMatcher<WireguardMatcher> { + pub fn into_wireguard_matcher(self) -> RelayMatcher<WireguardMatcher> { RelayMatcher { tunnel: self.tunnel.wireguard, location: self.location, diff --git a/mullvad-relay-selector/src/updater.rs b/mullvad-relay-selector/src/updater.rs index 480be795fb..31299eea14 100644 --- a/mullvad-relay-selector/src/updater.rs +++ b/mullvad-relay-selector/src/updater.rs @@ -57,7 +57,7 @@ pub struct RelayListUpdater { } impl RelayListUpdater { - pub fn new( + pub fn spawn( selector: super::RelaySelector, api_handle: MullvadRestHandle, cache_dir: &Path, diff --git a/mullvad-setup/src/main.rs b/mullvad-setup/src/main.rs index b257d92080..9d353b78a6 100644 --- a/mullvad-setup/src/main.rs +++ b/mullvad-setup/src/main.rs @@ -2,7 +2,7 @@ use clap::{crate_authors, crate_description, crate_name, App}; use mullvad_api::{self, proxy::ApiConnectionMode}; use mullvad_management_interface::new_rpc_client; use mullvad_types::version::ParsedAppVersion; -use std::{path::PathBuf, process, time::Duration}; +use std::{path::PathBuf, process, str::FromStr, time::Duration}; use talpid_core::{ firewall::{self, Firewall}, future_retry::{constant_interval, retry_future_n}, @@ -133,7 +133,7 @@ async fn main() { async fn is_older_version(old_version: &str) -> Result<ExitStatus, Error> { let parsed_version = - ParsedAppVersion::from_str(old_version).ok_or(Error::ParseVersionStringError)?; + ParsedAppVersion::from_str(old_version).map_err(|_| Error::ParseVersionStringError)?; Ok(if parsed_version < *APP_VERSION { ExitStatus::Ok @@ -152,7 +152,7 @@ async fn prepare_restart() -> Result<(), Error> { async fn reset_firewall() -> Result<(), Error> { // Ensure that the daemon isn't running - if let Ok(_) = new_rpc_client().await { + if new_rpc_client().await.is_ok() { return Err(Error::DaemonIsRunning); } diff --git a/mullvad-types/src/relay_constraints.rs b/mullvad-types/src/relay_constraints.rs index 4ea0a8dfc5..ffbb317bd3 100644 --- a/mullvad-types/src/relay_constraints.rs +++ b/mullvad-types/src/relay_constraints.rs @@ -625,19 +625,15 @@ impl RelaySettingsUpdate { RelaySettingsUpdate::CustomTunnelEndpoint(endpoint) => { endpoint.endpoint().protocol == TransportProtocol::Tcp } - RelaySettingsUpdate::Normal(update) => { - if let Some(constraints) = &update.openvpn_constraints { - !matches!( - &constraints.port, - Constraint::Only(TransportPort { - protocol: TransportProtocol::Udp, - .. - }) - ) - } else { - true - } - } + RelaySettingsUpdate::Normal(update) => !matches!( + &update.openvpn_constraints, + Some(OpenVpnConstraints { + port: Constraint::Only(TransportPort { + protocol: TransportProtocol::Udp, + .. + }) + }) + ), } } } diff --git a/mullvad-types/src/version.rs b/mullvad-types/src/version.rs index 7daf50e74b..8040b181b0 100644 --- a/mullvad-types/src/version.rs +++ b/mullvad-types/src/version.rs @@ -2,7 +2,10 @@ use jnix::IntoJava; use regex::Regex; use serde::{Deserialize, Serialize}; -use std::cmp::{Ord, Ordering, PartialOrd}; +use std::{ + cmp::{Ord, Ordering, PartialOrd}, + str::FromStr, +}; lazy_static::lazy_static! { static ref STABLE_REGEX: Regex = Regex::new(r"^(\d{4})\.(\d+)$").unwrap(); @@ -44,30 +47,33 @@ pub enum ParsedAppVersion { Dev(u32, u32, Option<u32>, String), } -impl ParsedAppVersion { - pub fn from_str(version: &str) -> Option<Self> { +impl FromStr for ParsedAppVersion { + type Err = (); + fn from_str(version: &str) -> Result<Self, Self::Err> { let get_int = |cap: ®ex::Captures<'_>, idx| cap.get(idx)?.as_str().parse().ok(); if let Some(caps) = STABLE_REGEX.captures(version) { - let year = get_int(&caps, 1)?; - let version = get_int(&caps, 2)?; - Some(Self::Stable(year, version)) + let year = get_int(&caps, 1).ok_or(())?; + let version = get_int(&caps, 2).ok_or(())?; + Ok(Self::Stable(year, version)) } else if let Some(caps) = BETA_REGEX.captures(version) { - let year = get_int(&caps, 1)?; - let version = get_int(&caps, 2)?; - let beta_version = get_int(&caps, 3)?; - Some(Self::Beta(year, version, beta_version)) + let year = get_int(&caps, 1).ok_or(())?; + let version = get_int(&caps, 2).ok_or(())?; + let beta_version = get_int(&caps, 3).ok_or(())?; + Ok(Self::Beta(year, version, beta_version)) } else if let Some(caps) = DEV_REGEX.captures(version) { - let year = get_int(&caps, 1)?; - let version = get_int(&caps, 2)?; + let year = get_int(&caps, 1).ok_or(())?; + let version = get_int(&caps, 2).ok_or(())?; let beta_version = caps.get(4).map(|_| get_int(&caps, 5).unwrap()); - let dev_hash = caps.get(6)?.as_str().to_string(); - Some(Self::Dev(year, version, beta_version, dev_hash)) + let dev_hash = caps.get(6).ok_or(())?.as_str().to_string(); + Ok(Self::Dev(year, version, beta_version, dev_hash)) } else { - None + Err(()) } } +} +impl ParsedAppVersion { pub fn is_dev(&self) -> bool { matches!(self, ParsedAppVersion::Dev(..)) } @@ -191,7 +197,7 @@ mod test { ]; for (input, expected_output) in tests { - assert_eq!(ParsedAppVersion::from_str(input), expected_output,); + assert_eq!(ParsedAppVersion::from_str(input).ok(), expected_output,); } } } diff --git a/talpid-core/src/dns/linux/resolvconf.rs b/talpid-core/src/dns/linux/resolvconf.rs index ea2f3f3704..97db14b622 100644 --- a/talpid-core/src/dns/linux/resolvconf.rs +++ b/talpid-core/src/dns/linux/resolvconf.rs @@ -22,16 +22,16 @@ pub enum Error { RunResolvconf(#[error(source)] io::Error), #[error(display = "Using 'resolvconf' to add a record failed: {}", stderr)] - AddRecordError { stderr: String }, + AddRecord { stderr: String }, #[error(display = "Using 'resolvconf' to delete a record failed")] - DeleteRecordError, + DeleteRecord, #[error(display = "Detected dnsmasq is runing and misconfigured")] - DnsmasqMisconfigurationError, + DnsmasqMisconfiguration, #[error(display = "Current /etc/resolv.conf is not generated by resolvconf")] - ResolvconfNotInUseError, + ResolvconfNotInUse, } pub struct Resolvconf { @@ -50,15 +50,15 @@ impl Resolvconf { // Check if resolvconf is managing DNS by /etc/resolv.conf if !is_dnsmasq_running - && !(Self::check_if_resolvconf_is_symlinked_correctly() - || Self::check_if_resolvconf_was_generated()) + && !Self::check_if_resolvconf_is_symlinked_correctly() + && !Self::check_if_resolvconf_was_generated() { - return Err(Error::ResolvconfNotInUseError); + return Err(Error::ResolvconfNotInUse); } // Check if resolvconf can manage DNS via dnsmasq if is_dnsmasq_running && Self::is_dnsmasq_configured_wrong() { - return Err(Error::DnsmasqMisconfigurationError); + return Err(Error::DnsmasqMisconfiguration); } Ok(Resolvconf { @@ -94,7 +94,7 @@ impl Resolvconf { if !output.status.success() { let stderr = String::from_utf8_lossy(&output.stderr).to_string(); - return Err(Error::AddRecordError { stderr }); + return Err(Error::AddRecord { stderr }); } self.record_names.insert(record_name); @@ -118,7 +118,7 @@ impl Resolvconf { record_name, String::from_utf8_lossy(&output.stderr) ); - result = Err(Error::DeleteRecordError); + result = Err(Error::DeleteRecord); } } diff --git a/talpid-core/src/dns/linux/static_resolv_conf.rs b/talpid-core/src/dns/linux/static_resolv_conf.rs index 196fb31003..691d7b468b 100644 --- a/talpid-core/src/dns/linux/static_resolv_conf.rs +++ b/talpid-core/src/dns/linux/static_resolv_conf.rs @@ -28,7 +28,7 @@ pub enum Error { ReadResolvConf(&'static str, #[error(source)] io::Error), #[error(display = "resolv.conf at {} could not be parsed", _0)] - ParseError(&'static str, #[error(source)] resolv_conf::ParseError), + Parse(&'static str, #[error(source)] resolv_conf::ParseError), #[error(display = "Failed to remove stale resolv.conf backup at {}", _0)] RemoveBackup(&'static str, #[error(source)] io::Error), @@ -179,7 +179,7 @@ fn read_config() -> Result<Config> { let contents = fs::read_to_string(RESOLV_CONF_PATH) .map_err(|e| Error::ReadResolvConf(RESOLV_CONF_PATH, e))?; - let config = Config::parse(&contents).map_err(|e| Error::ParseError(RESOLV_CONF_PATH, e))?; + let config = Config::parse(&contents).map_err(|e| Error::Parse(RESOLV_CONF_PATH, e))?; Ok(config) } @@ -198,8 +198,8 @@ fn restore_from_backup() -> Result<()> { match fs::read_to_string(RESOLV_CONF_BACKUP_PATH) { Ok(backup) => { log::info!("Restoring DNS state from backup"); - let config = Config::parse(&backup) - .map_err(|e| Error::ParseError(RESOLV_CONF_BACKUP_PATH, e))?; + let config = + Config::parse(&backup).map_err(|e| Error::Parse(RESOLV_CONF_BACKUP_PATH, e))?; write_config(&config)?; diff --git a/talpid-core/src/mpsc.rs b/talpid-core/src/mpsc.rs index 8c6424bc01..6492796cfc 100644 --- a/talpid-core/src/mpsc.rs +++ b/talpid-core/src/mpsc.rs @@ -1,11 +1,20 @@ +/// Error type for `Sender` trait. +#[derive(err_derive::Error, Debug)] +pub enum Error { + /// The underlying channel is closed. + #[error(display = "Channel is closed")] + ChannelClosed, +} + /// Abstraction over any type that can be used similarly to an `std::mpsc::Sender`. pub trait Sender<T> { /// Sends an item over the underlying channel, failing only if the channel is closed. - fn send(&self, item: T) -> Result<(), ()>; + fn send(&self, item: T) -> Result<(), Error>; } impl<E> Sender<E> for futures::channel::mpsc::UnboundedSender<E> { - fn send(&self, content: E) -> Result<(), ()> { - self.unbounded_send(content).map_err(|_| ()) + fn send(&self, content: E) -> Result<(), Error> { + self.unbounded_send(content) + .map_err(|_| Error::ChannelClosed) } } diff --git a/talpid-core/src/ping_monitor/icmp.rs b/talpid-core/src/ping_monitor/icmp.rs index 67f5b70cb5..0bcd9da72f 100644 --- a/talpid-core/src/ping_monitor/icmp.rs +++ b/talpid-core/src/ping_monitor/icmp.rs @@ -183,7 +183,7 @@ fn construct_icmpv4_packet_inner( let checksum = internet_checksum::checksum(buffer); (&mut buffer[ICMP_CHECKSUM_OFFSET..]) - .write(&checksum) + .write_all(&checksum) .unwrap(); true diff --git a/talpid-core/src/routing/linux.rs b/talpid-core/src/routing/linux.rs index 092ad6f52a..4b039fe9eb 100644 --- a/talpid-core/src/routing/linux.rs +++ b/talpid-core/src/routing/linux.rs @@ -87,13 +87,13 @@ pub type Result<T> = std::result::Result<T, Error>; #[error(no_from)] pub enum Error { #[error(display = "Failed to open a netlink connection")] - ConnectError(#[error(source)] io::Error), + Connect(#[error(source)] io::Error), #[error(display = "Failed to bind netlink socket")] - BindError(#[error(source)] io::Error), + Bind(#[error(source)] io::Error), #[error(display = "Netlink error")] - NetlinkError(#[error(source)] rtnetlink::Error), + Netlink(#[error(source)] rtnetlink::Error), #[error(display = "Route without a valid node")] InvalidRoute, @@ -108,16 +108,16 @@ pub enum Error { UnknownDeviceIndex(u32), #[error(display = "Failed to get a route for the given IP address")] - GetRouteError(#[error(source)] rtnetlink::Error), + GetRoute(#[error(source)] rtnetlink::Error), #[error(display = "No netlink response for route query")] - NoRouteError, + NoRoute, #[error(display = "Route node was malformed")] InvalidRouteNode, #[error(display = "No link found")] - LinkNotFoundError, + LinkNotFound, /// Unable to create routing table for tagged connections and packets. #[error(display = "Cannot find a free routing table ID")] @@ -140,14 +140,11 @@ pub struct RouteManagerImpl { impl RouteManagerImpl { pub async fn new(required_routes: HashSet<RequiredRoute>) -> Result<Self> { let (mut connection, handle, messages) = - rtnetlink::new_connection().map_err(Error::ConnectError)?; + rtnetlink::new_connection().map_err(Error::Connect)?; let mgroup_flags = RTMGRP_IPV4_ROUTE | RTMGRP_IPV6_ROUTE | RTMGRP_LINK | RTMGRP_NOTIFY; let addr = SocketAddr::new(0, mgroup_flags); - connection - .socket_mut() - .bind(&addr) - .map_err(Error::BindError)?; + connection.socket_mut().bind(&addr).map_err(Error::Bind)?; tokio::spawn(connection); @@ -179,11 +176,11 @@ impl RouteManagerImpl { let mut req = NetlinkMessage::from(RtnlMessage::NewRule((*rule).clone())); req.header.flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_CREATE | NLM_F_REPLACE; - let mut response = self.handle.request(req).map_err(Error::NetlinkError)?; + let mut response = self.handle.request(req).map_err(Error::Netlink)?; while let Some(message) = response.next().await { if let NetlinkPayload::Error(error) = message.payload { - return Err(Error::NetlinkError(rtnetlink::Error::NetlinkError(error))); + return Err(Error::Netlink(rtnetlink::Error::NetlinkError(error))); } } } @@ -236,7 +233,7 @@ impl RouteManagerImpl { let mut req = NetlinkMessage::from(RtnlMessage::GetRule(RuleMessage::default())); req.header.flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_DUMP; - let mut response = self.handle.request(req).map_err(Error::NetlinkError)?; + let mut response = self.handle.request(req).map_err(Error::Netlink)?; let mut rules = vec![]; @@ -246,7 +243,7 @@ impl RouteManagerImpl { rules.push(rule); } NetlinkPayload::Error(error) => { - return Err(Error::NetlinkError(rtnetlink::Error::NetlinkError(error))); + return Err(Error::Netlink(rtnetlink::Error::NetlinkError(error))); } _ => (), } @@ -260,12 +257,12 @@ impl RouteManagerImpl { let mut req = NetlinkMessage::from(RtnlMessage::DelRule(rule)); req.header.flags = NLM_F_REQUEST | NLM_F_ACK; - let mut response = self.handle.request(req).map_err(Error::NetlinkError)?; + let mut response = self.handle.request(req).map_err(Error::Netlink)?; while let Some(message) = response.next().await { if let NetlinkPayload::Error(error) = message.payload { if error.to_io().kind() != io::ErrorKind::NotFound { - return Err(Error::NetlinkError(rtnetlink::Error::NetlinkError(error))); + return Err(Error::Netlink(rtnetlink::Error::NetlinkError(error))); } } } @@ -296,7 +293,7 @@ impl RouteManagerImpl { ) -> Result<BTreeMap<u32, NetworkInterface>> { let mut link_map = BTreeMap::new(); let mut link_request = handle.link().get().execute(); - while let Some(link) = link_request.try_next().await.map_err(Error::NetlinkError)? { + while let Some(link) = link_request.try_next().await.map_err(Error::Netlink)? { if let Some((idx, device)) = Self::map_interface(link) { link_map.insert(idx, device); } @@ -543,7 +540,7 @@ impl RouteManagerImpl { async fn delete_route_if_exists(&self, route: &Route) -> Result<()> { if let Err(error) = self.delete_route(route).await { - if let Error::NetlinkError(rtnetlink::Error::NetlinkError(msg)) = &error { + if let Error::Netlink(rtnetlink::Error::NetlinkError(msg)) = &error { if msg.code == -libc::ESRCH { return Ok(()); } @@ -619,7 +616,7 @@ impl RouteManagerImpl { .del(route_message) .execute() .await - .map_err(Error::NetlinkError) + .map_err(Error::Netlink) } async fn add_route_direct(&mut self, route: Route) -> Result<()> { @@ -693,11 +690,11 @@ impl RouteManagerImpl { let mut req = NetlinkMessage::from(RtnlMessage::NewRoute(add_message)); req.header.flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_CREATE | NLM_F_REPLACE; - let mut response = self.handle.request(req).map_err(Error::NetlinkError)?; + let mut response = self.handle.request(req).map_err(Error::Netlink)?; while let Some(message) = response.next().await { if let NetlinkPayload::Error(err) = message.payload { - return Err(Error::NetlinkError(rtnetlink::Error::NetlinkError(err))); + return Err(Error::Netlink(rtnetlink::Error::NetlinkError(err))); } } Ok(()) @@ -759,7 +756,7 @@ impl RouteManagerImpl { } None => { log::error!("No route detected when assigning the mtu to the Wireguard tunnel"); - return Err(Error::NoRouteError); + return Err(Error::NoRoute); } } } @@ -767,17 +764,13 @@ impl RouteManagerImpl { "Retried {} times looking for the correct device and could not find it", RECURSION_LIMIT ); - Err(Error::NoRouteError) + Err(Error::NoRoute) } async fn get_device_mtu(&self, device: String) -> Result<u16> { let mut links = self.handle.link().get().execute(); let target_device = LinkNla::IfName(device); - while let Some(msg) = links - .try_next() - .await - .map_err(|_| Error::LinkNotFoundError)? - { + while let Some(msg) = links.try_next().await.map_err(|_| Error::LinkNotFound)? { let found = msg.nlas.iter().any(|e| *e == target_device); if found { if let Some(LinkNla::Mtu(mtu)) = @@ -788,7 +781,7 @@ impl RouteManagerImpl { } } } - Err(Error::LinkNotFoundError) + Err(Error::LinkNotFound) } async fn get_destination_route( @@ -813,11 +806,11 @@ impl RouteManagerImpl { let mut stream = execute_route_get_request(self.handle.clone(), message.clone()); match stream.try_next().await { Ok(Some(route_msg)) => self.parse_route_message(route_msg), - Ok(None) => Err(Error::NoRouteError), + Ok(None) => Err(Error::NoRoute), Err(rtnetlink::Error::NetlinkError(nl_err)) if nl_err.code == -libc::ENETUNREACH => { Ok(None) } - Err(err) => Err(Error::GetRouteError(err)), + Err(err) => Err(Error::GetRoute(err)), } } } diff --git a/talpid-core/src/routing/unix.rs b/talpid-core/src/routing/unix.rs index edfbdd2b85..326fb1fad1 100644 --- a/talpid-core/src/routing/unix.rs +++ b/talpid-core/src/routing/unix.rs @@ -19,16 +19,19 @@ use futures::stream::Stream; #[cfg(target_os = "linux")] use std::net::IpAddr; +#[allow(clippy::module_inception)] #[cfg(target_os = "macos")] #[path = "macos.rs"] mod imp; #[cfg(target_os = "macos")] pub(crate) use imp::listen_for_default_route_changes; +#[allow(clippy::module_inception)] #[cfg(target_os = "linux")] #[path = "linux.rs"] mod imp; +#[allow(clippy::module_inception)] #[cfg(target_os = "android")] #[path = "android.rs"] mod imp; diff --git a/talpid-core/src/tunnel/mod.rs b/talpid-core/src/tunnel/mod.rs index 5da3c092a4..f6ada1c2cf 100644 --- a/talpid-core/src/tunnel/mod.rs +++ b/talpid-core/src/tunnel/mod.rs @@ -1,6 +1,6 @@ use self::tun_provider::TunProvider; use crate::{logging, routing::RouteManagerHandle}; -use futures::channel::oneshot; +use futures::{channel::oneshot, future::BoxFuture}; use std::{ net::{IpAddr, Ipv4Addr, Ipv6Addr}, path::{Path, PathBuf}, @@ -98,6 +98,20 @@ pub struct TunnelMonitor { monitor: InternalTunnelMonitor, } +/// Arguments for creating a tunnel. +pub struct TunnelArgs<'a, L> +where + // L: (Fn(TunnelEvent) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>) + L: (Fn(TunnelEvent) -> BoxFuture<'static, ()>) + Send + Clone + Sync + 'static, +{ + /// Resource directory. + pub resource_dir: &'a Path, + /// Callback function called when an event happens. + pub on_event: L, + /// Receiver oneshot channel for closing the tunnel. + pub tunnel_close_rx: oneshot::Receiver<()>, +} + // TODO(emilsp) move most of the openvpn tunnel details to OpenVpnTunnelMonitor impl TunnelMonitor { /// Creates a new `TunnelMonitor` that connects to the given remote and notifies `on_event` @@ -107,12 +121,10 @@ impl TunnelMonitor { runtime: tokio::runtime::Handle, tunnel_parameters: &mut TunnelParameters, log_dir: &Option<PathBuf>, - resource_dir: &Path, - on_event: L, tun_provider: Arc<Mutex<TunProvider>>, - route_manager: RouteManagerHandle, retry_attempt: u32, - tunnel_close_rx: oneshot::Receiver<()>, + route_manager: RouteManagerHandle, + init_args: TunnelArgs<'_, L>, ) -> Result<Self> where L: (Fn(TunnelEvent) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>) @@ -129,9 +141,9 @@ impl TunnelMonitor { TunnelParameters::OpenVpn(config) => runtime.block_on(Self::start_openvpn_tunnel( config, log_file, - resource_dir, - on_event, - tunnel_close_rx, + init_args.resource_dir, + init_args.on_event, + init_args.tunnel_close_rx, #[cfg(target_os = "linux")] route_manager, )), @@ -142,12 +154,10 @@ impl TunnelMonitor { runtime, config, log_file, - resource_dir, - on_event, tun_provider, - route_manager, retry_attempt, - tunnel_close_rx, + route_manager, + init_args, ), } } @@ -178,12 +188,10 @@ impl TunnelMonitor { runtime: tokio::runtime::Handle, params: &mut wireguard_types::TunnelParameters, log: Option<PathBuf>, - resource_dir: &Path, - on_event: L, tun_provider: Arc<Mutex<TunProvider>>, - route_manager: RouteManagerHandle, retry_attempt: u32, - tunnel_close_rx: oneshot::Receiver<()>, + route_manager: RouteManagerHandle, + init_args: TunnelArgs<'_, L>, ) -> Result<Self> where L: (Fn(TunnelEvent) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>) @@ -211,12 +219,10 @@ impl TunnelMonitor { None }, log.as_deref(), - resource_dir, - on_event, tun_provider, - route_manager, retry_attempt, - tunnel_close_rx, + route_manager, + init_args, )?; Ok(TunnelMonitor { monitor: InternalTunnelMonitor::Wireguard(monitor), diff --git a/talpid-core/src/tunnel/openvpn/mod.rs b/talpid-core/src/tunnel/openvpn/mod.rs index 910d5bb49e..9fdfb3e80b 100644 --- a/talpid-core/src/tunnel/openvpn/mod.rs +++ b/talpid-core/src/tunnel/openvpn/mod.rs @@ -310,10 +310,19 @@ impl OpenVpnMonitor<OpenVpnCommand> { let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger(); + let openvpn_init_args = OpenVpnTunnelInitArgs { + event_server_abort_tx: event_server_abort_tx.clone(), + event_server_abort_rx, + plugin_path, + log_path, + user_pass_file, + proxy_auth_file, + proxy_monitor, + tunnel_close_rx, + }; Self::new_internal( cmd, - event_server_abort_tx.clone(), - event_server_abort_rx, + openvpn_init_args, event_server::OpenvpnEventProxyImpl { on_event, user_pass_file_path: user_pass_file_path.clone(), @@ -324,12 +333,6 @@ impl OpenVpnMonitor<OpenVpnCommand> { #[cfg(target_os = "linux")] ipv6_enabled, }, - plugin_path, - log_path, - user_pass_file, - proxy_auth_file, - proxy_monitor, - tunnel_close_rx, #[cfg(windows)] Box::new(wintun), ) @@ -371,23 +374,36 @@ fn extract_routes(env: &HashMap<String, String>) -> Result<HashSet<RequiredRoute Ok(routes) } +struct OpenVpnTunnelInitArgs { + event_server_abort_tx: triggered::Trigger, + event_server_abort_rx: triggered::Listener, + plugin_path: PathBuf, + log_path: Option<PathBuf>, + user_pass_file: mktemp::TempFile, + proxy_auth_file: Option<mktemp::TempFile>, + proxy_monitor: Option<Box<dyn ProxyMonitor>>, + tunnel_close_rx: oneshot::Receiver<()>, +} + impl<C: OpenVpnBuilder + Send + 'static> OpenVpnMonitor<C> { async fn new_internal<L>( mut cmd: C, - event_server_abort_tx: triggered::Trigger, - event_server_abort_rx: triggered::Listener, + init_args: OpenVpnTunnelInitArgs, on_event: L, - plugin_path: PathBuf, - log_path: Option<PathBuf>, - user_pass_file: mktemp::TempFile, - proxy_auth_file: Option<mktemp::TempFile>, - proxy_monitor: Option<Box<dyn ProxyMonitor>>, - tunnel_close_rx: oneshot::Receiver<()>, #[cfg(windows)] wintun: Box<dyn WintunContext>, ) -> Result<OpenVpnMonitor<C>> where L: event_server::OpenvpnEventProxy + Send + Sync + 'static, { + let event_server_abort_tx = init_args.event_server_abort_tx; + let event_server_abort_rx = init_args.event_server_abort_rx; + let plugin_path = init_args.plugin_path; + let log_path = init_args.log_path; + let user_pass_file = init_args.user_pass_file; + let proxy_auth_file = init_args.proxy_auth_file; + let proxy_monitor = init_args.proxy_monitor; + let tunnel_close_rx = init_args.tunnel_close_rx; + let (server_join_handle, ipc_path) = event_server::start(on_event, event_server_abort_rx) .await .map_err(Error::EventDispatcherError)?; @@ -1220,23 +1236,37 @@ mod tests { .map_err(Error::RuntimeError) } + fn create_init_args_plugin_log( + plugin_path: PathBuf, + log_path: Option<PathBuf>, + ) -> OpenVpnTunnelInitArgs { + let (_close_tx, close_rx) = oneshot::channel(); + let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger(); + OpenVpnTunnelInitArgs { + event_server_abort_tx, + event_server_abort_rx, + plugin_path, + log_path, + user_pass_file: TempFile::new(), + proxy_auth_file: None, + proxy_monitor: None, + tunnel_close_rx: close_rx, + } + } + + fn create_init_args() -> OpenVpnTunnelInitArgs { + create_init_args_plugin_log("".into(), None) + } + #[test] fn sets_plugin() { let builder = TestOpenVpnBuilder::default(); - let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger(); - let (_close_tx, close_rx) = oneshot::channel(); let runtime = new_runtime().unwrap(); + let openvpn_init_args = create_init_args_plugin_log("./my_test_plugin".into(), None); let _ = runtime.block_on(OpenVpnMonitor::new_internal( builder.clone(), - event_server_abort_tx, - event_server_abort_rx, + openvpn_init_args, TestOpenvpnEventProxy {}, - "./my_test_plugin".into(), - None, - TempFile::new(), - None, - None, - close_rx, #[cfg(windows)] Box::new(TestWintunContext {}), )); @@ -1249,20 +1279,13 @@ mod tests { #[test] fn sets_log() { let builder = TestOpenVpnBuilder::default(); - let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger(); - let (_close_tx, close_rx) = oneshot::channel(); let runtime = new_runtime().unwrap(); + let openvpn_init_args = + create_init_args_plugin_log("".into(), Some(PathBuf::from("./my_test_log_file"))); let _ = runtime.block_on(OpenVpnMonitor::new_internal( builder.clone(), - event_server_abort_tx, - event_server_abort_rx, + openvpn_init_args, TestOpenvpnEventProxy {}, - "".into(), - Some(PathBuf::from("./my_test_log_file")), - TempFile::new(), - None, - None, - close_rx, #[cfg(windows)] Box::new(TestWintunContext {}), )); @@ -1276,21 +1299,13 @@ mod tests { fn exit_successfully() { let mut builder = TestOpenVpnBuilder::default(); builder.process_handle = Some(TestProcessHandle(0)); - let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger(); - let (_close_tx, close_rx) = oneshot::channel(); let runtime = new_runtime().unwrap(); + let openvpn_init_args = create_init_args(); let testee = runtime .block_on(OpenVpnMonitor::new_internal( builder, - event_server_abort_tx, - event_server_abort_rx, + openvpn_init_args, TestOpenvpnEventProxy {}, - "".into(), - None, - TempFile::new(), - None, - None, - close_rx, #[cfg(windows)] Box::new(TestWintunContext {}), )) @@ -1302,21 +1317,13 @@ mod tests { fn exit_error() { let mut builder = TestOpenVpnBuilder::default(); builder.process_handle = Some(TestProcessHandle(1)); - let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger(); - let (_close_tx, close_rx) = oneshot::channel(); let runtime = new_runtime().unwrap(); + let openvpn_init_args = create_init_args(); let testee = runtime .block_on(OpenVpnMonitor::new_internal( builder, - event_server_abort_tx, - event_server_abort_rx, + openvpn_init_args, TestOpenvpnEventProxy {}, - "".into(), - None, - TempFile::new(), - None, - None, - close_rx, #[cfg(windows)] Box::new(TestWintunContext {}), )) @@ -1328,21 +1335,13 @@ mod tests { fn wait_closed() { let mut builder = TestOpenVpnBuilder::default(); builder.process_handle = Some(TestProcessHandle(1)); - let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger(); - let (_close_tx, close_rx) = oneshot::channel(); let runtime = new_runtime().unwrap(); + let openvpn_init_args = create_init_args(); let testee = runtime .block_on(OpenVpnMonitor::new_internal( builder, - event_server_abort_tx, - event_server_abort_rx, + openvpn_init_args, TestOpenvpnEventProxy {}, - "".into(), - None, - TempFile::new(), - None, - None, - close_rx, #[cfg(windows)] Box::new(TestWintunContext {}), )) @@ -1354,21 +1353,13 @@ mod tests { #[test] fn failed_process_start() { let builder = TestOpenVpnBuilder::default(); - let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger(); - let (_close_tx, close_rx) = oneshot::channel(); let runtime = new_runtime().unwrap(); + let openvpn_init_args = create_init_args(); let result = runtime .block_on(OpenVpnMonitor::new_internal( builder, - event_server_abort_tx, - event_server_abort_rx, + openvpn_init_args, TestOpenvpnEventProxy {}, - "".into(), - None, - TempFile::new(), - None, - None, - close_rx, #[cfg(windows)] Box::new(TestWintunContext {}), )) diff --git a/talpid-core/src/tunnel/tun_provider/unix.rs b/talpid-core/src/tunnel/tun_provider/unix.rs index d8d3b7ce01..5c48a3c663 100644 --- a/talpid-core/src/tunnel/tun_provider/unix.rs +++ b/talpid-core/src/tunnel/tun_provider/unix.rs @@ -22,6 +22,12 @@ pub enum Error { /// Factory of tunnel devices on Unix systems. pub struct UnixTunProvider; +impl Default for UnixTunProvider { + fn default() -> Self { + Self::new() + } +} + impl UnixTunProvider { pub fn new() -> Self { UnixTunProvider diff --git a/talpid-core/src/tunnel/wireguard/logging.rs b/talpid-core/src/tunnel/wireguard/logging.rs index 1a7a52ed8b..35ec10fc2f 100644 --- a/talpid-core/src/tunnel/wireguard/logging.rs +++ b/talpid-core/src/tunnel/wireguard/logging.rs @@ -112,7 +112,7 @@ pub unsafe extern "system" fn wg_go_logging_callback( let level = match level { WG_GO_LOG_VERBOSE => LogLevel::Verbose, - WG_GO_LOG_ERROR | _ => LogLevel::Error, + _ => LogLevel::Error, }; log_inner(logfile, level, "wireguard-go", &managed_msg); } @@ -121,5 +121,5 @@ pub unsafe extern "system" fn wg_go_logging_callback( pub type WgLogLevel = u32; // wireguard-go supports log levels 0 through 3 with 3 being the most verbose // const WG_GO_LOG_SILENT: WgLogLevel = 0; -const WG_GO_LOG_ERROR: WgLogLevel = 1; +// const WG_GO_LOG_ERROR: WgLogLevel = 1; const WG_GO_LOG_VERBOSE: WgLogLevel = 2; diff --git a/talpid-core/src/tunnel/wireguard/mod.rs b/talpid-core/src/tunnel/wireguard/mod.rs index 7f17726c33..e49286cb30 100644 --- a/talpid-core/src/tunnel/wireguard/mod.rs +++ b/talpid-core/src/tunnel/wireguard/mod.rs @@ -1,15 +1,11 @@ use self::config::Config; #[cfg(not(windows))] use super::tun_provider; -use super::{tun_provider::TunProvider, TunnelEvent, TunnelMetadata}; +use super::{tun_provider::TunProvider, TunnelArgs, TunnelEvent, TunnelMetadata}; use crate::routing::{self, RequiredRoute, RouteManagerHandle}; +use futures::future::{abortable, AbortHandle as FutureAbortHandle, BoxFuture, Future}; #[cfg(windows)] use futures::{channel::mpsc, StreamExt}; -use futures::{ - channel::oneshot, - future::{abortable, AbortHandle as FutureAbortHandle}, - Future, -}; #[cfg(target_os = "linux")] use lazy_static::lazy_static; #[cfg(target_os = "linux")] @@ -54,6 +50,7 @@ mod wireguard_nt; use self::wireguard_go::WgGoTunnel; type Result<T> = std::result::Result<T, Error>; +type EventCallback = Box<dyn (Fn(TunnelEvent) -> BoxFuture<'static, ()>) + Send + Sync + 'static>; /// Errors that can happen in the Wireguard tunnel monitor. #[derive(err_derive::Error, Debug)] @@ -104,12 +101,7 @@ pub struct WireguardMonitor { /// Tunnel implementation tunnel: Arc<Mutex<Option<Box<dyn Tunnel>>>>, /// Callback to signal tunnel events - event_callback: Box< - dyn (Fn(TunnelEvent) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>) - + Send - + Sync - + 'static, - >, + event_callback: EventCallback, close_msg_receiver: sync_mpsc::Receiver<CloseMsg>, pinger_stop_sender: sync_mpsc::Sender<()>, _obfuscator: Option<ObfuscatorHandle>, @@ -208,13 +200,13 @@ impl WireguardMonitor { mut config: Config, psk_negotiation: Option<PublicKey>, log_path: Option<&Path>, - resource_dir: &Path, - on_event: F, tun_provider: Arc<Mutex<TunProvider>>, - route_manager: RouteManagerHandle, retry_attempt: u32, - tunnel_close_rx: oneshot::Receiver<()>, + route_manager: RouteManagerHandle, + init_args: TunnelArgs<'_, F>, ) -> Result<WireguardMonitor> { + let on_event = init_args.on_event; + let endpoint_addrs: Vec<IpAddr> = config.peers.iter().map(|peer| peer.endpoint.ip()).collect(); let (close_msg_sender, close_msg_receiver) = sync_mpsc::channel(); @@ -228,7 +220,7 @@ impl WireguardMonitor { runtime.clone(), &Self::patch_allowed_ips(&config, psk_negotiation.is_some()), log_path, - resource_dir, + init_args.resource_dir, tun_provider, #[cfg(target_os = "windows")] setup_done_tx, @@ -351,7 +343,7 @@ impl WireguardMonitor { }); tokio::spawn(async move { - if tunnel_close_rx.await.is_ok() { + if init_args.tunnel_close_rx.await.is_ok() { monitor_handle.abort(); let _ = close_msg_sender.send(CloseMsg::Stop); } diff --git a/talpid-core/src/tunnel/wireguard/stats.rs b/talpid-core/src/tunnel/wireguard/stats.rs index bda8af2e1f..cec033f611 100644 --- a/talpid-core/src/tunnel/wireguard/stats.rs +++ b/talpid-core/src/tunnel/wireguard/stats.rs @@ -4,10 +4,10 @@ use super::wireguard_kernel::wg_message::{DeviceMessage, DeviceNla, PeerNla}; #[derive(err_derive::Error, Debug, PartialEq)] pub enum Error { #[error(display = "Failed to parse peer pubkey from string \"_0\"")] - PubKeyParseError(String, #[error(source)] hex::FromHexError), + PubKeyParse(String, #[error(source)] hex::FromHexError), #[error(display = "Failed to parse integer from string \"_0\"")] - IntParseError(String, #[error(source)] std::num::ParseIntError), + IntParse(String, #[error(source)] std::num::ParseIntError), #[error(display = "Device no longer exists")] NoTunnelDevice, @@ -47,7 +47,7 @@ impl Stats { "public_key" => { let mut buffer = [0u8; 32]; hex::decode_to_slice(value, &mut buffer) - .map_err(|err| Error::PubKeyParseError(value.to_string(), err))?; + .map_err(|err| Error::PubKeyParse(value.to_string(), err))?; peer = Some(buffer); tx_bytes = None; rx_bytes = None; @@ -57,7 +57,7 @@ impl Stats { value .trim() .parse() - .map_err(|err| Error::IntParseError(value.to_string(), err))?, + .map_err(|err| Error::IntParse(value.to_string(), err))?, ); } "tx_bytes" => { @@ -65,7 +65,7 @@ impl Stats { value .trim() .parse() - .map_err(|err| Error::IntParseError(value.to_string(), err))?, + .map_err(|err| Error::IntParse(value.to_string(), err))?, ); } @@ -145,7 +145,7 @@ mod test { assert_eq!( Stats::parse_config_str(invalid_input), - Err(Error::IntParseError(invalid_str, int_err)) + Err(Error::IntParse(invalid_str, int_err)) ); } } diff --git a/talpid-core/src/tunnel/wireguard/wireguard_kernel/mod.rs b/talpid-core/src/tunnel/wireguard/wireguard_kernel/mod.rs index 0f3866500e..5b7b6a1e12 100644 --- a/talpid-core/src/tunnel/wireguard/wireguard_kernel/mod.rs +++ b/talpid-core/src/tunnel/wireguard/wireguard_kernel/mod.rs @@ -33,16 +33,16 @@ pub use nm_tunnel::NetworkManagerTunnel; #[error(no_from)] pub enum Error { #[error(display = "Failed to decode netlink message")] - DecodeError(#[error(source)] DecodeError), + Decode(#[error(source)] DecodeError), #[error(display = "Failed to execute netlink control request")] - NetlinkControlMessageError(#[error(source)] nl_message::Error), + NetlinkControlMessage(#[error(source)] nl_message::Error), #[error(display = "Failed to open netlink socket")] - NetlinkSocketError(#[error(source)] std::io::Error), + NetlinkSocket(#[error(source)] std::io::Error), #[error(display = "Failed to send netlink control request")] - NetlinkRequestError(#[error(source)] netlink_proto::Error<NetlinkControlMessage>), + NetlinkRequest(#[error(source)] netlink_proto::Error<NetlinkControlMessage>), #[error(display = "WireGuard netlink interface unavailable. Is the kernel module loaded?")] WireguardNetlinkInterfaceUnavailable, @@ -60,25 +60,25 @@ pub enum Error { NoDevice, #[error(display = "Failed to get config: _0")] - WgGetConfError(netlink_packet_core::error::ErrorMessage), + WgGetConf(netlink_packet_core::error::ErrorMessage), #[error(display = "Failed to apply config: _0")] - WgSetConfError(netlink_packet_core::error::ErrorMessage), + WgSetConf(netlink_packet_core::error::ErrorMessage), #[error(display = "Interface name too long")] - InterfaceNameError, + InterfaceName, #[error(display = "Send request error")] - SendRequestError(#[error(source)] NetlinkError<DeviceMessage>), + SendRequest(#[error(source)] NetlinkError<DeviceMessage>), #[error(display = "Create device error")] - NetlinkCreateDeviceError(#[error(source)] rtnetlink::Error), + NetlinkCreateDevice(#[error(source)] rtnetlink::Error), #[error(display = "Add IP to device error")] - NetlinkSetIpError(rtnetlink::Error), + NetlinkSetIp(rtnetlink::Error), #[error(display = "Failed to delete device")] - DeleteDeviceError(#[error(source)] rtnetlink::Error), + DeleteDevice(#[error(source)] rtnetlink::Error), #[error(display = "NetworkManager error")] NetworkManager(#[error(source)] nm_tunnel::Error), @@ -98,7 +98,7 @@ impl Handle { pub async fn connect() -> Result<Self, Error> { let message_type = Self::get_wireguard_message_type().await?; let (conn, wireguard_connection, _messages) = - netlink_proto::new_connection(NETLINK_GENERIC).map_err(Error::NetlinkSocketError)?; + netlink_proto::new_connection(NETLINK_GENERIC).map_err(Error::NetlinkSocket)?; let wg_handle = WireguardConnection { message_type, connection: wireguard_connection, @@ -106,7 +106,7 @@ impl Handle { let (abortable_connection, wg_abort_handle) = abortable(conn); tokio::spawn(abortable_connection); let (conn, route_handle, _messages) = - rtnetlink::new_connection().map_err(Error::NetlinkSocketError)?; + rtnetlink::new_connection().map_err(Error::NetlinkSocket)?; let (abortable_connection, route_abort_handle) = abortable(conn); tokio::spawn(abortable_connection); @@ -120,21 +120,21 @@ impl Handle { async fn get_wireguard_message_type() -> Result<u16, Error> { let (conn, mut handle, _messages) = - netlink_proto::new_connection(NETLINK_GENERIC).map_err(Error::NetlinkSocketError)?; + netlink_proto::new_connection(NETLINK_GENERIC).map_err(Error::NetlinkSocket)?; let (conn, abort_handle) = abortable(conn); tokio::spawn(conn); let result = async move { let mut message: NetlinkMessage<NetlinkControlMessage> = NetlinkControlMessage::get_netlink_family_id(CString::new("wireguard").unwrap()) - .map_err(Error::NetlinkControlMessageError)? + .map_err(Error::NetlinkControlMessage)? .into(); message.header.flags = NLM_F_REQUEST | NLM_F_ACK; let mut req = handle .request(message, SocketAddr::new(0, 0)) - .map_err(Error::NetlinkRequestError)?; + .map_err(Error::NetlinkRequest)?; let response = req.next().await; if let Some(response) = response { if let NetlinkPayload::InnerMessage(msg) = response.payload { @@ -177,14 +177,14 @@ impl Handle { let mut response = self .route_handle .request(add_request) - .map_err(Error::NetlinkCreateDeviceError)?; + .map_err(Error::NetlinkCreateDevice)?; while let Some(response_message) = response.next().await { if let NetlinkPayload::Error(err) = response_message.payload { // if the device exists, verify that it's a wireguard device if -err.code != libc::EEXIST { - return Err(Error::NetlinkCreateDeviceError( - rtnetlink::Error::NetlinkError(err), - )); + return Err(Error::NetlinkCreateDevice(rtnetlink::Error::NetlinkError( + err, + ))); } } } @@ -208,9 +208,9 @@ impl Handle { let mut response = self .route_handle .request(request) - .map_err(Error::NetlinkSetIpError)?; + .map_err(Error::NetlinkSetIp)?; while let Some(response_message) = response.next().await { - consume_netlink_error(response_message, Error::NetlinkSetIpError)?; + consume_netlink_error(response_message, Error::NetlinkSetIp)?; } Ok(()) @@ -226,9 +226,9 @@ impl Handle { let mut response = self .route_handle .request(request) - .map_err(Error::DeleteDeviceError)?; + .map_err(Error::DeleteDevice)?; while let Some(message) = response.next().await { - consume_netlink_error(message, Error::DeleteDeviceError)?; + consume_netlink_error(message, Error::DeleteDevice)?; } Ok(()) @@ -269,7 +269,7 @@ impl WireguardConnection { let mut response = self .connection .request(netlink_message, SocketAddr::new(0, 0)) - .map_err(Error::SendRequestError)?; + .map_err(Error::SendRequest)?; match response.next().await { Some(received_message) => match received_message.payload { NetlinkPayload::InnerMessage(inner) => Ok(inner), @@ -277,7 +277,7 @@ impl WireguardConnection { if err.code == -libc::ENODEV { Err(Error::NoDevice) } else { - Err(Error::WgGetConfError(err)) + Err(Error::WgGetConf(err)) } } anything_else => { @@ -297,11 +297,11 @@ impl WireguardConnection { let mut request = self .connection .request(netlink_message, SocketAddr::new(0, 0)) - .map_err(Error::SendRequestError)?; + .map_err(Error::SendRequest)?; while let Some(response) = request.next().await { if let NetlinkPayload::Error(err) = response.payload { - return Err(Error::WgSetConfError(err)); + return Err(Error::WgSetConf(err)); } } Ok(()) diff --git a/talpid-core/src/tunnel/wireguard/wireguard_kernel/wg_message.rs b/talpid-core/src/tunnel/wireguard/wireguard_kernel/wg_message.rs index be2231f771..f2de334762 100644 --- a/talpid-core/src/tunnel/wireguard/wireguard_kernel/wg_message.rs +++ b/talpid-core/src/tunnel/wireguard/wireguard_kernel/wg_message.rs @@ -110,9 +110,9 @@ impl DeviceMessage { } pub fn get_by_name(message_type: u16, name: String) -> Result<Self, Error> { - let c_name = CString::new(name).map_err(|_| Error::InterfaceNameError)?; + let c_name = CString::new(name).map_err(|_| Error::InterfaceName)?; if c_name.as_bytes_with_nul().len() > libc::IFNAMSIZ { - return Err(Error::InterfaceNameError); + return Err(Error::InterfaceName); } Ok(Self { @@ -178,9 +178,7 @@ impl NetlinkDeserializable<DeviceMessage> for DeviceMessage { let new_payload = &payload[mem::size_of::<libc::genlmsghdr>()..]; let mut nlas = vec![]; for buf in NlasIterator::new(new_payload) { - nlas.push( - DeviceNla::parse(&buf.map_err(Error::DecodeError)?).map_err(Error::DecodeError)?, - ); + nlas.push(DeviceNla::parse(&buf.map_err(Error::Decode)?).map_err(Error::Decode)?); } Ok(DeviceMessage { @@ -391,13 +389,13 @@ impl Nla for PeerNla { InetAddr::V4(sockaddr_in) => { // SAFETY: `sockaddr_in` has no padding bytes buffer - .write(unsafe { struct_as_slice(sockaddr_in) }) + .write_all(unsafe { struct_as_slice(sockaddr_in) }) .expect("Buffer too small for sockaddr_in"); } InetAddr::V6(sockaddr_in6) => { // SAFETY: `sockaddr_in` has no padding bytes buffer - .write(unsafe { struct_as_slice(sockaddr_in6) }) + .write_all(unsafe { struct_as_slice(sockaddr_in6) }) .expect("Buffer too small for sockaddr_in6"); } }, @@ -408,7 +406,7 @@ impl Nla for PeerNla { let timespec: &libc::timespec = last_handshake.as_ref(); // SAFETY: `timespec` has no padding bytes buffer - .write(unsafe { struct_as_slice(timespec) }) + .write_all(unsafe { struct_as_slice(timespec) }) .expect("Buffer too small for timespec"); } RxBytes(num_bytes) | TxBytes(num_bytes) => NativeEndian::write_u64(buffer, *num_bytes), @@ -535,7 +533,7 @@ impl Nla for AllowedIpNla { } IpAddr(ip_addr) => { buffer - .write(&ip_addr_to_bytes(ip_addr)) + .write_all(&ip_addr_to_bytes(ip_addr)) .expect("Buffer too small for AllowedIpNla::IpAddr"); } CidrMask(cidr_mask) => buffer[0] = *cidr_mask, diff --git a/talpid-core/src/tunnel_state_machine/connecting_state.rs b/talpid-core/src/tunnel_state_machine/connecting_state.rs index 7536b26b09..e787729c04 100644 --- a/talpid-core/src/tunnel_state_machine/connecting_state.rs +++ b/talpid-core/src/tunnel_state_machine/connecting_state.rs @@ -6,7 +6,9 @@ use super::{ use crate::{ firewall::FirewallPolicy, routing::RouteManager, - tunnel::{self, tun_provider::TunProvider, TunnelEvent, TunnelMetadata, TunnelMonitor}, + tunnel::{ + self, tun_provider::TunProvider, TunnelArgs, TunnelEvent, TunnelMetadata, TunnelMonitor, + }, }; use cfg_if::cfg_if; use futures::{ @@ -142,16 +144,20 @@ impl ConnectingState { } }; + let init_args = TunnelArgs { + resource_dir: &resource_dir, + on_event: on_tunnel_event, + tunnel_close_rx, + }; + let block_reason = match TunnelMonitor::start( runtime, &mut tunnel_parameters, &log_dir, - &resource_dir, - on_tunnel_event, tun_provider, - route_manager_handle, retry_attempt, - tunnel_close_rx, + route_manager_handle, + init_args, ) { Ok(monitor) => { let reason = Self::wait_for_tunnel_monitor(monitor, retry_attempt); diff --git a/talpid-core/src/tunnel_state_machine/mod.rs b/talpid-core/src/tunnel_state_machine/mod.rs index 3552eeab61..4c3eda0ecb 100644 --- a/talpid-core/src/tunnel_state_machine/mod.rs +++ b/talpid-core/src/tunnel_state_machine/mod.rs @@ -132,23 +132,25 @@ pub async fn spawn( let (shutdown_tx, shutdown_rx) = oneshot::channel(); let weak_command_tx = Arc::downgrade(&command_tx); - let state_machine = TunnelStateMachine::new( - initial_settings, - weak_command_tx, - offline_state_listener, + + let init_args = TunnelStateMachineInitArgs { + settings: initial_settings, + command_tx: weak_command_tx, + offline_state_tx: offline_state_listener, tunnel_parameters_generator, tun_provider, log_dir, resource_dir, - command_rx, + commands_rx: command_rx, #[cfg(target_os = "windows")] volume_update_rx, #[cfg(target_os = "macos")] exclusion_gid, #[cfg(target_os = "android")] android_context, - ) - .await?; + }; + + let state_machine = TunnelStateMachine::new(init_args).await?; #[cfg(windows)] let split_tunnel = state_machine.shared_values.split_tunnel.handle(); @@ -219,20 +221,35 @@ struct TunnelStateMachine { shared_values: SharedTunnelStateValues, } +/// Tunnel state machine initialization arguments arguments +struct TunnelStateMachineInitArgs<G: TunnelParametersGenerator> { + settings: InitialTunnelState, + command_tx: std::sync::Weak<mpsc::UnboundedSender<TunnelCommand>>, + offline_state_tx: mpsc::UnboundedSender<bool>, + tunnel_parameters_generator: G, + tun_provider: TunProvider, + log_dir: Option<PathBuf>, + resource_dir: PathBuf, + commands_rx: mpsc::UnboundedReceiver<TunnelCommand>, + #[cfg(target_os = "windows")] + volume_update_rx: mpsc::UnboundedReceiver<()>, + #[cfg(target_os = "macos")] + exclusion_gid: u32, + #[cfg(target_os = "android")] + android_context: AndroidContext, +} + impl TunnelStateMachine { async fn new( - settings: InitialTunnelState, - command_tx: std::sync::Weak<mpsc::UnboundedSender<TunnelCommand>>, - offline_state_tx: mpsc::UnboundedSender<bool>, - tunnel_parameters_generator: impl TunnelParametersGenerator, - tun_provider: TunProvider, - log_dir: Option<PathBuf>, - resource_dir: PathBuf, - commands_rx: mpsc::UnboundedReceiver<TunnelCommand>, - #[cfg(target_os = "windows")] volume_update_rx: mpsc::UnboundedReceiver<()>, - #[cfg(target_os = "macos")] exclusion_gid: u32, - #[cfg(target_os = "android")] android_context: AndroidContext, + args: TunnelStateMachineInitArgs<impl TunnelParametersGenerator>, ) -> Result<Self, Error> { + #[cfg(target_os = "windows")] + let volume_update_rx = args.volume_update_rx; + #[cfg(target_os = "macos")] + let exclusion_gid = args.exclusion_gid; + #[cfg(target_os = "android")] + let android_context = args.android_context; + let runtime = tokio::runtime::Handle::current(); #[cfg(target_os = "macos")] @@ -242,20 +259,24 @@ impl TunnelStateMachine { let power_mgmt_rx = crate::windows::window::PowerManagementListener::new(); #[cfg(windows)] - let split_tunnel = - split_tunnel::SplitTunnel::new(runtime.clone(), command_tx.clone(), volume_update_rx) - .map_err(Error::InitSplitTunneling)?; + let split_tunnel = split_tunnel::SplitTunnel::new( + runtime.clone(), + args.command_tx.clone(), + volume_update_rx, + ) + .map_err(Error::InitSplitTunneling)?; - let args = FirewallArguments { - initial_state: if settings.block_when_disconnected || !settings.reset_firewall { - InitialFirewallState::Blocked(settings.allowed_endpoint.clone()) + let fw_args = FirewallArguments { + initial_state: if args.settings.block_when_disconnected || !args.settings.reset_firewall + { + InitialFirewallState::Blocked(args.settings.allowed_endpoint.clone()) } else { InitialFirewallState::None }, - allow_lan: settings.allow_lan, + allow_lan: args.settings.allow_lan, }; - let firewall = Firewall::from_args(args).map_err(Error::InitFirewallError)?; + let firewall = Firewall::from_args(fw_args).map_err(Error::InitFirewallError)?; let route_manager = RouteManager::new(HashSet::new()) .await .map_err(Error::InitRouteManagerError)?; @@ -267,20 +288,20 @@ impl TunnelStateMachine { .handle() .map_err(Error::InitRouteManagerError)?, #[cfg(target_os = "macos")] - command_tx.clone(), + args.command_tx.clone(), ) .map_err(Error::InitDnsMonitorError)?; let (offline_tx, mut offline_rx) = mpsc::unbounded(); - let initial_offline_state_tx = offline_state_tx.clone(); + let initial_offline_state_tx = args.offline_state_tx.clone(); tokio::spawn(async move { while let Some(offline) = offline_rx.next().await { - if let Some(tx) = command_tx.upgrade() { + if let Some(tx) = args.command_tx.upgrade() { let _ = tx.unbounded_send(TunnelCommand::IsOffline(offline)); } else { break; } - let _ = offline_state_tx.unbounded_send(offline); + let _ = args.offline_state_tx.unbounded_send(offline); } }); let mut offline_monitor = offline::spawn_monitor( @@ -301,7 +322,7 @@ impl TunnelStateMachine { #[cfg(windows)] split_tunnel - .set_paths_sync(&settings.exclude_paths) + .set_paths_sync(&args.settings.exclude_paths) .map_err(Error::InitSplitTunneling)?; let mut shared_values = SharedTunnelStateValues { @@ -312,15 +333,15 @@ impl TunnelStateMachine { dns_monitor, route_manager, _offline_monitor: offline_monitor, - allow_lan: settings.allow_lan, - block_when_disconnected: settings.block_when_disconnected, + allow_lan: args.settings.allow_lan, + block_when_disconnected: args.settings.block_when_disconnected, is_offline, - dns_servers: settings.dns_servers, - allowed_endpoint: settings.allowed_endpoint, - tunnel_parameters_generator: Box::new(tunnel_parameters_generator), - tun_provider: Arc::new(Mutex::new(tun_provider)), - log_dir, - resource_dir, + dns_servers: args.settings.dns_servers, + allowed_endpoint: args.settings.allowed_endpoint, + tunnel_parameters_generator: Box::new(args.tunnel_parameters_generator), + tun_provider: Arc::new(Mutex::new(args.tun_provider)), + log_dir: args.log_dir, + resource_dir: args.resource_dir, #[cfg(target_os = "linux")] connectivity_check_was_enabled: None, #[cfg(target_os = "macos")] @@ -331,11 +352,11 @@ impl TunnelStateMachine { tokio::task::spawn_blocking(move || { let (initial_state, _) = - DisconnectedState::enter(&mut shared_values, settings.reset_firewall); + DisconnectedState::enter(&mut shared_values, args.settings.reset_firewall); Ok(TunnelStateMachine { current_state: Some(initial_state), - commands: commands_rx.fuse(), + commands: args.commands_rx.fuse(), shared_values, }) }) diff --git a/talpid-dbus/src/network_manager.rs b/talpid-dbus/src/network_manager.rs index 87566891d1..5d89a69035 100644 --- a/talpid-dbus/src/network_manager.rs +++ b/talpid-dbus/src/network_manager.rs @@ -59,6 +59,7 @@ const MAXIMUM_SUPPORTED_MINOR_VERSION: u32 = 26; const NM_DEVICE_STATE_CHANGED: &str = "StateChanged"; pub type Result<T> = std::result::Result<T, Error>; +type NetworkSettings<'a> = HashMap<String, HashMap<String, Variant<Box<dyn RefArg + 'a>>>>; #[derive(err_derive::Error, Debug)] pub enum Error { @@ -447,10 +448,8 @@ impl NetworkManager { let device = self.as_path(&device_path); // Get the last applied connection - let (mut settings, version_id): ( - HashMap<String, HashMap<String, Variant<Box<dyn RefArg>>>>, - u64, - ) = device.method_call(NM_DEVICE, "GetAppliedConnection", (0u32,))?; + let (mut settings, version_id): (NetworkSettings, u64) = + device.method_call(NM_DEVICE, "GetAppliedConnection", (0u32,))?; // Keep changed routes. // These routes were modified outside NM, likely by RouteManager. @@ -576,7 +575,7 @@ impl NetworkManager { } fn update_dns_config<'a, T>( - settings: &mut HashMap<String, HashMap<String, Variant<Box<dyn RefArg + 'a>>>>, + settings: &mut NetworkSettings<'a>, ip_protocol: &'static str, servers: T, ) where diff --git a/talpid-dbus/src/systemd_resolved.rs b/talpid-dbus/src/systemd_resolved.rs index d664872557..6ba8603508 100644 --- a/talpid-dbus/src/systemd_resolved.rs +++ b/talpid-dbus/src/systemd_resolved.rs @@ -349,7 +349,7 @@ impl SystemdResolved { .map_err(Error::DBusRpcError) } - fn link_disable_dns_over_tls<'a, 'b: 'a>(&'a self, interface_index: u32) -> Result<()> { + fn link_disable_dns_over_tls(&self, interface_index: u32) -> Result<()> { let link_object_path = self .fetch_link(interface_index) .map_err(|e| Error::GetLinkError(Box::new(e)))?; diff --git a/talpid-types/src/net/wireguard.rs b/talpid-types/src/net/wireguard.rs index ff7ebef090..6a8bbfd521 100644 --- a/talpid-types/src/net/wireguard.rs +++ b/talpid-types/src/net/wireguard.rs @@ -95,6 +95,7 @@ fn default_wgnt_setting() -> bool { true } +#[allow(clippy::derivable_impls)] impl Default for TunnelOptions { fn default() -> Self { Self { |
