summaryrefslogtreecommitdiffhomepage
path: root/ios/MullvadREST/Transport/AccessMethodIterator.swift
blob: d1672ad8fb9d748cf4992a6f7da43d2156005ccd (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
//
//  AccessMethodIterator.swift
//  MullvadREST
//
//  Created by Mojgan on 2024-01-10.
//  Copyright © 2025 Mullvad VPN AB. All rights reserved.
//

import Combine
import Foundation
import MullvadLogging
import MullvadSettings
import MullvadTypes

/// Cycles through the enabled API access methods provided by a data source.
///
/// The iterator keeps a position (`index`) into the list of enabled access methods.
/// `pick()` returns the method at that position, wrapping around the end of the list.
/// `rotate()` advances the position and requests the newly picked method.
/// Every emission of `dataSource.accessMethodsPublisher` moves the position to the
/// access method returned by `dataSource.fetchLastReachable()` (when it is enabled) and
/// requests the picked method.
///
/// NOTE(review): the class is marked `@unchecked Sendable`, but `index` is read and written
/// without synchronization (from `rotate()` and from the publisher's sink). Confirm that all
/// callers run on a single thread/queue, or guard `index` with a lock.
final class AccessMethodIterator: @unchecked Sendable, SwiftConnectionModeProviding {
    private let logger = Logger(label: "AccessMethodIterator")

    private let dataSource: AccessMethodRepositoryDataSource

    /// Current position within the enabled access methods. It is reduced modulo the number
    /// of enabled methods when picking, so it may exceed the list length.
    private var index = 0
    private var cancellables = Set<Combine.AnyCancellable>()

    /// All access methods from the data source whose `isEnabled` is true, in fetched order.
    private var enabledConfigurations: [PersistentAccessMethod] {
        dataSource.fetchAll().filter { $0.isEnabled }
    }

    /// Identifier of the access method returned by `dataSource.fetchLastReachable()`.
    private var lastReachableApiAccessId: UUID? {
        dataSource.fetchLastReachable().id
    }

    /// Returns `REST.encryptedDNSHostname`.
    public var domainName: String {
        REST.encryptedDNSHostname
    }

    /// Creates an iterator over the access methods of `dataSource`.
    ///
    /// Subscribes to `dataSource.accessMethodsPublisher`. On each emission the current
    /// position is re-validated and the picked access method is requested from the data source.
    /// The subscription holds `self` weakly, so it does not keep the iterator alive.
    init(dataSource: AccessMethodRepositoryDataSource) {
        self.dataSource = dataSource

        self.dataSource
            .accessMethodsPublisher
            .sink { [weak self] _ in
                self?.refreshCacheIfNeeded()
            }
            .store(in: &cancellables)
    }

    /// Re-anchors `index` on the last reachable access method and requests the resulting pick.
    ///
    /// If the last reachable access method is among the enabled ones, `index` is moved to its
    /// position. Otherwise `index` is left unchanged. The picked method is then passed to
    /// `dataSource.requestAccessMethod(_:)`.
    private func refreshCacheIfNeeded() {
        // Take one snapshot so the index computed here and the pick below refer to the same
        // list. Fetch the last reachable id once instead of once per element in `firstIndex`.
        let configurations = enabledConfigurations
        let lastReachableId = lastReachableApiAccessId

        if let firstIndex = configurations.firstIndex(where: { $0.id == lastReachableId }) {
            index = firstIndex
        }

        dataSource.requestAccessMethod(pick(from: configurations))
    }

    /// Advances to the next enabled access method and requests it from the data source.
    ///
    /// `index` wraps to 0 instead of trapping if incrementing it would overflow.
    func rotate() {
        let (partial, isOverflow) = index.addingReportingOverflow(1)
        index = isOverflow ? 0 : partial
        dataSource.requestAccessMethod(pick())
    }

    /// Returns the access method at the current position.
    ///
    /// - Returns: The enabled access method at `index` modulo the number of enabled methods,
    ///   or `dataSource.directAccess` when no access method is enabled.
    func pick() -> PersistentAccessMethod {
        pick(from: enabledConfigurations)
    }

    /// Picks from an already-fetched list of enabled access methods.
    /// Falls back to `dataSource.directAccess` when `configurations` is empty.
    private func pick(from configurations: [PersistentAccessMethod]) -> PersistentAccessMethod {
        guard !configurations.isEmpty else {
            // Use the default (direct) strategy when every access method is disabled.
            return dataSource.directAccess
        }
        // Walk the enabled methods in fetched order, wrapping to the start after the end.
        // `index` is never negative, so the remainder is a valid position.
        return configurations[index % configurations.count]
    }

    /// Returns every access method from the data source, enabled or not.
    func accessMethods() -> [PersistentAccessMethod] {
        dataSource.fetchAll()
    }
}