use std::collections::HashSet;
use std::ops::ControlFlow;
use std::sync::Mutex;
use futures::channel::mpsc::{self, UnboundedReceiver, UnboundedSender};
use futures::channel::oneshot;
use futures::future::FutureExt;
use futures::select_biased;
use futures::stream::StreamExt;
use jnix::jni::objects::JValue;
use jnix::jni::{JNIEnv, objects::JObject};
use jnix::{FromJava, JnixEnv};
use talpid_types::android::{AndroidContext, NetworkState};
use crate::{Route, imp::RouteManagerCommand};
/// Error type for routing errors on Android.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// Timed out when waiting for network routes.
    #[error("Timed out when waiting for network routes")]
    RoutesTimedOut,
}
/// Failures that can only occur while polling Android for the initial [NetworkState].
#[derive(Debug, thiserror::Error)]
enum JvmError {
    /// A Java method returned a value of an unexpected kind.
    ///
    /// Fields: class name, method name, and a debug rendering of the returned value.
    #[error("Received an invalid result from {0}.{1}: {2}")]
    InvalidMethodResult(&'static str, &'static str, String),

    /// Invoking the named Java method failed.
    #[error("Failed to call Java method {0}")]
    CallMethod(&'static str, #[source] jnix::jni::errors::Error),

    /// Promoting a Java object reference to a global reference failed.
    #[error("Failed to create global reference to Java object")]
    CreateGlobalRef(#[source] jnix::jni::errors::Error),

    /// The current thread could not be attached to the Java VM.
    #[error("Failed to attach Java VM to tunnel thread")]
    AttachJvmToThread(#[source] jnix::jni::errors::Error),
}
/// The sender used by [Java_net_mullvad_talpid_ConnectivityListener_notifyDefaultNetworkChange]
/// to notify the route manager of changes to the network.
static ROUTE_UPDATES_TX: Mutex>>> = Mutex::new(None);
/// Android route manager actor.
#[derive(Debug)]
pub struct RouteManagerImpl {
/// The receiving channel for updates on changes to the network.
network_state_updates: UnboundedReceiver >,
/// Cached [NetworkState]. If no update events have been received yet, this value will be [None].
last_state: Option,
/// Clients waiting on response to [RouteManagerCommand::WaitForRoutes].
waiting_for_routes: Vec<(oneshot::Sender<()>, Vec)>,
}
impl RouteManagerImpl {
#[allow(clippy::unused_async)]
pub async fn new(android_context: AndroidContext) -> Result {
// Create a channel between the kotlin client and route manager
let (tx, rx) = futures::channel::mpsc::unbounded();
*ROUTE_UPDATES_TX.lock().unwrap() = Some(tx);
// Try to poll for the current network state at startup.
// This will most likely be null, but it covers the edge case where a NetworkState
// update has been emitted before anyone starts to listen for route updates some
// time in the future (when connecting).
let last_state = match current_network_state(android_context) {
Ok(initial_state) => initial_state,
Err(err) => {
log::error!("Failed while polling for initial NetworkState");
log::error!("{err}");
None
}
};
let route_manager = RouteManagerImpl {
network_state_updates: rx,
last_state,
waiting_for_routes: Default::default(),
};
Ok(route_manager)
}
pub(crate) async fn run(
mut self,
manage_rx: mpsc::UnboundedReceiver,
) -> Result<(), Error> {
let mut manage_rx = manage_rx.fuse();
loop {
select_biased! {
command = manage_rx.next().fuse() => {
let Some(command) = command else { break };
if self.handle_command(command).is_break() {
break;
}
}
network_state_update = self.network_state_updates.next().fuse() => {
// None means that the sender was dropped
let Some(network_state) = network_state_update else { break };
// update the last known NetworkState
self.last_state = network_state;
// notify waiting clients that routes exist
self.waiting_for_routes = self
.waiting_for_routes
.into_iter()
.filter_map(|(client, expected_routes)| {
if has_routes(self.last_state.as_ref(), expected_routes.clone()) {
let _ = client.send(());
None
} else {
Some((client, expected_routes))
}
})
.collect();
}
}
}
log::debug!("RouteManager exited");
Ok(())
}
fn handle_command(&mut self, command: RouteManagerCommand) -> ControlFlow<()> {
match command {
RouteManagerCommand::Shutdown(tx) => {
let _ = tx.send(());
return ControlFlow::Break(());
}
RouteManagerCommand::WaitForRoutes(response_tx, expected_routes) => {
// check if routes have already been configured on the Android system.
// otherwise, register a listener for network state changes.
// routes may come in at any moment in the future.
if has_routes(self.last_state.as_ref(), expected_routes.clone()) {
let _ = response_tx.send(());
} else {
self.waiting_for_routes.push((response_tx, expected_routes));
}
}
RouteManagerCommand::ClearRouteCache(tx) => {
self.clear_route_cache();
let _ = tx.send(());
}
}
ControlFlow::Continue(())
}
fn clear_route_cache(&mut self) {
self.last_state = None;
}
}
/// Check whether the [NetworkState] contains expected routes.
///
/// Matches the routes reported from Android and checks if all the routes we expect to be there is
/// present.
fn has_routes(state: Option<&NetworkState>, expected_routes: Vec) -> bool {
let Some(network_state) = state else {
return false;
};
let routes = configured_routes(network_state);
routes.is_superset(&HashSet::from_iter(expected_routes))
}
fn configured_routes(state: &NetworkState) -> HashSet {
match &state.routes {
None => Default::default(),
Some(route_info) => route_info.iter().map(Route::from).collect(),
}
}
/// Entry point for Android Java code to notify the current default network state.
#[unsafe(no_mangle)]
#[allow(non_snake_case)]
pub extern "system" fn Java_net_mullvad_talpid_ConnectivityListener_notifyDefaultNetworkChange(
env: JNIEnv<'_>,
_: JObject<'_>,
network_state: JObject<'_>,
) {
let env = JnixEnv::from(env);
let network_state: Option = FromJava::from_java(&env, network_state);
let Some(tx) = &*ROUTE_UPDATES_TX.lock().unwrap() else {
// No sender has been registered
log::error!("Received routes notification wíth no channel");
return;
};
log::trace!("Received network state update {:#?}", network_state);
if tx.unbounded_send(network_state).is_err() {
log::warn!("Failed to send offline change event");
}
}
/// Return the current NetworkState according to Android
fn current_network_state(
android_context: AndroidContext,
) -> Result, JvmError> {
let env = JnixEnv::from(
android_context
.jvm
.attach_current_thread_as_daemon()
.map_err(JvmError::AttachJvmToThread)?,
);
let result = env
.call_method(
android_context.vpn_service.as_obj(),
"getConnectivityListener",
"()Lnet/mullvad/talpid/ConnectivityListener;",
&[],
)
.map_err(|cause| JvmError::CallMethod("getConnectivityListener", cause))?;
let connectivity_listener = match result {
JValue::Object(object) => env
.new_global_ref(object)
.map_err(JvmError::CreateGlobalRef)?,
value => {
return Err(JvmError::InvalidMethodResult(
"MullvadVpnService",
"getConnectivityListener",
format!("{value:?}"),
));
}
};
let network_state = env
.call_method(
connectivity_listener.as_obj(),
"getCurrentDefaultNetworkState",
"()Lnet/mullvad/talpid/model/NetworkState;",
&[],
)
.map_err(|cause| JvmError::CallMethod("getCurrentDefaultNetworkState", cause))?;
let network_state: Option = FromJava::from_java(&env, network_state);
Ok(network_state)
}