diff options
| author | David Göransson <david.goransson@mullvad.net> | 2025-01-13 11:51:09 +0100 |
|---|---|---|
| committer | David Göransson <david.goransson@mullvad.net> | 2025-02-06 11:02:59 +0100 |
| commit | 341c10ba38752bc36151b8998064e706f70d9ea6 (patch) | |
| tree | afb60c53e267eda0b033f346b64afd9035d7495a /android | |
| parent | 612aad8d8d2ae779a4e5e01e85b2848b4fc7de3c (diff) | |
| download | mullvadvpn-341c10ba38752bc36151b8998064e706f70d9ea6.tar.xz mullvadvpn-341c10ba38752bc36151b8998064e706f70d9ea6.zip | |
Replace old waitForTunnelUp function
After invoking VpnService.establish() we will get a tunnel file
descriptor that corresponds to the interface that was created. However,
this has no guarantee of the routing table beeing up to date, and we
might thus send traffic outside the tunnel. Previously this was done
through looking at the tunFd to see that traffic is sent to verify that
the routing table has changed. If no traffic is seen some traffic is
induced to a random IP address to ensure traffic can be seen. This new
implementation is slower but won't risk sending UDP traffic to a random
public address at the internet.
Diffstat (limited to 'android')
9 files changed, 436 insertions, 178 deletions
diff --git a/android/app/src/test/kotlin/net/mullvad/talpid/TalpidVpnServiceFallbackDnsTest.kt b/android/app/src/test/kotlin/net/mullvad/talpid/TalpidVpnServiceFallbackDnsTest.kt new file mode 100644 index 0000000000..27e7658a11 --- /dev/null +++ b/android/app/src/test/kotlin/net/mullvad/talpid/TalpidVpnServiceFallbackDnsTest.kt @@ -0,0 +1,146 @@ +package net.mullvad.talpid + +import android.net.VpnService +import android.os.ParcelFileDescriptor +import arrow.core.right +import io.mockk.MockKAnnotations +import io.mockk.coVerify +import io.mockk.every +import io.mockk.mockk +import io.mockk.mockkConstructor +import io.mockk.mockkStatic +import io.mockk.spyk +import java.net.InetAddress +import net.mullvad.mullvadvpn.lib.common.test.assertLists +import net.mullvad.mullvadvpn.lib.common.util.prepareVpnSafe +import net.mullvad.mullvadvpn.lib.model.Prepared +import net.mullvad.talpid.model.CreateTunResult +import net.mullvad.talpid.model.InetNetwork +import net.mullvad.talpid.model.TunConfig +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.assertInstanceOf + +class TalpidVpnServiceFallbackDnsTest { + lateinit var talpidVpnService: TalpidVpnService + var builderMockk = mockk<VpnService.Builder>() + + @BeforeEach + fun setup() { + MockKAnnotations.init(this) + mockkStatic(VPN_SERVICE_EXTENSION) + + talpidVpnService = spyk<TalpidVpnService>(recordPrivateCalls = true) + every { talpidVpnService.prepareVpnSafe() } returns Prepared.right() + builderMockk = mockk<VpnService.Builder>() + + mockkConstructor(VpnService.Builder::class) + every { anyConstructed<VpnService.Builder>().setMtu(any()) } returns builderMockk + every { anyConstructed<VpnService.Builder>().setBlocking(any()) } returns builderMockk + every { anyConstructed<VpnService.Builder>().addAddress(any<InetAddress>(), any()) } returns + builderMockk + every { anyConstructed<VpnService.Builder>().addRoute(any<InetAddress>(), any()) } returns + builderMockk + every { + anyConstructed<VpnService.Builder>() + .addDnsServer(TalpidVpnService.FALLBACK_DUMMY_DNS_SERVER) + } returns builderMockk + val parcelFileDescriptor: ParcelFileDescriptor = mockk() + every { anyConstructed<VpnService.Builder>().establish() } returns parcelFileDescriptor + every { parcelFileDescriptor.detachFd() } returns 1 + } + + @Test + fun `opening tun with no DnsServers should add fallback DNS server`() { + val tunConfig = baseTunConfig.copy(dnsServers = arrayListOf()) + + val result = talpidVpnService.openTun(tunConfig) + + assertInstanceOf<CreateTunResult.Success>(result) + + // Fallback DNS server should be added if no DNS servers are provided + coVerify(exactly = 1) { + anyConstructed<VpnService.Builder>() + .addDnsServer(TalpidVpnService.FALLBACK_DUMMY_DNS_SERVER) + } + } + + @Test + fun `opening tun with all bad DnsServers should return InvalidDnsServers and add fallback`() { + val badDns1 = InetAddress.getByName("0.0.0.0") + val badDns2 = InetAddress.getByName("255.255.255.255") + every { anyConstructed<VpnService.Builder>().addDnsServer(badDns1) } throws + IllegalArgumentException() + every { anyConstructed<VpnService.Builder>().addDnsServer(badDns2) } throws + IllegalArgumentException() + + val tunConfig = baseTunConfig.copy(dnsServers = arrayListOf(badDns1, badDns2)) + val result = talpidVpnService.openTun(tunConfig) + + assertInstanceOf<CreateTunResult.InvalidDnsServers>(result) + assertLists(tunConfig.dnsServers, result.addresses) + // Fallback DNS server should be added if no valid DNS servers are provided + coVerify(exactly = 1) { + anyConstructed<VpnService.Builder>() + .addDnsServer(TalpidVpnService.FALLBACK_DUMMY_DNS_SERVER) + } + } + + @Test + fun `opening tun with 1 good and 1 bad DnsServers should return InvalidDnsServers`() { + val goodDnsServer = InetAddress.getByName("1.1.1.1") + val badDns = InetAddress.getByName("255.255.255.255") + every { anyConstructed<VpnService.Builder>().addDnsServer(goodDnsServer) } returns + builderMockk + every { anyConstructed<VpnService.Builder>().addDnsServer(badDns) } throws + IllegalArgumentException() + + val tunConfig = baseTunConfig.copy(dnsServers = arrayListOf(goodDnsServer, badDns)) + val result = talpidVpnService.openTun(tunConfig) + + assertInstanceOf<CreateTunResult.InvalidDnsServers>(result) + assertLists(arrayListOf(badDns), result.addresses) + + // Fallback DNS server should not be added since we have 1 good DNS server + coVerify(exactly = 0) { + anyConstructed<VpnService.Builder>() + .addDnsServer(TalpidVpnService.FALLBACK_DUMMY_DNS_SERVER) + } + } + + @Test + fun `providing good dns servers should not add the fallback dns and return success`() { + val goodDnsServer = InetAddress.getByName("1.1.1.1") + every { anyConstructed<VpnService.Builder>().addDnsServer(goodDnsServer) } returns + builderMockk + + val tunConfig = baseTunConfig.copy(dnsServers = arrayListOf(goodDnsServer)) + val result = talpidVpnService.openTun(tunConfig) + + assertInstanceOf<CreateTunResult.Success>(result) + + // Fallback DNS server should not be added since we have good DNS servers. + coVerify(exactly = 0) { + anyConstructed<VpnService.Builder>() + .addDnsServer(TalpidVpnService.FALLBACK_DUMMY_DNS_SERVER) + } + } + + companion object { + private const val VPN_SERVICE_EXTENSION = + "net.mullvad.mullvadvpn.lib.common.util.VpnServiceUtilsKt" + + val baseTunConfig = + TunConfig( + addresses = arrayListOf(InetAddress.getByName("45.83.223.209")), + dnsServers = arrayListOf(), + routes = + arrayListOf( + InetNetwork(InetAddress.getByName("0.0.0.0"), 0), + InetNetwork(InetAddress.getByName("::"), 0), + ), + mtu = 1280, + excludedPackages = arrayListOf(), + ) + } +} diff --git a/android/lib/common/src/main/kotlin/net/mullvad/mullvadvpn/lib/common/util/VpnServiceUtils.kt b/android/lib/common/src/main/kotlin/net/mullvad/mullvadvpn/lib/common/util/VpnServiceUtils.kt index 59833cb396..06c862936b 100644 --- a/android/lib/common/src/main/kotlin/net/mullvad/mullvadvpn/lib/common/util/VpnServiceUtils.kt +++ b/android/lib/common/src/main/kotlin/net/mullvad/mullvadvpn/lib/common/util/VpnServiceUtils.kt @@ -2,10 +2,14 @@ package net.mullvad.mullvadvpn.lib.common.util import android.content.Context import android.content.Intent +import android.net.VpnService import android.net.VpnService.prepare +import android.os.ParcelFileDescriptor import arrow.core.Either -import arrow.core.flatten +import arrow.core.flatMap import arrow.core.left +import arrow.core.raise.either +import arrow.core.raise.ensureNotNull import arrow.core.right import co.touchlab.kermit.Logger import net.mullvad.mullvadvpn.lib.common.util.SdkUtils.getInstalledPackagesList @@ -13,6 +17,8 @@ import net.mullvad.mullvadvpn.lib.model.PrepareError import net.mullvad.mullvadvpn.lib.model.Prepared /** + * Prepare to establish a VPN connection safely. + * * Invoking VpnService.prepare() can result in 3 out comes: * 1. IllegalStateException - There is a legacy VPN profile marked as always on * 2. Intent @@ -34,7 +40,7 @@ fun Context.prepareVpnSafe(): Either<PrepareError, Prepared> = else -> throw it } } - .map { intent -> + .flatMap { intent -> if (intent == null) { Prepared.right() } else { @@ -46,7 +52,6 @@ fun Context.prepareVpnSafe(): Either<PrepareError, Prepared> = } } } - .flatten() fun Context.getAlwaysOnVpnAppName(): String? { return resolveAlwaysOnVpnPackageName() @@ -59,3 +64,38 @@ fun Context.getAlwaysOnVpnAppName(): String? { ?.loadLabel(packageManager) ?.toString() } + +/** + * Establish a VPN connection safely. + * + * This function wraps the [VpnService.Builder.establish] function and catches any exceptions that + * may be thrown and type them to a more specific error. + * + * @return [ParcelFileDescriptor] if successful, [EstablishError] otherwise + */ +fun VpnService.Builder.establishSafe(): Either<EstablishError, ParcelFileDescriptor> = either { + val vpnInterfaceFd = + Either.catch { establish() } + .mapLeft { + when (it) { + is IllegalStateException -> EstablishError.ParameterNotApplied(it) + is IllegalArgumentException -> EstablishError.ParameterNotAccepted(it) + else -> EstablishError.UnknownError(it) + } + } + .bind() + + ensureNotNull(vpnInterfaceFd) { EstablishError.NullVpnInterface } + + vpnInterfaceFd +} + +sealed interface EstablishError { + data class ParameterNotApplied(val exception: IllegalStateException) : EstablishError + + data class ParameterNotAccepted(val exception: IllegalArgumentException) : EstablishError + + data object NullVpnInterface : EstablishError + + data class UnknownError(val error: Throwable) : EstablishError +} diff --git a/android/lib/daemon-grpc/src/main/kotlin/net/mullvad/mullvadvpn/lib/daemon/grpc/mapper/ToDomain.kt b/android/lib/daemon-grpc/src/main/kotlin/net/mullvad/mullvadvpn/lib/daemon/grpc/mapper/ToDomain.kt index daa04fc8d9..fe4cf11881 100644 --- a/android/lib/daemon-grpc/src/main/kotlin/net/mullvad/mullvadvpn/lib/daemon/grpc/mapper/ToDomain.kt +++ b/android/lib/daemon-grpc/src/main/kotlin/net/mullvad/mullvadvpn/lib/daemon/grpc/mapper/ToDomain.kt @@ -36,9 +36,6 @@ import net.mullvad.mullvadvpn.lib.model.DnsState import net.mullvad.mullvadvpn.lib.model.Endpoint import net.mullvad.mullvadvpn.lib.model.ErrorState import net.mullvad.mullvadvpn.lib.model.ErrorStateCause -import net.mullvad.mullvadvpn.lib.model.ErrorStateCause.AuthFailed -import net.mullvad.mullvadvpn.lib.model.ErrorStateCause.OtherAlwaysOnApp -import net.mullvad.mullvadvpn.lib.model.ErrorStateCause.TunnelParameterError import net.mullvad.mullvadvpn.lib.model.FeatureIndicator import net.mullvad.mullvadvpn.lib.model.GeoIpLocation import net.mullvad.mullvadvpn.lib.model.GeoLocationId @@ -125,7 +122,7 @@ private fun ManagementInterface.TunnelState.Error.toDomain(): TunnelState.Error val otherAlwaysOnAppError = errorState.let { if (it.hasOtherAlwaysOnAppError()) { - OtherAlwaysOnApp(it.otherAlwaysOnAppError.appName) + ErrorStateCause.OtherAlwaysOnApp(it.otherAlwaysOnAppError.appName) } else { null } @@ -238,7 +235,7 @@ internal fun ManagementInterface.ErrorState.toDomain( cause = when (cause!!) { ManagementInterface.ErrorState.Cause.AUTH_FAILED -> - AuthFailed(authFailedError.toDomain()) + ErrorStateCause.AuthFailed(authFailedError.toDomain()) ManagementInterface.ErrorState.Cause.IPV6_UNAVAILABLE -> ErrorStateCause.Ipv6Unavailable ManagementInterface.ErrorState.Cause.SET_FIREWALL_POLICY_ERROR -> @@ -247,7 +244,7 @@ internal fun ManagementInterface.ErrorState.toDomain( ManagementInterface.ErrorState.Cause.START_TUNNEL_ERROR -> ErrorStateCause.StartTunnelError ManagementInterface.ErrorState.Cause.TUNNEL_PARAMETER_ERROR -> - TunnelParameterError(parameterError.toDomain()) + ErrorStateCause.TunnelParameterError(parameterError.toDomain()) ManagementInterface.ErrorState.Cause.IS_OFFLINE -> ErrorStateCause.IsOffline ManagementInterface.ErrorState.Cause.SPLIT_TUNNEL_ERROR -> ErrorStateCause.StartTunnelError @@ -255,7 +252,6 @@ internal fun ManagementInterface.ErrorState.toDomain( ManagementInterface.ErrorState.Cause.NEED_FULL_DISK_PERMISSIONS, ManagementInterface.ErrorState.Cause.CREATE_TUNNEL_DEVICE -> throw IllegalArgumentException("Unrecognized error state cause") - ManagementInterface.ErrorState.Cause.NOT_PREPARED -> ErrorStateCause.NotPrepared ManagementInterface.ErrorState.Cause.OTHER_ALWAYS_ON_APP -> otherAlwaysOnApp!! ManagementInterface.ErrorState.Cause.OTHER_LEGACY_ALWAYS_ON_VPN -> diff --git a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/ConnectivityListener.kt b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/ConnectivityListener.kt index 86b27e3ba8..fdee5039ad 100644 --- a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/ConnectivityListener.kt +++ b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/ConnectivityListener.kt @@ -7,34 +7,48 @@ import android.net.NetworkCapabilities import android.net.NetworkRequest import co.touchlab.kermit.Logger import java.net.InetAddress +import kotlin.collections.ArrayList import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.SharingStarted import kotlinx.coroutines.flow.StateFlow import kotlinx.coroutines.flow.distinctUntilChanged -import kotlinx.coroutines.flow.filterIsInstance import kotlinx.coroutines.flow.map import kotlinx.coroutines.flow.onEach import kotlinx.coroutines.flow.scan import kotlinx.coroutines.flow.stateIn +import net.mullvad.talpid.model.NetworkState import net.mullvad.talpid.util.NetworkEvent -import net.mullvad.talpid.util.defaultNetworkFlow -import net.mullvad.talpid.util.networkFlow +import net.mullvad.talpid.util.RawNetworkState +import net.mullvad.talpid.util.defaultRawNetworkStateFlow +import net.mullvad.talpid.util.networkEvents -class ConnectivityListener(val connectivityManager: ConnectivityManager) { +class ConnectivityListener(private val connectivityManager: ConnectivityManager) { private lateinit var _isConnected: StateFlow<Boolean> // Used by JNI val isConnected get() = _isConnected.value - private lateinit var _currentDnsServers: StateFlow<List<InetAddress>> + private lateinit var _currentNetworkState: StateFlow<NetworkState?> + + // Used by JNI + val currentDefaultNetworkState: NetworkState? + get() = _currentNetworkState.value + // Used by JNI - val currentDnsServers - get() = ArrayList(_currentDnsServers.value) + val currentDnsServers: ArrayList<InetAddress> + get() = _currentNetworkState.value?.dnsServers ?: ArrayList() fun register(scope: CoroutineScope) { - _currentDnsServers = - dnsServerChanges().stateIn(scope, SharingStarted.Eagerly, currentDnsServers()) + // Consider implementing retry logic for the flows below, because registering a listener on + // the default network may fail if the network on Android 11 + // https://issuetracker.google.com/issues/175055271?pli=1 + _currentNetworkState = + connectivityManager + .defaultRawNetworkStateFlow() + .map { it?.toNetworkState() } + .onEach { notifyDefaultNetworkChange(it) } + .stateIn(scope, SharingStarted.Eagerly, null) _isConnected = hasInternetCapability() @@ -42,18 +56,6 @@ class ConnectivityListener(val connectivityManager: ConnectivityManager) { .stateIn(scope, SharingStarted.Eagerly, false) } - private fun dnsServerChanges(): Flow<List<InetAddress>> = - connectivityManager - .defaultNetworkFlow() - .filterIsInstance<NetworkEvent.LinkPropertiesChanged>() - .onEach { Logger.d("Link properties changed") } - .map { it.linkProperties.dnsServersWithoutFallback() } - - private fun currentDnsServers(): List<InetAddress> = - connectivityManager - .getLinkProperties(connectivityManager.activeNetwork) - ?.dnsServersWithoutFallback() ?: emptyList() - private fun LinkProperties.dnsServersWithoutFallback(): List<InetAddress> = dnsServers.filter { it.hostAddress != TalpidVpnService.FALLBACK_DUMMY_DNS_SERVER } @@ -65,7 +67,7 @@ class ConnectivityListener(val connectivityManager: ConnectivityManager) { .build() return connectivityManager - .networkFlow(request) + .networkEvents(request) .scan(setOf<Network>()) { networks, event -> when (event) { is NetworkEvent.Available -> { @@ -87,5 +89,14 @@ class ConnectivityListener(val connectivityManager: ConnectivityManager) { .distinctUntilChanged() } + private fun RawNetworkState.toNetworkState(): NetworkState = + NetworkState( + network.networkHandle, + linkProperties?.routes, + linkProperties?.dnsServersWithoutFallback(), + ) + private external fun notifyConnectivityChange(isConnected: Boolean) + + private external fun notifyDefaultNetworkChange(networkState: NetworkState?) } diff --git a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/TalpidVpnService.kt b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/TalpidVpnService.kt index 74d44005cd..a143df6132 100644 --- a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/TalpidVpnService.kt +++ b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/TalpidVpnService.kt @@ -1,18 +1,29 @@ package net.mullvad.talpid import android.net.ConnectivityManager +import android.net.VpnService import android.os.ParcelFileDescriptor import androidx.annotation.CallSuper import androidx.core.content.getSystemService import androidx.lifecycle.lifecycleScope +import arrow.core.Either +import arrow.core.mapOrAccumulate +import arrow.core.merge +import arrow.core.raise.either import co.touchlab.kermit.Logger import java.net.Inet4Address import java.net.Inet6Address import java.net.InetAddress import kotlin.properties.Delegates.observable +import net.mullvad.mullvadvpn.lib.common.util.establishSafe import net.mullvad.mullvadvpn.lib.common.util.prepareVpnSafe import net.mullvad.mullvadvpn.lib.model.PrepareError import net.mullvad.talpid.model.CreateTunResult +import net.mullvad.talpid.model.CreateTunResult.EstablishError +import net.mullvad.talpid.model.CreateTunResult.InvalidDnsServers +import net.mullvad.talpid.model.CreateTunResult.NotPrepared +import net.mullvad.talpid.model.CreateTunResult.OtherAlwaysOnApp +import net.mullvad.talpid.model.CreateTunResult.OtherLegacyAlwaysOnVpn import net.mullvad.talpid.model.TunConfig import net.mullvad.talpid.util.TalpidSdkUtils.setMeteredIfSupported @@ -22,7 +33,7 @@ open class TalpidVpnService : LifecycleVpnService() { val oldTunFd = when (oldTunStatus) { is CreateTunResult.Success -> oldTunStatus.tunFd - is CreateTunResult.InvalidDnsServers -> oldTunStatus.tunFd + is InvalidDnsServers -> oldTunStatus.tunFd else -> null } @@ -43,26 +54,30 @@ open class TalpidVpnService : LifecycleVpnService() { connectivityListener.register(lifecycleScope) } - fun openTun(config: TunConfig): CreateTunResult { + // Used by JNI + fun openTun(config: TunConfig): CreateTunResult = synchronized(this) { val tunStatus = activeTunStatus if (config == currentTunConfig && tunStatus != null && tunStatus.isOpen) { - return tunStatus + tunStatus } else { - return openTunImpl(config) + openTunImpl(config) } } - } - fun openTunForced(config: TunConfig): CreateTunResult { - synchronized(this) { - return openTunImpl(config) - } - } + // Used by JNI + fun openTunForced(config: TunConfig): CreateTunResult = + synchronized(this) { openTunImpl(config) } + + // Used by JNI + fun closeTun(): Unit = synchronized(this) { activeTunStatus = null } + + // Used by JNI + fun bypass(socket: Int): Boolean = protect(socket) private fun openTunImpl(config: TunConfig): CreateTunResult { - val newTunStatus = createTun(config) + val newTunStatus = createTun(config).merge() currentTunConfig = config activeTunStatus = newTunStatus @@ -70,95 +85,76 @@ open class TalpidVpnService : LifecycleVpnService() { return newTunStatus } - fun closeTun() { - synchronized(this) { activeTunStatus = null } - } - - // DROID-1407 - // Function is to be cleaned up and lint suppression to be removed. - @Suppress("ReturnCount") - private fun createTun(config: TunConfig): CreateTunResult { - prepareVpnSafe() - .mapLeft { it.toCreateTunResult() } - .onLeft { - return it - } - - val invalidDnsServerAddresses = ArrayList<InetAddress>() + private fun createTun( + config: TunConfig + ): Either<CreateTunResult.Error, CreateTunResult.Success> = either { + prepareVpnSafe().mapLeft { it.toCreateTunError() }.bind() - val builder = - Builder().apply { - for (address in config.addresses) { - addAddress(address, address.prefixLength()) - } + val builder = Builder() + builder.setMtu(config.mtu) + builder.setBlocking(false) + builder.setMeteredIfSupported(false) - for (dnsServer in config.dnsServers) { - try { - addDnsServer(dnsServer) - } catch (exception: IllegalArgumentException) { - invalidDnsServerAddresses.add(dnsServer) - } - } + config.addresses.forEach { builder.addAddress(it, it.prefixLength()) } + config.routes.forEach { builder.addRoute(it.address, it.prefixLength.toInt()) } + config.excludedPackages.forEach { app -> builder.addDisallowedApplication(app) } - // Avoids creating a tunnel with no DNS servers or if all DNS servers was invalid, - // since apps then may leak DNS requests. - // https://issuetracker.google.com/issues/337961996 - if (invalidDnsServerAddresses.size == config.dnsServers.size) { - Logger.w( - "All DNS servers invalid or non set, using fallback DNS server to " + - "minimize leaks, dnsServers.isEmpty(): ${config.dnsServers.isEmpty()}" - ) - addDnsServer(FALLBACK_DUMMY_DNS_SERVER) - } - - for (route in config.routes) { - addRoute(route.address, route.prefixLength.toInt()) - } - - config.excludedPackages.forEach { app -> addDisallowedApplication(app) } - setMtu(config.mtu) - setBlocking(false) - setMeteredIfSupported(false) - } - - val vpnInterfaceFd = - try { - builder.establish() - } catch (e: IllegalStateException) { - Logger.e("Failed to establish, a parameter could not be applied", e) - return CreateTunResult.TunnelDeviceError - } catch (e: IllegalArgumentException) { - Logger.e("Failed to establish a parameter was not accepted", e) - return CreateTunResult.TunnelDeviceError + // We don't care if adding DNS servers fails at this point, since we can still create a + // tunnel to consume traffic and then notify daemon to later enter blocked state. + val dnsConfigureResult = + config.dnsServers.mapOrAccumulate { + builder.addDnsServerSafe(it).bind() + Unit } - if (vpnInterfaceFd == null) { - Logger.e("VpnInterface returned null") - return CreateTunResult.TunnelDeviceError + // Never create a tunnel where all DNS servers are invalid or if none was ever set, since + // apps then may leak DNS requests. + // https://issuetracker.google.com/issues/337961996 + val shouldAddFallbackDns = + dnsConfigureResult.fold( + { invalidDnsServers -> invalidDnsServers.size == config.dnsServers.size }, + { addedDnsServers -> addedDnsServers.isEmpty() }, + ) + if (shouldAddFallbackDns) { + Logger.w( + "All DNS servers invalid or non set, using fallback DNS server to " + + "minimize leaks, dnsServers.isEmpty(): ${config.dnsServers.isEmpty()}" + ) + builder.addDnsServer(FALLBACK_DUMMY_DNS_SERVER) } - val tunFd = vpnInterfaceFd.detachFd() - - waitForTunnelUp(tunFd, config.routes.any { route -> route.isIpv6 }) + val vpnInterfaceFd = + builder + .establishSafe() + .onLeft { Logger.w("Failed to establish tunnel $it") } + .mapLeft { EstablishError } + .bind() - if (invalidDnsServerAddresses.isNotEmpty()) { - return CreateTunResult.InvalidDnsServers(invalidDnsServerAddresses, tunFd) - } + val tunFd = vpnInterfaceFd.detachFd() - return CreateTunResult.Success(tunFd) - } + dnsConfigureResult.mapLeft { InvalidDnsServers(it, tunFd) }.bind() - fun bypass(socket: Int): Boolean { - return protect(socket) + CreateTunResult.Success(tunFd) } - private fun PrepareError.toCreateTunResult() = + private fun PrepareError.toCreateTunError() = when (this) { - is PrepareError.OtherLegacyAlwaysOnVpn -> CreateTunResult.OtherLegacyAlwaysOnVpn - is PrepareError.NotPrepared -> CreateTunResult.NotPrepared - is PrepareError.OtherAlwaysOnApp -> CreateTunResult.OtherAlwaysOnApp(appName) + is PrepareError.OtherLegacyAlwaysOnVpn -> OtherLegacyAlwaysOnVpn + is PrepareError.NotPrepared -> NotPrepared + is PrepareError.OtherAlwaysOnApp -> OtherAlwaysOnApp(appName) } + private fun Builder.addDnsServerSafe( + dnsServer: InetAddress + ): Either<InetAddress, VpnService.Builder> = + Either.catch { addDnsServer(dnsServer) } + .mapLeft { + when (it) { + is IllegalArgumentException -> dnsServer + else -> throw it + } + } + private fun InetAddress.prefixLength(): Int = when (this) { is Inet4Address -> IPV4_PREFIX_LENGTH @@ -166,8 +162,6 @@ open class TalpidVpnService : LifecycleVpnService() { else -> throw IllegalArgumentException("Invalid IP address (not IPv4 nor IPv6)") } - private external fun waitForTunnelUp(tunFd: Int, isIpv6Enabled: Boolean) - companion object { const val FALLBACK_DUMMY_DNS_SERVER = "192.0.2.1" diff --git a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/model/CreateTunResult.kt b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/model/CreateTunResult.kt index 3cd73685f7..ef10dcd2f3 100644 --- a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/model/CreateTunResult.kt +++ b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/model/CreateTunResult.kt @@ -1,29 +1,38 @@ package net.mullvad.talpid.model import java.net.InetAddress +import java.util.ArrayList -sealed class CreateTunResult { - open val isOpen - get() = false +sealed interface CreateTunResult { + val isOpen: Boolean - class Success(val tunFd: Int) : CreateTunResult() { - override val isOpen - get() = true + data class Success(val tunFd: Int) : CreateTunResult { + override val isOpen = true } - class InvalidDnsServers(val addresses: ArrayList<InetAddress>, val tunFd: Int) : - CreateTunResult() { - override val isOpen - get() = true + sealed interface Error : CreateTunResult + + // Prepare errors + data object OtherLegacyAlwaysOnVpn : Error { + override val isOpen: Boolean = false } - // Establish error - data object TunnelDeviceError : CreateTunResult() + data class OtherAlwaysOnApp(val appName: String) : Error { + override val isOpen: Boolean = false + } - // Prepare errors - data object OtherLegacyAlwaysOnVpn : CreateTunResult() + data object NotPrepared : Error { + override val isOpen: Boolean = false + } - data class OtherAlwaysOnApp(val appName: String) : CreateTunResult() + // Establish error + data object EstablishError : Error { + override val isOpen: Boolean = false + } - data object NotPrepared : CreateTunResult() + data class InvalidDnsServers(val addresses: ArrayList<InetAddress>, val tunFd: Int) : Error { + constructor(address: List<InetAddress>, tunFd: Int) : this(ArrayList(address), tunFd) + + override val isOpen = true + } } diff --git a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/model/NetworkState.kt b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/model/NetworkState.kt new file mode 100644 index 0000000000..ca0b6db7e2 --- /dev/null +++ b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/model/NetworkState.kt @@ -0,0 +1,19 @@ +package net.mullvad.talpid.model + +import java.net.InetAddress + +data class NetworkState( + val networkHandle: Long, + val routes: ArrayList<RouteInfo>?, + val dnsServers: ArrayList<InetAddress>?, +) { + constructor( + networkHandle: Long, + routes: List<AndroidRouteInfo>?, + dnsServers: List<InetAddress>?, + ) : this( + networkHandle = networkHandle, + routes = routes?.map { it.toRoute() }?.let { ArrayList(it) }, + dnsServers = dnsServers?.let { ArrayList(it) }, + ) +} diff --git a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/model/RouteInfo.kt b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/model/RouteInfo.kt new file mode 100644 index 0000000000..a2b63b3ca7 --- /dev/null +++ b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/model/RouteInfo.kt @@ -0,0 +1,18 @@ +package net.mullvad.talpid.model + +import java.net.InetAddress + +typealias AndroidRouteInfo = android.net.RouteInfo + +data class RouteInfo( + val destination: InetNetwork, + val gateway: InetAddress?, + val interfaceName: String?, +) + +fun AndroidRouteInfo.toRoute() = + RouteInfo( + destination = InetNetwork(destination.address, destination.prefixLength.toShort()), + gateway = gateway, + interfaceName = `interface`, + ) diff --git a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/util/ConnectivityManagerUtil.kt b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/util/ConnectivityManagerUtil.kt index daf155c6e8..fddaa6fb88 100644 --- a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/util/ConnectivityManagerUtil.kt +++ b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/util/ConnectivityManagerUtil.kt @@ -10,59 +10,56 @@ import kotlinx.coroutines.channels.awaitClose import kotlinx.coroutines.channels.trySendBlocking import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.callbackFlow +import kotlinx.coroutines.flow.scan -fun ConnectivityManager.defaultNetworkFlow(): Flow<NetworkEvent> = - callbackFlow<NetworkEvent> { - val callback = - object : NetworkCallback() { - override fun onLinkPropertiesChanged( - network: Network, - linkProperties: LinkProperties, - ) { - super.onLinkPropertiesChanged(network, linkProperties) - trySendBlocking(NetworkEvent.LinkPropertiesChanged(network, linkProperties)) - } +internal fun ConnectivityManager.defaultNetworkEvents(): Flow<NetworkEvent> = callbackFlow { + val callback = + object : NetworkCallback() { + override fun onLinkPropertiesChanged(network: Network, linkProperties: LinkProperties) { + super.onLinkPropertiesChanged(network, linkProperties) + trySendBlocking(NetworkEvent.LinkPropertiesChanged(network, linkProperties)) + } - override fun onAvailable(network: Network) { - super.onAvailable(network) - trySendBlocking(NetworkEvent.Available(network)) - } + override fun onAvailable(network: Network) { + super.onAvailable(network) + trySendBlocking(NetworkEvent.Available(network)) + } - override fun onCapabilitiesChanged( - network: Network, - networkCapabilities: NetworkCapabilities, - ) { - super.onCapabilitiesChanged(network, networkCapabilities) - trySendBlocking(NetworkEvent.CapabilitiesChanged(network, networkCapabilities)) - } + override fun onCapabilitiesChanged( + network: Network, + networkCapabilities: NetworkCapabilities, + ) { + super.onCapabilitiesChanged(network, networkCapabilities) + trySendBlocking(NetworkEvent.CapabilitiesChanged(network, networkCapabilities)) + } - override fun onBlockedStatusChanged(network: Network, blocked: Boolean) { - super.onBlockedStatusChanged(network, blocked) - trySendBlocking(NetworkEvent.BlockedStatusChanged(network, blocked)) - } + override fun onBlockedStatusChanged(network: Network, blocked: Boolean) { + super.onBlockedStatusChanged(network, blocked) + trySendBlocking(NetworkEvent.BlockedStatusChanged(network, blocked)) + } - override fun onLosing(network: Network, maxMsToLive: Int) { - super.onLosing(network, maxMsToLive) - trySendBlocking(NetworkEvent.Losing(network, maxMsToLive)) - } + override fun onLosing(network: Network, maxMsToLive: Int) { + super.onLosing(network, maxMsToLive) + trySendBlocking(NetworkEvent.Losing(network, maxMsToLive)) + } - override fun onLost(network: Network) { - super.onLost(network) - trySendBlocking(NetworkEvent.Lost(network)) - } + override fun onLost(network: Network) { + super.onLost(network) + trySendBlocking(NetworkEvent.Lost(network)) + } - override fun onUnavailable() { - super.onUnavailable() - trySendBlocking(NetworkEvent.Unavailable) - } + override fun onUnavailable() { + super.onUnavailable() + trySendBlocking(NetworkEvent.Unavailable) } - registerDefaultNetworkCallback(callback) + } + registerDefaultNetworkCallback(callback) - awaitClose { unregisterNetworkCallback(callback) } - } + awaitClose { unregisterNetworkCallback(callback) } +} -fun ConnectivityManager.networkFlow(networkRequest: NetworkRequest): Flow<NetworkEvent> = - callbackFlow<NetworkEvent> { +fun ConnectivityManager.networkEvents(networkRequest: NetworkRequest): Flow<NetworkEvent> = + callbackFlow { val callback = object : NetworkCallback() { override fun onLinkPropertiesChanged( @@ -111,6 +108,26 @@ fun ConnectivityManager.networkFlow(networkRequest: NetworkRequest): Flow<Networ awaitClose { unregisterNetworkCallback(callback) } } +internal fun ConnectivityManager.defaultRawNetworkStateFlow(): Flow<RawNetworkState?> = + defaultNetworkEvents() + .scan( + null as RawNetworkState?, + { state, event -> + return@scan when (event) { + is NetworkEvent.Available -> RawNetworkState(network = event.network) + is NetworkEvent.BlockedStatusChanged -> + state?.copy(blockedStatus = event.blocked) + is NetworkEvent.CapabilitiesChanged -> + state?.copy(networkCapabilities = event.networkCapabilities) + is NetworkEvent.LinkPropertiesChanged -> + state?.copy(linkProperties = event.linkProperties) + is NetworkEvent.Losing -> state?.copy(maxMsToLive = event.maxMsToLive) + is NetworkEvent.Lost -> null + NetworkEvent.Unavailable -> null + } + }, + ) + sealed interface NetworkEvent { data class Available(val network: Network) : NetworkEvent @@ -130,3 +147,11 @@ sealed interface NetworkEvent { data class Lost(val network: Network) : NetworkEvent } + +internal data class RawNetworkState( + val network: Network, + val linkProperties: LinkProperties? = null, + val networkCapabilities: NetworkCapabilities? = null, + val blockedStatus: Boolean = false, + val maxMsToLive: Int? = null, +) |
