diff options
Diffstat (limited to 'android/lib/billing/src')
14 files changed, 1289 insertions, 0 deletions
diff --git a/android/lib/billing/src/androidTest/kotlin/net/mullvad/mullvadvpn/lib/billing/BillingRepositoryTest.kt b/android/lib/billing/src/androidTest/kotlin/net/mullvad/mullvadvpn/lib/billing/BillingRepositoryTest.kt new file mode 100644 index 0000000000..85982007b8 --- /dev/null +++ b/android/lib/billing/src/androidTest/kotlin/net/mullvad/mullvadvpn/lib/billing/BillingRepositoryTest.kt @@ -0,0 +1,388 @@ +package net.mullvad.mullvadvpn.lib.billing + +import android.app.Activity +import android.content.Context +import app.cash.turbine.test +import com.android.billingclient.api.BillingClient +import com.android.billingclient.api.BillingClient.BillingResponseCode +import com.android.billingclient.api.BillingClientStateListener +import com.android.billingclient.api.BillingFlowParams +import com.android.billingclient.api.BillingResult +import com.android.billingclient.api.ProductDetails +import com.android.billingclient.api.ProductDetailsResult +import com.android.billingclient.api.Purchase +import com.android.billingclient.api.PurchasesResult +import com.android.billingclient.api.PurchasesUpdatedListener +import com.android.billingclient.api.QueryPurchasesParams +import com.android.billingclient.api.queryProductDetails +import com.android.billingclient.api.queryPurchasesAsync +import io.mockk.CapturingSlot +import io.mockk.coEvery +import io.mockk.coVerify +import io.mockk.every +import io.mockk.mockk +import io.mockk.mockkStatic +import io.mockk.unmockkAll +import io.mockk.verify +import kotlin.test.assertEquals +import kotlin.test.assertIs +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.launch +import kotlinx.coroutines.test.UnconfinedTestDispatcher +import kotlinx.coroutines.test.runTest +import net.mullvad.mullvadvpn.lib.billing.model.BillingException +import net.mullvad.mullvadvpn.lib.billing.model.PurchaseEvent +import net.mullvad.mullvadvpn.lib.common.test.TestCoroutineRule +import net.mullvad.mullvadvpn.lib.common.test.assertLists +import org.junit.After +import org.junit.Before +import org.junit.Rule +import org.junit.Test +import org.koin.core.context.startKoin +import org.koin.core.context.stopKoin +import org.koin.dsl.module + +class BillingRepositoryTest { + @get:Rule val testCoroutineRule = TestCoroutineRule() + + private val mockContext: Context = mockk() + private lateinit var billingRepository: BillingRepository + + private val mockBillingClientBuilder: BillingClient.Builder = mockk(relaxed = true) + private val mockBillingClient: BillingClient = mockk() + + private val purchaseUpdatedListenerSlot: CapturingSlot<PurchasesUpdatedListener> = + CapturingSlot() + + @Before + fun setUp() { + startKoin { modules(module { single { mockk<Activity>() } }) } + + mockkStatic(BILLING_CLIENT_CLASS) + mockkStatic(BILLING_CLIENT_KOTLIN_CLASS) + mockkStatic(BILLING_FLOW_PARAMS) + + every { BillingClient.newBuilder(any()) } returns mockBillingClientBuilder + every { mockBillingClientBuilder.enablePendingPurchases() } returns mockBillingClientBuilder + every { mockBillingClientBuilder.setListener(capture(purchaseUpdatedListenerSlot)) } returns + mockBillingClientBuilder + every { mockBillingClientBuilder.build() } returns mockBillingClient + + billingRepository = BillingRepository(mockContext) + } + + @After + fun tearDown() { + unmockkAll() + stopKoin() + } + + @Test + fun testQueryProductsOk() = runTest { + // Arrange + val mockBillingResult: BillingResult = mockk() + val mockProductDetails: ProductDetails = mockk() + val expectedProductDetailsResult: ProductDetailsResult = mockk() + val productId = "TEST" + val price = "44.4" + + every { mockBillingResult.responseCode } returns BillingResponseCode.OK + every { mockBillingClient.isReady } returns true + every { mockBillingClient.connectionState } returns BillingClient.ConnectionState.CONNECTED + coEvery { mockBillingClient.queryProductDetails(any()) } returns + expectedProductDetailsResult + every { expectedProductDetailsResult.billingResult } returns mockBillingResult + every { expectedProductDetailsResult.productDetailsList } returns listOf(mockProductDetails) + every { mockProductDetails.productId } returns productId + every { mockProductDetails.oneTimePurchaseOfferDetails?.formattedPrice } returns price + + // Act + val result = billingRepository.queryProducts(listOf(productId)) + + // Assert + assertEquals(expectedProductDetailsResult, result) + } + + @Test + fun testQueryProductsItemUnavailable() = runTest { + // Arrange + val mockBillingResult: BillingResult = mockk() + val mockProductDetailsResult: ProductDetailsResult = mockk() + + every { mockBillingResult.responseCode } returns BillingResponseCode.ITEM_UNAVAILABLE + every { mockBillingClient.isReady } returns true + every { mockBillingClient.connectionState } returns BillingClient.ConnectionState.CONNECTED + coEvery { mockBillingClient.queryProductDetails(any()) } returns mockProductDetailsResult + every { mockProductDetailsResult.billingResult } returns mockBillingResult + every { mockProductDetailsResult.productDetailsList } returns emptyList() + + // Act + val result = billingRepository.queryProducts(listOf("TEST")) + + // Assert + assertEquals(mockProductDetailsResult, result) + } + + @Test + fun testQueryProductsBillingUnavailable() = runTest { + // Arrange + val mockBillingResult: BillingResult = mockk() + val mockProductDetailsResult: ProductDetailsResult = mockk() + + every { mockBillingResult.responseCode } returns BillingResponseCode.BILLING_UNAVAILABLE + every { mockBillingClient.isReady } returns true + every { mockBillingClient.connectionState } returns BillingClient.ConnectionState.CONNECTED + coEvery { mockBillingClient.queryProductDetails(any()) } returns mockProductDetailsResult + every { mockProductDetailsResult.billingResult } returns mockBillingResult + every { mockProductDetailsResult.productDetailsList } returns emptyList() + + // Act + val result = billingRepository.queryProducts(listOf("TEST")) + + // Assert + assertEquals(mockProductDetailsResult, result) + } + + @Test + fun testStartPurchaseFlowOk() = runTest { + // Arrange + val mockProductBillingResult: BillingResult = mockk() + val mockBillingResult: BillingResult = mockk() + val transactionId = "MOCK22" + val mockProductDetails: ProductDetails = mockk(relaxed = true) + val mockActivityProvider: () -> Activity = mockk() + every { mockBillingResult.responseCode } returns BillingResponseCode.OK + every { mockBillingClient.isReady } returns true + every { mockBillingClient.connectionState } returns BillingClient.ConnectionState.CONNECTED + every { mockBillingClient.launchBillingFlow(any(), any()) } returns mockBillingResult + every { BillingFlowParams.newBuilder() } returns mockk(relaxed = true) + every { mockProductBillingResult.responseCode } returns BillingResponseCode.OK + every { mockActivityProvider() } returns mockk() + + // Act + val result = + billingRepository.startPurchaseFlow( + mockProductDetails, + transactionId, + mockActivityProvider + ) + + // Assert + assertEquals(mockBillingResult, result) + } + + @Test + fun testStartPurchaseFlowBillingUnavailable() = runTest { + // Arrange + val mockBillingResult: BillingResult = mockk() + val transactionId = "MOCK22" + val mockProductDetails: ProductDetails = mockk(relaxed = true) + val mockActivityProvider: () -> Activity = mockk() + every { mockBillingResult.responseCode } returns BillingResponseCode.BILLING_UNAVAILABLE + every { mockBillingClient.isReady } returns true + every { mockBillingClient.connectionState } returns BillingClient.ConnectionState.CONNECTED + every { mockBillingClient.launchBillingFlow(any(), any()) } returns mockBillingResult + every { BillingFlowParams.newBuilder() } returns mockk(relaxed = true) + every { mockActivityProvider() } returns mockk() + + // Act + val result = + billingRepository.startPurchaseFlow( + mockProductDetails, + transactionId, + mockActivityProvider + ) + + // Assert + assertEquals(mockBillingResult, result) + } + + @Test + fun testQueryPurchasesFoundPurchases() = runTest { + // Arrange + val mockResult: PurchasesResult = mockk() + val mockPurchase: Purchase = mockk() + every { mockResult.billingResult.responseCode } returns BillingResponseCode.OK + every { mockResult.purchasesList } returns listOf(mockPurchase) + every { mockBillingClient.isReady } returns true + every { mockBillingClient.connectionState } returns BillingClient.ConnectionState.CONNECTED + coEvery { mockBillingClient.queryPurchasesAsync(any<QueryPurchasesParams>()) } returns + mockResult + every { BillingFlowParams.newBuilder() } returns mockk(relaxed = true) + + // Act + val result = billingRepository.queryPurchases() + + // Assert + assertEquals(mockResult, result) + } + + @Test + fun testQueryPurchasesNoPurchaseFound() = runTest { + // Arrange + val mockResult: PurchasesResult = mockk() + every { mockResult.billingResult.responseCode } returns BillingResponseCode.OK + every { mockResult.purchasesList } returns emptyList() + every { mockBillingClient.isReady } returns true + every { mockBillingClient.connectionState } returns BillingClient.ConnectionState.CONNECTED + coEvery { mockBillingClient.queryPurchasesAsync(any<QueryPurchasesParams>()) } returns + mockResult + every { BillingFlowParams.newBuilder() } returns mockk(relaxed = true) + + // Act + val result = billingRepository.queryPurchases() + + // Assert + assertEquals(mockResult, result) + } + + @Test + fun testQueryPurchasesError() = runTest { + // Arrange + val responseCode = BillingResponseCode.ITEM_UNAVAILABLE + val message = "ERROR" + val expectedError = BillingException(responseCode, message) + val mockResult: PurchasesResult = mockk() + every { mockResult.billingResult.responseCode } returns responseCode + every { mockResult.billingResult.debugMessage } returns message + every { mockResult.purchasesList } returns emptyList() + every { mockBillingClient.isReady } returns true + every { mockBillingClient.connectionState } returns BillingClient.ConnectionState.CONNECTED + coEvery { mockBillingClient.queryPurchasesAsync(any<QueryPurchasesParams>()) } returns + mockResult + every { BillingFlowParams.newBuilder() } returns mockk(relaxed = true) + + // Act + val result = billingRepository.queryPurchases() + + // Assert + assertEquals( + expectedError.toBillingResult().responseCode, + result.billingResult.responseCode + ) + assertEquals(expectedError.message, result.billingResult.debugMessage) + } + + @Test + fun testPurchaseEventPurchaseComplete() = runTest { + // Arrange + val mockPurchase: Purchase = mockk() + val mockPurchaseList = listOf(mockPurchase) + val mockBillingResult: BillingResult = mockk() + every { mockBillingResult.responseCode } returns BillingResponseCode.OK + + // Act, Assert + billingRepository.purchaseEvents.test { + purchaseUpdatedListenerSlot.captured.onPurchasesUpdated( + mockBillingResult, + mockPurchaseList + ) + val result = awaitItem() + assertIs<PurchaseEvent.Completed>(result) + assertLists(mockPurchaseList, result.purchases) + } + } + + @Test + fun testPurchaseEventUserCanceled() = runTest { + // Arrange + val mockBillingResult: BillingResult = mockk() + val mockResponseCode: Int = BillingResponseCode.USER_CANCELED + every { mockBillingResult.responseCode } returns mockResponseCode + + // Act, Assert + billingRepository.purchaseEvents.test { + purchaseUpdatedListenerSlot.captured.onPurchasesUpdated(mockBillingResult, null) + val result = awaitItem() + assertIs<PurchaseEvent.UserCanceled>(result) + } + } + + @Test + fun testPurchaseEventError() = runTest { + // Arrange + val mockDebugMessage = "ERROR" + val mockBillingResult: BillingResult = mockk() + val mockResponseCode: Int = BillingResponseCode.ERROR + val expectedError = + BillingException(responseCode = mockResponseCode, message = mockDebugMessage) + every { mockBillingResult.responseCode } returns mockResponseCode + every { mockBillingResult.debugMessage } returns mockDebugMessage + + // Act, Assert + billingRepository.purchaseEvents.test { + purchaseUpdatedListenerSlot.captured.onPurchasesUpdated(mockBillingResult, null) + val result = awaitItem() + assertIs<PurchaseEvent.Error>(result) + assertEquals(expectedError.message, result.exception.message) + } + } + + @Test + fun testEnsureConnectedStartConnection() = runTest { + // Arrange + val mockStartConnectionResult: BillingResult = mockk() + every { mockBillingClient.isReady } returns false + every { mockBillingClient.connectionState } returns + BillingClient.ConnectionState.DISCONNECTED + every { mockBillingClient.startConnection(any()) } answers + { + firstArg<BillingClientStateListener>() + .onBillingSetupFinished(mockStartConnectionResult) + } + every { mockStartConnectionResult.responseCode } returns BillingResponseCode.OK + coEvery { mockBillingClient.queryPurchasesAsync(any<QueryPurchasesParams>()) } returns + mockk(relaxed = true) + + // Act + billingRepository.queryPurchases() + + // Assert + verify { mockBillingClient.startConnection(any()) } + coVerify { mockBillingClient.queryPurchasesAsync(any<QueryPurchasesParams>()) } + } + + @OptIn(ExperimentalCoroutinesApi::class) + @Test + fun testEnsureConnectedOnlyOneSuccessfulConnection() = + runTest(UnconfinedTestDispatcher()) { + // Arrange + var hasConnected = false + val mockStartConnectionResult: BillingResult = mockk() + every { mockBillingClient.isReady } answers { hasConnected } + every { mockBillingClient.connectionState } answers + { + if (hasConnected) { + BillingClient.ConnectionState.CONNECTED + } else { + BillingClient.ConnectionState.DISCONNECTED + } + } + every { mockBillingClient.startConnection(any()) } answers + { + hasConnected = true + firstArg<BillingClientStateListener>() + .onBillingSetupFinished(mockStartConnectionResult) + } + every { mockStartConnectionResult.responseCode } returns BillingResponseCode.OK + coEvery { mockBillingClient.queryPurchasesAsync(any<QueryPurchasesParams>()) } returns + mockk(relaxed = true) + coEvery { mockBillingClient.queryProductDetails(any()) } returns mockk(relaxed = true) + + // Act + launch { billingRepository.queryPurchases() } + launch { billingRepository.queryProducts(listOf("MOCK")) } + + // Assert + verify(exactly = 1) { mockBillingClient.startConnection(any()) } + coVerify { mockBillingClient.queryPurchasesAsync(any<QueryPurchasesParams>()) } + coVerify { mockBillingClient.queryProductDetails(any()) } + } + + companion object { + private const val BILLING_CLIENT_CLASS = "com.android.billingclient.api.BillingClient" + private const val BILLING_CLIENT_KOTLIN_CLASS = + "com.android.billingclient.api.BillingClientKotlinKt" + private const val BILLING_FLOW_PARAMS = "com.android.billingclient.api.BillingFlowParams" + } +} diff --git a/android/lib/billing/src/main/AndroidManifest.xml b/android/lib/billing/src/main/AndroidManifest.xml new file mode 100644 index 0000000000..b2d3ea1235 --- /dev/null +++ b/android/lib/billing/src/main/AndroidManifest.xml @@ -0,0 +1,2 @@ +<?xml version="1.0" encoding="utf-8"?> +<manifest xmlns:android="http://schemas.android.com/apk/res/android" /> diff --git a/android/lib/billing/src/main/kotlin/net/mullvad/mullvadvpn/lib/billing/BillingPaymentRepository.kt b/android/lib/billing/src/main/kotlin/net/mullvad/mullvadvpn/lib/billing/BillingPaymentRepository.kt new file mode 100644 index 0000000000..76df623ada --- /dev/null +++ b/android/lib/billing/src/main/kotlin/net/mullvad/mullvadvpn/lib/billing/BillingPaymentRepository.kt @@ -0,0 +1,167 @@ +package net.mullvad.mullvadvpn.lib.billing + +import android.app.Activity +import com.android.billingclient.api.BillingClient.BillingResponseCode +import com.android.billingclient.api.Purchase +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.firstOrNull +import kotlinx.coroutines.flow.flow +import net.mullvad.mullvadvpn.lib.billing.extension.getProductDetails +import net.mullvad.mullvadvpn.lib.billing.extension.nonPendingPurchases +import net.mullvad.mullvadvpn.lib.billing.extension.responseCode +import net.mullvad.mullvadvpn.lib.billing.extension.toBillingException +import net.mullvad.mullvadvpn.lib.billing.extension.toPaymentAvailability +import net.mullvad.mullvadvpn.lib.billing.extension.toPaymentStatus +import net.mullvad.mullvadvpn.lib.billing.extension.toPurchaseResult +import net.mullvad.mullvadvpn.lib.billing.model.BillingException +import net.mullvad.mullvadvpn.lib.billing.model.PurchaseEvent +import net.mullvad.mullvadvpn.lib.payment.PaymentRepository +import net.mullvad.mullvadvpn.lib.payment.ProductIds +import net.mullvad.mullvadvpn.lib.payment.model.PaymentAvailability +import net.mullvad.mullvadvpn.lib.payment.model.ProductId +import net.mullvad.mullvadvpn.lib.payment.model.PurchaseResult +import net.mullvad.mullvadvpn.lib.payment.model.VerificationResult +import net.mullvad.mullvadvpn.model.PlayPurchase +import net.mullvad.mullvadvpn.model.PlayPurchaseInitResult +import net.mullvad.mullvadvpn.model.PlayPurchaseVerifyResult + +class BillingPaymentRepository( + private val billingRepository: BillingRepository, + private val playPurchaseRepository: PlayPurchaseRepository +) : PaymentRepository { + + override fun queryPaymentAvailability(): Flow<PaymentAvailability> = flow { + emit(PaymentAvailability.Loading) + val purchases = billingRepository.queryPurchases() + val productIdToPaymentStatus = + purchases.purchasesList + .filter { it.products.isNotEmpty() } + .associate { it.products.first() to it.purchaseState.toPaymentStatus() } + emit( + billingRepository + .queryProducts(listOf(ProductIds.OneMonth)) + .toPaymentAvailability(productIdToPaymentStatus) + ) + } + + override fun purchaseProduct( + productId: ProductId, + activityProvider: () -> Activity + ): Flow<PurchaseResult> = flow { + emit(PurchaseResult.FetchingProducts) + + val productDetailsResult = billingRepository.queryProducts(listOf(productId.value)) + + val productDetails = + when (productDetailsResult.responseCode()) { + BillingResponseCode.OK -> { + productDetailsResult.getProductDetails(productId.value) + ?: run { + emit(PurchaseResult.Error.NoProductFound(productId)) + return@flow + } + } + else -> { + emit( + PurchaseResult.Error.FetchProductsError( + productId, + productDetailsResult.toBillingException() + ) + ) + return@flow + } + } + + // Get transaction id + emit(PurchaseResult.FetchingObfuscationId) + val obfuscatedId: String = + when (val result = initialisePurchase()) { + is PlayPurchaseInitResult.Ok -> result.obfuscatedId + else -> { + emit(PurchaseResult.Error.TransactionIdError(productId, null)) + return@flow + } + } + + val result = + billingRepository.startPurchaseFlow( + productDetails = productDetails, + obfuscatedId = obfuscatedId, + activityProvider = activityProvider + ) + + if (result.responseCode == BillingResponseCode.OK) { + emit(PurchaseResult.BillingFlowStarted) + } else { + emit( + PurchaseResult.Error.BillingError( + BillingException(result.responseCode, result.debugMessage) + ) + ) + return@flow + } + + // Wait for a callback from the billing library + when (val event = billingRepository.purchaseEvents.firstOrNull()) { + is PurchaseEvent.Error -> emit(event.toPurchaseResult()) + is PurchaseEvent.Completed -> { + val purchase = + event.purchases.firstOrNull() + ?: run { + emit(PurchaseResult.Error.BillingError(null)) + return@flow + } + if (purchase.purchaseState == Purchase.PurchaseState.PENDING) { + emit(PurchaseResult.Completed.Pending) + } else { + emit(PurchaseResult.VerificationStarted) + if (verifyPurchase(event.purchases.first()) == PlayPurchaseVerifyResult.Ok) { + emit(PurchaseResult.Completed.Success) + } else { + emit(PurchaseResult.Error.VerificationError(null)) + } + } + } + PurchaseEvent.UserCanceled -> emit(event.toPurchaseResult()) + else -> emit(PurchaseResult.Error.BillingError(null)) + } + } + + override fun verifyPurchases(): Flow<VerificationResult> = flow { + emit(VerificationResult.FetchingUnfinishedPurchases) + val purchasesResult = billingRepository.queryPurchases() + when (purchasesResult.responseCode()) { + BillingResponseCode.OK -> { + val purchases = purchasesResult.nonPendingPurchases() + if (purchases.isNotEmpty()) { + emit(VerificationResult.VerificationStarted) + val verificationResult = verifyPurchase(purchases.first()) + emit( + when (verificationResult) { + is PlayPurchaseVerifyResult.Error -> + VerificationResult.Error.VerificationError(null) + PlayPurchaseVerifyResult.Ok -> VerificationResult.Success + } + ) + } else { + emit(VerificationResult.NothingToVerify) + } + } + else -> + emit(VerificationResult.Error.BillingError(purchasesResult.toBillingException())) + } + } + + private suspend fun initialisePurchase(): PlayPurchaseInitResult { + return playPurchaseRepository.initializePlayPurchase() + } + + private suspend fun verifyPurchase(purchase: Purchase): PlayPurchaseVerifyResult { + return playPurchaseRepository.verifyPlayPurchase( + PlayPurchase( + productId = purchase.products.first(), + purchaseToken = purchase.purchaseToken, + ) + ) + } +} diff --git a/android/lib/billing/src/main/kotlin/net/mullvad/mullvadvpn/lib/billing/BillingRepository.kt b/android/lib/billing/src/main/kotlin/net/mullvad/mullvadvpn/lib/billing/BillingRepository.kt new file mode 100644 index 0000000000..6274f8cb6f --- /dev/null +++ b/android/lib/billing/src/main/kotlin/net/mullvad/mullvadvpn/lib/billing/BillingRepository.kt @@ -0,0 +1,194 @@ +package net.mullvad.mullvadvpn.lib.billing + +import android.app.Activity +import android.content.Context +import com.android.billingclient.api.BillingClient +import com.android.billingclient.api.BillingClient.BillingResponseCode +import com.android.billingclient.api.BillingClientStateListener +import com.android.billingclient.api.BillingFlowParams +import com.android.billingclient.api.BillingResult +import com.android.billingclient.api.ProductDetails +import com.android.billingclient.api.ProductDetailsResult +import com.android.billingclient.api.PurchasesResult +import com.android.billingclient.api.PurchasesUpdatedListener +import com.android.billingclient.api.QueryProductDetailsParams +import com.android.billingclient.api.QueryProductDetailsParams.Product +import com.android.billingclient.api.QueryPurchasesParams +import com.android.billingclient.api.queryProductDetails +import com.android.billingclient.api.queryPurchasesAsync +import kotlin.coroutines.Continuation +import kotlin.coroutines.resume +import kotlin.coroutines.resumeWithException +import kotlin.coroutines.suspendCoroutine +import kotlinx.coroutines.flow.MutableSharedFlow +import kotlinx.coroutines.flow.asSharedFlow +import kotlinx.coroutines.sync.Mutex +import kotlinx.coroutines.sync.withLock +import net.mullvad.mullvadvpn.lib.billing.model.BillingException +import net.mullvad.mullvadvpn.lib.billing.model.PurchaseEvent + +class BillingRepository(context: Context) { + + private val billingClient: BillingClient + + private val purchaseUpdateListener: PurchasesUpdatedListener = + PurchasesUpdatedListener { result, purchases -> + when (result.responseCode) { + BillingResponseCode.OK -> { + _purchaseEvents.tryEmit( + PurchaseEvent.Completed(purchases?.toList() ?: emptyList()) + ) + } + BillingResponseCode.USER_CANCELED -> { + _purchaseEvents.tryEmit(PurchaseEvent.UserCanceled) + } + else -> { + _purchaseEvents.tryEmit( + PurchaseEvent.Error( + exception = + BillingException( + responseCode = result.responseCode, + message = result.debugMessage + ) + ) + ) + } + } + } + + private val _purchaseEvents = MutableSharedFlow<PurchaseEvent>(extraBufferCapacity = 1) + val purchaseEvents = _purchaseEvents.asSharedFlow() + + init { + billingClient = + BillingClient.newBuilder(context) + .enablePendingPurchases() + .setListener(purchaseUpdateListener) + .build() + } + + private val ensureConnectedMutex = Mutex() + + private suspend fun ensureConnected() = + ensureConnectedMutex.withLock { + suspendCoroutine { + if ( + billingClient.isReady && + billingClient.connectionState == BillingClient.ConnectionState.CONNECTED + ) { + it.resume(Unit) + } else { + startConnection(it) + } + } + } + + private fun startConnection(continuation: Continuation<Unit>) { + billingClient.startConnection( + object : BillingClientStateListener { + override fun onBillingServiceDisconnected() { + // Maybe do something here? + continuation.resumeWithException( + BillingException( + BillingResponseCode.SERVICE_DISCONNECTED, + "Billing service disconnected" + ) + ) + } + + override fun onBillingSetupFinished(result: BillingResult) { + if (result.responseCode == BillingResponseCode.OK) { + continuation.resume(Unit) + } else { + continuation.resumeWithException( + BillingException(result.responseCode, result.debugMessage) + ) + } + } + } + ) + } + + suspend fun queryProducts(productIds: List<String>): ProductDetailsResult { + return queryProductDetails(productIds) + } + + suspend fun startPurchaseFlow( + productDetails: ProductDetails, + obfuscatedId: String, + activityProvider: () -> Activity + ): BillingResult { + return try { + ensureConnected() + + val productDetailsParamsList = + listOf( + BillingFlowParams.ProductDetailsParams.newBuilder() + .setProductDetails(productDetails) + .build() + ) + + val billingFlowParams = + BillingFlowParams.newBuilder() + .setProductDetailsParamsList(productDetailsParamsList) + .setObfuscatedAccountId(obfuscatedId) + .build() + + val activity = activityProvider() + // Launch the billing flow + billingClient.launchBillingFlow(activity, billingFlowParams) + } catch (t: Throwable) { + if (t is BillingException) { + t.toBillingResult() + } else { + throw t + } + } + } + + suspend fun queryPurchases(): PurchasesResult { + return try { + ensureConnected() + + val queryPurchaseHistoryParams: QueryPurchasesParams = + QueryPurchasesParams.newBuilder() + .setProductType(BillingClient.ProductType.INAPP) + .build() + + billingClient.queryPurchasesAsync(queryPurchaseHistoryParams) + } catch (t: Throwable) { + if (t is BillingException) { + t.toPurchasesResult() + } else { + throw t + } + } + } + + private suspend fun queryProductDetails(productIds: List<String>): ProductDetailsResult { + return try { + ensureConnected() + + val productList = + productIds.map { productId -> + Product.newBuilder() + .setProductId(productId) + .setProductType(BillingClient.ProductType.INAPP) + .build() + } + val params = QueryProductDetailsParams.newBuilder() + params.setProductList(productList) + + billingClient.queryProductDetails(params.build()) + } catch (t: Throwable) { + if (t is BillingException) { + return ProductDetailsResult(t.toBillingResult(), null) + } else { + return ProductDetailsResult( + BillingResult.newBuilder().setResponseCode(BillingResponseCode.ERROR).build(), + null + ) + } + } + } +} diff --git a/android/lib/billing/src/main/kotlin/net/mullvad/mullvadvpn/lib/billing/PlayPurchaseRepository.kt b/android/lib/billing/src/main/kotlin/net/mullvad/mullvadvpn/lib/billing/PlayPurchaseRepository.kt new file mode 100644 index 0000000000..ac71372f76 --- /dev/null +++ b/android/lib/billing/src/main/kotlin/net/mullvad/mullvadvpn/lib/billing/PlayPurchaseRepository.kt @@ -0,0 +1,33 @@ +package net.mullvad.mullvadvpn.lib.billing + +import kotlinx.coroutines.flow.first +import net.mullvad.mullvadvpn.lib.ipc.Event +import net.mullvad.mullvadvpn.lib.ipc.MessageHandler +import net.mullvad.mullvadvpn.lib.ipc.Request +import net.mullvad.mullvadvpn.lib.ipc.events +import net.mullvad.mullvadvpn.model.PlayPurchase +import net.mullvad.mullvadvpn.model.PlayPurchaseInitError +import net.mullvad.mullvadvpn.model.PlayPurchaseInitResult +import net.mullvad.mullvadvpn.model.PlayPurchaseVerifyError +import net.mullvad.mullvadvpn.model.PlayPurchaseVerifyResult + +class PlayPurchaseRepository(private val messageHandler: MessageHandler) { + suspend fun initializePlayPurchase(): PlayPurchaseInitResult { + val result = messageHandler.trySendRequest(Request.InitPlayPurchase) + + return if (result) { + messageHandler.events<Event.PlayPurchaseInitResultEvent>().first().result + } else { + PlayPurchaseInitResult.Error(PlayPurchaseInitError.OtherError) + } + } + + suspend fun verifyPlayPurchase(purchase: PlayPurchase): PlayPurchaseVerifyResult { + val result = messageHandler.trySendRequest(Request.VerifyPlayPurchase(purchase)) + return if (result) { + messageHandler.events<Event.PlayPurchaseVerifyResultEvent>().first().result + } else { + PlayPurchaseVerifyResult.Error(PlayPurchaseVerifyError.OtherError) + } + } +} diff --git a/android/lib/billing/src/main/kotlin/net/mullvad/mullvadvpn/lib/billing/extension/ProductDetailsResultExtensions.kt b/android/lib/billing/src/main/kotlin/net/mullvad/mullvadvpn/lib/billing/extension/ProductDetailsResultExtensions.kt new file mode 100644 index 0000000000..3e4aee180a --- /dev/null +++ b/android/lib/billing/src/main/kotlin/net/mullvad/mullvadvpn/lib/billing/extension/ProductDetailsResultExtensions.kt @@ -0,0 +1,13 @@ +package net.mullvad.mullvadvpn.lib.billing.extension + +import com.android.billingclient.api.ProductDetails +import com.android.billingclient.api.ProductDetailsResult +import net.mullvad.mullvadvpn.lib.billing.model.BillingException + +fun ProductDetailsResult.getProductDetails(productId: String): ProductDetails? = + this.productDetailsList?.firstOrNull { it.productId == productId } + +fun ProductDetailsResult.responseCode(): Int = this.billingResult.responseCode + +fun ProductDetailsResult.toBillingException(): BillingException = + BillingException(responseCode = this.responseCode(), message = this.billingResult.debugMessage) diff --git a/android/lib/billing/src/main/kotlin/net/mullvad/mullvadvpn/lib/billing/extension/ProductDetailsResultToPaymentAvailability.kt b/android/lib/billing/src/main/kotlin/net/mullvad/mullvadvpn/lib/billing/extension/ProductDetailsResultToPaymentAvailability.kt new file mode 100644 index 0000000000..37cc701724 --- /dev/null +++ b/android/lib/billing/src/main/kotlin/net/mullvad/mullvadvpn/lib/billing/extension/ProductDetailsResultToPaymentAvailability.kt @@ -0,0 +1,37 @@ +package net.mullvad.mullvadvpn.lib.billing.extension + +import com.android.billingclient.api.BillingClient +import com.android.billingclient.api.ProductDetailsResult +import net.mullvad.mullvadvpn.lib.billing.model.BillingException +import net.mullvad.mullvadvpn.lib.payment.model.PaymentAvailability +import net.mullvad.mullvadvpn.lib.payment.model.PaymentStatus + +fun ProductDetailsResult.toPaymentAvailability( + productIdToPaymentStatus: Map<String, PaymentStatus?> +) = + when (this.billingResult.responseCode) { + BillingClient.BillingResponseCode.OK -> { + val productDetailsList = this.productDetailsList + if (productDetailsList?.isNotEmpty() == true) { + PaymentAvailability.ProductsAvailable( + productDetailsList.toPaymentProducts(productIdToPaymentStatus) + ) + } else { + PaymentAvailability.NoProductsFounds + } + } + BillingClient.BillingResponseCode.BILLING_UNAVAILABLE -> + PaymentAvailability.Error.BillingUnavailable + BillingClient.BillingResponseCode.SERVICE_UNAVAILABLE -> + PaymentAvailability.Error.ServiceUnavailable + BillingClient.BillingResponseCode.DEVELOPER_ERROR -> + PaymentAvailability.Error.DeveloperError + BillingClient.BillingResponseCode.FEATURE_NOT_SUPPORTED -> + PaymentAvailability.Error.FeatureNotSupported + BillingClient.BillingResponseCode.ITEM_UNAVAILABLE -> + PaymentAvailability.Error.ItemUnavailable + else -> + PaymentAvailability.Error.Other( + BillingException(this.billingResult.responseCode, this.billingResult.debugMessage) + ) + } diff --git a/android/lib/billing/src/main/kotlin/net/mullvad/mullvadvpn/lib/billing/extension/ProductDetailsToPaymentProduct.kt b/android/lib/billing/src/main/kotlin/net/mullvad/mullvadvpn/lib/billing/extension/ProductDetailsToPaymentProduct.kt new file mode 100644 index 0000000000..fa9a20613f --- /dev/null +++ b/android/lib/billing/src/main/kotlin/net/mullvad/mullvadvpn/lib/billing/extension/ProductDetailsToPaymentProduct.kt @@ -0,0 +1,17 @@ +package net.mullvad.mullvadvpn.lib.billing.extension + +import com.android.billingclient.api.ProductDetails +import net.mullvad.mullvadvpn.lib.payment.model.PaymentProduct +import net.mullvad.mullvadvpn.lib.payment.model.PaymentStatus +import net.mullvad.mullvadvpn.lib.payment.model.ProductId +import net.mullvad.mullvadvpn.lib.payment.model.ProductPrice + +fun ProductDetails.toPaymentProduct(productIdToStatus: Map<String, PaymentStatus?>) = + PaymentProduct( + productId = ProductId(this.productId), + price = ProductPrice(this.oneTimePurchaseOfferDetails?.formattedPrice ?: ""), + productIdToStatus[this.productId] + ) + +fun List<ProductDetails>.toPaymentProducts(productIdToStatus: Map<String, PaymentStatus?>) = + this.map { it.toPaymentProduct(productIdToStatus) } diff --git a/android/lib/billing/src/main/kotlin/net/mullvad/mullvadvpn/lib/billing/extension/PurchaseEventToPurchaseResult.kt b/android/lib/billing/src/main/kotlin/net/mullvad/mullvadvpn/lib/billing/extension/PurchaseEventToPurchaseResult.kt new file mode 100644 index 0000000000..e0e4bf0a77 --- /dev/null +++ b/android/lib/billing/src/main/kotlin/net/mullvad/mullvadvpn/lib/billing/extension/PurchaseEventToPurchaseResult.kt @@ -0,0 +1,11 @@ +package net.mullvad.mullvadvpn.lib.billing.extension + +import net.mullvad.mullvadvpn.lib.billing.model.PurchaseEvent +import net.mullvad.mullvadvpn.lib.payment.model.PurchaseResult + +fun PurchaseEvent.toPurchaseResult() = + when (this) { + is PurchaseEvent.Error -> PurchaseResult.Error.BillingError(this.exception) + is PurchaseEvent.Completed -> PurchaseResult.VerificationStarted + PurchaseEvent.UserCanceled -> PurchaseResult.Completed.Cancelled + } diff --git a/android/lib/billing/src/main/kotlin/net/mullvad/mullvadvpn/lib/billing/extension/PurchaseStateToPaymentStatus.kt b/android/lib/billing/src/main/kotlin/net/mullvad/mullvadvpn/lib/billing/extension/PurchaseStateToPaymentStatus.kt new file mode 100644 index 0000000000..701e5fde3d --- /dev/null +++ b/android/lib/billing/src/main/kotlin/net/mullvad/mullvadvpn/lib/billing/extension/PurchaseStateToPaymentStatus.kt @@ -0,0 +1,11 @@ +package net.mullvad.mullvadvpn.lib.billing.extension + +import com.android.billingclient.api.Purchase +import net.mullvad.mullvadvpn.lib.payment.model.PaymentStatus + +internal fun Int.toPaymentStatus(): PaymentStatus? = + when (this) { + Purchase.PurchaseState.PURCHASED -> PaymentStatus.VERIFICATION_IN_PROGRESS + Purchase.PurchaseState.PENDING -> PaymentStatus.PENDING + else -> null + } diff --git a/android/lib/billing/src/main/kotlin/net/mullvad/mullvadvpn/lib/billing/extension/PurchasesResultExtensions.kt b/android/lib/billing/src/main/kotlin/net/mullvad/mullvadvpn/lib/billing/extension/PurchasesResultExtensions.kt new file mode 100644 index 0000000000..d76d1a8b7e --- /dev/null +++ b/android/lib/billing/src/main/kotlin/net/mullvad/mullvadvpn/lib/billing/extension/PurchasesResultExtensions.kt @@ -0,0 +1,13 @@ +package net.mullvad.mullvadvpn.lib.billing.extension + +import com.android.billingclient.api.Purchase +import com.android.billingclient.api.PurchasesResult +import net.mullvad.mullvadvpn.lib.billing.model.BillingException + +fun PurchasesResult.nonPendingPurchases(): List<Purchase> = + this.purchasesList.filter { it.purchaseState != Purchase.PurchaseState.PENDING } + +fun PurchasesResult.responseCode(): Int = this.billingResult.responseCode + +fun PurchasesResult.toBillingException(): BillingException = + BillingException(responseCode = this.responseCode(), message = this.billingResult.debugMessage) diff --git a/android/lib/billing/src/main/kotlin/net/mullvad/mullvadvpn/lib/billing/model/BillingException.kt b/android/lib/billing/src/main/kotlin/net/mullvad/mullvadvpn/lib/billing/model/BillingException.kt new file mode 100644 index 0000000000..08f6a89cca --- /dev/null +++ b/android/lib/billing/src/main/kotlin/net/mullvad/mullvadvpn/lib/billing/model/BillingException.kt @@ -0,0 +1,15 @@ +package net.mullvad.mullvadvpn.lib.billing.model + +import com.android.billingclient.api.BillingResult +import com.android.billingclient.api.PurchasesResult + +class BillingException(private val responseCode: Int, message: String) : Throwable(message) { + + fun toBillingResult(): BillingResult = + BillingResult.newBuilder() + .setResponseCode(responseCode) + .setDebugMessage(message ?: "") + .build() + + fun toPurchasesResult(): PurchasesResult = PurchasesResult(toBillingResult(), emptyList()) +} diff --git a/android/lib/billing/src/main/kotlin/net/mullvad/mullvadvpn/lib/billing/model/PurchaseEvent.kt b/android/lib/billing/src/main/kotlin/net/mullvad/mullvadvpn/lib/billing/model/PurchaseEvent.kt new file mode 100644 index 0000000000..b88f31cae6 --- /dev/null +++ b/android/lib/billing/src/main/kotlin/net/mullvad/mullvadvpn/lib/billing/model/PurchaseEvent.kt @@ -0,0 +1,11 @@ +package net.mullvad.mullvadvpn.lib.billing.model + +import com.android.billingclient.api.Purchase + +sealed interface PurchaseEvent { + data object UserCanceled : PurchaseEvent + + data class Error(val exception: BillingException) : PurchaseEvent + + data class Completed(val purchases: List<Purchase>) : PurchaseEvent +} diff --git a/android/lib/billing/src/test/kotlin/net/mullvad/mullvadvpn/lib/billing/BillingPaymentRepositoryTest.kt b/android/lib/billing/src/test/kotlin/net/mullvad/mullvadvpn/lib/billing/BillingPaymentRepositoryTest.kt new file mode 100644 index 0000000000..fe25457e49 --- /dev/null +++ b/android/lib/billing/src/test/kotlin/net/mullvad/mullvadvpn/lib/billing/BillingPaymentRepositoryTest.kt @@ -0,0 +1,377 @@ +package net.mullvad.mullvadvpn.lib.billing + +import app.cash.turbine.test +import com.android.billingclient.api.BillingClient.BillingResponseCode +import com.android.billingclient.api.BillingResult +import com.android.billingclient.api.ProductDetails +import com.android.billingclient.api.ProductDetailsResult +import com.android.billingclient.api.Purchase +import io.mockk.coEvery +import io.mockk.every +import io.mockk.mockk +import io.mockk.mockkStatic +import kotlin.test.assertEquals +import kotlin.test.assertIs +import kotlinx.coroutines.flow.MutableSharedFlow +import kotlinx.coroutines.test.runTest +import net.mullvad.mullvadvpn.lib.billing.extension.toPaymentProduct +import net.mullvad.mullvadvpn.lib.billing.model.PurchaseEvent +import net.mullvad.mullvadvpn.lib.common.test.TestCoroutineRule +import net.mullvad.mullvadvpn.lib.payment.model.PaymentAvailability +import net.mullvad.mullvadvpn.lib.payment.model.PaymentProduct +import net.mullvad.mullvadvpn.lib.payment.model.ProductId +import net.mullvad.mullvadvpn.lib.payment.model.PurchaseResult +import net.mullvad.mullvadvpn.model.PlayPurchaseInitError +import net.mullvad.mullvadvpn.model.PlayPurchaseInitResult +import net.mullvad.mullvadvpn.model.PlayPurchaseVerifyError +import net.mullvad.mullvadvpn.model.PlayPurchaseVerifyResult +import org.junit.Before +import org.junit.Rule +import org.junit.Test + +class BillingPaymentRepositoryTest { + @get:Rule val testCoroutineRule = TestCoroutineRule() + + private val mockBillingRepository: BillingRepository = mockk() + private val mockPlayPurchaseRepository: PlayPurchaseRepository = mockk() + + private val purchaseEventFlow = MutableSharedFlow<PurchaseEvent>(extraBufferCapacity = 1) + + private lateinit var paymentRepository: BillingPaymentRepository + + @Before + fun setUp() { + mockkStatic(PRODUCT_DETAILS_TO_PAYMENT_PRODUCT_EXT) + + every { mockBillingRepository.purchaseEvents } returns purchaseEventFlow + + paymentRepository = + BillingPaymentRepository( + billingRepository = mockBillingRepository, + playPurchaseRepository = mockPlayPurchaseRepository + ) + } + + @Test + fun testQueryAvailablePaymentProductsAvailable() = runTest { + // Arrange + val expectedProduct: PaymentProduct = mockk() + val mockProduct: ProductDetails = mockk() + val mockResult: ProductDetailsResult = mockk() + coEvery { mockBillingRepository.queryPurchases() } returns mockk(relaxed = true) + coEvery { mockBillingRepository.queryProducts(any()) } returns mockResult + every { mockProduct.toPaymentProduct(any()) } returns expectedProduct + every { mockResult.billingResult.responseCode } returns BillingResponseCode.OK + every { mockResult.productDetailsList } returns listOf(mockProduct) + + // Act, Assert + paymentRepository.queryPaymentAvailability().test { + // Loading + awaitItem() + val result = awaitItem() + assertIs<PaymentAvailability.ProductsAvailable>(result) + assertEquals(expectedProduct, result.products.first()) + awaitComplete() + } + } + + @Test + fun testQueryAvailablePaymentProductsUnavailable() = runTest { + // Arrange + val mockResult: ProductDetailsResult = mockk() + every { mockResult.billingResult.responseCode } returns BillingResponseCode.OK + every { mockResult.productDetailsList } returns emptyList() + coEvery { mockBillingRepository.queryPurchases() } returns mockk(relaxed = true) + coEvery { mockBillingRepository.queryProducts(any()) } returns mockResult + + // Act, Assert + paymentRepository.queryPaymentAvailability().test { + // Loading + awaitItem() + val result = awaitItem() + assertIs<PaymentAvailability.NoProductsFounds>(result) + awaitComplete() + } + } + + @Test + fun testQueryAvailablePaymentBillingUnavailableError() = runTest { + // Arrange + val mockResult: ProductDetailsResult = mockk() + every { mockResult.billingResult.responseCode } returns + BillingResponseCode.BILLING_UNAVAILABLE + coEvery { mockBillingRepository.queryPurchases() } returns mockk(relaxed = true) + coEvery { mockBillingRepository.queryProducts(any()) } returns mockResult + + // Act, Assert + paymentRepository.queryPaymentAvailability().test { + // Loading + awaitItem() + val result = awaitItem() + assertIs<PaymentAvailability.Error.BillingUnavailable>(result) + awaitComplete() + } + } + + @Test + fun testPurchaseBillingProductStartPurchaseFetchProductsError() = runTest { + // Arrange + val mockProductId = ProductId("MOCK") + val mockProductDetailsResult = mockk<ProductDetailsResult>() + every { mockProductDetailsResult.billingResult.responseCode } returns + BillingResponseCode.BILLING_UNAVAILABLE + every { mockProductDetailsResult.billingResult.debugMessage } returns "ERROR" + coEvery { mockBillingRepository.queryProducts(listOf(mockProductId.value)) } returns + mockProductDetailsResult + + // Act, Assert + paymentRepository.purchaseProduct(mockProductId, mockk()).test { + assertIs<PurchaseResult.FetchingProducts>(awaitItem()) + val result = awaitItem() + assertIs<PurchaseResult.Error.FetchProductsError>(result) + awaitComplete() + } + } + + @Test + fun testPurchaseBillingProductStartPurchaseNoProductsFoundError() = runTest { + // Arrange + val mockProductId = ProductId("MOCK") + val mockProductDetailsResult = mockk<ProductDetailsResult>() + every { mockProductDetailsResult.billingResult.responseCode } returns BillingResponseCode.OK + every { mockProductDetailsResult.productDetailsList } returns emptyList() + coEvery { mockBillingRepository.queryProducts(listOf(mockProductId.value)) } returns + mockProductDetailsResult + + // Act, Assert + paymentRepository.purchaseProduct(mockProductId, mockk()).test { + assertIs<PurchaseResult.FetchingProducts>(awaitItem()) + val result = awaitItem() + assertIs<PurchaseResult.Error.NoProductFound>(result) + awaitComplete() + } + } + + @Test + fun testPurchaseBillingProductStartPurchaseTransactionIdError() = runTest { + // Arrange + val mockProductId = ProductId("MOCK") + val mockProductDetailsResult = mockk<ProductDetailsResult>() + val mockProductDetails: ProductDetails = mockk() + every { mockProductDetails.productId } returns mockProductId.value + every { mockProductDetailsResult.billingResult.responseCode } returns BillingResponseCode.OK + every { mockProductDetailsResult.productDetailsList } returns listOf(mockProductDetails) + coEvery { mockBillingRepository.queryProducts(listOf(mockProductId.value)) } returns + mockProductDetailsResult + coEvery { mockPlayPurchaseRepository.initializePlayPurchase() } returns + PlayPurchaseInitResult.Error(PlayPurchaseInitError.OtherError) + + // Act, Assert + paymentRepository.purchaseProduct(mockProductId, mockk()).test { + assertIs<PurchaseResult.FetchingProducts>(awaitItem()) + assertIs<PurchaseResult.FetchingObfuscationId>(awaitItem()) + val result = awaitItem() + assertIs<PurchaseResult.Error.TransactionIdError>(result) + awaitComplete() + } + } + + @Test + fun testPurchaseBillingProductStartPurchaseFlowBillingError() = runTest { + // Arrange + val mockProductId = ProductId("MOCK") + val mockProductDetailsResult = mockk<ProductDetailsResult>() + val mockProductDetails: ProductDetails = mockk() + every { mockProductDetails.productId } returns mockProductId.value + every { mockProductDetailsResult.billingResult.responseCode } returns BillingResponseCode.OK + every { mockProductDetailsResult.productDetailsList } returns listOf(mockProductDetails) + coEvery { mockBillingRepository.queryProducts(listOf(mockProductId.value)) } returns + mockProductDetailsResult + val mockBillingResult: BillingResult = mockk() + every { mockBillingResult.responseCode } returns BillingResponseCode.BILLING_UNAVAILABLE + every { mockBillingResult.debugMessage } returns "Mock error" + coEvery { + mockBillingRepository.startPurchaseFlow( + productDetails = any(), + obfuscatedId = any(), + activityProvider = any() + ) + } returns mockBillingResult + coEvery { mockPlayPurchaseRepository.initializePlayPurchase() } returns + PlayPurchaseInitResult.Ok("MOCK") + + // Act, Assert + paymentRepository.purchaseProduct(mockProductId, mockk()).test { + // Purchase started + assertIs<PurchaseResult.FetchingProducts>(awaitItem()) + assertIs<PurchaseResult.FetchingObfuscationId>(awaitItem()) + val result = awaitItem() + assertIs<PurchaseResult.Error.BillingError>(result) + awaitComplete() + } + } + + @Test + fun testPurchaseBillingProductPurchaseCanceled() = runTest { + // Arrange + val mockProductId = ProductId("MOCK") + val mockProductDetailsResult = mockk<ProductDetailsResult>() + val mockProductDetails: ProductDetails = mockk() + every { mockProductDetails.productId } returns mockProductId.value + every { mockProductDetailsResult.billingResult.responseCode } returns BillingResponseCode.OK + every { mockProductDetailsResult.productDetailsList } returns listOf(mockProductDetails) + coEvery { mockBillingRepository.queryProducts(listOf(mockProductId.value)) } returns + mockProductDetailsResult + val mockObfuscatedId = "MOCK-ID" + val mockBillingResult: BillingResult = mockk() + every { mockBillingResult.responseCode } returns BillingResponseCode.OK + coEvery { + mockBillingRepository.startPurchaseFlow( + productDetails = any(), + obfuscatedId = mockObfuscatedId, + activityProvider = any() + ) + } returns mockBillingResult + coEvery { mockPlayPurchaseRepository.initializePlayPurchase() } returns + PlayPurchaseInitResult.Ok(mockObfuscatedId) + + // Act, Assert + paymentRepository.purchaseProduct(mockProductId, mockk()).test { + assertIs<PurchaseResult.FetchingProducts>(awaitItem()) + assertIs<PurchaseResult.FetchingObfuscationId>(awaitItem()) + assertIs<PurchaseResult.BillingFlowStarted>(awaitItem()) + purchaseEventFlow.tryEmit(PurchaseEvent.UserCanceled) + val result = awaitItem() + assertIs<PurchaseResult.Completed.Cancelled>(result) + awaitComplete() + } + } + + @Test + fun testPurchaseBillingProductVerificationError() = runTest { + // Arrange + val mockProductId = ProductId("MOCK") + val mockProductDetailsResult = mockk<ProductDetailsResult>() + val mockProductDetails: ProductDetails = mockk() + every { mockProductDetails.productId } returns mockProductId.value + every { mockProductDetailsResult.billingResult.responseCode } returns BillingResponseCode.OK + every { mockProductDetailsResult.productDetailsList } returns listOf(mockProductDetails) + coEvery { mockBillingRepository.queryProducts(listOf(mockProductId.value)) } returns + mockProductDetailsResult + val mockPurchaseToken = "TOKEN" + val mockBillingPurchase: Purchase = mockk() + val mockBillingResult: BillingResult = mockk() + every { mockBillingPurchase.purchaseState } returns Purchase.PurchaseState.PURCHASED + every { mockBillingResult.responseCode } returns BillingResponseCode.OK + every { mockBillingPurchase.products } returns listOf(mockProductId.value) + every { mockBillingPurchase.purchaseToken } returns mockPurchaseToken + coEvery { + mockBillingRepository.startPurchaseFlow( + productDetails = any(), + obfuscatedId = any(), + activityProvider = any() + ) + } returns mockBillingResult + coEvery { mockPlayPurchaseRepository.initializePlayPurchase() } returns + PlayPurchaseInitResult.Ok("MOCK-ID") + coEvery { mockPlayPurchaseRepository.verifyPlayPurchase(any()) } returns + PlayPurchaseVerifyResult.Error(PlayPurchaseVerifyError.OtherError) + + // Act, Assert + paymentRepository.purchaseProduct(mockProductId, mockk()).test { + assertIs<PurchaseResult.FetchingProducts>(awaitItem()) + assertIs<PurchaseResult.FetchingObfuscationId>(awaitItem()) + assertIs<PurchaseResult.BillingFlowStarted>(awaitItem()) + purchaseEventFlow.tryEmit(PurchaseEvent.Completed(listOf(mockBillingPurchase))) + assertIs<PurchaseResult.VerificationStarted>(awaitItem()) + val result = awaitItem() + assertIs<PurchaseResult.Error.VerificationError>(result) + awaitComplete() + } + } + + @Test + fun testPurchaseBillingProductPurchaseCompleted() = runTest { + // Arrange + val mockProductId = ProductId("MOCK") + val mockProductDetailsResult = mockk<ProductDetailsResult>() + val mockProductDetails: ProductDetails = mockk() + every { mockProductDetails.productId } returns mockProductId.value + every { mockProductDetailsResult.billingResult.responseCode } returns BillingResponseCode.OK + every { mockProductDetailsResult.productDetailsList } returns listOf(mockProductDetails) + coEvery { mockBillingRepository.queryProducts(listOf(mockProductId.value)) } returns + mockProductDetailsResult + val mockPurchaseToken = "TOKEN" + val mockBillingPurchase: Purchase = mockk() + val mockBillingResult: BillingResult = mockk() + every { mockBillingPurchase.purchaseState } returns Purchase.PurchaseState.PURCHASED + every { mockBillingResult.responseCode } returns BillingResponseCode.OK + every { mockBillingPurchase.products } returns listOf(mockProductId.value) + every { mockBillingPurchase.purchaseToken } returns mockPurchaseToken + coEvery { + mockBillingRepository.startPurchaseFlow( + productDetails = any(), + obfuscatedId = any(), + activityProvider = any() + ) + } returns mockBillingResult + coEvery { mockPlayPurchaseRepository.initializePlayPurchase() } returns + PlayPurchaseInitResult.Ok("MOCK") + coEvery { mockPlayPurchaseRepository.verifyPlayPurchase(any()) } returns + PlayPurchaseVerifyResult.Ok + + // Act, Assert + paymentRepository.purchaseProduct(mockProductId, mockk()).test { + assertIs<PurchaseResult.FetchingProducts>(awaitItem()) + assertIs<PurchaseResult.FetchingObfuscationId>(awaitItem()) + assertIs<PurchaseResult.BillingFlowStarted>(awaitItem()) + purchaseEventFlow.tryEmit(PurchaseEvent.Completed(listOf(mockBillingPurchase))) + assertIs<PurchaseResult.VerificationStarted>(awaitItem()) + val result = awaitItem() + assertIs<PurchaseResult.Completed.Success>(result) + awaitComplete() + } + } + + @Test + fun testPurchaseBillingProductPurchasePending() = runTest { + // Arrange + val mockProductId = ProductId("MOCK") + val mockProductDetailsResult = mockk<ProductDetailsResult>() + val mockProductDetails: ProductDetails = mockk() + every { mockProductDetails.productId } returns mockProductId.value + every { mockProductDetailsResult.billingResult.responseCode } returns BillingResponseCode.OK + every { mockProductDetailsResult.productDetailsList } returns listOf(mockProductDetails) + coEvery { mockBillingRepository.queryProducts(listOf(mockProductId.value)) } returns + mockProductDetailsResult + val mockBillingPurchase: Purchase = mockk() + val mockBillingResult: BillingResult = mockk() + every { mockBillingPurchase.purchaseState } returns Purchase.PurchaseState.PENDING + every { mockBillingResult.responseCode } returns BillingResponseCode.OK + coEvery { + mockBillingRepository.startPurchaseFlow( + productDetails = any(), + obfuscatedId = any(), + activityProvider = any() + ) + } returns mockBillingResult + coEvery { mockPlayPurchaseRepository.initializePlayPurchase() } returns + PlayPurchaseInitResult.Ok("MOCK") + + // Act, Assert + paymentRepository.purchaseProduct(mockProductId, mockk()).test { + assertIs<PurchaseResult.FetchingProducts>(awaitItem()) + assertIs<PurchaseResult.FetchingObfuscationId>(awaitItem()) + assertIs<PurchaseResult.BillingFlowStarted>(awaitItem()) + purchaseEventFlow.tryEmit(PurchaseEvent.Completed(listOf(mockBillingPurchase))) + val result = awaitItem() + assertIs<PurchaseResult.Completed.Pending>(result) + awaitComplete() + } + } + + companion object { + private const val PRODUCT_DETAILS_TO_PAYMENT_PRODUCT_EXT = + "net.mullvad.mullvadvpn.lib.billing.extension.ProductDetailsToPaymentProductKt" + } +} |
