summaryrefslogtreecommitdiffhomepage
path: root/ios/MullvadVPN/TunnelManager/TunnelStore.swift
blob: 0344175033679b21f81904a6b7594cad5fee9b70 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
//
//  TunnelStore.swift
//  MullvadVPN
//
//  Created by pronebird on 07/12/2022.
//  Copyright © 2025 Mullvad VPN AB. All rights reserved.
//

import Foundation
import MullvadLogging
import MullvadTypes
import NetworkExtension
import UIKit

/// Abstraction over a store that vends VPN tunnels.
protocol TunnelStoreProtocol: Sendable {
    /// Concrete tunnel type vended by the store; must be comparable for identity checks.
    associatedtype TunnelType: TunnelProtocol, Equatable

    /// Creates and returns a new tunnel instance.
    func createNewTunnel() -> TunnelType

    /// Returns the tunnels the store currently considers persistent.
    func getPersistentTunnels() -> [TunnelType]
}

/// Wrapper around system VPN tunnels.
///
/// The store tracks two groups of tunnels:
/// - persistent tunnels, loaded via `loadAllFromPreferences` or promoted from new tunnels, and
/// - newly created tunnels, held in weak boxes until they report a status other than `.invalid`.
///
/// A new tunnel that reports a non-`.invalid` status is promoted to the persistent group, and a
/// persistent tunnel that reports `.invalid` is removed. All mutable state is guarded by `lock`,
/// which is what the `@unchecked Sendable` conformance relies on.
final class TunnelStore: TunnelStoreProtocol, TunnelStatusObserver, @unchecked Sendable {
    typealias TunnelType = Tunnel
    private let logger = Logger(label: "TunnelStore")

    /// Guards `persistentTunnels` and `newTunnels`. `NSLock` is not recursive, so private helpers
    /// that expect the lock to be held must not acquire it again.
    private let lock = NSLock()
    private let application: BackgroundTaskProviding

    /// Persistent tunnels registered with the system.
    private var persistentTunnels: [TunnelType] = []

    /// Newly created tunnels, stored as collection of weak boxes.
    private var newTunnels: [WeakBox<TunnelType>] = []

    /// - Parameter application: Handed to every tunnel as its background task provider and used as
    ///   the sender filter for `UIApplication.didBecomeActiveNotification`.
    init(application: BackgroundTaskProviding) {
        self.application = application
        NotificationCenter.default.addObserver(
            self,
            selector: #selector(applicationDidBecomeActive(_:)),
            name: UIApplication.didBecomeActiveNotification,
            object: application
        )
    }

    /// Returns a snapshot of the current persistent tunnels.
    func getPersistentTunnels() -> [TunnelType] {
        lock.lock()
        defer { lock.unlock() }

        return persistentTunnels
    }

    /// Reloads persistent tunnels from system preferences.
    ///
    /// On success, the store stops observing the previously loaded tunnels and replaces them with
    /// freshly constructed `Tunnel` instances, which it then observes. On failure, the current list
    /// is left untouched.
    ///
    /// - Parameter completion: Called with the load error, or `nil` on success, after `lock` has
    ///   been released.
    func loadPersistentTunnels(completion: @escaping (Error?) -> Void) {
        TunnelProviderManagerType.loadAllFromPreferences { managers, error in
            self.lock.lock()
            defer {
                self.lock.unlock()

                // Invoke the completion outside of the critical section.
                completion(error)
            }

            guard error == nil else { return }

            self.persistentTunnels.forEach { tunnel in
                tunnel.removeObserver(self)
            }

            self.persistentTunnels =
                managers?.map { manager in
                    let tunnel = Tunnel(tunnelProvider: manager, backgroundTaskProvider: self.application)
                    tunnel.addObserver(self)

                    self.logger.debug(
                        "Loaded persistent tunnel: \(tunnel.logFormat()) with status: \(tunnel.status)."
                    )

                    return tunnel
                } ?? []
        }
    }

    /// Creates a new, not-yet-persistent tunnel and starts observing its status.
    ///
    /// The tunnel is tracked weakly until it reports a non-`.invalid` status, at which point it is
    /// promoted to the persistent tunnels.
    func createNewTunnel() -> TunnelType {
        lock.lock()
        defer { lock.unlock() }

        let tunnelProviderManager = TunnelProviderManagerType()
        let tunnel = TunnelType(tunnelProvider: tunnelProviderManager, backgroundTaskProvider: application)
        tunnel.addObserver(self)

        // Drop boxes whose tunnels have already been deallocated before adding the new one.
        newTunnels.removeAll { $0.value == nil }
        newTunnels.append(WeakBox(tunnel))

        logger.debug("Create new tunnel: \(tunnel.logFormat()).")

        return tunnel
    }

    /// `TunnelStatusObserver` callback; reconciles the store with the tunnel's new status.
    func tunnel(_ tunnel: any TunnelProtocol, didReceiveStatus status: NEVPNStatus) {
        // The store only registers itself on `TunnelType` instances; ignore anything else rather
        // than crashing on a forced cast.
        guard let tunnel = tunnel as? TunnelType else { return }

        lock.lock()
        defer { lock.unlock() }

        handleTunnelStatus(tunnel: tunnel, status: status)
    }

    /// Moves `tunnel` between the new and persistent groups according to `status`.
    ///
    /// - Important: Must be called with `lock` held.
    private func handleTunnelStatus(tunnel: TunnelType, status: NEVPNStatus) {
        if status == .invalid {
            // A persistent tunnel reporting `.invalid` is dropped from the store.
            guard let index = persistentTunnels.firstIndex(of: tunnel) else { return }

            persistentTunnels.remove(at: index)
            logger.debug("Persistent tunnel was removed: \(tunnel.logFormat()).")
        } else {
            // Search `newTunnels` itself (not a compacted copy) so the index is valid for
            // `remove(at:)` even when some weak boxes have already been emptied.
            guard let index = newTunnels.firstIndex(where: { $0.value == tunnel }) else { return }

            newTunnels.remove(at: index)
            persistentTunnels.append(tunnel)
            logger.debug("New tunnel became persistent: \(tunnel.logFormat()).")
        }
    }

    @objc private func applicationDidBecomeActive(_ notification: Notification) {
        refreshStatus()
    }

    /// Re-applies the current status of every tracked tunnel.
    private func refreshStatus() {
        lock.lock()
        defer { lock.unlock() }

        // Snapshot first: `handleTunnelStatus` mutates both collections while we iterate.
        let allTunnels = persistentTunnels + newTunnels.compactMap { $0.value }

        for tunnel in allTunnels {
            handleTunnelStatus(tunnel: tunnel, status: tunnel.status)
        }
    }
}