diff options
| author | David Lönnhager <david.l@mullvad.net> | 2023-10-18 19:23:03 +0200 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2023-10-19 16:51:59 +0200 |
| commit | 4e5d08ebbc2cb49304ff0c2f4c77ecc11be30c79 (patch) | |
| tree | d6458f600d151347b99f52a0082157ee52b7e37f /mullvad-api/src/access.rs | |
| parent | 58b09e84af5a921af3b4db213d717bc2b0a03770 (diff) | |
| download | mullvadvpn-4e5d08ebbc2cb49304ff0c2f4c77ecc11be30c79.tar.xz mullvadvpn-4e5d08ebbc2cb49304ff0c2f4c77ecc11be30c79.zip | |
Fold all access token requests into a single request
Diffstat (limited to 'mullvad-api/src/access.rs')
| -rw-r--r-- | mullvad-api/src/access.rs | 231 |
1 files changed, 151 insertions, 80 deletions
diff --git a/mullvad-api/src/access.rs b/mullvad-api/src/access.rs index a3bec3f725..67c83ac4da 100644 --- a/mullvad-api/src/access.rs +++ b/mullvad-api/src/access.rs @@ -2,109 +2,180 @@ use crate::{ rest, rest::{RequestFactory, RequestServiceHandle}, }; +use futures::{ + channel::{mpsc, oneshot}, + StreamExt, +}; use hyper::StatusCode; use mullvad_types::account::{AccessToken, AccessTokenData, AccountToken}; -use std::{ - collections::HashMap, - sync::{Arc, Mutex}, -}; -use talpid_types::ErrorExt; +use std::collections::HashMap; +use tokio::select; pub const AUTH_URL_PREFIX: &str = "auth/v1"; #[derive(Clone)] -pub struct AccessTokenProxy { - service: RequestServiceHandle, - factory: RequestFactory, - access_from_account: Arc<Mutex<HashMap<AccountToken, AccessTokenData>>>, +pub struct AccessTokenStore { + tx: mpsc::UnboundedSender<StoreAction>, +} + +enum StoreAction { + /// Request an access token for `AccountToken`, or return a saved one if it's not expired. + GetAccessToken( + AccountToken, + oneshot::Sender<Result<AccessToken, rest::Error>>, + ), + /// Forget cached access token for `AccountToken`, and drop any in-flight requests + InvalidateToken(AccountToken), +} + +#[derive(Default)] +struct AccountState { + current_access_token: Option<AccessTokenData>, + inflight_request: Option<tokio::task::JoinHandle<()>>, + response_channels: Vec<oneshot::Sender<Result<AccessToken, rest::Error>>>, } -impl AccessTokenProxy { +impl AccessTokenStore { pub(crate) fn new(service: RequestServiceHandle, factory: RequestFactory) -> Self { - Self { - service, - factory, - access_from_account: Arc::new(Mutex::new(HashMap::new())), + let (tx, rx) = mpsc::unbounded(); + tokio::spawn(Self::service_requests(rx, service, factory)); + Self { tx } + } + + async fn service_requests( + mut rx: mpsc::UnboundedReceiver<StoreAction>, + service: RequestServiceHandle, + factory: RequestFactory, + ) { + let mut account_states: HashMap<AccountToken, AccountState> = HashMap::new(); + + let (completed_tx, mut completed_rx) = mpsc::unbounded(); + + loop { + select! { + action = rx.next() => { + let Some(action) = action else { + // We're done + break; + }; + + match action { + StoreAction::GetAccessToken(account, response_tx) => { + let account_state = account_states + .entry(account.clone()) + .or_default(); + + // If there is an unexpired access token, just return it. + // Otherwise, generate a new token + if let Some(ref access_token) = account_state.current_access_token { + if !access_token.is_expired() { + log::trace!("Using stored access token"); + let _ = response_tx.send(Ok(access_token.access_token.clone())); + continue; + } + + log::debug!("Replacing expired access token"); + account_state.current_access_token = None; + } + + // Begin requesting an access token if it's not already underway. + // If there's already an inflight request, just save `response_tx` + account_state + .inflight_request + .get_or_insert_with(|| { + let completed_tx = completed_tx.clone(); + let account = account.clone(); + let service = service.clone(); + let factory = factory.clone(); + + log::debug!("Fetching access token for an account"); + + tokio::spawn(async move { + let result = fetch_access_token(service, factory, account.clone()).await; + let _ = completed_tx.unbounded_send((account, result)); + }) + }); + + // Save the channel to respond to later + account_state.response_channels.push(response_tx); + } + StoreAction::InvalidateToken(account) => { + let account_state = account_states + .entry(account) + .or_default(); + + // Drop in-flight requests for the account + // & forget any existing access token + + log::debug!("Invalidating access token for an account"); + + if let Some(task) = account_state.inflight_request.take() { + task.abort(); + let _ = task.await; + } + + account_state.response_channels.clear(); + account_state.current_access_token = None; + } + } + } + + Some((account, result)) = completed_rx.next() => { + let account_state = account_states + .entry(account) + .or_default(); + + account_state.inflight_request = None; + + // Send response to all channels + for tx in account_state.response_channels.drain(..) { + let _ = tx.send(result.clone().map(|data| data.access_token)); + } + + if let Ok(access_token) = result { + account_state.current_access_token = Some(access_token); + } + } + } } } /// Obtain access token for an account, requesting a new one from the API if necessary. pub async fn get_token(&self, account: &AccountToken) -> Result<AccessToken, rest::Error> { - let existing_token = { - self.access_from_account - .lock() - .unwrap() - .get(account.as_str()) - .cloned() - }; - if let Some(access_token) = existing_token { - if access_token.is_expired() { - log::debug!("Replacing expired access token"); - return self.request_new_token(account.clone()).await; - } - log::trace!("Using stored access token"); - return Ok(access_token.access_token.clone()); - } - self.request_new_token(account.clone()).await + let (tx, rx) = oneshot::channel(); + let _ = self + .tx + .unbounded_send(StoreAction::GetAccessToken(account.to_owned(), tx)); + rx.await.map_err(|_| rest::Error::Aborted)? } /// Remove an access token if the API response calls for it. - pub fn check_response<T>(&self, account: &AccessToken, response: &Result<T, rest::Error>) { + pub fn check_response<T>(&self, account: &AccountToken, response: &Result<T, rest::Error>) { if let Err(rest::Error::ApiError(_status, code)) = response { if code == crate::INVALID_ACCESS_TOKEN { - log::debug!("Dropping invalid access token"); - self.remove_token(account); + let _ = self + .tx + .unbounded_send(StoreAction::InvalidateToken(account.to_owned())); } } } +} - /// Removes a stored access token. - fn remove_token(&self, account: &AccountToken) -> Option<AccessToken> { - self.access_from_account - .lock() - .unwrap() - .remove(account) - .map(|v| v.access_token) - } - - async fn request_new_token(&self, account: AccountToken) -> Result<AccessToken, rest::Error> { - log::debug!("Fetching access token for an account"); - let access_token = self - .fetch_access_token(account.clone()) - .await - .map_err(|error| { - log::error!( - "{}", - error.display_chain_with_msg("Failed to obtain access token") - ); - error - })?; - self.access_from_account - .lock() - .unwrap() - .insert(account, access_token.clone()); - Ok(access_token.access_token) +async fn fetch_access_token( + service: RequestServiceHandle, + factory: RequestFactory, + account_token: AccountToken, +) -> Result<AccessTokenData, rest::Error> { + #[derive(serde::Serialize)] + struct AccessTokenRequest { + account_number: String, } + let request = AccessTokenRequest { + account_number: account_token, + }; - async fn fetch_access_token( - &self, - account_token: AccountToken, - ) -> Result<AccessTokenData, rest::Error> { - #[derive(serde::Serialize)] - struct AccessTokenRequest { - account_number: String, - } - let request = AccessTokenRequest { - account_number: account_token, - }; - - let service = self.service.clone(); - - let rest_request = self - .factory - .post_json(&format!("{AUTH_URL_PREFIX}/token"), &request)?; - let response = service.request(rest_request).await?; - let response = rest::parse_rest_response(response, &[StatusCode::OK]).await?; - rest::deserialize_body(response).await - } + let rest_request = factory.post_json(&format!("{AUTH_URL_PREFIX}/token"), &request)?; + let response = service.request(rest_request).await?; + let response = rest::parse_rest_response(response, &[StatusCode::OK]).await?; + rest::deserialize_body(response).await } |
