diff options
| author | David Göransson <david.goransson@mullvad.net> | 2025-02-25 12:02:18 +0100 |
|---|---|---|
| committer | David Göransson <david.goransson@mullvad.net> | 2025-02-25 12:02:18 +0100 |
| commit | c41c9d846f8670886816c8883a7cbff833f95e26 (patch) | |
| tree | 1bf584764ee1b73430ecb05fb4b5602cd3838db1 /android | |
| parent | e4724d612354963ea00f0796f2c5cdf11c15c9d7 (diff) | |
| parent | fbc7f28235fc527ea9cd83ff2dc51feadfee1b3e (diff) | |
| download | mullvadvpn-c41c9d846f8670886816c8883a7cbff833f95e26.tar.xz mullvadvpn-c41c9d846f8670886816c8883a7cbff833f95e26.zip | |
Merge branch 'fix-wait-for-routes'
Diffstat (limited to 'android')
4 files changed, 72 insertions, 69 deletions
diff --git a/android/app/src/test/kotlin/net/mullvad/talpid/TalpidVpnServiceFallbackDnsTest.kt b/android/app/src/test/kotlin/net/mullvad/talpid/TalpidVpnServiceFallbackDnsTest.kt index 27e7658a11..e3faaf3884 100644 --- a/android/app/src/test/kotlin/net/mullvad/talpid/TalpidVpnServiceFallbackDnsTest.kt +++ b/android/app/src/test/kotlin/net/mullvad/talpid/TalpidVpnServiceFallbackDnsTest.kt @@ -34,6 +34,9 @@ class TalpidVpnServiceFallbackDnsTest { every { talpidVpnService.prepareVpnSafe() } returns Prepared.right() builderMockk = mockk<VpnService.Builder>() + every { talpidVpnService getProperty "connectivityListener" } returns + mockk<ConnectivityListener>(relaxed = true) + mockkConstructor(VpnService.Builder::class) every { anyConstructed<VpnService.Builder>().setMtu(any()) } returns builderMockk every { anyConstructed<VpnService.Builder>().setBlocking(any()) } returns builderMockk diff --git a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/ConnectivityListener.kt b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/ConnectivityListener.kt index fdee5039ad..b702a39a6e 100644 --- a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/ConnectivityListener.kt +++ b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/ConnectivityListener.kt @@ -2,21 +2,24 @@ package net.mullvad.talpid import android.net.ConnectivityManager import android.net.LinkProperties -import android.net.Network import android.net.NetworkCapabilities import android.net.NetworkRequest import co.touchlab.kermit.Logger import java.net.InetAddress import kotlin.collections.ArrayList import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.SharingStarted import kotlinx.coroutines.flow.StateFlow -import kotlinx.coroutines.flow.distinctUntilChanged import kotlinx.coroutines.flow.map +import kotlinx.coroutines.flow.merge import kotlinx.coroutines.flow.onEach +import kotlinx.coroutines.flow.receiveAsFlow import kotlinx.coroutines.flow.scan import kotlinx.coroutines.flow.stateIn +import kotlinx.coroutines.launch import net.mullvad.talpid.model.NetworkState import net.mullvad.talpid.util.NetworkEvent import net.mullvad.talpid.util.RawNetworkState @@ -29,66 +32,86 @@ class ConnectivityListener(private val connectivityManager: ConnectivityManager) val isConnected get() = _isConnected.value - private lateinit var _currentNetworkState: StateFlow<NetworkState?> + private val _mutableNetworkState = MutableStateFlow<NetworkState?>(null) + private val resetNetworkState: Channel<Unit> = Channel() // Used by JNI val currentDefaultNetworkState: NetworkState? - get() = _currentNetworkState.value + get() = _mutableNetworkState.value // Used by JNI val currentDnsServers: ArrayList<InetAddress> - get() = _currentNetworkState.value?.dnsServers ?: ArrayList() + get() = _mutableNetworkState.value?.dnsServers ?: ArrayList() fun register(scope: CoroutineScope) { // Consider implementing retry logic for the flows below, because registering a listener on // the default network may fail if the network on Android 11 // https://issuetracker.google.com/issues/175055271?pli=1 - _currentNetworkState = - connectivityManager - .defaultRawNetworkStateFlow() + scope.launch { + merge( + connectivityManager.defaultRawNetworkStateFlow(), + resetNetworkState.receiveAsFlow().map { null }, + ) .map { it?.toNetworkState() } .onEach { notifyDefaultNetworkChange(it) } - .stateIn(scope, SharingStarted.Eagerly, null) + .collect(_mutableNetworkState) + } _isConnected = hasInternetCapability() .onEach { notifyConnectivityChange(it) } - .stateIn(scope, SharingStarted.Eagerly, false) + .stateIn( + scope, + SharingStarted.Eagerly, + true, // Assume we have internet until we know otherwise + ) + } + + /** + * Invalidates the network state cache. E.g when the VPN is connected or disconnected, and we + * know the last known values not to be correct anymore. + */ + fun invalidateNetworkStateCache() { + _mutableNetworkState.value = null } private fun LinkProperties.dnsServersWithoutFallback(): List<InetAddress> = dnsServers.filter { it.hostAddress != TalpidVpnService.FALLBACK_DUMMY_DNS_SERVER } - private fun hasInternetCapability(): Flow<Boolean> { - val request = - NetworkRequest.Builder() - .addCapability(NetworkCapabilities.NET_CAPABILITY_INTERNET) - .addCapability(NetworkCapabilities.NET_CAPABILITY_NOT_VPN) - .build() + private val nonVPNNetworksRequest = + NetworkRequest.Builder().addCapability(NetworkCapabilities.NET_CAPABILITY_NOT_VPN).build() + private fun hasInternetCapability(): Flow<Boolean> { + @Suppress("DEPRECATION") return connectivityManager - .networkEvents(request) - .scan(setOf<Network>()) { networks, event -> + .networkEvents(nonVPNNetworksRequest) + .scan( + connectivityManager.allNetworks.associateWith { + connectivityManager.getNetworkCapabilities(it) + } + ) { networks, event -> when (event) { - is NetworkEvent.Available -> { - Logger.d("Network available ${event.network}") - (networks + event.network).also { - Logger.d("Number of networks: ${it.size}") - } - } is NetworkEvent.Lost -> { Logger.d("Network lost ${event.network}") (networks - event.network).also { Logger.d("Number of networks: ${it.size}") } } + is NetworkEvent.CapabilitiesChanged -> { + Logger.d("Network capabilities changed ${event.network}") + (networks + (event.network to event.networkCapabilities)).also { + Logger.d("Number of networks: ${it.size}") + } + } else -> networks } } - .map { it.isNotEmpty() } - .distinctUntilChanged() + .map { it.any { it.value.hasInternetCapability() } } } + private fun NetworkCapabilities?.hasInternetCapability(): Boolean = + this?.hasCapability(NetworkCapabilities.NET_CAPABILITY_INTERNET) == true + private fun RawNetworkState.toNetworkState(): NetworkState = NetworkState( network.networkHandle, diff --git a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/TalpidVpnService.kt b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/TalpidVpnService.kt index a143df6132..a227c9a770 100644 --- a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/TalpidVpnService.kt +++ b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/TalpidVpnService.kt @@ -42,8 +42,6 @@ open class TalpidVpnService : LifecycleVpnService() { } } - private var currentTunConfig: TunConfig? = null - // Used by JNI lateinit var connectivityListener: ConnectivityListener @@ -56,35 +54,18 @@ open class TalpidVpnService : LifecycleVpnService() { // Used by JNI fun openTun(config: TunConfig): CreateTunResult = - synchronized(this) { - val tunStatus = activeTunStatus - - if (config == currentTunConfig && tunStatus != null && tunStatus.isOpen) { - tunStatus - } else { - openTunImpl(config) - } - } - - // Used by JNI - fun openTunForced(config: TunConfig): CreateTunResult = - synchronized(this) { openTunImpl(config) } + synchronized(this) { createTun(config).merge().also { activeTunStatus = it } } // Used by JNI - fun closeTun(): Unit = synchronized(this) { activeTunStatus = null } + fun closeTun(): Unit = + synchronized(this) { + connectivityListener.invalidateNetworkStateCache() + activeTunStatus = null + } // Used by JNI fun bypass(socket: Int): Boolean = protect(socket) - private fun openTunImpl(config: TunConfig): CreateTunResult { - val newTunStatus = createTun(config).merge() - - currentTunConfig = config - activeTunStatus = newTunStatus - - return newTunStatus - } - private fun createTun( config: TunConfig ): Either<CreateTunResult.Error, CreateTunResult.Success> = either { @@ -123,6 +104,7 @@ open class TalpidVpnService : LifecycleVpnService() { builder.addDnsServer(FALLBACK_DUMMY_DNS_SERVER) } + connectivityListener.invalidateNetworkStateCache() val vpnInterfaceFd = builder .establishSafe() diff --git a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/util/ConnectivityManagerUtil.kt b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/util/ConnectivityManagerUtil.kt index fddaa6fb88..f3297a995e 100644 --- a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/util/ConnectivityManagerUtil.kt +++ b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/util/ConnectivityManagerUtil.kt @@ -109,24 +109,19 @@ fun ConnectivityManager.networkEvents(networkRequest: NetworkRequest): Flow<Netw } internal fun ConnectivityManager.defaultRawNetworkStateFlow(): Flow<RawNetworkState?> = - defaultNetworkEvents() - .scan( - null as RawNetworkState?, - { state, event -> - return@scan when (event) { - is NetworkEvent.Available -> RawNetworkState(network = event.network) - is NetworkEvent.BlockedStatusChanged -> - state?.copy(blockedStatus = event.blocked) - is NetworkEvent.CapabilitiesChanged -> - state?.copy(networkCapabilities = event.networkCapabilities) - is NetworkEvent.LinkPropertiesChanged -> - state?.copy(linkProperties = event.linkProperties) - is NetworkEvent.Losing -> state?.copy(maxMsToLive = event.maxMsToLive) - is NetworkEvent.Lost -> null - NetworkEvent.Unavailable -> null - } - }, - ) + defaultNetworkEvents().scan(null as RawNetworkState?) { state, event -> state.reduce(event) } + +private fun RawNetworkState?.reduce(event: NetworkEvent): RawNetworkState? = + when (event) { + is NetworkEvent.Available -> RawNetworkState(network = event.network) + is NetworkEvent.BlockedStatusChanged -> this?.copy(blockedStatus = event.blocked) + is NetworkEvent.CapabilitiesChanged -> + this?.copy(networkCapabilities = event.networkCapabilities) + is NetworkEvent.LinkPropertiesChanged -> this?.copy(linkProperties = event.linkProperties) + is NetworkEvent.Losing -> this?.copy(maxMsToLive = event.maxMsToLive) + is NetworkEvent.Lost -> null + NetworkEvent.Unavailable -> null + } sealed interface NetworkEvent { data class Available(val network: Network) : NetworkEvent |
