summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorEmīls <emils@mullvad.net>2021-01-25 15:29:56 +0000
committerEmīls <emils@mullvad.net>2021-01-25 15:29:56 +0000
commit2b055cc39c4eee4b8af8bb193143ddd4055b39ac (patch)
treedf2c586a211247a022cfb4618a4abbc8423027a0
parent631b690b32a183b03f2fc7ec2fb2ecf0d7e40aae (diff)
parent493aeeb995c0027afd92e2e5d26de64d0e0d88d9 (diff)
downloadmullvadvpn-2b055cc39c4eee4b8af8bb193143ddd4055b39ac.tar.xz
mullvadvpn-2b055cc39c4eee4b8af8bb193143ddd4055b39ac.zip
Merge branch 'android-socket-passthrough'
-rw-r--r--CHANGELOG.md1
-rw-r--r--Cargo.lock2
-rw-r--r--mullvad-daemon/src/lib.rs91
-rw-r--r--mullvad-problem-report/src/lib.rs2
-rw-r--r--mullvad-rpc/Cargo.toml4
-rw-r--r--mullvad-rpc/src/https_client_with_sni.rs183
-rw-r--r--mullvad-rpc/src/lib.rs25
-rw-r--r--mullvad-rpc/src/rest.rs61
-rw-r--r--mullvad-rpc/src/tcp_stream.rs124
-rw-r--r--talpid-core/src/tunnel/tun_provider/android/mod.rs38
-rw-r--r--talpid-core/src/tunnel_state_machine/connected_state.rs5
-rw-r--r--talpid-core/src/tunnel_state_machine/connecting_state.rs5
-rw-r--r--talpid-core/src/tunnel_state_machine/disconnected_state.rs5
-rw-r--r--talpid-core/src/tunnel_state_machine/disconnecting_state.rs15
-rw-r--r--talpid-core/src/tunnel_state_machine/error_state.rs6
-rw-r--r--talpid-core/src/tunnel_state_machine/mod.rs13
16 files changed, 500 insertions, 80 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 78dacd3c01..87b49e3166 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -32,6 +32,7 @@ Line wrap the file at 100 chars. Th
#### Android
- Allow to configure the tunnel to use custom DNS servers.
- Show only applications that has INTERNET permission on split tunnel screen.
+- Allow reaching the API server when connecting, disconnecting or in a blocked state.
#### Linux
- Improved compatiblitiy with newer versions of systemd-resolved.
diff --git a/Cargo.lock b/Cargo.lock
index a0cf1de768..5cdb346ab5 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1354,6 +1354,7 @@ dependencies = [
name = "mullvad-rpc"
version = "0.1.0"
dependencies = [
+ "bytes",
"chrono",
"err-derive",
"filetime",
@@ -1368,6 +1369,7 @@ dependencies = [
"regex",
"serde",
"serde_json",
+ "socket2",
"talpid-types",
"tempfile",
"tokio",
diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs
index abf5d9962a..42f067a4f3 100644
--- a/mullvad-daemon/src/lib.rs
+++ b/mullvad-daemon/src/lib.rs
@@ -41,6 +41,8 @@ use mullvad_types::{
wireguard::KeygenEvent,
};
use settings::SettingsPersister;
+#[cfg(target_os = "android")]
+use std::os::unix::io::RawFd;
#[cfg(not(target_os = "android"))]
use std::path::Path;
use std::{
@@ -62,7 +64,7 @@ use talpid_core::{
#[cfg(target_os = "android")]
use talpid_types::android::AndroidContext;
use talpid_types::{
- net::{openvpn, Endpoint, TransportProtocol, TunnelParameters, TunnelType},
+ net::{openvpn, Endpoint, TransportProtocol, TunnelEndpoint, TunnelParameters, TunnelType},
tunnel::{ErrorStateCause, ParameterGenerationError, TunnelStateTransition},
ErrorExt,
};
@@ -231,6 +233,8 @@ pub enum DaemonCommand {
/// Saves the target tunnel state and enters a blocking state. The state is restored
/// upon restart.
PrepareRestart,
+ #[cfg(target_os = "android")]
+ BypassSocket(RawFd, oneshot::Sender<()>),
}
/// All events that can happen in the daemon. Sent from various threads and exposed interfaces.
@@ -513,6 +517,8 @@ where
result_rx.await.map_err(|_| ())
})
},
+ #[cfg(target_os = "android")]
+ Self::create_bypass_tx(&internal_event_tx),
)
.await
.map_err(Error::InitRpcFactory)?;
@@ -522,6 +528,7 @@ where
let on_relay_list_update = move |relay_list: &RelayList| {
relay_list_listener.notify_relay_list(relay_list.clone());
};
+
let mut relay_selector = relays::RelaySelector::new(
rpc_handle.clone(),
on_relay_list_update,
@@ -794,17 +801,8 @@ where
&mut self,
tunnel_state_transition: TunnelStateTransition,
) {
- match &tunnel_state_transition {
- TunnelStateTransition::Disconnected
- | TunnelStateTransition::Connected(_)
- | TunnelStateTransition::Error(_) => {
- // Reset the RPCs so that they fail immediately after the underlying socket gets
- // invalidated due to the tunnel either coming up or breaking.
- self.rpc_handle.service().reset().await;
- }
- _ => (),
- };
-
+ self.reset_rpc_sockets_on_tunnel_state_transition(&tunnel_state_transition)
+ .await;
let tunnel_state = match tunnel_state_transition {
TunnelStateTransition::Disconnected => TunnelState::Disconnected,
TunnelStateTransition::Connecting(endpoint) => TunnelState::Connecting {
@@ -851,6 +849,19 @@ where
self.event_listener.notify_new_state(tunnel_state);
}
+ async fn reset_rpc_sockets_on_tunnel_state_transition(
+ &mut self,
+ tunnel_state_transition: &TunnelStateTransition,
+ ) {
+ match (&self.tunnel_state, &tunnel_state_transition) {
+ // only reset the API sockets if when connected or leaving the connected state
+ (&TunnelState::Connected { .. }, _) | (_, &TunnelStateTransition::Connected(_)) => {
+ self.rpc_handle.service().reset().await;
+ }
+ _ => (),
+ };
+ }
+
async fn handle_generate_tunnel_parameters(
&mut self,
tunnel_parameters_tx: &sync_mpsc::Sender<
@@ -1112,6 +1123,8 @@ where
ClearSplitTunnelProcesses(tx) => self.on_clear_split_tunnel_processes(tx),
Shutdown => self.trigger_shutdown_event(),
PrepareRestart => self.on_prepare_restart(),
+ #[cfg(target_os = "android")]
+ BypassSocket(fd, tx) => self.on_bypass_socket(fd, tx),
}
}
@@ -1288,10 +1301,9 @@ where
}
fn get_geo_location(&mut self) -> impl Future<Output = Result<GeoIpLocation, ()>> {
- let https_handle = self.rpc_runtime.rest_handle();
-
+ let rpc_service = self.rpc_runtime.rest_handle();
async {
- geoip::send_location_request(https_handle)
+ geoip::send_location_request(rpc_service)
.await
.map_err(|e| {
warn!("Unable to fetch GeoIP location: {}", e.display_chain());
@@ -1865,7 +1877,7 @@ where
.map_err(|e| {
format!("Failed to add new wireguard key to account data: {}", e)
})?;
- if let Some(TunnelType::Wireguard) = self.get_connected_tunnel_type() {
+ if let Some(TunnelType::Wireguard) = self.get_target_tunnel_type() {
self.reconnect_tunnel();
}
let keygen_event = KeygenEvent::NewKey(public_key);
@@ -1974,6 +1986,34 @@ where
self.clean_up_target_cache = false;
}
+ #[cfg(target_os = "android")]
+ fn on_bypass_socket(&mut self, fd: RawFd, tx: oneshot::Sender<()>) {
+ match self.tunnel_state {
+ // When connected, the API connection shouldn't be bypassed.
+ TunnelState::Connected { .. } => (),
+ _ => {
+ self.send_tunnel_command(TunnelCommand::BypassSocket(fd, tx));
+ }
+ }
+ }
+
+ #[cfg(target_os = "android")]
+ fn create_bypass_tx(
+ event_sender: &DaemonEventSender,
+ ) -> Option<mpsc::Sender<mullvad_rpc::SocketBypassRequest>> {
+ let (bypass_tx, mut bypass_rx) = mpsc::channel(1);
+ let daemon_tx = event_sender.to_specialized_sender();
+ tokio::runtime::Handle::current().spawn(async move {
+ while let Some((raw_fd, done_tx)) = bypass_rx.next().await {
+ if let Err(_) = daemon_tx.send(DaemonCommand::BypassSocket(raw_fd, done_tx)) {
+ log::error!("Can't send socket bypass request to daemon");
+ break;
+ }
+ }
+ });
+ Some(bypass_tx)
+ }
+
/// Set the target state of the client. If it changed trigger the operations needed to
/// progress towards that state.
/// Returns a bool representing whether or not a state change was initiated.
@@ -2026,10 +2066,7 @@ where
}
fn get_connected_tunnel_type(&self) -> Option<TunnelType> {
- use talpid_types::net::TunnelEndpoint;
- use TunnelState::Connected;
-
- if let Connected {
+ if let TunnelState::Connected {
endpoint: TunnelEndpoint { tunnel_type, .. },
..
} = self.tunnel_state
@@ -2040,6 +2077,20 @@ where
}
}
+ fn get_target_tunnel_type(&self) -> Option<TunnelType> {
+ match self.tunnel_state {
+ TunnelState::Connected {
+ endpoint: TunnelEndpoint { tunnel_type, .. },
+ ..
+ }
+ | TunnelState::Connecting {
+ endpoint: TunnelEndpoint { tunnel_type, .. },
+ ..
+ } => Some(tunnel_type),
+ _ => None,
+ }
+ }
+
fn send_tunnel_command(&mut self, command: TunnelCommand) {
self.tunnel_command_tx
.unbounded_send(command)
diff --git a/mullvad-problem-report/src/lib.rs b/mullvad-problem-report/src/lib.rs
index 9bd1fba185..6b820f58f3 100644
--- a/mullvad-problem-report/src/lib.rs
+++ b/mullvad-problem-report/src/lib.rs
@@ -282,6 +282,8 @@ pub fn send_problem_report(
cache_dir,
false,
|_| Ok(()),
+ #[cfg(target_os = "android")]
+ None,
))
.map_err(Error::CreateRpcClientError)?;
let rpc_client = mullvad_rpc::ProblemReportProxy::new(rpc_manager.mullvad_rest_handle());
diff --git a/mullvad-rpc/Cargo.toml b/mullvad-rpc/Cargo.toml
index cf27d2e287..680a880608 100644
--- a/mullvad-rpc/Cargo.toml
+++ b/mullvad-rpc/Cargo.toml
@@ -8,6 +8,7 @@ edition = "2018"
publish = false
[dependencies]
+bytes = "0.5"
chrono = { version = "0.4", features = ["serde"] }
err-derive = "0.2.1"
futures = "0.3"
@@ -32,6 +33,9 @@ talpid-types = { path = "../talpid-types" }
filetime = "0.2"
tempfile = "3.0"
+[target.'cfg(target_os="android")'.dependencies]
+socket2 = "0.3"
+
[[bin]]
name = "relay_list"
diff --git a/mullvad-rpc/src/https_client_with_sni.rs b/mullvad-rpc/src/https_client_with_sni.rs
index 48b90e76e7..b80b2db95b 100644
--- a/mullvad-rpc/src/https_client_with_sni.rs
+++ b/mullvad-rpc/src/https_client_with_sni.rs
@@ -1,15 +1,30 @@
+use crate::{rest::RequestCommand, tcp_stream::TcpStream};
+use futures::{
+ channel::{mpsc, oneshot},
+ sink::SinkExt,
+};
use http::uri::Scheme;
-use hyper::{client::HttpConnector, service::Service, Uri};
+use hyper::{
+ client::connect::dns::{GaiResolver, Name},
+ service::Service,
+ Uri,
+};
use hyper_rustls::MaybeHttpsStream;
+#[cfg(target_os = "android")]
+use std::os::unix::io::{AsRawFd, RawFd};
use std::{
fmt,
future::Future,
io::{self, BufReader},
+ net::{IpAddr, SocketAddr},
pin::Pin,
- str,
+ str::{self, FromStr},
sync::Arc,
task::{Context, Poll},
+ time::Duration,
};
+
+use tokio::{net::TcpStream as TokioTcpStream, runtime::Handle, time::timeout};
use tokio_rustls::rustls;
use webpki::DNSNameRef;
@@ -18,14 +33,23 @@ const OLD_ROOT_CERT: &[u8] = include_bytes!("../old_le_root_cert.pem");
// New LetsEncrypt root certificate
const NEW_ROOT_CERT: &[u8] = include_bytes!("../new_le_root_cert.pem");
+const CONNECT_TIMEOUT: Duration = Duration::from_secs(5);
+
/// A Connector for the `https` scheme.
#[derive(Clone)]
pub struct HttpsConnectorWithSni {
+ next_socket_id: usize,
+ handle: Handle,
sni_hostname: Option<String>,
- http: HttpConnector,
+ service_tx: Option<mpsc::Sender<RequestCommand>>,
+ #[cfg(target_os = "android")]
+ socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
tls: Arc<rustls::ClientConfig>,
}
+#[cfg(target_os = "android")]
+pub type SocketBypassRequest = (RawFd, oneshot::Sender<()>);
+
impl HttpsConnectorWithSni {
/// Construct a new HttpsConnectorWithSni.
///
@@ -33,15 +57,24 @@ impl HttpsConnectorWithSni {
///
/// This uses hyper's default `HttpConnector`, and default `TlsConnector`.
/// If you wish to use something besides the defaults, use `From::from`.
- pub fn new() -> Self {
- let mut http = HttpConnector::new();
- http.enforce_http(false);
-
+ pub fn new(
+ handle: Handle,
+ sni_hostname: Option<String>,
+ #[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
+ ) -> Self {
let mut config = rustls::ClientConfig::new();
config.enable_sni = true;
config.root_store = Self::read_cert_store();
- HttpsConnectorWithSni::from((http, config))
+ HttpsConnectorWithSni {
+ next_socket_id: 0,
+ handle,
+ sni_hostname,
+ #[cfg(target_os = "android")]
+ socket_bypass_tx,
+ service_tx: None,
+ tls: Arc::new(config),
+ }
}
fn read_cert_store() -> rustls::RootCertStore {
@@ -65,22 +98,75 @@ impl HttpsConnectorWithSni {
}
- /// Configure a hostname to use with SNI.
- ///
- /// Configures the TLS connection handshake to request a certificate for a given domain,
- /// instead of the domain obtained from the URI. Use `None` to use the domain from the URI.
- pub fn set_sni_hostname(&mut self, hostname: Option<String>) {
- self.sni_hostname = hostname;
+ /// Set a channel to register sockets with the request service.
+ pub(crate) fn set_service_tx(&mut self, service_tx: mpsc::Sender<RequestCommand>) {
+ self.service_tx = Some(service_tx);
}
-}
-impl From<(HttpConnector, rustls::ClientConfig)> for HttpsConnectorWithSni {
- fn from(args: (HttpConnector, rustls::ClientConfig)) -> HttpsConnectorWithSni {
- HttpsConnectorWithSni {
- sni_hostname: None,
- http: args.0,
- tls: Arc::new(args.1),
+ fn next_id(&mut self) -> usize {
+ let next_id = self.next_socket_id;
+ self.next_socket_id = self.next_socket_id.wrapping_add(1);
+ next_id
+ }
+
+ #[cfg(not(target_os = "android"))]
+ async fn open_socket(addr: SocketAddr) -> std::io::Result<TokioTcpStream> {
+ timeout(CONNECT_TIMEOUT, TokioTcpStream::connect(addr))
+ .await
+ .map_err(|err| io::Error::new(io::ErrorKind::TimedOut, err))?
+ }
+
+ #[cfg(target_os = "android")]
+ async fn open_socket(
+ addr: SocketAddr,
+ socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
+ ) -> std::io::Result<TokioTcpStream> {
+ use socket2::{Domain, Protocol, Socket, Type};
+ let domain = match addr {
+ SocketAddr::V4(_) => Domain::ipv4(),
+ SocketAddr::V6(_) => Domain::ipv6(),
+ };
+ let socket = Socket::new(domain, Type::stream(), Some(Protocol::tcp()))?.into_tcp_stream();
+
+ if let Some(mut tx) = socket_bypass_tx {
+ let (done_tx, done_rx) = oneshot::channel();
+ let _ = tx.send((socket.as_raw_fd(), done_tx)).await;
+ if let Err(_) = done_rx.await {
+ log::error!("Failed to bypass socket, connection might fail");
+ }
}
+
+ timeout(CONNECT_TIMEOUT, TokioTcpStream::connect_std(socket, &addr))
+ .await
+ .map_err(|err| io::Error::new(io::ErrorKind::TimedOut, err))?
+ }
+
+ async fn resolve_address(hostname: &str) -> io::Result<SocketAddr> {
+ match Self::parse_addr(&hostname) {
+ Some(addr) => Ok(addr),
+ None => {
+ let mut addrs = GaiResolver::new()
+ .call(
+ Name::from_str(&hostname)
+ .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?,
+ )
+ .await
+ .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;
+ let addr = addrs
+ .next()
+ .ok_or(io::Error::new(io::ErrorKind::Other, "Empty DNS response"))?;
+ Ok(SocketAddr::new(addr, 443))
+ }
+ }
+ }
+
+
+ fn parse_addr(hostname: &str) -> Option<SocketAddr> {
+ if let Ok(addr) = hostname.parse::<SocketAddr>() {
+ return Some(addr);
+ }
+ let ip = hostname.parse::<IpAddr>().ok()?;
+ Some(SocketAddr::new(ip, 443))
}
}
@@ -90,8 +176,9 @@ impl fmt::Debug for HttpsConnectorWithSni {
}
}
+
impl Service<Uri> for HttpsConnectorWithSni {
- type Response = MaybeHttpsStream<tokio::net::TcpStream>;
+ type Response = MaybeHttpsStream<TcpStream>;
type Error = io::Error;
type Future =
Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
@@ -102,7 +189,6 @@ impl Service<Uri> for HttpsConnectorWithSni {
fn call(&mut self, uri: Uri) -> Self::Future {
let tls_connector: tokio_rustls::TlsConnector = self.tls.clone().into();
- let mut http = self.http.clone();
let sni_hostname = self
.sni_hostname
.clone()
@@ -110,7 +196,12 @@ impl Service<Uri> for HttpsConnectorWithSni {
.ok_or_else(|| {
io::Error::new(io::ErrorKind::InvalidInput, "invalid url, missing host")
});
+ let service_tx = self.service_tx.clone();
+ let socket_id = self.next_id();
+ let handle = self.handle.clone();
+ #[cfg(target_os = "android")]
+ let socket_bypass_tx = self.socket_bypass_tx.clone();
let fut = async move {
if uri.scheme() != Some(&Scheme::HTTPS) {
@@ -119,14 +210,49 @@ impl Service<Uri> for HttpsConnectorWithSni {
"invalid url, not https",
));
}
+ let host_addr = uri.host().ok_or(io::Error::new(
+ io::ErrorKind::InvalidInput,
+ "invalid url, missing host",
+ ))?;
let hostname = sni_hostname?;
let host = DNSNameRef::try_from_ascii_str(&hostname)
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid hostname"))?;
- let connection = http
- .call(uri)
- .await
- .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;
- let tls_connection = tls_connector.connect(host, connection).await?;
+ let addr = Self::resolve_address(host_addr).await?;
+
+ let tokio_connection = Self::open_socket(
+ addr,
+ #[cfg(target_os = "android")]
+ socket_bypass_tx,
+ )
+ .await?;
+
+ let (socket_shutdown_tx, socket_shutdown_rx) = oneshot::channel();
+
+
+ let (tcp_stream, socket_handle) =
+ TcpStream::new(tokio_connection, socket_id, Some(socket_shutdown_tx));
+ if let Some(mut service_tx) = service_tx {
+ if service_tx
+ .send(RequestCommand::SocketOpened(socket_id, socket_handle))
+ .await
+ .is_err()
+ {
+ log::error!("Failed to submit new socket to request service");
+ }
+ handle.spawn(async move {
+ let _ = socket_shutdown_rx.await;
+ if service_tx
+ .send(RequestCommand::SocketClosed(socket_id))
+ .await
+ .is_err()
+ {
+ log::error!("Failed to send socket closure command to request service");
+ }
+ });
+ }
+
+
+ let tls_connection = tls_connector.connect(host, tcp_stream).await?;
Ok(MaybeHttpsStream::Https(tls_connection))
};
@@ -136,6 +262,7 @@ impl Service<Uri> for HttpsConnectorWithSni {
}
}
+
#[cfg(test)]
mod test {
use super::HttpsConnectorWithSni;
diff --git a/mullvad-rpc/src/lib.rs b/mullvad-rpc/src/lib.rs
index e211a89c2e..7fd3a2480f 100644
--- a/mullvad-rpc/src/lib.rs
+++ b/mullvad-rpc/src/lib.rs
@@ -1,6 +1,8 @@
#![deny(rust_2018_idioms)]
use chrono::{offset::Utc, DateTime};
+#[cfg(target_os = "android")]
+use futures::channel::mpsc;
use hyper::Method;
use mullvad_types::{
account::{AccountToken, VoucherSubmission},
@@ -20,6 +22,9 @@ pub mod rest;
mod https_client_with_sni;
use crate::https_client_with_sni::HttpsConnectorWithSni;
+#[cfg(target_os = "android")]
+pub use crate::https_client_with_sni::SocketBypassRequest;
+mod tcp_stream;
mod address_cache;
mod relay_list;
@@ -41,9 +46,10 @@ const API_ADDRESS: (IpAddr, u16) = (crate::API_IP, 443);
/// A type that helps with the creation of RPC connections.
pub struct MullvadRpcRuntime {
- https_connector: HttpsConnectorWithSni,
handle: tokio::runtime::Handle,
pub address_cache: AddressCache,
+ #[cfg(target_os = "android")]
+ socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
}
#[derive(err_derive::Error, Debug)]
@@ -59,13 +65,14 @@ impl MullvadRpcRuntime {
/// Create a new `MullvadRpcRuntime`.
pub fn new(handle: tokio::runtime::Handle) -> Result<Self, Error> {
Ok(MullvadRpcRuntime {
- https_connector: HttpsConnectorWithSni::new(),
handle,
address_cache: AddressCache::new(
vec![API_ADDRESS.into()],
None,
Arc::new(Box::new(|_| Ok(()))),
)?,
+ #[cfg(target_os = "android")]
+ socket_bypass_tx: None,
})
}
@@ -78,6 +85,7 @@ impl MullvadRpcRuntime {
cache_dir: &Path,
write_changes: bool,
address_change_listener: impl Fn(SocketAddr) -> Result<(), ()> + Send + Sync + 'static,
+ #[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
) -> Result<Self, Error> {
let cache_file = cache_dir.join(API_IP_CACHE_FILENAME);
let write_file = if write_changes {
@@ -125,19 +133,22 @@ impl MullvadRpcRuntime {
}
};
- let https_connector = HttpsConnectorWithSni::new();
-
Ok(MullvadRpcRuntime {
- https_connector,
handle,
address_cache,
+ #[cfg(target_os = "android")]
+ socket_bypass_tx,
})
}
/// Creates a new request service and returns a handle to it.
fn new_request_service(&mut self, sni_hostname: Option<String>) -> rest::RequestServiceHandle {
- let mut https_connector = self.https_connector.clone();
- https_connector.set_sni_hostname(sni_hostname);
+ let https_connector = HttpsConnectorWithSni::new(
+ self.handle.clone(),
+ sni_hostname,
+ #[cfg(target_os = "android")]
+ self.socket_bypass_tx.clone(),
+ );
let service = rest::RequestService::new(
https_connector,
diff --git a/mullvad-rpc/src/rest.rs b/mullvad-rpc/src/rest.rs
index 82a6271c74..07746a3150 100644
--- a/mullvad-rpc/src/rest.rs
+++ b/mullvad-rpc/src/rest.rs
@@ -1,4 +1,7 @@
-use crate::address_cache::AddressCache;
+use crate::{
+ address_cache::AddressCache, https_client_with_sni::HttpsConnectorWithSni,
+ tcp_stream::TcpStreamHandle,
+};
use futures::{
channel::{mpsc, oneshot},
future::{abortable, AbortHandle, Aborted},
@@ -7,7 +10,7 @@ use futures::{
TryFutureExt,
};
use hyper::{
- client::{connect::Connect, Client},
+ client::Client,
header::{self, HeaderValue},
Method, Uri,
};
@@ -73,30 +76,37 @@ pub enum Error {
/// A service that executes HTTP requests, allowing for on-demand termination of all in-flight
/// requests
-pub(crate) struct RequestService<C> {
+pub(crate) struct RequestService {
command_tx: mpsc::Sender<RequestCommand>,
command_rx: mpsc::Receiver<RequestCommand>,
- client: hyper::Client<C, hyper::Body>,
- connector: C,
+ sockets: BTreeMap<usize, TcpStreamHandle>,
+ client: hyper::Client<HttpsConnectorWithSni, hyper::Body>,
handle: Handle,
next_id: u64,
in_flight_requests: BTreeMap<u64, AbortHandle>,
address_cache: AddressCache,
}
-impl<C: Connect + Clone + Send + Sync + 'static> RequestService<C> {
+impl RequestService {
/// Constructs a new request service.
- pub fn new(connector: C, handle: Handle, address_cache: AddressCache) -> RequestService<C> {
- let client = Self::new_client(connector.clone());
-
+ pub fn new(
+ mut connector: HttpsConnectorWithSni,
+ handle: Handle,
+ address_cache: AddressCache,
+ ) -> RequestService {
let (command_tx, command_rx) = mpsc::channel(1);
+
+ connector.set_service_tx(command_tx.clone());
+ let client = Client::builder().build(connector);
+
+
Self {
command_tx,
command_rx,
+ sockets: BTreeMap::new(),
client,
in_flight_requests: BTreeMap::new(),
next_id: 0,
- connector,
handle,
address_cache,
}
@@ -110,10 +120,6 @@ impl<C: Connect + Clone + Send + Sync + 'static> RequestService<C> {
}
}
- fn new_client(connector: C) -> Client<C, hyper::Body> {
- Client::builder().pool_max_idle_per_host(0).build(connector)
- }
-
fn process_command(&mut self, command: RequestCommand) {
match command {
RequestCommand::NewRequest(request, completion_tx) => {
@@ -173,12 +179,19 @@ impl<C: Connect + Clone + Send + Sync + 'static> RequestService<C> {
self.in_flight_requests.insert(id, abort_handle);
}
+ RequestCommand::SocketOpened(id, socket) => {
+ self.sockets.insert(id, socket);
+ }
+ RequestCommand::SocketClosed(id) => {
+ self.sockets.remove(&id);
+ }
RequestCommand::RequestFinished(id) => {
self.in_flight_requests.remove(&id);
}
- RequestCommand::Reset => {
+ RequestCommand::Reset(tx) => {
self.reset();
+ let _ = tx.send(());
}
}
}
@@ -188,7 +201,12 @@ impl<C: Connect + Clone + Send + Sync + 'static> RequestService<C> {
for (_, abort_handle) in old_requests.into_iter() {
abort_handle.abort();
}
- let _ = mem::replace(&mut self.client, Self::new_client(self.connector.clone()));
+
+ let old_sockets = mem::replace(&mut self.sockets, BTreeMap::new());
+ for (_, socket) in old_sockets.into_iter() {
+ socket.close();
+ }
+
self.next_id = 0;
}
@@ -202,6 +220,7 @@ impl<C: Connect + Clone + Send + Sync + 'static> RequestService<C> {
while let Some(command) = self.command_rx.next().await {
self.process_command(command);
}
+ self.reset();
}
}
@@ -229,8 +248,10 @@ impl RequestServiceHandle {
/// Resets the corresponding RequestService, dropping all in-flight requests.
pub async fn reset(&self) {
let mut tx = self.tx.clone();
+ let (done_tx, done_rx) = oneshot::channel();
- let _ = tx.send(RequestCommand::Reset).await;
+ let _ = tx.send(RequestCommand::Reset(done_tx)).await;
+ let _ = done_rx.await;
}
/// Submits a `RestRequest` for exectuion to the request service.
@@ -252,13 +273,15 @@ impl RequestServiceHandle {
}
#[derive(Debug)]
-enum RequestCommand {
+pub(crate) enum RequestCommand {
NewRequest(
RestRequest,
oneshot::Sender<std::result::Result<Response, Error>>,
),
RequestFinished(u64),
- Reset,
+ SocketOpened(usize, TcpStreamHandle),
+ SocketClosed(usize),
+ Reset(oneshot::Sender<()>),
}
diff --git a/mullvad-rpc/src/tcp_stream.rs b/mullvad-rpc/src/tcp_stream.rs
new file mode 100644
index 0000000000..7bf89a47b2
--- /dev/null
+++ b/mullvad-rpc/src/tcp_stream.rs
@@ -0,0 +1,124 @@
+use bytes::buf::Buf;
+use futures::channel::oneshot;
+use hyper::client::connect::{Connected, Connection};
+use std::{
+ io,
+ net::Shutdown,
+ pin::Pin,
+ sync::{Arc, Mutex, Weak},
+ task::{Context, Poll},
+};
+use tokio::{
+ io::{AsyncRead, AsyncWrite},
+ net::TcpStream as TokioTcpStream,
+};
+
+#[derive(Debug)]
+pub struct TcpStreamHandle {
+ inner: Weak<Mutex<StreamInner>>,
+}
+
+impl TcpStreamHandle {
+ pub fn close(self) {
+ if let Some(inner_lock) = self.inner.upgrade() {
+ if let Ok(mut inner) = inner_lock.lock() {
+ if let Err(err) = inner.stream.shutdown(Shutdown::Both) {
+ log::error!("Failed to shut down TCP socket: {}", err);
+ }
+ let _ = inner.shutdown_tx.take();
+ }
+ }
+ }
+}
+
+
+pub struct TcpStream {
+ inner: Arc<Mutex<StreamInner>>,
+}
+
+impl TcpStream {
+ pub fn new(
+ stream: TokioTcpStream,
+ id: usize,
+ shutdown_tx: Option<oneshot::Sender<()>>,
+ ) -> (Self, TcpStreamHandle) {
+ let inner = Arc::new(Mutex::new(StreamInner {
+ id,
+ stream,
+ shutdown_tx,
+ }));
+ (
+ Self {
+ inner: inner.clone(),
+ },
+ TcpStreamHandle {
+ inner: Arc::downgrade(&inner),
+ },
+ )
+ }
+
+ fn do_stream<T>(&self, mut stream_fn: impl FnMut(&mut TokioTcpStream) -> T) -> T {
+ let mut inner = self.inner.lock().expect("TCP lock poisoned");
+ stream_fn(&mut inner.stream)
+ }
+}
+
+impl Drop for TcpStream {
+ fn drop(&mut self) {
+ if let Ok(mut inner) = self.inner.lock() {
+ if let Some(tx) = inner.shutdown_tx.take() {
+ let _ = tx.send(());
+ }
+ }
+ }
+}
+
+
+impl AsyncWrite for TcpStream {
+ fn poll_write(
+ self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ buf: &[u8],
+ ) -> Poll<Result<usize, io::Error>> {
+ self.do_stream(|stream| Pin::new(stream).poll_write(cx, buf))
+ }
+
+ fn poll_write_buf<B: Buf>(
+ self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ buf: &mut B,
+ ) -> Poll<Result<usize, io::Error>> {
+ self.do_stream(|stream| Pin::new(stream).poll_write_buf(cx, buf))
+ }
+
+ fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
+ self.do_stream(|stream| Pin::new(stream).poll_flush(cx))
+ }
+
+ fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
+ self.do_stream(|stream| Pin::new(stream).poll_shutdown(cx))
+ }
+}
+
+impl AsyncRead for TcpStream {
+ fn poll_read(
+ self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ buf: &mut [u8],
+ ) -> Poll<Result<usize, io::Error>> {
+ self.do_stream(|stream| Pin::new(stream).poll_read(cx, buf))
+ }
+}
+
+impl Connection for TcpStream {
+ fn connected(&self) -> Connected {
+ Connected::new()
+ }
+}
+
+#[derive(Debug)]
+struct StreamInner {
+ id: usize,
+ stream: TokioTcpStream,
+ shutdown_tx: Option<oneshot::Sender<()>>,
+}
diff --git a/talpid-core/src/tunnel/tun_provider/android/mod.rs b/talpid-core/src/tunnel/tun_provider/android/mod.rs
index fa48f115b9..575f05f4c2 100644
--- a/talpid-core/src/tunnel/tun_provider/android/mod.rs
+++ b/talpid-core/src/tunnel/tun_provider/android/mod.rs
@@ -136,8 +136,8 @@ impl AndroidTunProvider {
})
}
- /// Open a tunnel device that routes everything but `allowed_endpoint`, custom DNS, and (potentially)
- /// LAN routes via the tunnel device.
+ /// Open a tunnel device that routes everything but `allowed_endpoint`, custom DNS, and
+ /// (potentially) LAN routes via the tunnel device.
///
/// Will open a new tunnel if there is already an active tunnel. The previous tunnel will be
/// closed.
@@ -321,6 +321,33 @@ impl AndroidTunProvider {
}
}
+ /// Allow a socket to bypass the tunnel.
+ pub fn bypass(&mut self, socket: RawFd) -> Result<(), Error> {
+ let env = JnixEnv::from(
+ self.jvm
+ .attach_current_thread_as_daemon()
+ .map_err(|cause| Error::AttachJvmToThread(cause))?,
+ );
+ let create_tun_method = env
+ .get_method_id(&self.class, "bypass", "(I)Z")
+ .map_err(|cause| Error::FindMethod("bypass", cause))?;
+
+ let result = env
+ .call_method_unchecked(
+ self.object.as_obj(),
+ create_tun_method,
+ JavaType::Primitive(Primitive::Boolean),
+ &[JValue::Int(socket)],
+ )
+ .map_err(|cause| Error::CallMethod("bypass", cause))?;
+
+ match result {
+ JValue::Bool(0) => Err(Error::Bypass),
+ JValue::Bool(_) => Ok(()),
+ value => Err(Error::InvalidMethodResult("bypass", format!("{:?}", value))),
+ }
+ }
+
fn call_method(
&self,
name: &'static str,
@@ -385,11 +412,10 @@ impl VpnServiceTun {
)
.map_err(|cause| Error::CallMethod("bypass", cause))?;
- match result {
- JValue::Bool(0) => Err(Error::Bypass),
- JValue::Bool(_) => Ok(()),
- value => Err(Error::InvalidMethodResult("bypass", format!("{:?}", value))),
+ if !bool::from_java(&env, result) {
+ return Err(Error::Bypass);
}
+ Ok(())
}
}
diff --git a/talpid-core/src/tunnel_state_machine/connected_state.rs b/talpid-core/src/tunnel_state_machine/connected_state.rs
index 0c305de9a7..2f1d53fd9a 100644
--- a/talpid-core/src/tunnel_state_machine/connected_state.rs
+++ b/talpid-core/src/tunnel_state_machine/connected_state.rs
@@ -258,6 +258,11 @@ impl ConnectedState {
Some(TunnelCommand::Block(reason)) => {
self.disconnect(shared_values, AfterDisconnect::Block(reason))
}
+ #[cfg(target_os = "android")]
+ Some(TunnelCommand::BypassSocket(fd, done_tx)) => {
+ shared_values.bypass_socket(fd, done_tx);
+ SameState(self.into())
+ }
}
}
diff --git a/talpid-core/src/tunnel_state_machine/connecting_state.rs b/talpid-core/src/tunnel_state_machine/connecting_state.rs
index 0b03ceeca1..7ec0aaf55a 100644
--- a/talpid-core/src/tunnel_state_machine/connecting_state.rs
+++ b/talpid-core/src/tunnel_state_machine/connecting_state.rs
@@ -284,6 +284,11 @@ impl ConnectingState {
Some(TunnelCommand::Block(reason)) => {
self.disconnect(shared_values, AfterDisconnect::Block(reason))
}
+ #[cfg(target_os = "android")]
+ Some(TunnelCommand::BypassSocket(fd, done_tx)) => {
+ shared_values.bypass_socket(fd, done_tx);
+ SameState(self.into())
+ }
}
}
diff --git a/talpid-core/src/tunnel_state_machine/disconnected_state.rs b/talpid-core/src/tunnel_state_machine/disconnected_state.rs
index 922eb69c88..7a02a9e550 100644
--- a/talpid-core/src/tunnel_state_machine/disconnected_state.rs
+++ b/talpid-core/src/tunnel_state_machine/disconnected_state.rs
@@ -110,6 +110,11 @@ impl TunnelState for DisconnectedState {
Some(TunnelCommand::Block(reason)) => {
NewState(ErrorState::enter(shared_values, reason))
}
+ #[cfg(target_os = "android")]
+ Some(TunnelCommand::BypassSocket(fd, done_tx)) => {
+ shared_values.bypass_socket(fd, done_tx);
+ SameState(self.into())
+ }
Some(_) => SameState(self.into()),
None => Finished,
}
diff --git a/talpid-core/src/tunnel_state_machine/disconnecting_state.rs b/talpid-core/src/tunnel_state_machine/disconnecting_state.rs
index 48a83a6dc3..12e0bebe5e 100644
--- a/talpid-core/src/tunnel_state_machine/disconnecting_state.rs
+++ b/talpid-core/src/tunnel_state_machine/disconnecting_state.rs
@@ -54,6 +54,11 @@ impl DisconnectingState {
Some(TunnelCommand::Connect) => AfterDisconnect::Reconnect(0),
Some(TunnelCommand::Disconnect) | None => AfterDisconnect::Nothing,
Some(TunnelCommand::Block(reason)) => AfterDisconnect::Block(reason),
+ #[cfg(target_os = "android")]
+ Some(TunnelCommand::BypassSocket(fd, done_tx)) => {
+ shared_values.bypass_socket(fd, done_tx);
+ AfterDisconnect::Nothing
+ }
},
AfterDisconnect::Block(reason) => match command {
Some(TunnelCommand::AllowLan(allow_lan)) => {
@@ -86,6 +91,11 @@ impl DisconnectingState {
Some(TunnelCommand::Connect) => AfterDisconnect::Reconnect(0),
Some(TunnelCommand::Disconnect) => AfterDisconnect::Nothing,
Some(TunnelCommand::Block(new_reason)) => AfterDisconnect::Block(new_reason),
+ #[cfg(target_os = "android")]
+ Some(TunnelCommand::BypassSocket(fd, done_tx)) => {
+ shared_values.bypass_socket(fd, done_tx);
+ AfterDisconnect::Block(reason)
+ }
None => AfterDisconnect::Block(reason),
},
AfterDisconnect::Reconnect(retry_attempt) => match command {
@@ -119,6 +129,11 @@ impl DisconnectingState {
Some(TunnelCommand::Connect) => AfterDisconnect::Reconnect(retry_attempt),
Some(TunnelCommand::Disconnect) | None => AfterDisconnect::Nothing,
Some(TunnelCommand::Block(reason)) => AfterDisconnect::Block(reason),
+ #[cfg(target_os = "android")]
+ Some(TunnelCommand::BypassSocket(fd, done_tx)) => {
+ shared_values.bypass_socket(fd, done_tx);
+ AfterDisconnect::Reconnect(retry_attempt)
+ }
},
};
diff --git a/talpid-core/src/tunnel_state_machine/error_state.rs b/talpid-core/src/tunnel_state_machine/error_state.rs
index 51159d274f..6674cca13e 100644
--- a/talpid-core/src/tunnel_state_machine/error_state.rs
+++ b/talpid-core/src/tunnel_state_machine/error_state.rs
@@ -151,6 +151,12 @@ impl TunnelState for ErrorState {
Some(TunnelCommand::Block(reason)) => {
NewState(ErrorState::enter(shared_values, reason))
}
+
+ #[cfg(target_os = "android")]
+ Some(TunnelCommand::BypassSocket(fd, done_tx)) => {
+ shared_values.bypass_socket(fd, done_tx);
+ SameState(self.into())
+ }
}
}
}
diff --git a/talpid-core/src/tunnel_state_machine/mod.rs b/talpid-core/src/tunnel_state_machine/mod.rs
index 5a86722c01..90012a8296 100644
--- a/talpid-core/src/tunnel_state_machine/mod.rs
+++ b/talpid-core/src/tunnel_state_machine/mod.rs
@@ -23,6 +23,8 @@ use futures::{
channel::{mpsc, oneshot},
stream, StreamExt,
};
+#[cfg(target_os = "android")]
+use std::os::unix::io::RawFd;
use std::{
collections::HashSet,
io,
@@ -171,6 +173,9 @@ pub enum TunnelCommand {
Disconnect,
/// Disconnect any open tunnel and block all network access
Block(ErrorStateCause),
+ /// Bypass a socket, allowing traffic to flow through outside the tunnel.
+ #[cfg(target_os = "android")]
+ BypassSocket(RawFd, oneshot::Sender<()>),
}
type TunnelCommandReceiver = stream::Fuse<mpsc::UnboundedReceiver<TunnelCommand>>;
@@ -405,6 +410,14 @@ impl SharedTunnelStateValues {
log::trace!("Connectivity check wasn't disabled by the daemon");
}
}
+
+ #[cfg(target_os = "android")]
+ pub fn bypass_socket(&mut self, fd: RawFd, tx: oneshot::Sender<()>) {
+ if let Err(err) = self.tun_provider.bypass(fd) {
+ log::error!("Failed to bypass socket {}", err);
+ }
+ let _ = tx.send(());
+ }
}
/// Asynchronous result of an attempt to progress a state.