diff options
Diffstat (limited to 'android/lib')
8 files changed, 290 insertions, 178 deletions
diff --git a/android/lib/common/src/main/kotlin/net/mullvad/mullvadvpn/lib/common/util/VpnServiceUtils.kt b/android/lib/common/src/main/kotlin/net/mullvad/mullvadvpn/lib/common/util/VpnServiceUtils.kt index 59833cb396..06c862936b 100644 --- a/android/lib/common/src/main/kotlin/net/mullvad/mullvadvpn/lib/common/util/VpnServiceUtils.kt +++ b/android/lib/common/src/main/kotlin/net/mullvad/mullvadvpn/lib/common/util/VpnServiceUtils.kt @@ -2,10 +2,14 @@ package net.mullvad.mullvadvpn.lib.common.util import android.content.Context import android.content.Intent +import android.net.VpnService import android.net.VpnService.prepare +import android.os.ParcelFileDescriptor import arrow.core.Either -import arrow.core.flatten +import arrow.core.flatMap import arrow.core.left +import arrow.core.raise.either +import arrow.core.raise.ensureNotNull import arrow.core.right import co.touchlab.kermit.Logger import net.mullvad.mullvadvpn.lib.common.util.SdkUtils.getInstalledPackagesList @@ -13,6 +17,8 @@ import net.mullvad.mullvadvpn.lib.model.PrepareError import net.mullvad.mullvadvpn.lib.model.Prepared /** + * Prepare to establish a VPN connection safely. + * * Invoking VpnService.prepare() can result in 3 out comes: * 1. IllegalStateException - There is a legacy VPN profile marked as always on * 2. Intent @@ -34,7 +40,7 @@ fun Context.prepareVpnSafe(): Either<PrepareError, Prepared> = else -> throw it } } - .map { intent -> + .flatMap { intent -> if (intent == null) { Prepared.right() } else { @@ -46,7 +52,6 @@ fun Context.prepareVpnSafe(): Either<PrepareError, Prepared> = } } } - .flatten() fun Context.getAlwaysOnVpnAppName(): String? { return resolveAlwaysOnVpnPackageName() @@ -59,3 +64,38 @@ fun Context.getAlwaysOnVpnAppName(): String? { ?.loadLabel(packageManager) ?.toString() } + +/** + * Establish a VPN connection safely. + * + * This function wraps the [VpnService.Builder.establish] function and catches any exceptions that + * may be thrown and type them to a more specific error. + * + * @return [ParcelFileDescriptor] if successful, [EstablishError] otherwise + */ +fun VpnService.Builder.establishSafe(): Either<EstablishError, ParcelFileDescriptor> = either { + val vpnInterfaceFd = + Either.catch { establish() } + .mapLeft { + when (it) { + is IllegalStateException -> EstablishError.ParameterNotApplied(it) + is IllegalArgumentException -> EstablishError.ParameterNotAccepted(it) + else -> EstablishError.UnknownError(it) + } + } + .bind() + + ensureNotNull(vpnInterfaceFd) { EstablishError.NullVpnInterface } + + vpnInterfaceFd +} + +sealed interface EstablishError { + data class ParameterNotApplied(val exception: IllegalStateException) : EstablishError + + data class ParameterNotAccepted(val exception: IllegalArgumentException) : EstablishError + + data object NullVpnInterface : EstablishError + + data class UnknownError(val error: Throwable) : EstablishError +} diff --git a/android/lib/daemon-grpc/src/main/kotlin/net/mullvad/mullvadvpn/lib/daemon/grpc/mapper/ToDomain.kt b/android/lib/daemon-grpc/src/main/kotlin/net/mullvad/mullvadvpn/lib/daemon/grpc/mapper/ToDomain.kt index daa04fc8d9..fe4cf11881 100644 --- a/android/lib/daemon-grpc/src/main/kotlin/net/mullvad/mullvadvpn/lib/daemon/grpc/mapper/ToDomain.kt +++ b/android/lib/daemon-grpc/src/main/kotlin/net/mullvad/mullvadvpn/lib/daemon/grpc/mapper/ToDomain.kt @@ -36,9 +36,6 @@ import net.mullvad.mullvadvpn.lib.model.DnsState import net.mullvad.mullvadvpn.lib.model.Endpoint import net.mullvad.mullvadvpn.lib.model.ErrorState import net.mullvad.mullvadvpn.lib.model.ErrorStateCause -import net.mullvad.mullvadvpn.lib.model.ErrorStateCause.AuthFailed -import net.mullvad.mullvadvpn.lib.model.ErrorStateCause.OtherAlwaysOnApp -import net.mullvad.mullvadvpn.lib.model.ErrorStateCause.TunnelParameterError import net.mullvad.mullvadvpn.lib.model.FeatureIndicator import net.mullvad.mullvadvpn.lib.model.GeoIpLocation import net.mullvad.mullvadvpn.lib.model.GeoLocationId @@ -125,7 +122,7 @@ private fun ManagementInterface.TunnelState.Error.toDomain(): TunnelState.Error val otherAlwaysOnAppError = errorState.let { if (it.hasOtherAlwaysOnAppError()) { - OtherAlwaysOnApp(it.otherAlwaysOnAppError.appName) + ErrorStateCause.OtherAlwaysOnApp(it.otherAlwaysOnAppError.appName) } else { null } @@ -238,7 +235,7 @@ internal fun ManagementInterface.ErrorState.toDomain( cause = when (cause!!) { ManagementInterface.ErrorState.Cause.AUTH_FAILED -> - AuthFailed(authFailedError.toDomain()) + ErrorStateCause.AuthFailed(authFailedError.toDomain()) ManagementInterface.ErrorState.Cause.IPV6_UNAVAILABLE -> ErrorStateCause.Ipv6Unavailable ManagementInterface.ErrorState.Cause.SET_FIREWALL_POLICY_ERROR -> @@ -247,7 +244,7 @@ internal fun ManagementInterface.ErrorState.toDomain( ManagementInterface.ErrorState.Cause.START_TUNNEL_ERROR -> ErrorStateCause.StartTunnelError ManagementInterface.ErrorState.Cause.TUNNEL_PARAMETER_ERROR -> - TunnelParameterError(parameterError.toDomain()) + ErrorStateCause.TunnelParameterError(parameterError.toDomain()) ManagementInterface.ErrorState.Cause.IS_OFFLINE -> ErrorStateCause.IsOffline ManagementInterface.ErrorState.Cause.SPLIT_TUNNEL_ERROR -> ErrorStateCause.StartTunnelError @@ -255,7 +252,6 @@ internal fun ManagementInterface.ErrorState.toDomain( ManagementInterface.ErrorState.Cause.NEED_FULL_DISK_PERMISSIONS, ManagementInterface.ErrorState.Cause.CREATE_TUNNEL_DEVICE -> throw IllegalArgumentException("Unrecognized error state cause") - ManagementInterface.ErrorState.Cause.NOT_PREPARED -> ErrorStateCause.NotPrepared ManagementInterface.ErrorState.Cause.OTHER_ALWAYS_ON_APP -> otherAlwaysOnApp!! ManagementInterface.ErrorState.Cause.OTHER_LEGACY_ALWAYS_ON_VPN -> diff --git a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/ConnectivityListener.kt b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/ConnectivityListener.kt index 86b27e3ba8..fdee5039ad 100644 --- a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/ConnectivityListener.kt +++ b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/ConnectivityListener.kt @@ -7,34 +7,48 @@ import android.net.NetworkCapabilities import android.net.NetworkRequest import co.touchlab.kermit.Logger import java.net.InetAddress +import kotlin.collections.ArrayList import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.SharingStarted import kotlinx.coroutines.flow.StateFlow import kotlinx.coroutines.flow.distinctUntilChanged -import kotlinx.coroutines.flow.filterIsInstance import kotlinx.coroutines.flow.map import kotlinx.coroutines.flow.onEach import kotlinx.coroutines.flow.scan import kotlinx.coroutines.flow.stateIn +import net.mullvad.talpid.model.NetworkState import net.mullvad.talpid.util.NetworkEvent -import net.mullvad.talpid.util.defaultNetworkFlow -import net.mullvad.talpid.util.networkFlow +import net.mullvad.talpid.util.RawNetworkState +import net.mullvad.talpid.util.defaultRawNetworkStateFlow +import net.mullvad.talpid.util.networkEvents -class ConnectivityListener(val connectivityManager: ConnectivityManager) { +class ConnectivityListener(private val connectivityManager: ConnectivityManager) { private lateinit var _isConnected: StateFlow<Boolean> // Used by JNI val isConnected get() = _isConnected.value - private lateinit var _currentDnsServers: StateFlow<List<InetAddress>> + private lateinit var _currentNetworkState: StateFlow<NetworkState?> + + // Used by JNI + val currentDefaultNetworkState: NetworkState? + get() = _currentNetworkState.value + // Used by JNI - val currentDnsServers - get() = ArrayList(_currentDnsServers.value) + val currentDnsServers: ArrayList<InetAddress> + get() = _currentNetworkState.value?.dnsServers ?: ArrayList() fun register(scope: CoroutineScope) { - _currentDnsServers = - dnsServerChanges().stateIn(scope, SharingStarted.Eagerly, currentDnsServers()) + // Consider implementing retry logic for the flows below, because registering a listener on + // the default network may fail if the network on Android 11 + // https://issuetracker.google.com/issues/175055271?pli=1 + _currentNetworkState = + connectivityManager + .defaultRawNetworkStateFlow() + .map { it?.toNetworkState() } + .onEach { notifyDefaultNetworkChange(it) } + .stateIn(scope, SharingStarted.Eagerly, null) _isConnected = hasInternetCapability() @@ -42,18 +56,6 @@ class ConnectivityListener(val connectivityManager: ConnectivityManager) { .stateIn(scope, SharingStarted.Eagerly, false) } - private fun dnsServerChanges(): Flow<List<InetAddress>> = - connectivityManager - .defaultNetworkFlow() - .filterIsInstance<NetworkEvent.LinkPropertiesChanged>() - .onEach { Logger.d("Link properties changed") } - .map { it.linkProperties.dnsServersWithoutFallback() } - - private fun currentDnsServers(): List<InetAddress> = - connectivityManager - .getLinkProperties(connectivityManager.activeNetwork) - ?.dnsServersWithoutFallback() ?: emptyList() - private fun LinkProperties.dnsServersWithoutFallback(): List<InetAddress> = dnsServers.filter { it.hostAddress != TalpidVpnService.FALLBACK_DUMMY_DNS_SERVER } @@ -65,7 +67,7 @@ class ConnectivityListener(val connectivityManager: ConnectivityManager) { .build() return connectivityManager - .networkFlow(request) + .networkEvents(request) .scan(setOf<Network>()) { networks, event -> when (event) { is NetworkEvent.Available -> { @@ -87,5 +89,14 @@ class ConnectivityListener(val connectivityManager: ConnectivityManager) { .distinctUntilChanged() } + private fun RawNetworkState.toNetworkState(): NetworkState = + NetworkState( + network.networkHandle, + linkProperties?.routes, + linkProperties?.dnsServersWithoutFallback(), + ) + private external fun notifyConnectivityChange(isConnected: Boolean) + + private external fun notifyDefaultNetworkChange(networkState: NetworkState?) } diff --git a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/TalpidVpnService.kt b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/TalpidVpnService.kt index 74d44005cd..a143df6132 100644 --- a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/TalpidVpnService.kt +++ b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/TalpidVpnService.kt @@ -1,18 +1,29 @@ package net.mullvad.talpid import android.net.ConnectivityManager +import android.net.VpnService import android.os.ParcelFileDescriptor import androidx.annotation.CallSuper import androidx.core.content.getSystemService import androidx.lifecycle.lifecycleScope +import arrow.core.Either +import arrow.core.mapOrAccumulate +import arrow.core.merge +import arrow.core.raise.either import co.touchlab.kermit.Logger import java.net.Inet4Address import java.net.Inet6Address import java.net.InetAddress import kotlin.properties.Delegates.observable +import net.mullvad.mullvadvpn.lib.common.util.establishSafe import net.mullvad.mullvadvpn.lib.common.util.prepareVpnSafe import net.mullvad.mullvadvpn.lib.model.PrepareError import net.mullvad.talpid.model.CreateTunResult +import net.mullvad.talpid.model.CreateTunResult.EstablishError +import net.mullvad.talpid.model.CreateTunResult.InvalidDnsServers +import net.mullvad.talpid.model.CreateTunResult.NotPrepared +import net.mullvad.talpid.model.CreateTunResult.OtherAlwaysOnApp +import net.mullvad.talpid.model.CreateTunResult.OtherLegacyAlwaysOnVpn import net.mullvad.talpid.model.TunConfig import net.mullvad.talpid.util.TalpidSdkUtils.setMeteredIfSupported @@ -22,7 +33,7 @@ open class TalpidVpnService : LifecycleVpnService() { val oldTunFd = when (oldTunStatus) { is CreateTunResult.Success -> oldTunStatus.tunFd - is CreateTunResult.InvalidDnsServers -> oldTunStatus.tunFd + is InvalidDnsServers -> oldTunStatus.tunFd else -> null } @@ -43,26 +54,30 @@ open class TalpidVpnService : LifecycleVpnService() { connectivityListener.register(lifecycleScope) } - fun openTun(config: TunConfig): CreateTunResult { + // Used by JNI + fun openTun(config: TunConfig): CreateTunResult = synchronized(this) { val tunStatus = activeTunStatus if (config == currentTunConfig && tunStatus != null && tunStatus.isOpen) { - return tunStatus + tunStatus } else { - return openTunImpl(config) + openTunImpl(config) } } - } - fun openTunForced(config: TunConfig): CreateTunResult { - synchronized(this) { - return openTunImpl(config) - } - } + // Used by JNI + fun openTunForced(config: TunConfig): CreateTunResult = + synchronized(this) { openTunImpl(config) } + + // Used by JNI + fun closeTun(): Unit = synchronized(this) { activeTunStatus = null } + + // Used by JNI + fun bypass(socket: Int): Boolean = protect(socket) private fun openTunImpl(config: TunConfig): CreateTunResult { - val newTunStatus = createTun(config) + val newTunStatus = createTun(config).merge() currentTunConfig = config activeTunStatus = newTunStatus @@ -70,95 +85,76 @@ open class TalpidVpnService : LifecycleVpnService() { return newTunStatus } - fun closeTun() { - synchronized(this) { activeTunStatus = null } - } - - // DROID-1407 - // Function is to be cleaned up and lint suppression to be removed. - @Suppress("ReturnCount") - private fun createTun(config: TunConfig): CreateTunResult { - prepareVpnSafe() - .mapLeft { it.toCreateTunResult() } - .onLeft { - return it - } - - val invalidDnsServerAddresses = ArrayList<InetAddress>() + private fun createTun( + config: TunConfig + ): Either<CreateTunResult.Error, CreateTunResult.Success> = either { + prepareVpnSafe().mapLeft { it.toCreateTunError() }.bind() - val builder = - Builder().apply { - for (address in config.addresses) { - addAddress(address, address.prefixLength()) - } + val builder = Builder() + builder.setMtu(config.mtu) + builder.setBlocking(false) + builder.setMeteredIfSupported(false) - for (dnsServer in config.dnsServers) { - try { - addDnsServer(dnsServer) - } catch (exception: IllegalArgumentException) { - invalidDnsServerAddresses.add(dnsServer) - } - } + config.addresses.forEach { builder.addAddress(it, it.prefixLength()) } + config.routes.forEach { builder.addRoute(it.address, it.prefixLength.toInt()) } + config.excludedPackages.forEach { app -> builder.addDisallowedApplication(app) } - // Avoids creating a tunnel with no DNS servers or if all DNS servers was invalid, - // since apps then may leak DNS requests. - // https://issuetracker.google.com/issues/337961996 - if (invalidDnsServerAddresses.size == config.dnsServers.size) { - Logger.w( - "All DNS servers invalid or non set, using fallback DNS server to " + - "minimize leaks, dnsServers.isEmpty(): ${config.dnsServers.isEmpty()}" - ) - addDnsServer(FALLBACK_DUMMY_DNS_SERVER) - } - - for (route in config.routes) { - addRoute(route.address, route.prefixLength.toInt()) - } - - config.excludedPackages.forEach { app -> addDisallowedApplication(app) } - setMtu(config.mtu) - setBlocking(false) - setMeteredIfSupported(false) - } - - val vpnInterfaceFd = - try { - builder.establish() - } catch (e: IllegalStateException) { - Logger.e("Failed to establish, a parameter could not be applied", e) - return CreateTunResult.TunnelDeviceError - } catch (e: IllegalArgumentException) { - Logger.e("Failed to establish a parameter was not accepted", e) - return CreateTunResult.TunnelDeviceError + // We don't care if adding DNS servers fails at this point, since we can still create a + // tunnel to consume traffic and then notify daemon to later enter blocked state. + val dnsConfigureResult = + config.dnsServers.mapOrAccumulate { + builder.addDnsServerSafe(it).bind() + Unit } - if (vpnInterfaceFd == null) { - Logger.e("VpnInterface returned null") - return CreateTunResult.TunnelDeviceError + // Never create a tunnel where all DNS servers are invalid or if none was ever set, since + // apps then may leak DNS requests. + // https://issuetracker.google.com/issues/337961996 + val shouldAddFallbackDns = + dnsConfigureResult.fold( + { invalidDnsServers -> invalidDnsServers.size == config.dnsServers.size }, + { addedDnsServers -> addedDnsServers.isEmpty() }, + ) + if (shouldAddFallbackDns) { + Logger.w( + "All DNS servers invalid or non set, using fallback DNS server to " + + "minimize leaks, dnsServers.isEmpty(): ${config.dnsServers.isEmpty()}" + ) + builder.addDnsServer(FALLBACK_DUMMY_DNS_SERVER) } - val tunFd = vpnInterfaceFd.detachFd() - - waitForTunnelUp(tunFd, config.routes.any { route -> route.isIpv6 }) + val vpnInterfaceFd = + builder + .establishSafe() + .onLeft { Logger.w("Failed to establish tunnel $it") } + .mapLeft { EstablishError } + .bind() - if (invalidDnsServerAddresses.isNotEmpty()) { - return CreateTunResult.InvalidDnsServers(invalidDnsServerAddresses, tunFd) - } + val tunFd = vpnInterfaceFd.detachFd() - return CreateTunResult.Success(tunFd) - } + dnsConfigureResult.mapLeft { InvalidDnsServers(it, tunFd) }.bind() - fun bypass(socket: Int): Boolean { - return protect(socket) + CreateTunResult.Success(tunFd) } - private fun PrepareError.toCreateTunResult() = + private fun PrepareError.toCreateTunError() = when (this) { - is PrepareError.OtherLegacyAlwaysOnVpn -> CreateTunResult.OtherLegacyAlwaysOnVpn - is PrepareError.NotPrepared -> CreateTunResult.NotPrepared - is PrepareError.OtherAlwaysOnApp -> CreateTunResult.OtherAlwaysOnApp(appName) + is PrepareError.OtherLegacyAlwaysOnVpn -> OtherLegacyAlwaysOnVpn + is PrepareError.NotPrepared -> NotPrepared + is PrepareError.OtherAlwaysOnApp -> OtherAlwaysOnApp(appName) } + private fun Builder.addDnsServerSafe( + dnsServer: InetAddress + ): Either<InetAddress, VpnService.Builder> = + Either.catch { addDnsServer(dnsServer) } + .mapLeft { + when (it) { + is IllegalArgumentException -> dnsServer + else -> throw it + } + } + private fun InetAddress.prefixLength(): Int = when (this) { is Inet4Address -> IPV4_PREFIX_LENGTH @@ -166,8 +162,6 @@ open class TalpidVpnService : LifecycleVpnService() { else -> throw IllegalArgumentException("Invalid IP address (not IPv4 nor IPv6)") } - private external fun waitForTunnelUp(tunFd: Int, isIpv6Enabled: Boolean) - companion object { const val FALLBACK_DUMMY_DNS_SERVER = "192.0.2.1" diff --git a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/model/CreateTunResult.kt b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/model/CreateTunResult.kt index 3cd73685f7..ef10dcd2f3 100644 --- a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/model/CreateTunResult.kt +++ b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/model/CreateTunResult.kt @@ -1,29 +1,38 @@ package net.mullvad.talpid.model import java.net.InetAddress +import java.util.ArrayList -sealed class CreateTunResult { - open val isOpen - get() = false +sealed interface CreateTunResult { + val isOpen: Boolean - class Success(val tunFd: Int) : CreateTunResult() { - override val isOpen - get() = true + data class Success(val tunFd: Int) : CreateTunResult { + override val isOpen = true } - class InvalidDnsServers(val addresses: ArrayList<InetAddress>, val tunFd: Int) : - CreateTunResult() { - override val isOpen - get() = true + sealed interface Error : CreateTunResult + + // Prepare errors + data object OtherLegacyAlwaysOnVpn : Error { + override val isOpen: Boolean = false } - // Establish error - data object TunnelDeviceError : CreateTunResult() + data class OtherAlwaysOnApp(val appName: String) : Error { + override val isOpen: Boolean = false + } - // Prepare errors - data object OtherLegacyAlwaysOnVpn : CreateTunResult() + data object NotPrepared : Error { + override val isOpen: Boolean = false + } - data class OtherAlwaysOnApp(val appName: String) : CreateTunResult() + // Establish error + data object EstablishError : Error { + override val isOpen: Boolean = false + } - data object NotPrepared : CreateTunResult() + data class InvalidDnsServers(val addresses: ArrayList<InetAddress>, val tunFd: Int) : Error { + constructor(address: List<InetAddress>, tunFd: Int) : this(ArrayList(address), tunFd) + + override val isOpen = true + } } diff --git a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/model/NetworkState.kt b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/model/NetworkState.kt new file mode 100644 index 0000000000..ca0b6db7e2 --- /dev/null +++ b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/model/NetworkState.kt @@ -0,0 +1,19 @@ +package net.mullvad.talpid.model + +import java.net.InetAddress + +data class NetworkState( + val networkHandle: Long, + val routes: ArrayList<RouteInfo>?, + val dnsServers: ArrayList<InetAddress>?, +) { + constructor( + networkHandle: Long, + routes: List<AndroidRouteInfo>?, + dnsServers: List<InetAddress>?, + ) : this( + networkHandle = networkHandle, + routes = routes?.map { it.toRoute() }?.let { ArrayList(it) }, + dnsServers = dnsServers?.let { ArrayList(it) }, + ) +} diff --git a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/model/RouteInfo.kt b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/model/RouteInfo.kt new file mode 100644 index 0000000000..a2b63b3ca7 --- /dev/null +++ b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/model/RouteInfo.kt @@ -0,0 +1,18 @@ +package net.mullvad.talpid.model + +import java.net.InetAddress + +typealias AndroidRouteInfo = android.net.RouteInfo + +data class RouteInfo( + val destination: InetNetwork, + val gateway: InetAddress?, + val interfaceName: String?, +) + +fun AndroidRouteInfo.toRoute() = + RouteInfo( + destination = InetNetwork(destination.address, destination.prefixLength.toShort()), + gateway = gateway, + interfaceName = `interface`, + ) diff --git a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/util/ConnectivityManagerUtil.kt b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/util/ConnectivityManagerUtil.kt index daf155c6e8..fddaa6fb88 100644 --- a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/util/ConnectivityManagerUtil.kt +++ b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/util/ConnectivityManagerUtil.kt @@ -10,59 +10,56 @@ import kotlinx.coroutines.channels.awaitClose import kotlinx.coroutines.channels.trySendBlocking import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.callbackFlow +import kotlinx.coroutines.flow.scan -fun ConnectivityManager.defaultNetworkFlow(): Flow<NetworkEvent> = - callbackFlow<NetworkEvent> { - val callback = - object : NetworkCallback() { - override fun onLinkPropertiesChanged( - network: Network, - linkProperties: LinkProperties, - ) { - super.onLinkPropertiesChanged(network, linkProperties) - trySendBlocking(NetworkEvent.LinkPropertiesChanged(network, linkProperties)) - } +internal fun ConnectivityManager.defaultNetworkEvents(): Flow<NetworkEvent> = callbackFlow { + val callback = + object : NetworkCallback() { + override fun onLinkPropertiesChanged(network: Network, linkProperties: LinkProperties) { + super.onLinkPropertiesChanged(network, linkProperties) + trySendBlocking(NetworkEvent.LinkPropertiesChanged(network, linkProperties)) + } - override fun onAvailable(network: Network) { - super.onAvailable(network) - trySendBlocking(NetworkEvent.Available(network)) - } + override fun onAvailable(network: Network) { + super.onAvailable(network) + trySendBlocking(NetworkEvent.Available(network)) + } - override fun onCapabilitiesChanged( - network: Network, - networkCapabilities: NetworkCapabilities, - ) { - super.onCapabilitiesChanged(network, networkCapabilities) - trySendBlocking(NetworkEvent.CapabilitiesChanged(network, networkCapabilities)) - } + override fun onCapabilitiesChanged( + network: Network, + networkCapabilities: NetworkCapabilities, + ) { + super.onCapabilitiesChanged(network, networkCapabilities) + trySendBlocking(NetworkEvent.CapabilitiesChanged(network, networkCapabilities)) + } - override fun onBlockedStatusChanged(network: Network, blocked: Boolean) { - super.onBlockedStatusChanged(network, blocked) - trySendBlocking(NetworkEvent.BlockedStatusChanged(network, blocked)) - } + override fun onBlockedStatusChanged(network: Network, blocked: Boolean) { + super.onBlockedStatusChanged(network, blocked) + trySendBlocking(NetworkEvent.BlockedStatusChanged(network, blocked)) + } - override fun onLosing(network: Network, maxMsToLive: Int) { - super.onLosing(network, maxMsToLive) - trySendBlocking(NetworkEvent.Losing(network, maxMsToLive)) - } + override fun onLosing(network: Network, maxMsToLive: Int) { + super.onLosing(network, maxMsToLive) + trySendBlocking(NetworkEvent.Losing(network, maxMsToLive)) + } - override fun onLost(network: Network) { - super.onLost(network) - trySendBlocking(NetworkEvent.Lost(network)) - } + override fun onLost(network: Network) { + super.onLost(network) + trySendBlocking(NetworkEvent.Lost(network)) + } - override fun onUnavailable() { - super.onUnavailable() - trySendBlocking(NetworkEvent.Unavailable) - } + override fun onUnavailable() { + super.onUnavailable() + trySendBlocking(NetworkEvent.Unavailable) } - registerDefaultNetworkCallback(callback) + } + registerDefaultNetworkCallback(callback) - awaitClose { unregisterNetworkCallback(callback) } - } + awaitClose { unregisterNetworkCallback(callback) } +} -fun ConnectivityManager.networkFlow(networkRequest: NetworkRequest): Flow<NetworkEvent> = - callbackFlow<NetworkEvent> { +fun ConnectivityManager.networkEvents(networkRequest: NetworkRequest): Flow<NetworkEvent> = + callbackFlow { val callback = object : NetworkCallback() { override fun onLinkPropertiesChanged( @@ -111,6 +108,26 @@ fun ConnectivityManager.networkFlow(networkRequest: NetworkRequest): Flow<Networ awaitClose { unregisterNetworkCallback(callback) } } +internal fun ConnectivityManager.defaultRawNetworkStateFlow(): Flow<RawNetworkState?> = + defaultNetworkEvents() + .scan( + null as RawNetworkState?, + { state, event -> + return@scan when (event) { + is NetworkEvent.Available -> RawNetworkState(network = event.network) + is NetworkEvent.BlockedStatusChanged -> + state?.copy(blockedStatus = event.blocked) + is NetworkEvent.CapabilitiesChanged -> + state?.copy(networkCapabilities = event.networkCapabilities) + is NetworkEvent.LinkPropertiesChanged -> + state?.copy(linkProperties = event.linkProperties) + is NetworkEvent.Losing -> state?.copy(maxMsToLive = event.maxMsToLive) + is NetworkEvent.Lost -> null + NetworkEvent.Unavailable -> null + } + }, + ) + sealed interface NetworkEvent { data class Available(val network: Network) : NetworkEvent @@ -130,3 +147,11 @@ sealed interface NetworkEvent { data class Lost(val network: Network) : NetworkEvent } + +internal data class RawNetworkState( + val network: Network, + val linkProperties: LinkProperties? = null, + val networkCapabilities: NetworkCapabilities? = null, + val blockedStatus: Boolean = false, + val maxMsToLive: Int? = null, +) |
