diff options
| author | Emīls <emils@mullvad.net> | 2020-07-01 13:14:35 +0100 |
|---|---|---|
| committer | Emīls <emils@mullvad.net> | 2020-07-03 15:11:25 +0100 |
| commit | e9d8f92059650d717e8871176ead70ef49092a9d (patch) | |
| tree | 832c3b77e5f1bed4d4b8a215401669f446bfbb6f | |
| parent | cb0acc43fe52a559c9c3914ef4373e4fff3d3af4 (diff) | |
| download | mullvadvpn-e9d8f92059650d717e8871176ead70ef49092a9d.tar.xz mullvadvpn-e9d8f92059650d717e8871176ead70ef49092a9d.zip | |
Use new style future for wireguard keys
| -rw-r--r-- | mullvad-daemon/src/account_history.rs | 80 | ||||
| -rw-r--r-- | mullvad-daemon/src/lib.rs | 51 | ||||
| -rw-r--r-- | mullvad-daemon/src/wireguard.rs | 405 | ||||
| -rw-r--r-- | mullvad-rpc/src/lib.rs | 61 | ||||
| -rw-r--r-- | mullvad-rpc/src/rest.rs | 62 |
5 files changed, 287 insertions, 372 deletions
diff --git a/mullvad-daemon/src/account_history.rs b/mullvad-daemon/src/account_history.rs index 4a96428d94..05ea1c7d3b 100644 --- a/mullvad-daemon/src/account_history.rs +++ b/mullvad-daemon/src/account_history.rs @@ -1,20 +1,13 @@ -#[cfg(target_os = "android")] -use futures01::future::{Executor, Future}; -#[cfg(not(target_os = "android"))] -use futures01::{ - future::{self, Executor, Future}, - sync::oneshot, -}; use mullvad_rpc::{rest::MullvadRestHandle, WireguardKeyProxy}; use mullvad_types::{account::AccountToken, wireguard::WireguardData}; use std::{ collections::VecDeque, fs, + future::Future, io::{self, Seek, Write}, path::Path, }; use talpid_types::ErrorExt; -use tokio_core::reactor::Remote; pub type Result<T> = std::result::Result<T, Error>; @@ -39,7 +32,6 @@ pub struct AccountHistory { file: io::BufWriter<fs::File>, accounts: VecDeque<AccountEntry>, rpc_handle: MullvadRestHandle, - tokio_remote: Remote, } @@ -48,7 +40,6 @@ impl AccountHistory { cache_dir: &Path, settings_dir: &Path, rpc_handle: MullvadRestHandle, - tokio_remote: Remote, ) -> Result<AccountHistory> { Self::migrate_from_old_file_location(cache_dir, settings_dir); @@ -95,7 +86,6 @@ impl AccountHistory { file, accounts, rpc_handle, - tokio_remote, }; if let Err(e) = history.save_to_disk() { log::error!("Failed to save account cache after opening it: {}", e); @@ -170,10 +160,16 @@ impl AccountHistory { &self, account: &str, wg_data: &WireguardData, - ) -> impl Future<Item = (), Error = ()> { + ) -> impl Future<Output = ()> + 'static { let mut rpc = WireguardKeyProxy::new(self.rpc_handle.clone()); - rpc.remove_wireguard_key(String::from(account), &wg_data.private_key.public_key()) - .map_err(|e| log::error!("Failed to remove WireGuard key: {}", e)) + let pub_key = wg_data.private_key.public_key(); + let account = String::from(account); + + async move { + if let Err(err) = rpc.remove_wireguard_key(account, &pub_key).await { + log::error!("Failed to remove WireGuard key: {}", err); + } + } } /// Always inserts a new entry at the start of the list @@ -186,10 +182,9 @@ impl AccountHistory { if self.accounts.len() > ACCOUNT_HISTORY_LIMIT { let last_entry = self.accounts.pop_back().unwrap(); if let Some(wg_data) = last_entry.wireguard { - let fut = self.create_remove_wg_key_rpc(&last_entry.account, &wg_data); - if let Err(e) = self.tokio_remote.execute(fut) { - log::error!("Failed to spawn future to remove WireGuard key: {:?}", e); - } + self.rpc_handle + .service() + .spawn(self.create_remove_wg_key_rpc(&last_entry.account, &wg_data)); } } @@ -213,10 +208,9 @@ impl AccountHistory { }; if let Some(wg_data) = entry.wireguard { - let fut = self.create_remove_wg_key_rpc(account, &wg_data); - if let Err(e) = self.tokio_remote.execute(fut) { - log::error!("Failed to spawn future to remove WireGuard key: {:?}", e); - } + self.rpc_handle + .service() + .spawn(self.create_remove_wg_key_rpc(account, &wg_data)) } let _ = self.accounts.pop_front(); @@ -224,35 +218,31 @@ impl AccountHistory { } /// Remove account history - #[cfg(not(target_os = "android"))] pub fn clear(&mut self) -> Result<()> { - let mut rpc = WireguardKeyProxy::new(self.rpc_handle.clone()); - log::debug!("account_history::clear"); - let mut removal_futures = Vec::with_capacity(ACCOUNT_HISTORY_LIMIT); + let rpc = WireguardKeyProxy::new(self.rpc_handle.clone()); - for entry in self.accounts.iter() { - if let Some(wg_data) = &entry.wireguard { - let fut = rpc - .remove_wireguard_key(entry.account.clone(), &wg_data.private_key.public_key()) - .map_err(|e| log::error!("Failed to remove WireGuard key: {}", e)); - removal_futures.push(fut); - } - } + let removal: Vec<_> = self + .accounts + .drain(0..) + .filter_map(move |entry| { + let account = entry.account.clone(); + let mut rpc = rpc.clone(); + entry.wireguard.map(move |wg_data| { + let public_key = wg_data.private_key.public_key(); + async move { + if let Err(err) = rpc.remove_wireguard_key(account, &public_key).await { + log::error!("Failed to remove WireGuard key: {}", err); + } + } + }) + }) + .collect(); - let joined_futs = future::join_all(removal_futures); - let (tx, rx) = oneshot::channel(); - let execute_result = self.tokio_remote.execute(joined_futs.then(|result| { - let _ = tx.send(result); - Ok(()) - })); - if let Err(e) = execute_result { - log::error!("Failed to spawn future to remove WireGuard keys: {:?}", e); - } else { - let _ = rx.wait(); - } + let joined_futs = futures::future::join_all(removal); + self.rpc_handle.service().block_on(joined_futs); self.accounts = VecDeque::new(); self.save_to_disk() diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs index a3a856f792..0c12acf1f5 100644 --- a/mullvad-daemon/src/lib.rs +++ b/mullvad-daemon/src/lib.rs @@ -528,13 +528,9 @@ where settings.show_beta_releases, ); rpc_runtime.runtime().spawn(version_updater.run()); - let account_history = account_history::AccountHistory::new( - &cache_dir, - &settings_dir, - rpc_handle.clone(), - core_handle.remote.clone(), - ) - .map_err(Error::LoadAccountHistory)?; + let account_history = + account_history::AccountHistory::new(&cache_dir, &settings_dir, rpc_handle.clone()) + .map_err(Error::LoadAccountHistory)?; // Restore the tunnel to a previous state let target_cache = cache_dir.join(TARGET_START_STATE_FILE); @@ -573,11 +569,8 @@ where ) .map_err(Error::TunnelError)?; - let wireguard_key_manager = wireguard::KeyManager::new( - internal_event_tx.clone(), - rpc_handle.clone(), - core_handle.remote.clone(), - ); + let wireguard_key_manager = + wireguard::KeyManager::new(internal_event_tx.clone(), rpc_handle.clone()); // Attempt to download a fresh relay list relay_selector.update(); @@ -1653,15 +1646,8 @@ where .unwrap_or(true) { log::info!("Automatically generating new wireguard key for account"); - if let Err(e) = self - .wireguard_key_manager - .generate_key_async(account, Some(FIRST_KEY_PUSH_TIMEOUT)) - { - log::error!( - "{}", - e.display_chain_with_msg("Failed to start generating wireguard key") - ); - } + self.wireguard_key_manager + .generate_key_async(account, Some(FIRST_KEY_PUSH_TIMEOUT)); } else { log::info!("Account already has wireguard key"); } @@ -1778,17 +1764,20 @@ where } }; - let fut = self + let verification_rpc = self .wireguard_key_manager - .verify_wireguard_key(account, public_key) - .and_then(|is_valid| { - Self::oneshot_send(tx, is_valid, "verify_wireguard_key response"); - Ok(()) - }) - .map_err(|e: wireguard::Error| log::error!("Failed to verify wireguard key - {}", e)); - if let Err(e) = self.core_handle.remote.execute(fut) { - log::error!("Failed to spawn a future to verify wireguard key: {:?}", e); - } + .verify_wireguard_key(account, public_key); + + self.rpc_handle.service().spawn(async move { + match verification_rpc.await { + Ok(is_valid) => { + Self::oneshot_send(tx, is_valid, "verify_wireguard_key response"); + } + Err(err) => { + log::error!("Failed to verify wireguard key - {}", err); + } + } + }); } fn on_get_settings(&self, tx: oneshot::Sender<Settings>) { diff --git a/mullvad-daemon/src/wireguard.rs b/mullvad-daemon/src/wireguard.rs index 7cf1209429..923fc4fb34 100644 --- a/mullvad-daemon/src/wireguard.rs +++ b/mullvad-daemon/src/wireguard.rs @@ -1,20 +1,23 @@ use crate::{account_history::AccountHistory, DaemonEventSender, InternalDaemonEvent}; use chrono::offset::Utc; -use futures01::{future::Executor, stream::Stream, sync::oneshot, Async, Future, Poll}; -use mullvad_rpc::rest::{Error as RestError, MullvadRestHandle}; +use mullvad_rpc::rest::{CancelHandle, Cancellable, Error as RestError, MullvadRestHandle}; use mullvad_types::account::AccountToken; pub use mullvad_types::wireguard::*; -use std::time::Duration; -use talpid_core::mpsc::Sender; +use std::{ + future::Future, + pin::Pin, + time::{Duration, Instant}, +}; + +use talpid_core::{ + future_retry::{retry_future_with_backoff, ExponentialBackoff, Jittered}, + mpsc::Sender, +}; + pub use talpid_types::net::wireguard::{ ConnectionConfig, PrivateKey, TunnelConfig, TunnelParameters, }; use talpid_types::ErrorExt; -use tokio_core::reactor::Remote; -use tokio_retry::{ - strategy::{jitter, ExponentialBackoff}, - RetryIf, -}; use tokio_timer; /// Default automatic key rotation @@ -27,8 +30,6 @@ const KEY_CHECK_INTERVAL: Duration = Duration::from_secs(60); #[derive(err_derive::Error, Debug)] pub enum Error { - #[error(display = "Failed to spawn future")] - ExectuionError, #[error(display = "Unexpected HTTP request error")] RestError(#[error(source)] mullvad_rpc::rest::Error), #[error(display = "Account already has maximum number of keys")] @@ -42,7 +43,6 @@ pub type Result<T> = std::result::Result<T, Error>; pub struct KeyManager { daemon_tx: DaemonEventSender, http_handle: MullvadRestHandle, - tokio_remote: Remote, current_job: Option<CancelHandle>, abort_scheduler_tx: Option<CancelHandle>, @@ -50,15 +50,10 @@ pub struct KeyManager { } impl KeyManager { - pub(crate) fn new( - daemon_tx: DaemonEventSender, - http_handle: MullvadRestHandle, - tokio_remote: Remote, - ) -> Self { + pub(crate) fn new(daemon_tx: DaemonEventSender, http_handle: MullvadRestHandle) -> Self { Self { daemon_tx, http_handle, - tokio_remote, current_job: None, abort_scheduler_tx: None, auto_rotation_interval: Duration::new(0, 0), @@ -112,24 +107,12 @@ impl KeyManager { self.reset(); let private_key = PrivateKey::new_from_random(); - self.run_future_sync(self.push_future_generator(account, private_key, None)()) + self.http_handle + .service() + .block_on(self.push_future_generator(account, private_key, None)()) .map_err(Self::map_rpc_error) } - /// Run a future on the given tokio remote - fn run_future_sync<T: Send + 'static, E: Send + 'static>( - &mut self, - fut: impl Future<Item = T, Error = E> + Send + 'static, - ) -> std::result::Result<T, E> { - self.reset(); - let (tx, rx) = oneshot::channel(); - - let _ = self.tokio_remote.execute(fut.then(|result| { - let _ = tx.send(result); - Ok(()) - })); - rx.wait().unwrap() - } /// Replace a key for an account synchronously pub fn replace_key( @@ -138,8 +121,9 @@ impl KeyManager { old_key: PublicKey, ) -> Result<WireguardData> { self.reset(); + let new_key = PrivateKey::new_from_random(); - self.run_future_sync(Self::replace_key_rpc( + self.http_handle.service().block_on(Self::replace_key_rpc( self.http_handle.clone(), account, old_key, @@ -152,10 +136,10 @@ impl KeyManager { &self, account: AccountToken, key: talpid_types::net::wireguard::PublicKey, - ) -> impl Future<Item = bool, Error = Error> { + ) -> impl Future<Output = Result<bool>> { let mut rpc = mullvad_rpc::WireguardKeyProxy::new(self.http_handle.clone()); - rpc.get_wireguard_key(account, &key) - .then(|response| match response { + async move { + match rpc.get_wireguard_key(account, &key).await { Ok(_) => Ok(true), Err(mullvad_rpc::rest::Error::ApiError(status, _code)) if status == mullvad_rpc::StatusCode::NOT_FOUND => @@ -163,16 +147,13 @@ impl KeyManager { Ok(false) } Err(err) => Err(Self::map_rpc_error(err)), - }) + } + } } /// Generate a new private key asynchronously. The new keys will be sent to the daemon channel. - pub fn generate_key_async( - &mut self, - account: AccountToken, - timeout: Option<Duration>, - ) -> Result<()> { + pub fn generate_key_async(&mut self, account: AccountToken, timeout: Option<Duration>) { self.reset(); let private_key = PrivateKey::new_from_random(); @@ -186,73 +167,62 @@ impl KeyManager { let fut = inner_future_generator(); let error_tx = error_tx.clone(); let error_account = error_account.clone(); - fut.map_err(move |err| { - let should_retry = match &err { - RestError::ApiError(_status, code) - if code == mullvad_rpc::KEY_LIMIT_REACHED => - { - false + async move { + let response = fut.await; + match response { + Ok(addresses) => Ok(addresses), + Err(err) => { + let should_retry = if let RestError::ApiError(_status, code) = &err { + code != mullvad_rpc::KEY_LIMIT_REACHED + } else { + true + }; + let _ = error_tx.send(InternalDaemonEvent::WgKeyEvent(( + error_account, + Err(Self::map_rpc_error(err)), + ))); + Err(should_retry) } - _ => true, - }; - - let _ = error_tx.send(InternalDaemonEvent::WgKeyEvent(( - error_account, - Err(Self::map_rpc_error(err)), - ))); - - should_retry - }) + } + } }; - let retry_strategy = ExponentialBackoff::from_millis(300) - .max_delay(Duration::from_secs(60 * 60)) - .map(jitter); + let retry_strategy = Jittered::jitter( + ExponentialBackoff::from_millis(300).max_delay(Duration::from_secs(60 * 60)), + ); + let should_retry = move |result: &std::result::Result<_, bool>| -> bool { + match result { + Ok(_) => false, + Err(should_retry) => *should_retry, + } + }; let upload_future = - RetryIf::spawn(retry_strategy, future_generator, |should_retry: &bool| { - *should_retry - }) - .map_err(move |err| { - match err { - // This should really be unreachable, since the retry strategy is infinite. - tokio_retry::Error::OperationError(_) => {} - tokio_retry::Error::TimerError(timer_error) => { - log::error!("Tokio timer error {}", timer_error); - () - } - } - }); + retry_future_with_backoff(future_generator, should_retry, retry_strategy); - let (fut, cancel_handle) = Cancellable::new(upload_future); + let (cancellable_upload, cancel_handle) = Cancellable::new(Box::pin(upload_future)); let daemon_tx = self.daemon_tx.clone(); - let fut = fut.then(move |result| { - match result { - Ok(wireguard_data) => { + let future = async move { + match cancellable_upload.await { + Ok(Ok(wireguard_data)) => { let _ = daemon_tx.send(InternalDaemonEvent::WgKeyEvent(( account, Ok(wireguard_data), ))); } - Err(CancelErr::Inner(_)) => {} - Err(CancelErr::Cancelled) => { + Ok(Err(_)) => {} + Err(_) => { log::error!("Key generation cancelled"); } - }; - Ok(()) - }); + } + }; - let result = self - .tokio_remote - .execute(fut) - .map_err(|_| Error::ExectuionError); - if result.is_ok() { - self.current_job = Some(cancel_handle); - } - result + + self.http_handle.service().spawn(Box::pin(future)); + self.current_job = Some(cancel_handle); } @@ -261,42 +231,49 @@ impl KeyManager { account: AccountToken, private_key: PrivateKey, timeout: Option<Duration>, - ) -> Box<dyn FnMut() -> Box<dyn Future<Item = WireguardData, Error = RestError> + Send> + Send> - { + ) -> Box< + dyn FnMut() -> Pin< + Box<dyn Future<Output = std::result::Result<WireguardData, RestError>> + Send>, + > + Send, + > { let mut rpc = mullvad_rpc::WireguardKeyProxy::new(self.http_handle.clone()); let public_key = private_key.public_key(); let push_future = - move || -> Box<dyn Future<Item = WireguardData, Error = RestError> + Send> { + move || -> std::pin::Pin<Box<dyn Future<Output = std::result::Result<WireguardData, RestError>> + Send >> { let key = private_key.clone(); - Box::new( - rpc.push_wg_key(account.clone(), public_key.clone(), timeout) - .map(move |addresses| WireguardData { - private_key: key, - addresses, - created: Utc::now(), - }), - ) + let address_future = rpc + .push_wg_key(account.clone(), public_key.clone(), timeout); + Box::pin(async move { + let addresses = address_future.await?; + Ok(WireguardData { + private_key: key, + addresses, + created: Utc::now(), + }) + }) }; Box::new(push_future) } - fn replace_key_rpc( + async fn replace_key_rpc( http_handle: MullvadRestHandle, account: AccountToken, old_key: PublicKey, new_key: PrivateKey, - ) -> impl Future<Item = WireguardData, Error = Error> + Send { + ) -> Result<WireguardData> { let mut rpc = mullvad_rpc::WireguardKeyProxy::new(http_handle); let new_public_key = new_key.public_key(); - rpc.replace_wg_key(account, old_key.key, new_public_key) - .map_err(Self::map_rpc_error) - .map(move |addresses| WireguardData { - private_key: new_key, - addresses, - created: Utc::now(), - }) + let addresses = rpc + .replace_wg_key(account, old_key.key, new_public_key) + .await + .map_err(Self::map_rpc_error)?; + Ok(WireguardData { + private_key: new_key, + addresses, + created: Utc::now(), + }) } fn map_rpc_error(err: mullvad_rpc::rest::Error) -> Error { @@ -312,100 +289,89 @@ impl KeyManager { } } - fn create_rotation_check( - key: PublicKey, - rotation_interval_secs: u64, - ) -> impl Future<Item = (), Error = Error> + Send { - tokio_timer::wheel() - .build() - .interval(KEY_CHECK_INTERVAL) - .map_err(Error::RotationScheduleError) - .take_while(move |_| { - Ok( - (Utc::now().signed_duration_since(key.created)).num_seconds() as u64 - <= rotation_interval_secs, - ) - }) - .for_each(|_| Ok(())) + async fn key_rotation_timer(key: PublicKey, rotation_interval_secs: u64) { + let mut interval = tokio02::time::interval(KEY_CHECK_INTERVAL); + loop { + interval.tick().await; + if (Utc::now().signed_duration_since(key.created)).num_seconds() as u64 + >= rotation_interval_secs + { + return; + } + } } - fn next_automatic_rotation( + async fn next_automatic_rotation( daemon_tx: DaemonEventSender, http_handle: MullvadRestHandle, public_key: PublicKey, rotation_interval_secs: u64, account_token: AccountToken, - ) -> impl Future<Item = PublicKey, Error = Error> + Send { - let expiration_timer = - Self::create_rotation_check(public_key.clone(), rotation_interval_secs); + ) -> Result<PublicKey> { let account_token_copy = account_token.clone(); + Self::key_rotation_timer(public_key.clone(), rotation_interval_secs).await; - expiration_timer - .and_then(move |_| { - log::info!("Replacing WireGuard key"); - - let private_key = PrivateKey::new_from_random(); - Self::replace_key_rpc(http_handle, account_token, public_key, private_key) - }) - .then(move |rpc_result| { - match rpc_result { - Ok(data) => { - // Update account data - let _ = daemon_tx.send(InternalDaemonEvent::WgKeyEvent(( - account_token_copy, - Ok(data.clone()), - ))); - Ok(data.get_public_key()) - } - Err(Error::TooManyKeys) => { - let _ = daemon_tx.send(InternalDaemonEvent::WgKeyEvent(( - account_token_copy, - Err(Error::TooManyKeys), - ))); - Err(Error::TooManyKeys) - } - Err(unknown_err) => Err(unknown_err), - } - }) + let private_key = PrivateKey::new_from_random(); + let rpc_result = + Self::replace_key_rpc(http_handle, account_token, public_key, private_key).await; + match rpc_result { + Ok(data) => { + // Update account data + let _ = daemon_tx.send(InternalDaemonEvent::WgKeyEvent(( + account_token_copy, + Ok(data.clone()), + ))); + Ok(data.get_public_key()) + } + Err(Error::TooManyKeys) => { + let _ = daemon_tx.send(InternalDaemonEvent::WgKeyEvent(( + account_token_copy, + Err(Error::TooManyKeys), + ))); + Err(Error::TooManyKeys) + } + Err(unknown_err) => Err(unknown_err), + } } - fn create_automatic_rotation( + async fn create_automatic_rotation( daemon_tx: DaemonEventSender, http_handle: MullvadRestHandle, - public_key: PublicKey, + mut public_key: PublicKey, rotation_interval_secs: u64, account_token: AccountToken, - ) -> impl Future<Item = (), Error = ()> + Send { - tokio_timer::wheel() - .build() - .interval(AUTOMATIC_ROTATION_RETRY_DELAY) - .map_err(Error::RotationScheduleError) - .fold(public_key, move |old_public_key, _| { - let fut = Self::next_automatic_rotation( - daemon_tx.clone(), - http_handle.clone(), - old_public_key.clone(), - rotation_interval_secs, - account_token.clone(), - ); - fut.then(|result| match result { - Ok(new_public_key) => Ok(new_public_key), - Err(Error::TooManyKeys) => { - log::error!("Account has too many keys, stopping automatic rotation"); - Err(Error::TooManyKeys) - } - Err(e) => { - log::error!( - "{}. Retrying in {} seconds", - e.display_chain_with_msg("Key rotation failed:"), - AUTOMATIC_ROTATION_RETRY_DELAY.as_secs(), - ); - Ok(old_public_key) - } - }) - }) - .map_err(|_| ()) - .map(|_| ()) + ) { + let mut interval = tokio02::time::interval_at( + (Instant::now() + AUTOMATIC_ROTATION_RETRY_DELAY).into(), + AUTOMATIC_ROTATION_RETRY_DELAY, + ); + + loop { + let daemon_tx = daemon_tx.clone(); + interval.tick().await; + let new_key_result = Self::next_automatic_rotation( + daemon_tx, + http_handle.clone(), + public_key.clone(), + rotation_interval_secs, + account_token.clone(), + ) + .await; + match new_key_result { + Ok(new_key) => public_key = new_key, + Err(Error::TooManyKeys) => { + log::error!("Account has too many keys, stopping automatic rotation"); + return; + } + Err(err) => { + log::error!( + "{}. Retrying in {} seconds", + err.display_chain_with_msg("Key rotation failed:"), + AUTOMATIC_ROTATION_RETRY_DELAY.as_secs(), + ); + } + } + } } fn run_automatic_rotation(&mut self, account_token: AccountToken, public_key: PublicKey) { @@ -425,11 +391,9 @@ impl KeyManager { self.auto_rotation_interval.as_secs(), account_token, ); - let (fut, cancel_handle) = Cancellable::new(fut); + let (cancellable, cancel_handle) = Cancellable::new(Box::pin(fut)); - if let Err(e) = self.tokio_remote.execute(fut.map_err(|_| ())) { - log::error!("Failed to execute auto key rotation: {:?}", e.kind()); - } + self.http_handle.service().spawn(cancellable); self.abort_scheduler_tx = Some(cancel_handle); } @@ -440,54 +404,3 @@ impl KeyManager { } } } - -pub enum CancelErr<E> { - Cancelled, - Inner(E), -} - -pub struct Cancellable<T, E, F: Future<Item = T, Error = E>> { - rx: oneshot::Receiver<()>, - f: F, -} - -pub struct CancelHandle { - tx: oneshot::Sender<()>, -} - -impl CancelHandle { - fn cancel(self) { - let _ = self.tx.send(()); - } -} - - -impl<T, E, F> Cancellable<T, E, F> -where - F: Future<Item = T, Error = E>, -{ - fn new(f: F) -> (Self, CancelHandle) { - let (tx, rx) = oneshot::channel(); - (Self { f, rx }, CancelHandle { tx }) - } -} - -impl<T, E, F> Future for Cancellable<T, E, F> -where - F: Future<Item = T, Error = E>, -{ - type Item = T; - type Error = CancelErr<E>; - - fn poll(&mut self) -> Poll<T, CancelErr<E>> { - match self.rx.poll() { - Ok(Async::Ready(_)) | Err(_) => return Err(CancelErr::Cancelled), - Ok(Async::NotReady) => (), - }; - - match self.f.poll() { - Ok(v) => Ok(v), - Err(e) => Err(CancelErr::Inner(e)), - } - } -} diff --git a/mullvad-rpc/src/lib.rs b/mullvad-rpc/src/lib.rs index bc6a71a5ba..793a0f241c 100644 --- a/mullvad-rpc/src/lib.rs +++ b/mullvad-rpc/src/lib.rs @@ -315,6 +315,7 @@ impl AppVersionProxy { /// Error code for when an account has too many keys. Returned when trying to push a new key. pub const KEY_LIMIT_REACHED: &str = "KEY_LIMIT_REACHED"; +#[derive(Clone)] pub struct WireguardKeyProxy { handle: rest::MullvadRestHandle, } @@ -325,13 +326,12 @@ impl WireguardKeyProxy { Self { handle } } - pub fn push_wg_key( &mut self, account_token: AccountToken, public_key: wireguard::PublicKey, timeout: Option<std::time::Duration>, - ) -> impl Future01<Item = mullvad_types::wireguard::AssociatedAddresses, Error = rest::Error> + ) -> impl Future<Output = Result<mullvad_types::wireguard::AssociatedAddresses, rest::Error>> + 'static { #[derive(serde::Serialize)] struct PublishRequest { @@ -342,30 +342,24 @@ impl WireguardKeyProxy { let body = PublishRequest { pubkey: public_key }; let request = self.handle.factory.post_json(&"/v1/wireguard-keys", &body); - - let future = async move { + async move { let mut request = request?; if let Some(timeout) = timeout { request.set_timeout(timeout); } request.set_auth(Some(account_token))?; let response = service.request(request).await?; - rest::parse_rest_response(response, StatusCode::CREATED).await - }; - - - self.handle - .service - .compat_spawn(async move { rest::deserialize_body(future.await?).await }) + rest::deserialize_body(rest::parse_rest_response(response, StatusCode::CREATED).await?) + .await + } } - pub fn replace_wg_key( + pub async fn replace_wg_key( &mut self, account_token: AccountToken, old: wireguard::PublicKey, new: wireguard::PublicKey, - ) -> impl Future01<Item = mullvad_types::wireguard::AssociatedAddresses, Error = rest::Error> - { + ) -> Result<mullvad_types::wireguard::AssociatedAddresses, rest::Error> { #[derive(serde::Serialize)] struct ReplacementRequest { old: wireguard::PublicKey, @@ -375,29 +369,27 @@ impl WireguardKeyProxy { let service = self.handle.service.clone(); let body = ReplacementRequest { old, new }; - let request = rest::post_request_with_json( + let response = rest::post_request_with_json( &self.handle.factory, service, &"/v1/replace-wireguard-key", &body, Some(account_token), StatusCode::CREATED, - ); + ) + .await?; - self.handle - .service - .compat_spawn(async move { rest::deserialize_body(request.await?).await }) + rest::deserialize_body(response).await } - pub fn get_wireguard_key( + pub async fn get_wireguard_key( &mut self, account_token: AccountToken, key: &wireguard::PublicKey, - ) -> impl Future01<Item = mullvad_types::wireguard::AssociatedAddresses, Error = rest::Error> - { + ) -> Result<mullvad_types::wireguard::AssociatedAddresses, rest::Error> { let service = self.handle.service.clone(); - let request = rest::send_request( + let response = rest::send_request( &self.handle.factory, service, &format!( @@ -407,20 +399,20 @@ impl WireguardKeyProxy { Method::GET, Some(account_token), StatusCode::OK, - ); - self.handle - .service - .compat_spawn(async move { rest::deserialize_body(request.await?).await }) + ) + .await?; + + rest::deserialize_body(response).await } - pub fn remove_wireguard_key( + pub async fn remove_wireguard_key( &mut self, account_token: AccountToken, key: &wireguard::PublicKey, - ) -> impl Future01<Item = (), Error = rest::Error> { + ) -> Result<(), rest::Error> { let service = self.handle.service.clone(); - let request = rest::send_request( + let _ = rest::send_request( &self.handle.factory, service, &format!( @@ -430,11 +422,8 @@ impl WireguardKeyProxy { Method::DELETE, Some(account_token), StatusCode::NO_CONTENT, - ); - - self.handle.service.compat_spawn(async move { - let _ = request.await?; - Ok(()) - }) + ) + .await?; + Ok(()) } } diff --git a/mullvad-rpc/src/rest.rs b/mullvad-rpc/src/rest.rs index d29feee9f0..c2677c7212 100644 --- a/mullvad-rpc/src/rest.rs +++ b/mullvad-rpc/src/rest.rs @@ -1,9 +1,8 @@ use futures::{ channel::{mpsc, oneshot}, - future::{self, Either}, sink::SinkExt, stream::StreamExt, - TryFutureExt, + FutureExt, TryFutureExt, }; use futures01::Future as OldFuture; use hyper::{ @@ -11,7 +10,16 @@ use hyper::{ header::{self, HeaderValue}, Method, Uri, }; -use std::{collections::BTreeMap, future::Future, mem, net::IpAddr, str::FromStr, time::Duration}; +use std::{ + collections::BTreeMap, + future::Future, + mem, + net::IpAddr, + pin::Pin, + str::FromStr, + task::{Context, Poll}, + time::Duration, +}; use tokio::runtime::Handle; pub use hyper::StatusCode; @@ -180,14 +188,10 @@ impl RequestServiceHandle { /// Resets the corresponding RequestService, dropping all in-flight requests. pub fn reset(&self) { let mut tx = self.tx.clone(); - let (done_tx, done_rx) = oneshot::channel(); - self.handle.spawn(async move { + self.handle.block_on(async move { let _ = tx.send(RequestCommand::Reset).await; - let _ = done_tx.send(()); }); - - let _ = futures::executor::block_on(done_rx); } /// Submits a `RestRequest` for exectuion to the request service. @@ -222,6 +226,13 @@ impl RequestServiceHandle { pub fn spawn<T: Send + 'static>(&self, future: impl Future<Output = T> + Send + 'static) { let _ = self.handle.spawn(future); } + + pub fn block_on<T: Send + 'static>( + &self, + future: impl Future<Output = T> + Send + 'static, + ) -> T { + self.handle.block_on(future) + } } #[derive(Debug)] @@ -427,7 +438,7 @@ pub struct CancelHandle { } impl CancelHandle { - fn cancel(self) { + pub fn cancel(self) { let _ = self.tx.send(()); } } @@ -435,18 +446,41 @@ impl CancelHandle { impl<F> Cancellable<F> where - F: Future + Unpin, + F: Future, { - fn new(f: F) -> (Self, CancelHandle) { + pub fn new(f: F) -> (Self, CancelHandle) { let (tx, rx) = oneshot::channel(); (Self { f, rx }, CancelHandle { tx }) } async fn into_future(self) -> std::result::Result<F::Output, CancelErr> { - match future::select(self.rx, self.f).await { - Either::Left(_) => Err(CancelErr(())), - Either::Right((value, _)) => Ok(value), + futures::select! { + _cancelled = self.rx.fuse() => { + Err(CancelErr(())) + }, + value = self.f.fuse() => { + Ok(value) + } + } + } +} + + +impl<F: Future<Output = T> + Unpin, T: Unpin> Future for Cancellable<F> { + type Output = std::result::Result<T, CancelErr>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let inner = self.get_mut(); + + if let Poll::Ready(ready) = inner.f.poll_unpin(cx) { + return Poll::Ready(Ok(ready)); + } + + if let Poll::Ready(_) = inner.rx.poll_unpin(cx) { + return Poll::Ready(Err(CancelErr(()))); } + + Poll::Pending } } |
