summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorDavid Lönnhager <david.l@mullvad.net>2022-01-21 13:23:27 +0100
committerDavid Lönnhager <david.l@mullvad.net>2022-03-01 15:30:22 +0100
commitbcf3278eeb1b63f2ff8fa6ee68ab4cc8bb8b76fd (patch)
tree95721793a135d304f73a8724a5f014960316df80
parent45d827f96b524f1183a2f3709700c0bea643faab (diff)
downloadmullvadvpn-bcf3278eeb1b63f2ff8fa6ee68ab4cc8bb8b76fd.tar.xz
mullvadvpn-bcf3278eeb1b63f2ff8fa6ee68ab4cc8bb8b76fd.zip
Add Shadowsocks support to HTTPS connector
-rw-r--r--mullvad-rpc/src/address_cache.rs139
-rw-r--r--mullvad-rpc/src/bin/relay_list.rs16
-rw-r--r--mullvad-rpc/src/https_client_with_sni.rs184
-rw-r--r--mullvad-rpc/src/lib.rs68
-rw-r--r--mullvad-rpc/src/proxy.rs171
-rw-r--r--mullvad-rpc/src/rest.rs93
6 files changed, 452 insertions, 219 deletions
diff --git a/mullvad-rpc/src/address_cache.rs b/mullvad-rpc/src/address_cache.rs
index 099ffbc075..3b6fcba074 100644
--- a/mullvad-rpc/src/address_cache.rs
+++ b/mullvad-rpc/src/address_cache.rs
@@ -1,13 +1,9 @@
-use std::{
- io,
- net::SocketAddr,
- ops::{Deref, DerefMut},
- path::Path,
- sync::{Arc, Mutex},
-};
+use super::API;
+use std::{io, net::SocketAddr, path::Path, sync::Arc};
use tokio::{
fs,
io::{AsyncReadExt, AsyncWriteExt},
+ sync::Mutex,
};
#[derive(err_derive::Error, Debug)]
@@ -27,99 +23,66 @@ pub enum Error {
#[error(display = "The address cache is empty")]
EmptyAddressCache,
-
- #[error(display = "The address change listener returned an error")]
- ChangeListenerError,
}
-pub type CurrentAddressChangeListener =
- dyn Fn(SocketAddr) -> Result<(), ()> + Send + Sync + 'static;
-
#[derive(Clone)]
pub struct AddressCache {
inner: Arc<Mutex<AddressCacheInner>>,
write_path: Option<Arc<Path>>,
- change_listener: Arc<Box<CurrentAddressChangeListener>>,
}
impl AddressCache {
- /// Initialize cache using the given list, and write changes to `write_path`.
- pub fn new(address: SocketAddr, write_path: Option<Box<Path>>) -> Result<Self, Error> {
+ /// Initialize cache using the hardcoded address, and write changes to `write_path`.
+ pub fn new(write_path: Option<Box<Path>>) -> Result<Self, Error> {
+ Self::new_inner(API.addr, write_path)
+ }
+
+ /// Initialize cache using `read_path`, and write changes to `write_path`.
+ pub async fn from_file(read_path: &Path, write_path: Option<Box<Path>>) -> Result<Self, Error> {
+ log::debug!("Loading API addresses from {}", read_path.display());
+ Self::new_inner(read_address_file(read_path).await?, write_path)
+ }
+
+ fn new_inner(address: SocketAddr, write_path: Option<Box<Path>>) -> Result<Self, Error> {
let cache = AddressCacheInner::from_address(address);
- log::trace!("API address cache: {:?}", cache.address);
- log::debug!("Using API address: {:?}", Self::get_address_inner(&cache));
+ log::debug!("Using API address: {}", cache.address);
let address_cache = Self {
inner: Arc::new(Mutex::new(cache)),
write_path: write_path.map(|cache| Arc::from(cache)),
- change_listener: Arc::new(Box::new(|_| Ok(()))),
};
Ok(address_cache)
}
- /// Initialize cache using `read_path`, and write changes to `write_path`.
- pub async fn from_file(read_path: &Path, write_path: Option<Box<Path>>) -> Result<Self, Error> {
- log::debug!("Loading API addresses from {}", read_path.display());
- Self::new(read_address_file(read_path).await?, write_path)
- }
-
- pub fn set_change_listener(&mut self, change_listener: Arc<Box<CurrentAddressChangeListener>>) {
- self.change_listener = change_listener;
+ /// Returns the address if the hostname equals `API.host`. Otherwise, returns `None`.
+ pub async fn resolve_hostname(&self, hostname: &str) -> Option<SocketAddr> {
+ if hostname.eq_ignore_ascii_case(&API.host) {
+ Some(self.get_address().await)
+ } else {
+ None
+ }
}
/// Returns the currently selected address.
- pub fn get_address(&self) -> SocketAddr {
- let inner = self.inner.lock().unwrap();
- Self::get_address_inner(&inner)
- }
-
- fn get_address_inner(inner: &AddressCacheInner) -> SocketAddr {
- inner.address
+ pub async fn get_address(&self) -> SocketAddr {
+ self.inner.lock().await.address
}
pub async fn set_address(&self, address: SocketAddr) -> io::Result<()> {
- let should_update = {
- let mut inner = self.inner.lock().unwrap();
- let mut transaction = AddressCacheTransaction::new(&mut inner);
-
- let current_address = transaction.address.clone();
-
- if address != current_address {
- transaction.address = address.clone();
- tokio::task::block_in_place(move || {
- if (*self.change_listener)(Self::get_address_inner(&transaction)).is_err() {
- log::error!("Failed to select new API endpoint");
- return Err(io::Error::new(
- io::ErrorKind::Other,
- "callback returned an error",
- ));
- }
- transaction.commit();
- Ok(())
- })?;
- true
- } else {
- false
- }
- };
- if should_update {
- log::trace!("API address cache: {}", address);
- self.save_to_disk().await?;
+ let mut inner = self.inner.lock().await;
+ if address != inner.address {
+ self.save_to_disk(&address).await?;
+ inner.address = address;
}
Ok(())
}
- async fn save_to_disk(&self) -> io::Result<()> {
+ async fn save_to_disk(&self, address: &SocketAddr) -> io::Result<()> {
let write_path = match self.write_path.as_ref() {
Some(write_path) => write_path,
None => return Ok(()),
};
- let address = {
- let inner = self.inner.lock().unwrap();
- inner.address.clone()
- };
-
let temp_path = write_path.with_file_name("api-cache.temp");
let mut file = fs::File::create(&temp_path).await?;
@@ -132,16 +95,6 @@ impl AddressCache {
}
}
-impl crate::rest::AddressProvider for AddressCache {
- fn get_address(&self) -> String {
- self.get_address().to_string()
- }
-
- fn clone_box(&self) -> Box<dyn crate::rest::AddressProvider> {
- Box::new(self.clone())
- }
-}
-
#[derive(Clone, PartialEq, Eq)]
struct AddressCacheInner {
address: SocketAddr,
@@ -153,38 +106,6 @@ impl AddressCacheInner {
}
}
-struct AddressCacheTransaction<'a> {
- current: &'a mut AddressCacheInner,
- working_cache: AddressCacheInner,
-}
-
-impl<'a> AddressCacheTransaction<'a> {
- fn new(cache: &'a mut AddressCacheInner) -> Self {
- Self {
- working_cache: cache.clone(),
- current: cache,
- }
- }
-
- fn commit(self) {
- *self.current = self.working_cache;
- }
-}
-
-impl<'a> Deref for AddressCacheTransaction<'a> {
- type Target = AddressCacheInner;
-
- fn deref(&self) -> &Self::Target {
- &self.working_cache
- }
-}
-
-impl<'a> DerefMut for AddressCacheTransaction<'a> {
- fn deref_mut(&mut self) -> &mut Self::Target {
- &mut self.working_cache
- }
-}
-
async fn read_address_file(path: &Path) -> Result<SocketAddr, Error> {
let mut file = fs::File::open(path)
.await
diff --git a/mullvad-rpc/src/bin/relay_list.rs b/mullvad-rpc/src/bin/relay_list.rs
index db7fc29854..66118d3ade 100644
--- a/mullvad-rpc/src/bin/relay_list.rs
+++ b/mullvad-rpc/src/bin/relay_list.rs
@@ -2,18 +2,24 @@
//! Used by the installer artifact packer to bundle the latest available
//! relay list at the time of creating the installer.
-use mullvad_rpc::{rest::Error as RestError, MullvadRpcRuntime, RelayListProxy};
+use mullvad_rpc::{
+ proxy::ApiConnectionMode, rest::Error as RestError, MullvadRpcRuntime, RelayListProxy,
+};
use std::process;
use talpid_types::ErrorExt;
#[tokio::main]
async fn main() {
- let mut runtime =
+ let runtime =
MullvadRpcRuntime::new(tokio::runtime::Handle::current()).expect("Failed to load runtime");
- let relay_list_request = RelayListProxy::new(runtime.mullvad_rest_handle())
- .relay_list(None)
- .await;
+ let relay_list_request = RelayListProxy::new(
+ runtime
+ .mullvad_rest_handle(ApiConnectionMode::Direct.into_repeat(), |_| async { true })
+ .await,
+ )
+ .relay_list(None)
+ .await;
let relay_list = match relay_list_request {
Ok(relay_list) => relay_list,
diff --git a/mullvad-rpc/src/https_client_with_sni.rs b/mullvad-rpc/src/https_client_with_sni.rs
index 4d8108f89b..8328ed93d9 100644
--- a/mullvad-rpc/src/https_client_with_sni.rs
+++ b/mullvad-rpc/src/https_client_with_sni.rs
@@ -1,8 +1,10 @@
use crate::{
abortable_stream::{AbortableStream, AbortableStreamHandle},
+ proxy::{ApiConnection, ApiConnectionMode, ProxyConfig},
tls_stream::TlsStream,
+ AddressCache,
};
-use futures::{channel::mpsc, StreamExt};
+use futures::{channel::mpsc, future, StreamExt};
#[cfg(target_os = "android")]
use futures::{channel::oneshot, sink::SinkExt};
use http::uri::Scheme;
@@ -11,6 +13,13 @@ use hyper::{
service::Service,
Uri,
};
+use shadowsocks::{
+ config::ServerType,
+ context::{Context as SsContext, SharedContext},
+ crypto::v1::CipherKind,
+ relay::tcprelay::ProxyClientStream,
+ ServerConfig,
+};
#[cfg(target_os = "android")]
use std::os::unix::io::{AsRawFd, RawFd};
use std::{
@@ -24,6 +33,7 @@ use std::{
task::{Context, Poll},
time::Duration,
};
+use talpid_types::ErrorExt;
#[cfg(target_os = "android")]
use tokio::net::TcpSocket;
@@ -33,13 +43,70 @@ const CONNECT_TIMEOUT: Duration = Duration::from_secs(5);
#[derive(Clone)]
pub struct HttpsConnectorWithSniHandle {
- tx: mpsc::UnboundedSender<()>,
+ tx: mpsc::UnboundedSender<HttpsConnectorRequest>,
}
impl HttpsConnectorWithSniHandle {
/// Stop all streams produced by this connector
pub fn reset(&self) {
- let _ = self.tx.unbounded_send(());
+ let _ = self.tx.unbounded_send(HttpsConnectorRequest::Reset);
+ }
+
+ /// Change the proxy settings for the connector
+ pub fn set_connection_mode(&self, proxy: ApiConnectionMode) {
+ let _ = self
+ .tx
+ .unbounded_send(HttpsConnectorRequest::SetConnectionMode(proxy));
+ }
+}
+
+enum HttpsConnectorRequest {
+ Reset,
+ SetConnectionMode(ApiConnectionMode),
+}
+
+#[derive(Clone)]
+enum InnerConnectionMode {
+ /// Connect directly to the target.
+ Direct,
+ /// Connect to the destination via a proxy.
+ Proxied(ParsedShadowsocksConfig),
+}
+
+#[derive(Clone)]
+struct ParsedShadowsocksConfig {
+ peer: SocketAddr,
+ password: String,
+ cipher: CipherKind,
+}
+
+impl From<ParsedShadowsocksConfig> for ServerConfig {
+ fn from(config: ParsedShadowsocksConfig) -> Self {
+ ServerConfig::new(config.peer, config.password, config.cipher)
+ }
+}
+
+#[derive(err_derive::Error, Debug)]
+enum ProxyConfigError {
+ #[error(display = "Unrecognized cipher selected: {}", _0)]
+ InvalidCipher(String),
+}
+
+impl TryFrom<ApiConnectionMode> for InnerConnectionMode {
+ type Error = ProxyConfigError;
+
+ fn try_from(config: ApiConnectionMode) -> Result<Self, Self::Error> {
+ Ok(match config {
+ ApiConnectionMode::Direct => InnerConnectionMode::Direct,
+ ApiConnectionMode::Proxied(ProxyConfig::Shadowsocks(config)) => {
+ InnerConnectionMode::Proxied(ParsedShadowsocksConfig {
+ peer: config.peer,
+ password: config.password,
+ cipher: CipherKind::from_str(&config.cipher)
+ .map_err(|_| ProxyConfigError::InvalidCipher(config.cipher))?,
+ })
+ }
+ })
}
}
@@ -48,12 +115,16 @@ impl HttpsConnectorWithSniHandle {
pub struct HttpsConnectorWithSni {
inner: Arc<Mutex<HttpsConnectorWithSniInner>>,
sni_hostname: Option<String>,
+ address_cache: AddressCache,
+ abort_notify: Arc<tokio::sync::Notify>,
+ proxy_context: SharedContext,
#[cfg(target_os = "android")]
socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
}
struct HttpsConnectorWithSniInner {
stream_handles: Vec<AbortableStreamHandle>,
+ proxy_config: InnerConnectionMode,
}
#[cfg(target_os = "android")]
@@ -63,24 +134,46 @@ impl HttpsConnectorWithSni {
pub fn new(
handle: Handle,
sni_hostname: Option<String>,
+ address_cache: AddressCache,
#[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
) -> (Self, HttpsConnectorWithSniHandle) {
- let (tx, mut rx): (_, mpsc::UnboundedReceiver<()>) = mpsc::unbounded();
+ let (tx, mut rx) = mpsc::unbounded();
+ let abort_notify = Arc::new(tokio::sync::Notify::new());
let inner = Arc::new(Mutex::new(HttpsConnectorWithSniInner {
stream_handles: vec![],
+ proxy_config: InnerConnectionMode::Direct,
}));
let inner_copy = inner.clone();
+ let notify = abort_notify.clone();
handle.spawn(async move {
// Handle requests by `HttpsConnectorWithSniHandle`s
- while let Some(()) = rx.next().await {
+ while let Some(request) = rx.next().await {
let handles = {
let mut inner = inner_copy.lock().unwrap();
+
+ if let HttpsConnectorRequest::SetConnectionMode(config) = request {
+ match InnerConnectionMode::try_from(config) {
+ Ok(config) => {
+ inner.proxy_config = config;
+ }
+ Err(error) => {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg(
+ "Failed to parse new API proxy config"
+ )
+ );
+ }
+ }
+ }
+
std::mem::take(&mut inner.stream_handles)
};
for handle in handles {
handle.close();
}
+ notify.notify_waiters();
}
});
@@ -88,6 +181,9 @@ impl HttpsConnectorWithSni {
HttpsConnectorWithSni {
inner,
sni_hostname,
+ address_cache,
+ abort_notify,
+ proxy_context: SsContext::new_shared(ServerType::Local),
#[cfg(target_os = "android")]
socket_bypass_tx,
},
@@ -125,17 +221,24 @@ impl HttpsConnectorWithSni {
.map_err(|err| io::Error::new(io::ErrorKind::TimedOut, err))?
}
- async fn resolve_address(uri: &Uri) -> io::Result<SocketAddr> {
+ async fn resolve_address(address_cache: AddressCache, uri: Uri) -> io::Result<SocketAddr> {
let hostname = uri.host().ok_or(io::Error::new(
io::ErrorKind::InvalidInput,
"invalid url, missing host",
))?;
let port = uri.port_u16().unwrap_or(443);
-
if let Some(addr) = hostname.parse::<IpAddr>().ok() {
return Ok(SocketAddr::new(addr, port));
}
+ // Preferentially, use cached address.
+ //
+ if let Some(addr) = address_cache.resolve_hostname(hostname).await {
+ return Ok(SocketAddr::new(addr.ip(), port));
+ }
+
+ // Use getaddrinfo as a fallback
+ //
let mut addrs = GaiResolver::new()
.call(
Name::from_str(&hostname)
@@ -157,7 +260,7 @@ impl fmt::Debug for HttpsConnectorWithSni {
}
impl Service<Uri> for HttpsConnectorWithSni {
- type Response = TlsStream<AbortableStream<TcpStream>>;
+ type Response = AbortableStream<ApiConnection>;
type Error = io::Error;
type Future =
Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
@@ -175,8 +278,11 @@ impl Service<Uri> for HttpsConnectorWithSni {
io::Error::new(io::ErrorKind::InvalidInput, "invalid url, missing host")
});
let inner = self.inner.clone();
+ let abort_notify = self.abort_notify.clone();
+ let proxy_context = self.proxy_context.clone();
#[cfg(target_os = "android")]
let socket_bypass_tx = self.socket_bypass_tx.clone();
+ let address_cache = self.address_cache.clone();
let fut = async move {
if uri.scheme() != Some(&Scheme::HTTPS) {
@@ -187,16 +293,62 @@ impl Service<Uri> for HttpsConnectorWithSni {
}
let hostname = sni_hostname?;
- let addr = Self::resolve_address(&uri).await?;
+ let addr = Self::resolve_address(address_cache, uri).await?;
- let tokio_connection = Self::open_socket(
- addr,
+ // Loop until we have established a connection. This starts over if a new endpoint
+ // is selected while connecting.
+ let stream = loop {
+ let config = { inner.lock().unwrap().proxy_config.clone() };
+ let hostname_copy = hostname.clone();
+ let addr_copy = addr.clone();
+ let context = proxy_context.clone();
#[cfg(target_os = "android")]
- socket_bypass_tx,
- )
- .await?;
+ let socket_bypass_tx_copy = socket_bypass_tx.clone();
+
+ let stream_fut: Pin<
+ Box<dyn Future<Output = Result<ApiConnection, io::Error>> + Send>,
+ > = Box::pin(async move {
+ match config {
+ InnerConnectionMode::Direct => {
+ let socket = Self::open_socket(
+ addr_copy,
+ #[cfg(target_os = "android")]
+ socket_bypass_tx_copy,
+ )
+ .await?;
+ let tls_stream =
+ TlsStream::connect_https(socket, &hostname_copy).await?;
+ Ok(ApiConnection::Direct(tls_stream))
+ }
+ InnerConnectionMode::Proxied(proxy_config) => {
+ let socket = Self::open_socket(
+ proxy_config.peer,
+ #[cfg(target_os = "android")]
+ socket_bypass_tx_copy,
+ )
+ .await?;
+ let proxy = ProxyClientStream::from_stream(
+ context,
+ socket,
+ &ServerConfig::from(proxy_config),
+ addr,
+ );
+ let tls_stream =
+ TlsStream::connect_https(proxy, &hostname_copy).await?;
+ Ok(ApiConnection::Proxied(tls_stream))
+ }
+ }
+ });
+
+ // Wait for connection. Abort and retry if we switched to a different server.
+ if let future::Either::Left((stream, _)) =
+ future::select(stream_fut, Box::pin(abort_notify.notified())).await
+ {
+ break stream?;
+ }
+ };
- let (tcp_stream, socket_handle) = AbortableStream::new(tokio_connection);
+ let (stream, socket_handle) = AbortableStream::new(stream);
{
let mut inner = inner.lock().unwrap();
@@ -204,7 +356,7 @@ impl Service<Uri> for HttpsConnectorWithSni {
inner.stream_handles.push(socket_handle);
}
- Ok(TlsStream::connect_https(tcp_stream, &hostname).await?)
+ Ok(stream)
};
Box::pin(fut)
diff --git a/mullvad-rpc/src/lib.rs b/mullvad-rpc/src/lib.rs
index 27cd3f87ae..e625dca376 100644
--- a/mullvad-rpc/src/lib.rs
+++ b/mullvad-rpc/src/lib.rs
@@ -3,17 +3,18 @@
use chrono::{offset::Utc, DateTime};
#[cfg(target_os = "android")]
use futures::channel::mpsc;
+use futures::Stream;
use hyper::Method;
use mullvad_types::{
account::{AccountToken, VoucherSubmission},
version::AppVersion,
};
+use proxy::ApiConnectionMode;
use std::{
collections::BTreeMap,
future::Future,
net::{IpAddr, Ipv4Addr, SocketAddr},
path::Path,
- sync::Arc,
};
use talpid_types::{net::wireguard, ErrorExt};
@@ -23,14 +24,14 @@ pub mod rest;
mod abortable_stream;
mod https_client_with_sni;
+pub mod proxy;
mod tls_stream;
-mod proxy;
#[cfg(target_os = "android")]
pub use crate::https_client_with_sni::SocketBypassRequest;
mod address_cache;
mod relay_list;
-pub use address_cache::{AddressCache, CurrentAddressChangeListener};
+pub use address_cache::AddressCache;
pub use hyper::StatusCode;
pub use relay_list::RelayListProxy;
@@ -151,7 +152,7 @@ impl MullvadRpcRuntime {
) -> Result<Self, Error> {
Ok(MullvadRpcRuntime {
handle,
- address_cache: AddressCache::new(API.addr, None)?,
+ address_cache: AddressCache::new(None)?,
api_availability: ApiAvailability::new(availability::State::default()),
#[cfg(target_os = "android")]
socket_bypass_tx,
@@ -192,7 +193,7 @@ impl MullvadRpcRuntime {
)
);
}
- AddressCache::new(API.addr, write_file)?
+ AddressCache::new(write_file)?
}
};
@@ -205,44 +206,52 @@ impl MullvadRpcRuntime {
})
}
- pub fn set_address_change_listener(
- &mut self,
- address_change_listener: impl Fn(SocketAddr) -> Result<(), ()> + Send + Sync + 'static,
- ) {
- self.address_cache
- .set_change_listener(Arc::new(Box::new(address_change_listener)));
- }
-
/// Creates a new request service and returns a handle to it.
- fn new_request_service(
- &mut self,
+ async fn new_request_service<
+ T: Stream<Item = ApiConnectionMode> + Unpin + Send + 'static,
+ AcceptedNewEndpoint: Future<Output = bool> + Send + 'static,
+ >(
+ &self,
sni_hostname: Option<String>,
+ proxy_provider: T,
+ new_address_callback: impl (Fn(SocketAddr) -> AcceptedNewEndpoint) + Send + Sync + 'static,
#[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
) -> rest::RequestServiceHandle {
let service = rest::RequestService::new(
self.handle.clone(),
sni_hostname,
self.api_availability.handle(),
+ self.address_cache.clone(),
+ proxy_provider,
+ new_address_callback,
#[cfg(target_os = "android")]
socket_bypass_tx,
- );
+ )
+ .await;
let handle = service.handle();
self.handle.spawn(service.into_future());
handle
}
/// Returns a request factory initialized to create requests for the master API
- pub fn mullvad_rest_handle(&mut self) -> rest::MullvadRestHandle {
- let service = self.new_request_service(
- Some(API.host.clone()),
- #[cfg(target_os = "android")]
- self.socket_bypass_tx.clone(),
- );
- let factory = rest::RequestFactory::new(
- API.host.clone(),
- Box::new(self.address_cache.clone()),
- Some("app".to_owned()),
- );
+ pub async fn mullvad_rest_handle<
+ T: Stream<Item = ApiConnectionMode> + Unpin + Send + 'static,
+ AcceptedNewEndpoint: Future<Output = bool> + Send + 'static,
+ >(
+ &self,
+ proxy_provider: T,
+ new_address_callback: impl (Fn(SocketAddr) -> AcceptedNewEndpoint) + Send + Sync + 'static,
+ ) -> rest::MullvadRestHandle {
+ let service = self
+ .new_request_service(
+ Some(API.host.clone()),
+ proxy_provider,
+ new_address_callback,
+ #[cfg(target_os = "android")]
+ self.socket_bypass_tx.clone(),
+ )
+ .await;
+ let factory = rest::RequestFactory::new(API.host.clone(), Some("app".to_owned()));
rest::MullvadRestHandle::new(
service,
@@ -253,12 +262,15 @@ impl MullvadRpcRuntime {
}
/// Returns a new request service handle
- pub fn rest_handle(&mut self) -> rest::RequestServiceHandle {
+ pub async fn rest_handle(&mut self) -> rest::RequestServiceHandle {
self.new_request_service(
None,
+ ApiConnectionMode::Direct.into_repeat(),
+ |_| async { true },
#[cfg(target_os = "android")]
None,
)
+ .await
}
pub fn handle(&mut self) -> &mut tokio::runtime::Handle {
diff --git a/mullvad-rpc/src/proxy.rs b/mullvad-rpc/src/proxy.rs
index dd21df164e..009a1960dc 100644
--- a/mullvad-rpc/src/proxy.rs
+++ b/mullvad-rpc/src/proxy.rs
@@ -1,67 +1,204 @@
use crate::tls_stream::TlsStream;
+use futures::Stream;
use hyper::client::connect::{Connected, Connection};
+use rand::{distributions::Alphanumeric, Rng};
+use serde::{Deserialize, Serialize};
use shadowsocks::relay::tcprelay::ProxyClientStream;
use std::{
- io,
+ fmt, io,
+ net::SocketAddr,
+ path::Path,
pin::Pin,
task::{self, Poll},
};
+use talpid_types::{net::openvpn::ShadowsocksProxySettings, ErrorExt};
use tokio::{
- io::{AsyncRead, AsyncWrite, ReadBuf},
+ fs,
+ io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf},
net::TcpStream,
};
+const CURRENT_CONFIG_FILENAME: &str = "api-endpoint.json";
+
+#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
+pub enum ApiConnectionMode {
+ /// Connect directly to the target.
+ Direct,
+ /// Connect to the destination via a proxy.
+ Proxied(ProxyConfig),
+}
+
+impl fmt::Display for ApiConnectionMode {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
+ match self {
+ ApiConnectionMode::Direct => write!(f, "unproxied"),
+ ApiConnectionMode::Proxied(settings) => settings.fmt(f),
+ }
+ }
+}
+
+#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
+pub enum ProxyConfig {
+ Shadowsocks(ShadowsocksProxySettings),
+}
+
+impl fmt::Display for ProxyConfig {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
+ match self {
+ // TODO: Do not hardcode TCP
+ ProxyConfig::Shadowsocks(ss) => write!(f, "Shadowsocks {}/TCP", ss.peer),
+ }
+ }
+}
+
+impl ApiConnectionMode {
+ /// Reads the proxy config from `CURRENT_CONFIG_FILENAME`.
+ /// This returns `ApiConnectionMode::Direct` if reading from disk fails for any reason.
+ pub async fn try_from_cache(cache_dir: &Path) -> Self {
+ Self::from_cache(cache_dir).await.unwrap_or_else(|error| {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Failed to read API endpoint cache")
+ );
+ ApiConnectionMode::Direct
+ })
+ }
+
+ /// Reads the proxy config from `CURRENT_CONFIG_FILENAME`.
+ /// If the file does not exist, this returns `Ok(ApiConnectionMode::Direct)`.
+ async fn from_cache(cache_dir: &Path) -> io::Result<Self> {
+ let path = cache_dir.join(CURRENT_CONFIG_FILENAME);
+ match fs::read_to_string(path).await {
+ Ok(s) => serde_json::from_str(&s).map_err(|error| {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg(&format!(
+ "Failed to deserialize \"{}\"",
+ CURRENT_CONFIG_FILENAME
+ ))
+ );
+ io::Error::new(io::ErrorKind::Other, "deserialization failed")
+ }),
+ Err(error) => {
+ if error.kind() == io::ErrorKind::NotFound {
+ Ok(ApiConnectionMode::Direct)
+ } else {
+ Err(error)
+ }
+ }
+ }
+ }
+
+ /// Stores this config to `CURRENT_CONFIG_FILENAME`.
+ /// The content is saved to a temporary file first, which ensures that
+ /// consumers of the file never end up with partial content.
+ pub async fn save(&self, cache_dir: &Path) -> io::Result<()> {
+ let path = cache_dir.join(CURRENT_CONFIG_FILENAME);
+ let mut temp_ext = String::from("temp");
+ temp_ext.push_str(
+ &rand::thread_rng()
+ .sample_iter(&Alphanumeric)
+ .take(5)
+ .map(char::from)
+ .collect::<String>(),
+ );
+ let temp_path = path.with_extension(temp_ext);
+
+ {
+ let mut file = fs::File::create(&temp_path).await?;
+ let json = serde_json::to_string_pretty(self)
+ .map_err(|_| io::Error::new(io::ErrorKind::Other, "serialization failed"))?;
+ file.write_all(json.as_bytes()).await?;
+ file.write_all(b"\n").await?;
+ file.sync_data().await?;
+ }
+
+ fs::rename(&temp_path, path).await
+ }
+
+ /// Attempts to remove `CURRENT_CONFIG_FILENAME`, if it exists.
+ pub async fn try_delete_cache(cache_dir: &Path) {
+ let path = cache_dir.join(CURRENT_CONFIG_FILENAME);
+ if let Err(err) = fs::remove_file(path).await {
+ if err.kind() != std::io::ErrorKind::NotFound {
+ log::error!(
+ "{}",
+ err.display_chain_with_msg("Failed to remove old API config")
+ );
+ }
+ }
+ }
+
+ /// Returns the remote address, or `None` for `ApiConnectionMode::Direct`.
+ pub fn get_endpoint(&self) -> Option<SocketAddr> {
+ match self {
+ ApiConnectionMode::Proxied(ProxyConfig::Shadowsocks(ss)) => Some(ss.peer),
+ ApiConnectionMode::Direct => None,
+ }
+ }
+
+ pub fn is_proxy(&self) -> bool {
+ *self != ApiConnectionMode::Direct
+ }
+
+ /// Convenience function that returns a stream that repeats
+ /// this config forever.
+ pub fn into_repeat(self) -> impl Stream<Item = ApiConnectionMode> {
+ futures::stream::repeat(self)
+ }
+}
+
/// Stream that is either a regular TLS stream or TLS via shadowsocks
-pub enum MaybeProxyStream {
- Tls(TlsStream<TcpStream>),
+pub enum ApiConnection {
+ Direct(TlsStream<TcpStream>),
Proxied(TlsStream<ProxyClientStream<TcpStream>>),
}
-impl AsyncRead for MaybeProxyStream {
+impl AsyncRead for ApiConnection {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
match Pin::get_mut(self) {
- MaybeProxyStream::Tls(s) => Pin::new(s).poll_read(cx, buf),
- MaybeProxyStream::Proxied(s) => Pin::new(s).poll_read(cx, buf),
+ ApiConnection::Direct(s) => Pin::new(s).poll_read(cx, buf),
+ ApiConnection::Proxied(s) => Pin::new(s).poll_read(cx, buf),
}
}
}
-impl AsyncWrite for MaybeProxyStream {
+impl AsyncWrite for ApiConnection {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
match Pin::get_mut(self) {
- MaybeProxyStream::Tls(s) => Pin::new(s).poll_write(cx, buf),
- MaybeProxyStream::Proxied(s) => Pin::new(s).poll_write(cx, buf),
+ ApiConnection::Direct(s) => Pin::new(s).poll_write(cx, buf),
+ ApiConnection::Proxied(s) => Pin::new(s).poll_write(cx, buf),
}
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
match Pin::get_mut(self) {
- MaybeProxyStream::Tls(s) => Pin::new(s).poll_flush(cx),
- MaybeProxyStream::Proxied(s) => Pin::new(s).poll_flush(cx),
+ ApiConnection::Direct(s) => Pin::new(s).poll_flush(cx),
+ ApiConnection::Proxied(s) => Pin::new(s).poll_flush(cx),
}
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
match Pin::get_mut(self) {
- MaybeProxyStream::Tls(s) => Pin::new(s).poll_shutdown(cx),
- MaybeProxyStream::Proxied(s) => Pin::new(s).poll_shutdown(cx),
+ ApiConnection::Direct(s) => Pin::new(s).poll_shutdown(cx),
+ ApiConnection::Proxied(s) => Pin::new(s).poll_shutdown(cx),
}
}
}
-impl Connection for MaybeProxyStream {
+impl Connection for ApiConnection {
fn connected(&self) -> Connected {
match self {
- MaybeProxyStream::Tls(s) => s.connected(),
- MaybeProxyStream::Proxied(s) => s.connected(),
+ ApiConnection::Direct(s) => s.connected(),
+ ApiConnection::Proxied(s) => s.connected(),
}
}
}
diff --git a/mullvad-rpc/src/rest.rs b/mullvad-rpc/src/rest.rs
index 014719b26b..3b7d0a0bc7 100644
--- a/mullvad-rpc/src/rest.rs
+++ b/mullvad-rpc/src/rest.rs
@@ -4,13 +4,14 @@ use crate::{
address_cache::AddressCache,
availability::ApiAvailabilityHandle,
https_client_with_sni::{HttpsConnectorWithSni, HttpsConnectorWithSniHandle},
+ proxy::ApiConnectionMode,
};
use futures::{
channel::{mpsc, oneshot},
future::{abortable, AbortHandle, Aborted},
sink::SinkExt,
stream::StreamExt,
- TryFutureExt,
+ Stream, TryFutureExt,
};
use hyper::{
client::Client,
@@ -21,7 +22,7 @@ use std::{
collections::BTreeMap,
future::Future,
mem,
- net::{IpAddr, SocketAddr},
+ net::SocketAddr,
str::FromStr,
time::{Duration, Instant},
};
@@ -88,7 +89,11 @@ impl Error {
/// A service that executes HTTP requests, allowing for on-demand termination of all in-flight
/// requests
-pub(crate) struct RequestService {
+pub(crate) struct RequestService<
+ T: Stream<Item = ApiConnectionMode>,
+ F: Fn(SocketAddr) -> AcceptedNewEndpoint,
+ AcceptedNewEndpoint: Future<Output = bool>,
+> {
command_tx: mpsc::Sender<RequestCommand>,
command_rx: mpsc::Receiver<RequestCommand>,
connector_handle: HttpsConnectorWithSniHandle,
@@ -96,24 +101,41 @@ pub(crate) struct RequestService {
handle: Handle,
next_id: u64,
in_flight_requests: BTreeMap<u64, AbortHandle>,
+ proxy_config_provider: T,
+ new_address_callback: F,
+ address_cache: AddressCache,
api_availability: ApiAvailabilityHandle,
}
-impl RequestService {
+impl<
+ T: Stream<Item = ApiConnectionMode> + Unpin + Send + 'static,
+ F: (Fn(SocketAddr) -> AcceptedNewEndpoint) + Send + Sync + 'static,
+ AcceptedNewEndpoint: Future<Output = bool> + Send + 'static,
+ > RequestService<T, F, AcceptedNewEndpoint>
+{
/// Constructs a new request service.
- pub fn new(
+ pub async fn new(
handle: Handle,
sni_hostname: Option<String>,
api_availability: ApiAvailabilityHandle,
+ address_cache: AddressCache,
+ mut proxy_config_provider: T,
+ new_address_callback: F,
#[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
- ) -> RequestService {
+ ) -> RequestService<T> {
let (connector, connector_handle) = HttpsConnectorWithSni::new(
handle.clone(),
sni_hostname,
+ address_cache.clone(),
#[cfg(target_os = "android")]
socket_bypass_tx.clone(),
);
+ proxy_config_provider
+ .next()
+ .await
+ .map(|config| connector_handle.set_connection_mode(config));
+
let (command_tx, command_rx) = mpsc::channel(1);
let client = Client::builder().build(connector);
@@ -125,6 +147,9 @@ impl RequestService {
handle,
in_flight_requests: BTreeMap::new(),
next_id: 0,
+ proxy_config_provider,
+ new_address_callback,
+ address_cache,
api_availability,
}
}
@@ -137,7 +162,7 @@ impl RequestService {
}
}
- fn process_command(&mut self, command: RequestCommand) {
+ async fn process_command(&mut self, command: RequestCommand) {
match command {
RequestCommand::NewRequest(request, completion_tx) => {
let id = self.id();
@@ -166,10 +191,7 @@ impl RequestService {
if let Err(err) = &response {
if err.is_network_error() && !api_availability.get_state().is_offline() {
log::error!("{}", err.display_chain_with_msg("HTTP request failed"));
-
- // TODO: ask provider for new proxy config
- // TODO: notify connector handle of this new config
- // TODO: pass proxy config to tunnel state machine
+ let _ = tx.send(RequestCommand::NextApiConfig).await;
}
}
@@ -191,6 +213,18 @@ impl RequestService {
self.reset();
let _ = tx.send(());
}
+ RequestCommand::NextApiConfig => {
+ if let Some(new_config) = self.proxy_config_provider.next().await {
+ let endpoint = match new_config.get_endpoint() {
+ Some(endpoint) => endpoint,
+ None => self.address_cache.get_address().await,
+ };
+ // Switch to new connection mode unless rejected by address change callback
+ if (self.new_address_callback)(endpoint).await {
+ self.connector_handle.set_connection_mode(new_config);
+ }
+ }
+ }
}
}
@@ -211,7 +245,7 @@ impl RequestService {
pub async fn into_future(mut self) {
while let Some(command) = self.command_rx.next().await {
- self.process_command(command);
+ self.process_command(command).await;
}
self.reset();
}
@@ -259,6 +293,7 @@ pub(crate) enum RequestCommand {
),
RequestFinished(u64),
Reset(oneshot::Sender<()>),
+ NextApiConfig,
}
/// A REST request that is sent to the RequestService to be executed.
@@ -358,20 +393,14 @@ pub struct ErrorResponse {
#[derive(Clone)]
pub struct RequestFactory {
hostname: String,
- address_provider: Box<dyn AddressProvider>,
path_prefix: Option<String>,
pub timeout: Duration,
}
impl RequestFactory {
- pub fn new(
- hostname: String,
- address_provider: Box<dyn AddressProvider>,
- path_prefix: Option<String>,
- ) -> Self {
+ pub fn new(hostname: String, path_prefix: Option<String>) -> Self {
Self {
hostname,
- address_provider,
path_prefix,
timeout: DEFAULT_TIMEOUT,
}
@@ -432,9 +461,8 @@ impl RequestFactory {
}
fn get_uri(&self, path: &str) -> Result<Uri> {
- let host = self.address_provider.get_address();
let prefix = self.path_prefix.as_ref().map(AsRef::as_ref).unwrap_or("");
- let uri = format!("https://{}/{}{}", host, prefix, path);
+ let uri = format!("https://{}/{}{}", self.hostname, prefix, path);
hyper::Uri::from_str(&uri).map_err(Error::UriError)
}
@@ -444,29 +472,6 @@ impl RequestFactory {
}
}
-pub trait AddressProvider: Send + Sync {
- /// Must return a string that represents either a host or a host with port
- fn get_address(&self) -> String;
- fn clone_box(&self) -> Box<dyn AddressProvider>;
-}
-
-impl Clone for Box<dyn AddressProvider> {
- fn clone(&self) -> Self {
- self.clone_box()
- }
-}
-
-impl AddressProvider for IpAddr {
- /// Must return a string that represents either a host or a host with port
- fn get_address(&self) -> String {
- self.to_string()
- }
-
- fn clone_box(&self) -> Box<dyn AddressProvider> {
- Box::new(*self)
- }
-}
-
pub fn get_request<T: serde::de::DeserializeOwned>(
factory: &RequestFactory,
service: RequestServiceHandle,