diff options
| author | Janito Vaqueiro Ferreira Filho <janito@mullvad.net> | 2021-05-14 18:10:40 -0300 |
|---|---|---|
| committer | Janito Vaqueiro Ferreira Filho <janito@mullvad.net> | 2021-05-14 18:10:40 -0300 |
| commit | af9f0236984d5e41b2517a3fd52c417290094c82 (patch) | |
| tree | db8e01306d284f9eed7eaab9784e44c9bac3aff9 /android | |
| parent | 052d6cfe31191da136fdbab748284cabe5923491 (diff) | |
| parent | ddda7d310f73ea16fe00d62f9704e556e6c4fc97 (diff) | |
| download | mullvadvpn-af9f0236984d5e41b2517a3fd52c417290094c82.tar.xz mullvadvpn-af9f0236984d5e41b2517a3fd52c417290094c82.zip | |
Merge branch 'use-flow-for-tile-ipc'
Diffstat (limited to 'android')
10 files changed, 349 insertions, 48 deletions
diff --git a/android/build.gradle b/android/build.gradle index 19035b857d..8a7d7e2a16 100644 --- a/android/build.gradle +++ b/android/build.gradle @@ -162,6 +162,7 @@ dependencies { androidTestImplementation "androidx.test.ext:junit:1.1.2" androidTestImplementation "io.mockk:mockk-android:$mockkVersion" androidTestImplementation "org.koin:koin-test:$koinVersion" + androidTestImplementation "org.jetbrains.kotlin:kotlin-test:$kotlinVersion" // debugImplementation because LeakCanary should only run in debug builds. // debugImplementation 'com.squareup.leakcanary:leakcanary-android:2.6' } diff --git a/android/src/androidTest/kotlin/net/mullvad/mullvadvpn/ipc/HandlerFlowTest.kt b/android/src/androidTest/kotlin/net/mullvad/mullvadvpn/ipc/HandlerFlowTest.kt new file mode 100644 index 0000000000..709f330b0d --- /dev/null +++ b/android/src/androidTest/kotlin/net/mullvad/mullvadvpn/ipc/HandlerFlowTest.kt @@ -0,0 +1,48 @@ +package net.mullvad.mullvadvpn.ipc + +import android.os.Bundle +import android.os.Looper +import android.os.Message +import android.os.Parcelable +import kotlin.test.assertEquals +import kotlinx.coroutines.flow.take +import kotlinx.coroutines.flow.toList +import kotlinx.coroutines.runBlocking +import kotlinx.parcelize.Parcelize +import org.junit.Test + +class HandlerFlowTest { + val looper by lazy { Looper.getMainLooper() } + + val handler: HandlerFlow<Data?> by lazy { + HandlerFlow(looper) { message -> + message.data.getParcelable(DATA_KEY) + } + } + + @Test + fun test_message_extraction() { + sendMessage(Data(1)) + sendMessage(Data(2)) + sendMessage(Data(3)) + + val extractedData = runBlocking { handler.take(3).toList() } + + assertEquals(listOf(Data(1), Data(2), Data(3)), extractedData) + } + + private fun sendMessage(messageData: Data) { + val message = Message().apply { + data = Bundle().apply { putParcelable(DATA_KEY, messageData) } + } + + handler.handleMessage(message) + } + + companion object { + const val DATA_KEY = "data" + + @Parcelize + data class Data(val id: Int) : Parcelable + } +} diff --git a/android/src/main/kotlin/net/mullvad/mullvadvpn/ipc/Event.kt b/android/src/main/kotlin/net/mullvad/mullvadvpn/ipc/Event.kt index 6a197224ed..faaba05d89 100644 --- a/android/src/main/kotlin/net/mullvad/mullvadvpn/ipc/Event.kt +++ b/android/src/main/kotlin/net/mullvad/mullvadvpn/ipc/Event.kt @@ -1,6 +1,7 @@ package net.mullvad.mullvadvpn.ipc import android.os.Message as RawMessage +import android.os.Messenger import kotlinx.parcelize.Parcelize import net.mullvad.mullvadvpn.model.AppVersionInfo as AppVersionInfoData import net.mullvad.mullvadvpn.model.GeoIpLocation @@ -28,7 +29,7 @@ sealed class Event : Message.EventMessage() { data class CurrentVersion(val version: String?) : Event() @Parcelize - object ListenerReady : Event() + data class ListenerReady(val connection: Messenger, val listenerId: Int) : Event() @Parcelize data class LoginStatus(val status: LoginStatusData?) : Event() diff --git a/android/src/main/kotlin/net/mullvad/mullvadvpn/ipc/HandlerFlow.kt b/android/src/main/kotlin/net/mullvad/mullvadvpn/ipc/HandlerFlow.kt new file mode 100644 index 0000000000..943c55eeff --- /dev/null +++ b/android/src/main/kotlin/net/mullvad/mullvadvpn/ipc/HandlerFlow.kt @@ -0,0 +1,45 @@ +package net.mullvad.mullvadvpn.ipc + +import android.os.Handler +import android.os.Looper +import android.os.Message +import android.util.Log +import kotlinx.coroutines.CancellationException +import kotlinx.coroutines.InternalCoroutinesApi +import kotlinx.coroutines.channels.Channel +import kotlinx.coroutines.channels.ClosedSendChannelException +import kotlinx.coroutines.channels.sendBlocking +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.FlowCollector +import kotlinx.coroutines.flow.collect +import kotlinx.coroutines.flow.consumeAsFlow +import kotlinx.coroutines.flow.onCompletion + +class HandlerFlow<T>( + looper: Looper, + private val extractor: (Message) -> T +) : Handler(looper), Flow<T> { + private val channel = Channel<T>(Channel.UNLIMITED) + private val flow = channel.consumeAsFlow().onCompletion { + removeCallbacksAndMessages(null) + } + + @InternalCoroutinesApi + override suspend fun collect(collector: FlowCollector<T>) = flow.collect(collector) + + override fun handleMessage(message: Message) { + val extractedData = extractor(message) + + try { + channel.sendBlocking(extractedData) + } catch (exception: Exception) { + when (exception) { + is ClosedSendChannelException, is CancellationException -> { + Log.w("mullvad", "Received a message after HandlerFlow was closed", exception) + removeCallbacksAndMessages(null) + } + else -> throw exception + } + } + } +} diff --git a/android/src/main/kotlin/net/mullvad/mullvadvpn/ipc/Request.kt b/android/src/main/kotlin/net/mullvad/mullvadvpn/ipc/Request.kt index 0f64daece7..8093cac415 100644 --- a/android/src/main/kotlin/net/mullvad/mullvadvpn/ipc/Request.kt +++ b/android/src/main/kotlin/net/mullvad/mullvadvpn/ipc/Request.kt @@ -90,6 +90,9 @@ sealed class Request : Message.RequestMessage() { data class SubmitVoucher(val voucher: String) : Request() @Parcelize + data class UnregisterListener(val listenerId: Int) : Request() + + @Parcelize data class VpnPermissionResponse(val isGranted: Boolean) : Request() @Parcelize diff --git a/android/src/main/kotlin/net/mullvad/mullvadvpn/ipc/ServiceConnection.kt b/android/src/main/kotlin/net/mullvad/mullvadvpn/ipc/ServiceConnection.kt new file mode 100644 index 0000000000..751bcc9bc3 --- /dev/null +++ b/android/src/main/kotlin/net/mullvad/mullvadvpn/ipc/ServiceConnection.kt @@ -0,0 +1,103 @@ +package net.mullvad.mullvadvpn.ipc + +import android.content.Context +import android.content.Intent +import android.os.IBinder +import android.os.Looper +import android.os.Messenger +import kotlin.reflect.KClass +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.flow.MutableStateFlow +import kotlinx.coroutines.flow.SharingStarted +import kotlinx.coroutines.flow.StateFlow +import kotlinx.coroutines.flow.collect +import kotlinx.coroutines.flow.filterNotNull +import kotlinx.coroutines.flow.map +import kotlinx.coroutines.flow.onCompletion +import kotlinx.coroutines.flow.stateIn +import kotlinx.coroutines.launch +import net.mullvad.mullvadvpn.model.TunnelState +import net.mullvad.mullvadvpn.service.MullvadVpnService +import net.mullvad.mullvadvpn.util.DispatchingFlow +import net.mullvad.mullvadvpn.util.bindServiceFlow +import net.mullvad.mullvadvpn.util.dispatchTo + +class ServiceConnection(context: Context, scope: CoroutineScope) { + private val activeListeners = MutableStateFlow<Pair<Messenger, Int>?>(null) + private val handler = HandlerFlow(Looper.getMainLooper(), Event::fromMessage) + private val listener = Messenger(handler) + private val listenerId = MutableStateFlow<Int?>(null) + + private lateinit var listenerRegistrations: StateFlow<Pair<Messenger, Int>?> + + lateinit var tunnelState: StateFlow<TunnelState> + private set + + init { + val dispatcher = handler + .filterNotNull() + .dispatchTo { + listenerRegistrations = subscribeToState(Event.ListenerReady::class, scope) { + Pair(connection, listenerId) + } + + tunnelState = subscribeToState( + Event.TunnelStateChange::class, + scope, + TunnelState.Disconnected + ) { tunnelState } + } + + scope.launch { connect(context) } + scope.launch { dispatcher.collect() } + scope.launch { unregisterOldListeners() } + scope.launch { listenerRegistrations.collect { activeListeners.value = it } } + } + + private suspend fun connect(context: Context) { + val intent = Intent(context, MullvadVpnService::class.java) + + context.bindServiceFlow(intent).collect { binder -> + activeListeners.value = null + binder?.let(::registerListener) + } + } + + private fun registerListener(binder: IBinder) { + val request = Request.RegisterListener(listener) + val messenger = Messenger(binder) + + messenger.send(request.message) + } + + private suspend fun unregisterOldListeners() { + var oldListener: Pair<Messenger, Int>? = null + + activeListeners + .onCompletion { oldListener?.let(::unregisterListener) } + .collect { newListener -> + oldListener?.let(::unregisterListener) + oldListener = newListener + } + } + + private fun unregisterListener(registration: Pair<Messenger, Int>) { + val (messenger, listenerId) = registration + val request = Request.UnregisterListener(listenerId) + + messenger.send(request.message) + } + + private fun <V : Any, D> DispatchingFlow<in V>.subscribeToState( + event: KClass<V>, + scope: CoroutineScope, + dataExtractor: suspend V.() -> D + ) = subscribe(event).map(dataExtractor).stateIn(scope, SharingStarted.Lazily, null) + + private fun <V : Any, D> DispatchingFlow<in V>.subscribeToState( + event: KClass<V>, + scope: CoroutineScope, + initialValue: D, + dataExtractor: suspend V.() -> D + ) = subscribe(event).map(dataExtractor).stateIn(scope, SharingStarted.Lazily, initialValue) +} diff --git a/android/src/main/kotlin/net/mullvad/mullvadvpn/service/MullvadTileService.kt b/android/src/main/kotlin/net/mullvad/mullvadvpn/service/MullvadTileService.kt index db9662f5d6..b487e49a0e 100644 --- a/android/src/main/kotlin/net/mullvad/mullvadvpn/service/MullvadTileService.kt +++ b/android/src/main/kotlin/net/mullvad/mullvadvpn/service/MullvadTileService.kt @@ -1,46 +1,32 @@ package net.mullvad.mullvadvpn.service -import android.content.ComponentName import android.content.Intent import android.graphics.drawable.Icon import android.os.Build -import android.os.IBinder -import android.os.Messenger import android.service.quicksettings.Tile import android.service.quicksettings.TileService import kotlin.properties.Delegates.observable +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.FlowPreview +import kotlinx.coroutines.MainScope +import kotlinx.coroutines.cancel +import kotlinx.coroutines.flow.collect +import kotlinx.coroutines.flow.debounce +import kotlinx.coroutines.launch import net.mullvad.mullvadvpn.R +import net.mullvad.mullvadvpn.ipc.ServiceConnection import net.mullvad.mullvadvpn.model.TunnelState -import net.mullvad.mullvadvpn.ui.serviceconnection.ServiceConnection import net.mullvad.talpid.tunnel.ActionAfterDisconnect class MullvadTileService : TileService() { - private val serviceConnectionManager = object : android.content.ServiceConnection { - override fun onServiceConnected(className: ComponentName, binder: IBinder) { - serviceConnection = ServiceConnection(Messenger(binder)) - } - - override fun onServiceDisconnected(className: ComponentName) { - serviceConnection = null - } - } - - private var serviceConnection by observable<ServiceConnection?>( - null - ) { _, oldConnection, newConnection -> - oldConnection?.onDestroy() - - newConnection?.connectionProxy?.run { - onStateChange.subscribe(this@MullvadTileService, ::updateTunnelState) - } - } - private var secured by observable(false) { _, wasSecured, isSecured -> if (wasSecured != isSecured) { updateTileState() } } + private lateinit var scope: CoroutineScope + private lateinit var securedIcon: Icon private lateinit var unsecuredIcon: Icon @@ -54,11 +40,9 @@ class MullvadTileService : TileService() { override fun onStartListening() { super.onStartListening() - val intent = Intent(this, MullvadVpnService::class.java) - - bindService(intent, serviceConnectionManager, BIND_IMPORTANT) + scope = MainScope() - updateTileState() + scope.launch { listenToTunnelState() } } override fun onClick() { @@ -80,12 +64,18 @@ class MullvadTileService : TileService() { } override fun onStopListening() { - unbindService(serviceConnectionManager) - serviceConnection = null - + scope.cancel() super.onStopListening() } + @OptIn(FlowPreview::class) + private suspend fun listenToTunnelState() { + ServiceConnection(this@MullvadTileService, scope) + .tunnelState + .debounce(300L) + .collect(::updateTunnelState) + } + private fun updateTunnelState(tunnelState: TunnelState) { secured = when (tunnelState) { is TunnelState.Disconnected -> false diff --git a/android/src/main/kotlin/net/mullvad/mullvadvpn/service/endpoint/ServiceEndpoint.kt b/android/src/main/kotlin/net/mullvad/mullvadvpn/service/endpoint/ServiceEndpoint.kt index 93c7a8a9cf..0a0c41b42e 100644 --- a/android/src/main/kotlin/net/mullvad/mullvadvpn/service/endpoint/ServiceEndpoint.kt +++ b/android/src/main/kotlin/net/mullvad/mullvadvpn/service/endpoint/ServiceEndpoint.kt @@ -25,13 +25,22 @@ class ServiceEndpoint( val connectivityListener: ConnectivityListener, context: Context ) { - private val listeners = mutableSetOf<Messenger>() - private val registrationQueue: SendChannel<Messenger> = startRegistrator() + companion object { + sealed class Command { + data class RegisterListener(val listener: Messenger) : Command() + data class UnregisterListener(val listenerId: Int) : Command() + } + } + + private val listeners = mutableMapOf<Int, Messenger>() + private val commands: SendChannel<Command> = startRegistrator() internal val dispatcher = DispatchingHandler(looper) { message -> Request.fromMessage(message) } + private var listenerIdCounter = 0 + val messenger = Messenger(dispatcher) val vpnPermission = VpnPermission(context, this) @@ -50,14 +59,20 @@ class ServiceEndpoint( val voucherRedeemer = VoucherRedeemer(this) init { - dispatcher.registerHandler(Request.RegisterListener::class) { request -> - registrationQueue.sendBlocking(request.listener) + dispatcher.apply { + registerHandler(Request.RegisterListener::class) { request -> + commands.sendBlocking(Command.RegisterListener(request.listener)) + } + + registerHandler(Request.UnregisterListener::class) { request -> + commands.sendBlocking(Command.UnregisterListener(request.listenerId)) + } } } fun onDestroy() { dispatcher.onDestroy() - registrationQueue.close() + commands.close() accountCache.onDestroy() appVersionInfoCache.onDestroy() @@ -74,13 +89,13 @@ class ServiceEndpoint( internal fun sendEvent(event: Event) { synchronized(this) { - val deadListeners = mutableSetOf<Messenger>() + val deadListeners = mutableSetOf<Int>() - for (listener in listeners) { + for ((id, listener) in listeners) { try { listener.send(event.message) } catch (_: DeadObjectException) { - deadListeners.add(listener) + deadListeners.add(id) } } @@ -88,17 +103,20 @@ class ServiceEndpoint( } } - private fun startRegistrator() = GlobalScope.actor<Messenger>( + private fun startRegistrator() = GlobalScope.actor<Command>( Dispatchers.Default, Channel.UNLIMITED ) { try { - while (true) { - val listener = channel.receive() + for (command in channel) { + when (command) { + is Command.RegisterListener -> { + intermittentDaemon.await() - intermittentDaemon.await() - - registerListener(listener) + registerListener(command.listener) + } + is Command.UnregisterListener -> unregisterListener(command.listenerId) + } } } catch (exception: ClosedReceiveChannelException) { // Registration queue closed; stop registrator @@ -107,7 +125,9 @@ class ServiceEndpoint( private fun registerListener(listener: Messenger) { synchronized(this) { - listeners.add(listener) + val listenerId = newListenerId() + + listeners.put(listenerId, listener) val initialEvents = mutableListOf( Event.TunnelStateChange(connectionProxy.state), @@ -121,7 +141,7 @@ class ServiceEndpoint( Event.AppVersionInfo(appVersionInfoCache.appVersionInfo), Event.NewRelayList(relayListListener.relayList), Event.AuthToken(authTokenCache.authToken), - Event.ListenerReady + Event.ListenerReady(messenger, listenerId) ) if (vpnPermission.waitingForResponse) { @@ -133,4 +153,18 @@ class ServiceEndpoint( } } } + + private fun unregisterListener(listenerId: Int) { + synchronized(this) { + listeners.remove(listenerId) + } + } + + private fun newListenerId(): Int { + val listenerId = listenerIdCounter + + listenerIdCounter += 1 + + return listenerId + } } diff --git a/android/src/main/kotlin/net/mullvad/mullvadvpn/util/DispatchingFlow.kt b/android/src/main/kotlin/net/mullvad/mullvadvpn/util/DispatchingFlow.kt new file mode 100644 index 0000000000..af66a092ba --- /dev/null +++ b/android/src/main/kotlin/net/mullvad/mullvadvpn/util/DispatchingFlow.kt @@ -0,0 +1,49 @@ +package net.mullvad.mullvadvpn.util + +import java.util.concurrent.ConcurrentHashMap +import kotlin.reflect.KClass +import kotlinx.coroutines.InternalCoroutinesApi +import kotlinx.coroutines.channels.Channel +import kotlinx.coroutines.channels.ClosedSendChannelException +import kotlinx.coroutines.channels.SendChannel +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.FlowCollector +import kotlinx.coroutines.flow.collect +import kotlinx.coroutines.flow.consumeAsFlow + +class DispatchingFlow<T : Any>(private val upstream: Flow<T>) : Flow<T> { + private val subscribers = ConcurrentHashMap<KClass<out T>, SendChannel<T>>() + + fun <V : T> subscribe( + variant: KClass<V>, + capacity: Int = Channel.CONFLATED + ): Flow<V> { + val channel = Channel<V>(capacity) + + // This is safe because `collect` will only send to this channel if the instance class is V + @Suppress("UNCHECKED_CAST") + subscribers[variant] = channel as SendChannel<T> + + return channel.consumeAsFlow() + } + + fun <V : T> unsubscribe(variant: KClass<V>) = subscribers.remove(variant) + + @InternalCoroutinesApi + override suspend fun collect(collector: FlowCollector<T>) { + upstream.collect { event -> + try { + subscribers[event::class]?.send(event) + } catch (closedException: ClosedSendChannelException) { + subscribers.remove(event::class) + } + + collector.emit(event) + } + + subscribers.clear() + } +} + +fun <T : Any> Flow<T>.dispatchTo(configureSubscribers: DispatchingFlow<T>.() -> Unit) = + DispatchingFlow(this).also(configureSubscribers) diff --git a/android/src/main/kotlin/net/mullvad/mullvadvpn/util/FlowUtils.kt b/android/src/main/kotlin/net/mullvad/mullvadvpn/util/FlowUtils.kt index 45c45f4e16..71ce51a005 100644 --- a/android/src/main/kotlin/net/mullvad/mullvadvpn/util/FlowUtils.kt +++ b/android/src/main/kotlin/net/mullvad/mullvadvpn/util/FlowUtils.kt @@ -1,5 +1,10 @@ package net.mullvad.mullvadvpn.util +import android.content.ComponentName +import android.content.Context +import android.content.Intent +import android.content.ServiceConnection +import android.os.IBinder import android.view.animation.Animation import kotlin.coroutines.EmptyCoroutineContext import kotlinx.coroutines.Dispatchers @@ -26,3 +31,25 @@ fun Animation.transitionFinished(): Flow<Unit> = callbackFlow<Unit> { } } }.take(1) + +fun Context.bindServiceFlow(intent: Intent, flags: Int = 0): Flow<IBinder?> = callbackFlow { + val connectionCallback = object : ServiceConnection { + override fun onServiceConnected(className: ComponentName, binder: IBinder) { + safeOffer(binder) + } + + override fun onServiceDisconnected(className: ComponentName) { + safeOffer(null) + } + } + + bindService(intent, connectionCallback, flags) + + awaitClose { + safeOffer(null) + + Dispatchers.Default.dispatch(EmptyCoroutineContext) { + unbindService(connectionCallback) + } + } +} |
