diff options
| author | Janito Vaqueiro Ferreira Filho <janito@mullvad.net> | 2020-01-07 13:45:32 -0300 |
|---|---|---|
| committer | Janito Vaqueiro Ferreira Filho <janito@mullvad.net> | 2020-01-07 13:45:32 -0300 |
| commit | df1c08b5a6ecfbc51e6b06c18481d16c614e494f (patch) | |
| tree | 4370c3d9459c0c9cacde20a4d87a73a12bd9b943 | |
| parent | c36a45f09463f604d833b440021323ba28b3f2ed (diff) | |
| parent | 58c33ca12806f7ffb6a16e45a36d7aa1d1ef4522 (diff) | |
| download | mullvadvpn-df1c08b5a6ecfbc51e6b06c18481d16c614e494f.tar.xz mullvadvpn-df1c08b5a6ecfbc51e6b06c18481d16c614e494f.zip | |
Merge branch 'stop-version-updater-on-shutdown'
| -rw-r--r-- | mullvad-daemon/src/lib.rs | 77 | ||||
| -rw-r--r-- | mullvad-daemon/src/version_check.rs | 32 | ||||
| -rw-r--r-- | mullvad-daemon/src/wireguard.rs | 19 | ||||
| -rw-r--r-- | talpid-core/src/mpsc.rs | 32 |
4 files changed, 95 insertions, 65 deletions
diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs index 4d01e802f5..9a9e42e697 100644 --- a/mullvad-daemon/src/lib.rs +++ b/mullvad-daemon/src/lib.rs @@ -28,8 +28,12 @@ use crate::management_interface::{ }; use futures::{ future::{self, Executor}, - sync::{mpsc::UnboundedSender, oneshot}, - Future, + stream::Wait, + sync::{ + mpsc::{UnboundedReceiver, UnboundedSender}, + oneshot, + }, + Future, Stream, }; use log::{debug, error, info, warn}; use mullvad_rpc::{AccountsProxy, HttpHandle, WireguardKeyProxy}; @@ -178,6 +182,12 @@ impl From<ManagementCommand> for InternalDaemonEvent { } } +impl From<AppVersionInfo> for InternalDaemonEvent { + fn from(command: AppVersionInfo) -> Self { + InternalDaemonEvent::NewAppVersionInfo(command) + } +} + #[derive(Clone, Debug, Eq, PartialEq)] enum DaemonExecutionState { Running, @@ -224,7 +234,7 @@ impl DaemonExecutionState { pub struct DaemonCommandSender(IntoSender<ManagementCommand, InternalDaemonEvent>); impl DaemonCommandSender { - pub(crate) fn new(internal_event_sender: mpsc::Sender<InternalDaemonEvent>) -> Self { + pub(crate) fn new(internal_event_sender: UnboundedSender<InternalDaemonEvent>) -> Self { DaemonCommandSender(IntoSender::from(internal_event_sender)) } @@ -257,8 +267,8 @@ pub struct Daemon<L: EventListener = ManagementInterfaceEventBroadcaster> { tunnel_state: TunnelState, target_state: TargetState, state: DaemonExecutionState, - rx: mpsc::Receiver<InternalDaemonEvent>, - tx: mpsc::Sender<InternalDaemonEvent>, + rx: Wait<UnboundedReceiver<InternalDaemonEvent>>, + tx: UnboundedSender<InternalDaemonEvent>, reconnection_loop_tx: Option<mpsc::Sender<()>>, event_listener: L, settings: Settings, @@ -286,7 +296,7 @@ impl Daemon<ManagementInterfaceEventBroadcaster> { if rpc_uniqueness_check::is_another_instance_running() { return Err(Error::DaemonIsAlreadyRunning); } - let (tx, rx) = mpsc::channel(); + let (tx, rx) = futures::sync::mpsc::unbounded(); let management_interface_broadcaster = Self::start_management_interface(tx.clone())?; Self::start_internal( @@ -304,7 +314,7 @@ impl Daemon<ManagementInterfaceEventBroadcaster> { // Starts the management interface and spawns a thread that will process it. // Returns a handle that allows notifying all subscribers on events. fn start_management_interface( - event_tx: mpsc::Sender<InternalDaemonEvent>, + event_tx: UnboundedSender<InternalDaemonEvent>, ) -> Result<ManagementInterfaceEventBroadcaster> { let multiplex_event_tx = IntoSender::from(event_tx.clone()); let server = Self::start_management_interface_server(multiplex_event_tx)?; @@ -325,12 +335,12 @@ impl Daemon<ManagementInterfaceEventBroadcaster> { fn spawn_management_interface_wait_thread( server: ManagementInterfaceServer, - exit_tx: mpsc::Sender<InternalDaemonEvent>, + exit_tx: UnboundedSender<InternalDaemonEvent>, ) { thread::spawn(move || { server.wait(); info!("Management interface shut down"); - let _ = exit_tx.send(InternalDaemonEvent::ManagementInterfaceExited); + let _ = exit_tx.unbounded_send(InternalDaemonEvent::ManagementInterfaceExited); }); } } @@ -346,7 +356,7 @@ where cache_dir: PathBuf, #[cfg(target_os = "android")] android_context: AndroidContext, ) -> Result<Self> { - let (tx, rx) = mpsc::channel(); + let (tx, rx) = futures::sync::mpsc::unbounded(); Self::start_internal( tx, @@ -361,8 +371,8 @@ where } fn start_internal( - internal_event_tx: mpsc::Sender<InternalDaemonEvent>, - internal_event_rx: mpsc::Receiver<InternalDaemonEvent>, + internal_event_tx: UnboundedSender<InternalDaemonEvent>, + internal_event_rx: UnboundedReceiver<InternalDaemonEvent>, event_listener: L, log_dir: Option<PathBuf>, resource_dir: PathBuf, @@ -397,12 +407,6 @@ where &cache_dir, ); - let version_check_internal_event_tx = internal_event_tx.clone(); - let on_version_check_update = move |app_version_info: &AppVersionInfo| { - let _ = version_check_internal_event_tx.send(InternalDaemonEvent::NewAppVersionInfo( - app_version_info.clone(), - )); - }; let app_version_info = match version_check::load_cache(&cache_dir) { Ok(app_version_info) => app_version_info, Err(error) => { @@ -422,7 +426,7 @@ where let version_check_future = version_check::VersionUpdater::new( rpc_handle.clone(), cache_dir.clone(), - on_version_check_update, + internal_event_tx.clone(), app_version_info.clone(), ); tokio_remote.spawn(|_| version_check_future); @@ -462,7 +466,7 @@ where tunnel_state: TunnelState::Disconnected, target_state: TargetState::Unsecured, state: DaemonExecutionState::Running, - rx: internal_event_rx, + rx: internal_event_rx.wait(), tx: internal_event_tx, reconnection_loop_tx: None, event_listener, @@ -511,7 +515,7 @@ where info!("Automatically connecting since auto-connect is turned on"); self.set_target_state(TargetState::Secured); } - while let Ok(event) = self.rx.recv() { + while let Some(Ok(event)) = self.rx.next() { self.handle_event(event)?; if self.state == DaemonExecutionState::Finished { break; @@ -788,9 +792,11 @@ where if let Err(mpsc::RecvTimeoutError::Timeout) = rx.recv_timeout(delay) { debug!("Attempting to reconnect"); - let _ = tunnel_command_tx.send(InternalDaemonEvent::ManagementInterfaceEvent( - ManagementCommand::SetTargetState(result_tx, TargetState::Secured), - )); + let _ = tunnel_command_tx.unbounded_send( + InternalDaemonEvent::ManagementInterfaceEvent( + ManagementCommand::SetTargetState(result_tx, TargetState::Secured), + ), + ); } }); } @@ -1030,8 +1036,10 @@ where move |result| -> std::result::Result<(), ()> { match result { Ok(account_token) => { - let _ = - daemon_tx.send(InternalDaemonEvent::NewAccountEvent(account_token, tx)); + let _ = daemon_tx.unbounded_send(InternalDaemonEvent::NewAccountEvent( + account_token, + tx, + )); } Err(err) => { let _ = tx.send(Err(err)); @@ -1642,17 +1650,17 @@ where } pub struct DaemonShutdownHandle { - tx: mpsc::Sender<InternalDaemonEvent>, + tx: UnboundedSender<InternalDaemonEvent>, } impl DaemonShutdownHandle { pub fn shutdown(&self) { - let _ = self.tx.send(InternalDaemonEvent::TriggerShutdown); + let _ = self.tx.unbounded_send(InternalDaemonEvent::TriggerShutdown); } } struct MullvadTunnelParametersGenerator { - tx: mpsc::Sender<InternalDaemonEvent>, + tx: UnboundedSender<InternalDaemonEvent>, } impl TunnelParametersGenerator for MullvadTunnelParametersGenerator { @@ -1661,10 +1669,13 @@ impl TunnelParametersGenerator for MullvadTunnelParametersGenerator { retry_attempt: u32, ) -> std::result::Result<TunnelParameters, ParameterGenerationError> { let (response_tx, response_rx) = mpsc::channel(); - if let Err(_) = self.tx.send(InternalDaemonEvent::GenerateTunnelParameters( - response_tx, - retry_attempt, - )) { + if let Err(_) = self + .tx + .unbounded_send(InternalDaemonEvent::GenerateTunnelParameters( + response_tx, + retry_attempt, + )) + { log::error!("Failed to send daemon command to generate tunnel parameters!"); return Err(ParameterGenerationError::NoMatchingRelay); } diff --git a/mullvad-daemon/src/version_check.rs b/mullvad-daemon/src/version_check.rs index 6a89ee8060..45352d3e2d 100644 --- a/mullvad-daemon/src/version_check.rs +++ b/mullvad-daemon/src/version_check.rs @@ -1,4 +1,4 @@ -use futures::{Async, Future, Poll}; +use futures::{sync::mpsc::UnboundedSender, Async, Future, Poll}; use mullvad_rpc::{AppVersionProxy, HttpHandle}; use mullvad_types::version::AppVersionInfo; use std::{ @@ -52,17 +52,17 @@ pub enum Error { Download(#[error(source)] mullvad_rpc::Error), } -impl<F> From<TimeoutError<F>> for Error { - fn from(_: TimeoutError<F>) -> Error { +impl<T> From<TimeoutError<T>> for Error { + fn from(_: TimeoutError<T>) -> Error { Error::DownloadTimeout } } -pub struct VersionUpdater<F: Fn(&AppVersionInfo) + Send + 'static> { +pub struct VersionUpdater<T: From<AppVersionInfo>> { version_proxy: AppVersionProxy<HttpHandle>, cache_path: PathBuf, - on_version_update: F, + update_sender: UnboundedSender<T>, last_app_version_info: AppVersionInfo, next_update_time: Instant, state: VersionUpdaterState, @@ -73,11 +73,11 @@ enum VersionUpdaterState { Updating(Box<dyn Future<Item = AppVersionInfo, Error = Error> + Send + 'static>), } -impl<F: Fn(&AppVersionInfo) + Send + 'static> VersionUpdater<F> { +impl<T: From<AppVersionInfo>> VersionUpdater<T> { pub fn new( rpc_handle: HttpHandle, cache_dir: PathBuf, - on_version_update: F, + update_sender: UnboundedSender<T>, last_app_version_info: AppVersionInfo, ) -> Self { let version_proxy = AppVersionProxy::new(rpc_handle); @@ -85,7 +85,7 @@ impl<F: Fn(&AppVersionInfo) + Send + 'static> VersionUpdater<F> { Self { version_proxy, cache_path, - on_version_update, + update_sender, last_app_version_info, next_update_time: Instant::now(), state: VersionUpdaterState::Sleeping(Self::create_sleep_future()), @@ -118,12 +118,16 @@ impl<F: Fn(&AppVersionInfo) + Send + 'static> VersionUpdater<F> { } } -impl<F: Fn(&AppVersionInfo) + Send + 'static> Future for VersionUpdater<F> { +impl<T: From<AppVersionInfo>> Future for VersionUpdater<T> { type Item = (); type Error = (); fn poll(&mut self) -> Poll<Self::Item, Self::Error> { loop { + if self.update_sender.is_closed() { + log::warn!("Version update receiver is closed, stopping version updater"); + return Ok(Async::Ready(())); + } let next_state = match &mut self.state { VersionUpdaterState::Sleeping(timer) => match timer.poll() { Ok(Async::NotReady) => return Ok(Async::NotReady), @@ -150,7 +154,15 @@ impl<F: Fn(&AppVersionInfo) + Send + 'static> Future for VersionUpdater<F> { log::debug!("Got new version check: {:?}", app_version_info); self.next_update_time = Instant::now() + UPDATE_INTERVAL; if app_version_info != self.last_app_version_info { - (self.on_version_update)(&app_version_info); + if let Err(_) = self + .update_sender + .unbounded_send(app_version_info.clone().into()) + { + log::warn!( + "Version update receiver is closed, stopping version updater" + ); + return Ok(Async::Ready(())); + } self.last_app_version_info = app_version_info; if let Err(e) = self.write_cache() { log::error!( diff --git a/mullvad-daemon/src/wireguard.rs b/mullvad-daemon/src/wireguard.rs index 1e1b8f72a5..bce77b9041 100644 --- a/mullvad-daemon/src/wireguard.rs +++ b/mullvad-daemon/src/wireguard.rs @@ -2,13 +2,13 @@ use crate::{account_history::AccountHistory, InternalDaemonEvent}; use chrono::offset::Utc; use futures::{ future::{Executor, IntoFuture}, - sync::oneshot, + sync::{mpsc::UnboundedSender, oneshot}, Async, Future, Poll, }; use jsonrpc_client_core::Error as JsonRpcError; use mullvad_types::account::AccountToken; pub use mullvad_types::wireguard::*; -use std::{cmp, sync::mpsc, time::Duration}; +use std::{cmp, time::Duration}; pub use talpid_types::net::wireguard::{ ConnectionConfig, PrivateKey, TunnelConfig, TunnelParameters, }; @@ -47,7 +47,7 @@ pub enum Error { pub type Result<T> = std::result::Result<T, Error>; pub struct KeyManager { - daemon_tx: mpsc::Sender<InternalDaemonEvent>, + daemon_tx: UnboundedSender<InternalDaemonEvent>, http_handle: mullvad_rpc::HttpHandle, tokio_remote: Remote, current_job: Option<CancelHandle>, @@ -59,7 +59,7 @@ pub struct KeyManager { impl KeyManager { pub(crate) fn new( - daemon_tx: mpsc::Sender<InternalDaemonEvent>, + daemon_tx: UnboundedSender<InternalDaemonEvent>, http_handle: mullvad_rpc::HttpHandle, tokio_remote: Remote, ) -> Self { @@ -205,13 +205,14 @@ impl KeyManager { let fut = fut.then(move |result| { match result { Ok(wireguard_data) => { - let _ = daemon_tx.send(InternalDaemonEvent::WgKeyEvent(( + let _ = daemon_tx.unbounded_send(InternalDaemonEvent::WgKeyEvent(( account, Ok(wireguard_data), ))); } Err(CancelErr::Inner(e)) => { - let _ = daemon_tx.send(InternalDaemonEvent::WgKeyEvent((account, Err(e)))); + let _ = daemon_tx + .unbounded_send(InternalDaemonEvent::WgKeyEvent((account, Err(e)))); } Err(CancelErr::Cancelled) => { log::error!("Key generation cancelled"); @@ -310,7 +311,7 @@ impl KeyManager { } fn next_automatic_rotation( - daemon_tx: mpsc::Sender<InternalDaemonEvent>, + daemon_tx: UnboundedSender<InternalDaemonEvent>, http_handle: mullvad_rpc::HttpHandle, public_key: PublicKey, rotation_interval_secs: u64, @@ -340,7 +341,7 @@ impl KeyManager { }) .map(move |wireguard_data| { // Update account data - let _ = daemon_tx.send(InternalDaemonEvent::WgKeyEvent(( + let _ = daemon_tx.unbounded_send(InternalDaemonEvent::WgKeyEvent(( account_token_copy, Ok(wireguard_data.clone()), ))); @@ -350,7 +351,7 @@ impl KeyManager { } fn create_automatic_rotation( - daemon_tx: mpsc::Sender<InternalDaemonEvent>, + daemon_tx: UnboundedSender<InternalDaemonEvent>, http_handle: mullvad_rpc::HttpHandle, public_key: PublicKey, rotation_interval_secs: u64, diff --git a/talpid-core/src/mpsc.rs b/talpid-core/src/mpsc.rs index fd8814ab74..21807ca377 100644 --- a/talpid-core/src/mpsc.rs +++ b/talpid-core/src/mpsc.rs @@ -1,9 +1,10 @@ -use std::{marker::PhantomData, sync::mpsc}; +use futures::sync::mpsc::{SendError, UnboundedSender}; +use std::marker::PhantomData; /// Abstraction over an `mpsc::Sender` that first converts the value to another type before sending. #[derive(Debug, Clone)] pub struct IntoSender<T, U> { - sender: mpsc::Sender<U>, + sender: UnboundedSender<U>, _marker: PhantomData<T>, } @@ -12,16 +13,16 @@ where T: Into<U>, { /// Converts the `T` into a `U` and sends it on the channel. - pub fn send(&self, t: T) -> Result<(), mpsc::SendError<U>> { - self.sender.send(t.into()) + pub fn send(&self, t: T) -> Result<(), SendError<U>> { + self.sender.unbounded_send(t.into()) } } -impl<T, U> From<mpsc::Sender<U>> for IntoSender<T, U> +impl<T, U> From<UnboundedSender<U>> for IntoSender<T, U> where T: Into<U>, { - fn from(sender: mpsc::Sender<U>) -> Self { + fn from(sender: UnboundedSender<U>) -> Self { IntoSender { sender, _marker: PhantomData, @@ -32,7 +33,8 @@ where #[cfg(test)] mod tests { use super::*; - use std::{sync::mpsc, thread}; + use futures::{sync::mpsc, Stream}; + use std::thread; #[derive(Debug, Eq, PartialEq)] enum Inner { @@ -54,25 +56,29 @@ mod tests { #[test] fn sender() { - let (tx, rx) = mpsc::channel::<Outer>(); + let (tx, rx) = mpsc::unbounded(); let inner_tx: IntoSender<Inner, Outer> = tx.clone().into(); - tx.send(Outer::Other).unwrap(); + tx.unbounded_send(Outer::Other).unwrap(); inner_tx.send(Inner::Two).unwrap(); - assert_eq!(Outer::Other, rx.recv().unwrap()); - assert_eq!(Outer::Inner(Inner::Two), rx.recv().unwrap()); + let mut sync_rx = rx.wait(); + + assert_eq!(Outer::Other, sync_rx.next().unwrap().unwrap()); + assert_eq!(Outer::Inner(Inner::Two), sync_rx.next().unwrap().unwrap()); } #[test] fn send_between_thread() { - let (tx, rx) = mpsc::channel::<Outer>(); + let (tx, rx) = mpsc::unbounded(); let inner_tx: IntoSender<Inner, Outer> = tx.clone().into(); thread::spawn(move || { inner_tx.send(Inner::One).unwrap(); }); - assert_eq!(Outer::Inner(Inner::One), rx.recv().unwrap()); + let mut sync_rx = rx.wait(); + + assert_eq!(Outer::Inner(Inner::One), sync_rx.next().unwrap().unwrap()); } } |
