summaryrefslogtreecommitdiffhomepage
path: root/mullvad-api/src/access.rs
diff options
context:
space:
mode:
authorDavid Lönnhager <david.l@mullvad.net>2023-10-18 19:23:03 +0200
committerDavid Lönnhager <david.l@mullvad.net>2023-10-19 16:51:59 +0200
commit4e5d08ebbc2cb49304ff0c2f4c77ecc11be30c79 (patch)
treed6458f600d151347b99f52a0082157ee52b7e37f /mullvad-api/src/access.rs
parent58b09e84af5a921af3b4db213d717bc2b0a03770 (diff)
downloadmullvadvpn-4e5d08ebbc2cb49304ff0c2f4c77ecc11be30c79.tar.xz
mullvadvpn-4e5d08ebbc2cb49304ff0c2f4c77ecc11be30c79.zip
Fold all access token requests into a single request
Diffstat (limited to 'mullvad-api/src/access.rs')
-rw-r--r--mullvad-api/src/access.rs231
1 files changed, 151 insertions, 80 deletions
diff --git a/mullvad-api/src/access.rs b/mullvad-api/src/access.rs
index a3bec3f725..67c83ac4da 100644
--- a/mullvad-api/src/access.rs
+++ b/mullvad-api/src/access.rs
@@ -2,109 +2,180 @@ use crate::{
rest,
rest::{RequestFactory, RequestServiceHandle},
};
+use futures::{
+ channel::{mpsc, oneshot},
+ StreamExt,
+};
use hyper::StatusCode;
use mullvad_types::account::{AccessToken, AccessTokenData, AccountToken};
-use std::{
- collections::HashMap,
- sync::{Arc, Mutex},
-};
-use talpid_types::ErrorExt;
+use std::collections::HashMap;
+use tokio::select;
pub const AUTH_URL_PREFIX: &str = "auth/v1";
#[derive(Clone)]
-pub struct AccessTokenProxy {
- service: RequestServiceHandle,
- factory: RequestFactory,
- access_from_account: Arc<Mutex<HashMap<AccountToken, AccessTokenData>>>,
+pub struct AccessTokenStore {
+ tx: mpsc::UnboundedSender<StoreAction>,
+}
+
+enum StoreAction {
+ /// Request an access token for `AccountToken`, or return a saved one if it's not expired.
+ GetAccessToken(
+ AccountToken,
+ oneshot::Sender<Result<AccessToken, rest::Error>>,
+ ),
+ /// Forget cached access token for `AccountToken`, and drop any in-flight requests
+ InvalidateToken(AccountToken),
+}
+
+#[derive(Default)]
+struct AccountState {
+ current_access_token: Option<AccessTokenData>,
+ inflight_request: Option<tokio::task::JoinHandle<()>>,
+ response_channels: Vec<oneshot::Sender<Result<AccessToken, rest::Error>>>,
}
-impl AccessTokenProxy {
+impl AccessTokenStore {
pub(crate) fn new(service: RequestServiceHandle, factory: RequestFactory) -> Self {
- Self {
- service,
- factory,
- access_from_account: Arc::new(Mutex::new(HashMap::new())),
+ let (tx, rx) = mpsc::unbounded();
+ tokio::spawn(Self::service_requests(rx, service, factory));
+ Self { tx }
+ }
+
+ async fn service_requests(
+ mut rx: mpsc::UnboundedReceiver<StoreAction>,
+ service: RequestServiceHandle,
+ factory: RequestFactory,
+ ) {
+ let mut account_states: HashMap<AccountToken, AccountState> = HashMap::new();
+
+ let (completed_tx, mut completed_rx) = mpsc::unbounded();
+
+ loop {
+ select! {
+ action = rx.next() => {
+ let Some(action) = action else {
+ // We're done
+ break;
+ };
+
+ match action {
+ StoreAction::GetAccessToken(account, response_tx) => {
+ let account_state = account_states
+ .entry(account.clone())
+ .or_default();
+
+ // If there is an unexpired access token, just return it.
+ // Otherwise, generate a new token
+ if let Some(ref access_token) = account_state.current_access_token {
+ if !access_token.is_expired() {
+ log::trace!("Using stored access token");
+ let _ = response_tx.send(Ok(access_token.access_token.clone()));
+ continue;
+ }
+
+ log::debug!("Replacing expired access token");
+ account_state.current_access_token = None;
+ }
+
+ // Begin requesting an access token if it's not already underway.
+ // If there's already an inflight request, just save `response_tx`
+ account_state
+ .inflight_request
+ .get_or_insert_with(|| {
+ let completed_tx = completed_tx.clone();
+ let account = account.clone();
+ let service = service.clone();
+ let factory = factory.clone();
+
+ log::debug!("Fetching access token for an account");
+
+ tokio::spawn(async move {
+ let result = fetch_access_token(service, factory, account.clone()).await;
+ let _ = completed_tx.unbounded_send((account, result));
+ })
+ });
+
+ // Save the channel to respond to later
+ account_state.response_channels.push(response_tx);
+ }
+ StoreAction::InvalidateToken(account) => {
+ let account_state = account_states
+ .entry(account)
+ .or_default();
+
+ // Drop in-flight requests for the account
+ // & forget any existing access token
+
+ log::debug!("Invalidating access token for an account");
+
+ if let Some(task) = account_state.inflight_request.take() {
+ task.abort();
+ let _ = task.await;
+ }
+
+ account_state.response_channels.clear();
+ account_state.current_access_token = None;
+ }
+ }
+ }
+
+ Some((account, result)) = completed_rx.next() => {
+ let account_state = account_states
+ .entry(account)
+ .or_default();
+
+ account_state.inflight_request = None;
+
+ // Send response to all channels
+ for tx in account_state.response_channels.drain(..) {
+ let _ = tx.send(result.clone().map(|data| data.access_token));
+ }
+
+ if let Ok(access_token) = result {
+ account_state.current_access_token = Some(access_token);
+ }
+ }
+ }
}
}
/// Obtain access token for an account, requesting a new one from the API if necessary.
pub async fn get_token(&self, account: &AccountToken) -> Result<AccessToken, rest::Error> {
- let existing_token = {
- self.access_from_account
- .lock()
- .unwrap()
- .get(account.as_str())
- .cloned()
- };
- if let Some(access_token) = existing_token {
- if access_token.is_expired() {
- log::debug!("Replacing expired access token");
- return self.request_new_token(account.clone()).await;
- }
- log::trace!("Using stored access token");
- return Ok(access_token.access_token.clone());
- }
- self.request_new_token(account.clone()).await
+ let (tx, rx) = oneshot::channel();
+ let _ = self
+ .tx
+ .unbounded_send(StoreAction::GetAccessToken(account.to_owned(), tx));
+ rx.await.map_err(|_| rest::Error::Aborted)?
}
/// Remove an access token if the API response calls for it.
- pub fn check_response<T>(&self, account: &AccessToken, response: &Result<T, rest::Error>) {
+ pub fn check_response<T>(&self, account: &AccountToken, response: &Result<T, rest::Error>) {
if let Err(rest::Error::ApiError(_status, code)) = response {
if code == crate::INVALID_ACCESS_TOKEN {
- log::debug!("Dropping invalid access token");
- self.remove_token(account);
+ let _ = self
+ .tx
+ .unbounded_send(StoreAction::InvalidateToken(account.to_owned()));
}
}
}
+}
- /// Removes a stored access token.
- fn remove_token(&self, account: &AccountToken) -> Option<AccessToken> {
- self.access_from_account
- .lock()
- .unwrap()
- .remove(account)
- .map(|v| v.access_token)
- }
-
- async fn request_new_token(&self, account: AccountToken) -> Result<AccessToken, rest::Error> {
- log::debug!("Fetching access token for an account");
- let access_token = self
- .fetch_access_token(account.clone())
- .await
- .map_err(|error| {
- log::error!(
- "{}",
- error.display_chain_with_msg("Failed to obtain access token")
- );
- error
- })?;
- self.access_from_account
- .lock()
- .unwrap()
- .insert(account, access_token.clone());
- Ok(access_token.access_token)
+async fn fetch_access_token(
+ service: RequestServiceHandle,
+ factory: RequestFactory,
+ account_token: AccountToken,
+) -> Result<AccessTokenData, rest::Error> {
+ #[derive(serde::Serialize)]
+ struct AccessTokenRequest {
+ account_number: String,
}
+ let request = AccessTokenRequest {
+ account_number: account_token,
+ };
- async fn fetch_access_token(
- &self,
- account_token: AccountToken,
- ) -> Result<AccessTokenData, rest::Error> {
- #[derive(serde::Serialize)]
- struct AccessTokenRequest {
- account_number: String,
- }
- let request = AccessTokenRequest {
- account_number: account_token,
- };
-
- let service = self.service.clone();
-
- let rest_request = self
- .factory
- .post_json(&format!("{AUTH_URL_PREFIX}/token"), &request)?;
- let response = service.request(rest_request).await?;
- let response = rest::parse_rest_response(response, &[StatusCode::OK]).await?;
- rest::deserialize_body(response).await
- }
+ let rest_request = factory.post_json(&format!("{AUTH_URL_PREFIX}/token"), &request)?;
+ let response = service.request(rest_request).await?;
+ let response = rest::parse_rest_response(response, &[StatusCode::OK]).await?;
+ rest::deserialize_body(response).await
}