summaryrefslogtreecommitdiffhomepage
path: root/android/lib
diff options
context:
space:
mode:
authorDavid Göransson <david.goransson@mullvad.net>2024-11-20 08:45:25 +0100
committerDavid Lönnhager <david.l@mullvad.net>2024-11-22 13:38:22 +0100
commit133845955492ecafb6447eaa9ceba34cb972f488 (patch)
tree814a03fe2be83ea9a8921410850de9fbf321d2de /android/lib
parent168c9afb19e9bec61b40ecfb5ab12ed7983f35e0 (diff)
downloadmullvadvpn-133845955492ecafb6447eaa9ceba34cb972f488.tar.xz
mullvadvpn-133845955492ecafb6447eaa9ceba34cb972f488.zip
Refactor ConnectivityListener
Diffstat (limited to 'android/lib')
-rw-r--r--android/lib/talpid/build.gradle.kts1
-rw-r--r--android/lib/talpid/src/main/kotlin/net/mullvad/talpid/ConnectivityListener.kt127
-rw-r--r--android/lib/talpid/src/main/kotlin/net/mullvad/talpid/TalpidVpnService.kt8
-rw-r--r--android/lib/talpid/src/main/kotlin/net/mullvad/talpid/util/ConnectivityManagerUtil.kt132
4 files changed, 207 insertions, 61 deletions
diff --git a/android/lib/talpid/build.gradle.kts b/android/lib/talpid/build.gradle.kts
index a5cd613de1..c53c2add28 100644
--- a/android/lib/talpid/build.gradle.kts
+++ b/android/lib/talpid/build.gradle.kts
@@ -31,6 +31,7 @@ android {
dependencies {
implementation(projects.lib.model)
+ implementation(libs.androidx.ktx)
implementation(libs.androidx.lifecycle.service)
implementation(libs.kermit)
implementation(libs.kotlin.stdlib)
diff --git a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/ConnectivityListener.kt b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/ConnectivityListener.kt
index f1fe3ca807..a37cf18578 100644
--- a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/ConnectivityListener.kt
+++ b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/ConnectivityListener.kt
@@ -1,86 +1,95 @@
package net.mullvad.talpid
-import android.content.Context
import android.net.ConnectivityManager
-import android.net.ConnectivityManager.NetworkCallback
import android.net.LinkProperties
import android.net.Network
import android.net.NetworkCapabilities
import android.net.NetworkRequest
-import co.touchlab.kermit.Logger
import java.net.InetAddress
-import kotlin.properties.Delegates.observable
+import kotlinx.coroutines.CoroutineScope
+import kotlinx.coroutines.flow.Flow
+import kotlinx.coroutines.flow.SharingStarted
+import kotlinx.coroutines.flow.StateFlow
+import kotlinx.coroutines.flow.distinctUntilChanged
+import kotlinx.coroutines.flow.filterIsInstance
+import kotlinx.coroutines.flow.map
+import kotlinx.coroutines.flow.onEach
+import kotlinx.coroutines.flow.scan
+import kotlinx.coroutines.flow.stateIn
+import net.mullvad.talpid.util.NetworkEvent
+import net.mullvad.talpid.util.defaultNetworkFlow
+import net.mullvad.talpid.util.networkFlow
-class ConnectivityListener {
- private val availableNetworks = HashSet<Network>()
-
- private val callback =
- object : NetworkCallback() {
- override fun onAvailable(network: Network) {
- availableNetworks.add(network)
- isConnected = true
- }
-
- override fun onLost(network: Network) {
- availableNetworks.remove(network)
- isConnected = availableNetworks.isNotEmpty()
- }
- }
-
- private val defaultNetworkCallback =
- object : NetworkCallback() {
- override fun onLinkPropertiesChanged(network: Network, linkProperties: LinkProperties) {
- super.onLinkPropertiesChanged(network, linkProperties)
- currentDnsServers = ArrayList(linkProperties.dnsServers)
+class ConnectivityListener(val connectivityManager: ConnectivityManager) {
+ // Used by JNI
+ var senderAddress = 0L
+ set(value) {
+ if (value == 0L) {
+ destroySender(field)
}
+ field = value
}
- private lateinit var connectivityManager: ConnectivityManager
+ private lateinit var _isConnected: StateFlow<Boolean>
+ // Used by JNI
+ val isConnected
+ get() = _isConnected.value
+ private lateinit var _currentDnsServers: StateFlow<List<InetAddress>>
// Used by JNI
- var isConnected by
- observable(false) { _, oldValue, newValue ->
- if (newValue != oldValue) {
- if (senderAddress != 0L) {
- notifyConnectivityChange(newValue, senderAddress)
+ val currentDnsServers
+ get() = ArrayList(_currentDnsServers.value)
+
+ fun register(scope: CoroutineScope) {
+ _currentDnsServers =
+ dnsServerChanges().stateIn(scope, SharingStarted.Eagerly, currentDnsServers())
+
+ _isConnected =
+ hasInternetCapability()
+ .onEach {
+ if (senderAddress != 0L) {
+ notifyConnectivityChange(it, senderAddress)
+ }
}
- }
- }
+ .stateIn(scope, SharingStarted.Eagerly, false)
+ }
- var currentDnsServers: ArrayList<InetAddress> = ArrayList()
- private set(value) {
- field = ArrayList(value.filter { it.hostAddress != TalpidVpnService.FALLBACK_DUMMY_DNS_SERVER })
- Logger.d("New currentDnsServers: $field")
- }
+ fun unregister() {
+ senderAddress = 0L
+ }
- var senderAddress = 0L
+ private fun dnsServerChanges(): Flow<List<InetAddress>> =
+ connectivityManager
+ .defaultNetworkFlow()
+ .filterIsInstance<NetworkEvent.LinkPropertiesChanged>()
+ .map { it.linkProperties.dnsServersWithoutFallback() }
- fun register(context: Context) {
+ private fun currentDnsServers(): List<InetAddress> =
+ connectivityManager
+ .getLinkProperties(connectivityManager.activeNetwork)
+ ?.dnsServersWithoutFallback() ?: emptyList()
+
+ private fun LinkProperties.dnsServersWithoutFallback(): List<InetAddress> =
+ dnsServers.filter { it.hostAddress != TalpidVpnService.FALLBACK_DUMMY_DNS_SERVER }
+
+ private fun hasInternetCapability(): Flow<Boolean> {
val request =
NetworkRequest.Builder()
.addCapability(NetworkCapabilities.NET_CAPABILITY_INTERNET)
.addCapability(NetworkCapabilities.NET_CAPABILITY_NOT_VPN)
.build()
- connectivityManager =
- context.getSystemService(Context.CONNECTIVITY_SERVICE) as ConnectivityManager
-
- connectivityManager.registerNetworkCallback(request, callback)
- currentDnsServers =
- connectivityManager.getLinkProperties(connectivityManager.activeNetwork)?.dnsServers?.let { ArrayList(it) }
- ?: ArrayList()
- connectivityManager.registerDefaultNetworkCallback(defaultNetworkCallback)
- }
-
- fun unregister() {
- connectivityManager.unregisterNetworkCallback(callback)
- connectivityManager.unregisterNetworkCallback(defaultNetworkCallback)
-
- if (senderAddress != 0L) {
- var oldSender = senderAddress
- senderAddress = 0L
- destroySender(oldSender)
- }
+ return connectivityManager
+ .networkFlow(request)
+ .scan(setOf<Network>()) { networks, event ->
+ when (event) {
+ is NetworkEvent.Available -> networks + event.network
+ is NetworkEvent.Lost -> networks - event.network
+ else -> networks
+ }
+ }
+ .map { it.isNotEmpty() }
+ .distinctUntilChanged()
}
private external fun notifyConnectivityChange(isConnected: Boolean, senderAddress: Long)
diff --git a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/TalpidVpnService.kt b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/TalpidVpnService.kt
index 61c0be2ccf..dfd6699b1e 100644
--- a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/TalpidVpnService.kt
+++ b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/TalpidVpnService.kt
@@ -1,7 +1,10 @@
package net.mullvad.talpid
+import android.net.ConnectivityManager
import android.os.ParcelFileDescriptor
import androidx.annotation.CallSuper
+import androidx.core.content.getSystemService
+import androidx.lifecycle.lifecycleScope
import co.touchlab.kermit.Logger
import java.net.Inet4Address
import java.net.Inet6Address
@@ -29,12 +32,13 @@ open class TalpidVpnService : LifecycleVpnService() {
private var currentTunConfig: TunConfig? = null
// Used by JNI
- val connectivityListener = ConnectivityListener()
+ lateinit var connectivityListener: ConnectivityListener
@CallSuper
override fun onCreate() {
super.onCreate()
- connectivityListener.register(this)
+ connectivityListener = ConnectivityListener(getSystemService<ConnectivityManager>()!!)
+ connectivityListener.register(lifecycleScope)
}
@CallSuper
diff --git a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/util/ConnectivityManagerUtil.kt b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/util/ConnectivityManagerUtil.kt
new file mode 100644
index 0000000000..daf155c6e8
--- /dev/null
+++ b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/util/ConnectivityManagerUtil.kt
@@ -0,0 +1,132 @@
+package net.mullvad.talpid.util
+
+import android.net.ConnectivityManager
+import android.net.ConnectivityManager.NetworkCallback
+import android.net.LinkProperties
+import android.net.Network
+import android.net.NetworkCapabilities
+import android.net.NetworkRequest
+import kotlinx.coroutines.channels.awaitClose
+import kotlinx.coroutines.channels.trySendBlocking
+import kotlinx.coroutines.flow.Flow
+import kotlinx.coroutines.flow.callbackFlow
+
+fun ConnectivityManager.defaultNetworkFlow(): Flow<NetworkEvent> =
+ callbackFlow<NetworkEvent> {
+ val callback =
+ object : NetworkCallback() {
+ override fun onLinkPropertiesChanged(
+ network: Network,
+ linkProperties: LinkProperties,
+ ) {
+ super.onLinkPropertiesChanged(network, linkProperties)
+ trySendBlocking(NetworkEvent.LinkPropertiesChanged(network, linkProperties))
+ }
+
+ override fun onAvailable(network: Network) {
+ super.onAvailable(network)
+ trySendBlocking(NetworkEvent.Available(network))
+ }
+
+ override fun onCapabilitiesChanged(
+ network: Network,
+ networkCapabilities: NetworkCapabilities,
+ ) {
+ super.onCapabilitiesChanged(network, networkCapabilities)
+ trySendBlocking(NetworkEvent.CapabilitiesChanged(network, networkCapabilities))
+ }
+
+ override fun onBlockedStatusChanged(network: Network, blocked: Boolean) {
+ super.onBlockedStatusChanged(network, blocked)
+ trySendBlocking(NetworkEvent.BlockedStatusChanged(network, blocked))
+ }
+
+ override fun onLosing(network: Network, maxMsToLive: Int) {
+ super.onLosing(network, maxMsToLive)
+ trySendBlocking(NetworkEvent.Losing(network, maxMsToLive))
+ }
+
+ override fun onLost(network: Network) {
+ super.onLost(network)
+ trySendBlocking(NetworkEvent.Lost(network))
+ }
+
+ override fun onUnavailable() {
+ super.onUnavailable()
+ trySendBlocking(NetworkEvent.Unavailable)
+ }
+ }
+ registerDefaultNetworkCallback(callback)
+
+ awaitClose { unregisterNetworkCallback(callback) }
+ }
+
+fun ConnectivityManager.networkFlow(networkRequest: NetworkRequest): Flow<NetworkEvent> =
+ callbackFlow<NetworkEvent> {
+ val callback =
+ object : NetworkCallback() {
+ override fun onLinkPropertiesChanged(
+ network: Network,
+ linkProperties: LinkProperties,
+ ) {
+ super.onLinkPropertiesChanged(network, linkProperties)
+ trySendBlocking(NetworkEvent.LinkPropertiesChanged(network, linkProperties))
+ }
+
+ override fun onAvailable(network: Network) {
+ super.onAvailable(network)
+ trySendBlocking(NetworkEvent.Available(network))
+ }
+
+ override fun onCapabilitiesChanged(
+ network: Network,
+ networkCapabilities: NetworkCapabilities,
+ ) {
+ super.onCapabilitiesChanged(network, networkCapabilities)
+ trySendBlocking(NetworkEvent.CapabilitiesChanged(network, networkCapabilities))
+ }
+
+ override fun onBlockedStatusChanged(network: Network, blocked: Boolean) {
+ super.onBlockedStatusChanged(network, blocked)
+ trySendBlocking(NetworkEvent.BlockedStatusChanged(network, blocked))
+ }
+
+ override fun onLosing(network: Network, maxMsToLive: Int) {
+ super.onLosing(network, maxMsToLive)
+ trySendBlocking(NetworkEvent.Losing(network, maxMsToLive))
+ }
+
+ override fun onLost(network: Network) {
+ super.onLost(network)
+ trySendBlocking(NetworkEvent.Lost(network))
+ }
+
+ override fun onUnavailable() {
+ super.onUnavailable()
+ trySendBlocking(NetworkEvent.Unavailable)
+ }
+ }
+ registerNetworkCallback(networkRequest, callback)
+
+ awaitClose { unregisterNetworkCallback(callback) }
+ }
+
+sealed interface NetworkEvent {
+ data class Available(val network: Network) : NetworkEvent
+
+ data object Unavailable : NetworkEvent
+
+ data class LinkPropertiesChanged(val network: Network, val linkProperties: LinkProperties) :
+ NetworkEvent
+
+ data class CapabilitiesChanged(
+ val network: Network,
+ val networkCapabilities: NetworkCapabilities,
+ ) : NetworkEvent
+
+ data class BlockedStatusChanged(val network: Network, val blocked: Boolean) : NetworkEvent
+
+ data class Losing(val network: Network, val maxMsToLive: Int) : NetworkEvent
+
+ data class Lost(val network: Network) : NetworkEvent
+}