summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorSebastian Holmin <sebastian.holmin@mullvad.net>2024-10-01 10:06:39 +0200
committerSebastian Holmin <sebastian.holmin@mullvad.net>2024-10-16 09:56:49 +0200
commitae1c1cefb5c048fcec021397a85ad886197913f7 (patch)
tree60075efd2ef0fddcdf2f4700519ce8639eb589de
parent088b1f68db4268e923d011ea79c05f6c46ca72f2 (diff)
downloadmullvadvpn-ae1c1cefb5c048fcec021397a85ad886197913f7.tar.xz
mullvadvpn-ae1c1cefb5c048fcec021397a85ad886197913f7.zip
Replace occurrences of old `Body` type in `rest` mod
Use `Empty<Bytes>` for outgoing, `Incoming` for responses and generic paras for our type wrapping `Request`.
-rw-r--r--mullvad-api/src/lib.rs3
-rw-r--r--mullvad-api/src/relay_list.rs4
-rw-r--r--mullvad-api/src/rest.rs223
-rw-r--r--mullvad-daemon/src/geoip.rs2
4 files changed, 154 insertions, 78 deletions
diff --git a/mullvad-api/src/lib.rs b/mullvad-api/src/lib.rs
index fc731ee6ea..6b3ac3c951 100644
--- a/mullvad-api/src/lib.rs
+++ b/mullvad-api/src/lib.rs
@@ -1,7 +1,6 @@
#![allow(rustdoc::private_intra_doc_links)]
#[cfg(target_os = "android")]
use futures::channel::mpsc;
-use hyper::Method;
#[cfg(target_os = "android")]
use mullvad_types::account::{PlayPurchase, PlayPurchasePaymentToken};
use mullvad_types::{
@@ -710,7 +709,7 @@ impl AppVersionProxy {
let service = self.handle.service.clone();
let path = format!("{APP_URL_PREFIX}/releases/{platform}/{app_version}");
- let request = self.handle.factory.request(&path, Method::GET);
+ let request = self.handle.factory.get(&path);
async move {
let request = request?
diff --git a/mullvad-api/src/relay_list.rs b/mullvad-api/src/relay_list.rs
index 5f2b2d6d81..f1375b5f6f 100644
--- a/mullvad-api/src/relay_list.rs
+++ b/mullvad-api/src/relay_list.rs
@@ -2,7 +2,7 @@
use crate::rest;
-use hyper::{header, Method, StatusCode};
+use hyper::{header, StatusCode};
use mullvad_types::{location, relay_list};
use talpid_types::net::wireguard;
@@ -34,7 +34,7 @@ impl RelayListProxy {
etag: Option<String>,
) -> impl Future<Output = Result<Option<relay_list::RelayList>, rest::Error>> {
let service = self.handle.service.clone();
- let request = self.handle.factory.request("app/v1/relays", Method::GET);
+ let request = self.handle.factory.get("app/v1/relays");
async move {
let mut request = request?
diff --git a/mullvad-api/src/rest.rs b/mullvad-api/src/rest.rs
index c8df2aea8e..0a2bed93b7 100644
--- a/mullvad-api/src/rest.rs
+++ b/mullvad-api/src/rest.rs
@@ -11,10 +11,13 @@ use futures::{
channel::{mpsc, oneshot},
stream::StreamExt,
};
+use http_body_util::{combinators::BoxBody, BodyExt, Empty, Full};
use hyper::{
- client::{connect::Connect, Client},
+ body::{Body, Bytes, Incoming},
+ // client::{connect::Connect, Client},
header::{self, HeaderValue},
- Method, Uri,
+ Method,
+ Uri,
};
use mullvad_types::account::AccountNumber;
use std::{
@@ -62,8 +65,8 @@ pub enum Error {
ApiError(StatusCode, String),
/// The string given was not a valid URI.
- #[error("Not a valid URI")]
- InvalidUri,
+ #[error("Not a valid URI {0}")]
+ InvalidUri(#[from] Arc<http::uri::InvalidUri>),
#[error("Set account number on factory with no access token store")]
NoAccessTokenStore,
@@ -119,7 +122,11 @@ pub(crate) struct RequestService<T: ConnectionModeProvider> {
command_tx: Weak<mpsc::UnboundedSender<RequestCommand>>,
command_rx: mpsc::UnboundedReceiver<RequestCommand>,
connector_handle: HttpsConnectorWithSniHandle,
- client: hyper::Client<HttpsConnectorWithSni, hyper::Body>,
+ // client: hyper_util::client::legacy::Client<
+ // HttpsConnectorWithSni,
+ // BoxBody<dyn hyper::body::Buf, Error>,
+ // >,
+ client: HttpsConnectorWithSni,
connection_mode_provider: T,
connection_mode_generation: usize,
api_availability: ApiAvailability,
@@ -144,7 +151,8 @@ impl<T: ConnectionModeProvider + 'static> RequestService<T> {
connector_handle.set_connection_mode(connection_mode_provider.initial());
let (command_tx, command_rx) = mpsc::unbounded();
- let client = Client::builder().build(connector);
+ // let client =
+ // hyper_util::client::legacy::Client::builder(TokioExecutor::new()).build(connector);
let command_tx = Arc::new(command_tx);
@@ -152,7 +160,7 @@ impl<T: ConnectionModeProvider + 'static> RequestService<T> {
command_tx: Arc::downgrade(&command_tx),
command_rx,
connector_handle,
- client,
+ client: connector,
connection_mode_provider,
connection_mode_generation: 0,
api_availability,
@@ -203,13 +211,15 @@ impl<T: ConnectionModeProvider + 'static> RequestService<T> {
fn handle_new_request(
&mut self,
- request: Request,
- completion_tx: oneshot::Sender<Result<Response>>,
+ request: Request<BoxBody<Bytes, Error>>,
+ completion_tx: oneshot::Sender<Result<Response<Incoming>>>,
) {
let tx = self.command_tx.upgrade();
let api_availability = self.api_availability.clone();
- let request_future = request.into_future(self.client.clone(), api_availability.clone());
+ let request_future = request
+ .map(|r| http::Request::map(r, BodyExt::boxed))
+ .into_future(self.client.clone(), api_availability.clone());
let connection_mode_generation = self.connection_mode_generation;
@@ -246,8 +256,14 @@ impl RequestServiceHandle {
}
/// Submits a `RestRequest` for execution to the request service.
- pub async fn request(&self, request: Request) -> Result<Response> {
+ pub async fn request<B>(&self, request: Request<B>) -> Result<Response<Incoming>>
+ where
+ B: Body + Send + Sync + 'static,
+ Error: From<B::Error>,
+ Bytes: From<B::Data>,
+ {
let (completion_tx, completion_rx) = oneshot::channel();
+ let request = request.map(|r| r.map(box_body));
self.tx
.unbounded_send(RequestCommand::NewRequest(request, completion_tx))
.map_err(|_| Error::RestServiceDown)?;
@@ -258,8 +274,8 @@ impl RequestServiceHandle {
#[derive(Debug)]
pub(crate) enum RequestCommand {
NewRequest(
- Request,
- oneshot::Sender<std::result::Result<Response, Error>>,
+ Request<BoxBody<Bytes, Error>>,
+ oneshot::Sender<std::result::Result<Response<Incoming>, Error>>,
),
Reset,
NextApiConfig(usize),
@@ -267,18 +283,18 @@ pub(crate) enum RequestCommand {
/// A REST request that is sent to the RequestService to be executed.
#[derive(Debug)]
-pub struct Request {
- request: hyper::Request<hyper::Body>,
+pub struct Request<B> {
+ request: hyper::Request<B>,
timeout: Duration,
access_token_store: Option<AccessTokenStore>,
account: Option<AccountNumber>,
expected_status: &'static [hyper::StatusCode],
}
-impl Request {
+// TODO: merge with `RequestFactory::get`
/// Constructs a GET request with the given URI. Returns an error if the URI is not valid.
- pub fn get(uri: &str) -> Result<Self> {
- let uri = hyper::Uri::from_str(uri).map_err(|_| Error::InvalidUri)?;
+pub fn get(uri: &str) -> Result<Request<Empty<Bytes>>> {
+ let uri = hyper::Uri::from_str(uri)?;
let mut builder = http::request::Builder::new()
.method(Method::GET)
@@ -287,18 +303,16 @@ impl Request {
if let Some(host) = uri.host() {
builder = builder.header(
header::HOST,
- HeaderValue::from_str(host).map_err(|_| Error::InvalidHeaderError)?,
+ HeaderValue::from_str(host).map_err(|_e| Error::InvalidHeaderError)?,
);
};
- let request = builder.uri(uri).body(hyper::Body::empty())?;
- Ok(Self::new(request, None))
+ let request = builder.uri(uri).body(Empty::<Bytes>::new())?;
+ Ok(Request::new(request, None))
}
- fn new(
- request: hyper::Request<hyper::Body>,
- access_token_store: Option<AccessTokenStore>,
- ) -> Self {
+impl<B: Body> Request<B> {
+ fn new(request: hyper::Request<B>, access_token_store: Option<AccessTokenStore>) -> Self {
Self {
request,
timeout: DEFAULT_TIMEOUT,
@@ -336,11 +350,64 @@ impl Request {
Ok(self)
}
+ /// Returns the URI of the request
+ pub fn uri(&self) -> &Uri {
+ self.request.uri()
+ }
+}
+impl<B> Request<B> {
+ /// Map the underlying [`hyper::Request`] type
+ fn map<F, B2>(self, f: F) -> Request<B2>
+ where
+ F: FnOnce(hyper::Request<B>) -> hyper::Request<B2>,
+ {
+ Request {
+ request: f(self.request),
+ timeout: self.timeout,
+ access_token_store: self.access_token_store,
+ account: self.account,
+ expected_status: self.expected_status,
+ }
+ }
+}
+
+fn box_body<B>(body: B) -> BoxBody<Bytes, Error>
+where
+ B: Body + Send + Sync + 'static,
+ Error: From<B::Error>,
+ Bytes: From<B::Data>,
+{
+ try_downcast(body).unwrap_or_else(|body| {
+ body.map_frame(|frame| frame.map_data(Bytes::from))
+ .map_err(Error::from)
+ .boxed()
+ })
+}
+
+pub(crate) fn try_downcast<T, K>(k: K) -> core::result::Result<T, K>
+where
+ T: 'static,
+ K: Send + 'static,
+{
+ let mut k = Some(k);
+ if let Some(k) = <dyn std::any::Any>::downcast_mut::<Option<T>>(&mut k) {
+ Ok(k.take().unwrap())
+ } else {
+ Err(k.unwrap())
+ }
+}
+
+impl<B> Request<B>
+where
+ B: Body + Send + 'static + Unpin,
+ B::Data: Send,
+ B::Error: Into<Box<dyn StdError + Send + Sync>>,
+{
async fn into_future<C: Connect + Clone + Send + Sync + 'static>(
self,
- hyper_client: hyper::Client<C>,
+ hyper_client: hyper_util::client::legacy::Client<C, B>,
api_availability: ApiAvailability,
- ) -> Result<Response> {
+ ) -> Result<Response<Incoming>> {
let timeout = self.timeout;
let inner_fut = self.into_future_without_timeout(hyper_client, api_availability);
tokio::time::timeout(timeout, inner_fut)
@@ -348,11 +415,14 @@ impl Request {
.map_err(|_| Error::TimeoutError)?
}
- async fn into_future_without_timeout<C: Connect + Clone + Send + Sync + 'static>(
+ async fn into_future_without_timeout<C>(
mut self,
- hyper_client: hyper::Client<C>,
+ hyper_client: hyper_util::client::legacy::Client<C, B>,
api_availability: ApiAvailability,
- ) -> Result<Response> {
+ ) -> Result<Response<Incoming>>
+ where
+ C: Connect + Clone + Send + Sync + 'static,
+ {
let _ = api_availability.wait_for_unsuspend().await;
// Obtain access token first
@@ -399,21 +469,19 @@ impl Request {
Ok(Response::new(response))
}
-
- /// Returns the URI of the request
- pub fn uri(&self) -> &Uri {
- self.request.uri()
- }
}
/// Successful result of a REST request
#[derive(Debug)]
-pub struct Response {
- response: hyper::Response<hyper::Body>,
+pub struct Response<B> {
+ response: hyper::Response<B>,
}
-impl Response {
- fn new(response: hyper::Response<hyper::Body>) -> Self {
+impl<B: Body> Response<B>
+where
+ Error: From<<B as Body>::Error>,
+{
+ fn new(response: hyper::Response<B>) -> Self {
Self { response }
}
@@ -426,8 +494,7 @@ impl Response {
}
pub async fn deserialize<T: serde::de::DeserializeOwned>(self) -> Result<T> {
- let body_length = get_body_length(&self.response);
- deserialize_body_inner(self.response, body_length).await
+ deserialize_body_inner(self.response).await
}
}
@@ -462,38 +529,46 @@ impl RequestFactory {
}
}
- pub fn request(&self, path: &str, method: Method) -> Result<Request> {
+ pub fn request<B: Body + Default>(&self, path: &str, method: Method) -> Result<Request<B>> {
Ok(
Request::new(self.hyper_request(path, method)?, self.token_store.clone())
.timeout(self.default_timeout),
)
}
- pub fn get(&self, path: &str) -> Result<Request> {
+ pub fn get(&self, path: &str) -> Result<Request<Empty<Bytes>>> {
self.request(path, Method::GET)
}
- pub fn post(&self, path: &str) -> Result<Request> {
+ pub fn post(&self, path: &str) -> Result<Request<Empty<Bytes>>> {
self.request(path, Method::POST)
}
- pub fn put(&self, path: &str) -> Result<Request> {
+ pub fn put(&self, path: &str) -> Result<Request<Empty<Bytes>>> {
self.request(path, Method::PUT)
}
- pub fn delete(&self, path: &str) -> Result<Request> {
+ pub fn delete(&self, path: &str) -> Result<Request<Empty<Bytes>>> {
self.request(path, Method::DELETE)
}
- pub fn head(&self, path: &str) -> Result<Request> {
+ pub fn head(&self, path: &str) -> Result<Request<Empty<Bytes>>> {
self.request(path, Method::HEAD)
}
- pub fn post_json<S: serde::Serialize>(&self, path: &str, body: &S) -> Result<Request> {
+ pub fn post_json<S: serde::Serialize>(
+ &self,
+ path: &str,
+ body: &S,
+ ) -> Result<Request<Full<Bytes>>> {
self.json_request(Method::POST, path, body)
}
- pub fn put_json<S: serde::Serialize>(&self, path: &str, body: &S) -> Result<Request> {
+ pub fn put_json<S: serde::Serialize>(
+ &self,
+ path: &str,
+ body: &S,
+ ) -> Result<Request<Full<Bytes>>> {
self.json_request(Method::PUT, path, body)
}
@@ -501,18 +576,17 @@ impl RequestFactory {
self.default_timeout = timeout;
self
}
-
fn json_request<S: serde::Serialize>(
&self,
method: Method,
path: &str,
body: &S,
- ) -> Result<Request> {
+ ) -> Result<Request<Full<Bytes>>> {
let mut request = self.hyper_request(path, method)?;
- let json_body = serde_json::to_string(&body)?;
- let body_length = json_body.as_bytes().len();
- *request.body_mut() = json_body.into_bytes().into();
+ let json_body = serde_json::to_vec(&body)?;
+ let body_length = json_body.len();
+ *request.body_mut() = Full::new(Bytes::from(json_body));
let headers = request.headers_mut();
headers.insert(header::CONTENT_LENGTH, HeaderValue::from(body_length));
@@ -524,7 +598,7 @@ impl RequestFactory {
Ok(Request::new(request, self.token_store.clone()).timeout(self.default_timeout))
}
- fn hyper_request(&self, path: &str, method: Method) -> Result<hyper::Request<hyper::Body>> {
+ fn hyper_request<B: Default>(&self, path: &str, method: Method) -> Result<http::Request<B>> {
let uri = self.get_uri(path)?;
let request = http::request::Builder::new()
.method(method)
@@ -536,17 +610,17 @@ impl RequestFactory {
HeaderValue::from_str(&self.hostname).map_err(|_| Error::InvalidHeaderError)?,
);
- let result = request.body(hyper::Body::empty())?;
+ let result = request.body(B::default())?;
Ok(result)
}
fn get_uri(&self, path: &str) -> Result<Uri> {
let uri = format!("https://{}/{}", self.hostname, path);
- hyper::Uri::from_str(&uri).map_err(|_| Error::InvalidUri)
+ Ok(hyper::Uri::from_str(&uri)?)
}
}
-fn get_body_length(response: &hyper::Response<hyper::Body>) -> usize {
+fn get_body_length<B>(response: &hyper::Response<B>) -> usize {
response
.headers()
.get(header::CONTENT_LENGTH)
@@ -555,20 +629,22 @@ fn get_body_length(response: &hyper::Response<hyper::Body>) -> usize {
.unwrap_or(0)
}
-async fn handle_error_response<T>(response: hyper::Response<hyper::Body>) -> Result<T> {
+async fn handle_error_response<T, B: Body>(response: hyper::Response<B>) -> Result<T>
+where
+ Error: From<B::Error>,
+{
let status = response.status();
let error_message = match status {
hyper::StatusCode::METHOD_NOT_ALLOWED => "Method not allowed",
status => match get_body_length(&response) {
0 => status.canonical_reason().unwrap_or("Unexpected error"),
- body_length => {
+ _length => {
return match response.headers().get("content-type") {
Some(content_type) if content_type == "application/problem+json" => {
// TODO: We should make sure we unify the new error format and the old
// error format so that they both produce the same Errors for the same
// problems after being processed.
- let err: NewErrorResponse =
- deserialize_body_inner(response, body_length).await?;
+ let err: NewErrorResponse = deserialize_body_inner(response).await?;
// The new error type replaces the `code` field with the `type` field.
// This is what is used to programmatically check the error.
Err(Error::ApiError(
@@ -578,8 +654,7 @@ async fn handle_error_response<T>(response: hyper::Response<hyper::Body>) -> Res
))
}
_ => {
- let err: OldErrorResponse =
- deserialize_body_inner(response, body_length).await?;
+ let err: OldErrorResponse = deserialize_body_inner(response).await?;
Err(Error::ApiError(status, err.code))
}
};
@@ -589,16 +664,17 @@ async fn handle_error_response<T>(response: hyper::Response<hyper::Body>) -> Res
Err(Error::ApiError(status, error_message.to_owned()))
}
-async fn deserialize_body_inner<T: serde::de::DeserializeOwned>(
- mut response: hyper::Response<hyper::Body>,
- body_length: usize,
-) -> Result<T> {
- let mut body: Vec<u8> = Vec::with_capacity(body_length);
- while let Some(chunk) = response.body_mut().next().await {
- body.extend(&chunk?);
- }
+async fn deserialize_body_inner<T, B>(response: hyper::Response<B>) -> Result<T>
+where
+ T: serde::de::DeserializeOwned,
+ B: Body,
+ Error: From<B::Error>,
+{
+ use http_body_util::BodyExt;
- serde_json::from_slice(&body).map_err(Error::from)
+ let collected = BodyExt::collect(response).await?;
+ let res = serde_json::from_slice(&collected.to_bytes())?;
+ Ok(res)
}
#[derive(Clone)]
@@ -639,3 +715,4 @@ macro_rules! impl_into_arc_err {
impl_into_arc_err!(hyper::Error);
impl_into_arc_err!(serde_json::Error);
impl_into_arc_err!(http::Error);
+impl_into_arc_err!(http::uri::InvalidUri);
diff --git a/mullvad-daemon/src/geoip.rs b/mullvad-daemon/src/geoip.rs
index da2fb3e8db..815b83b13f 100644
--- a/mullvad-daemon/src/geoip.rs
+++ b/mullvad-daemon/src/geoip.rs
@@ -154,7 +154,7 @@ async fn send_location_request_internal(
service: RequestServiceHandle,
) -> Result<AmIMullvad, Error> {
let future_service = service.clone();
- let request = mullvad_api::rest::Request::get(uri)?;
+ let request = mullvad_api::rest::get(uri)?;
future_service.request(request).await?.deserialize().await
}