summaryrefslogtreecommitdiffhomepage
path: root/ios/PacketTunnel/DeviceCheck/DeviceCheckOperation.swift
blob: faf430017f52bf569a84322d217ff4abaa33f042 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
//
//  DeviceCheckOperation.swift
//  PacketTunnel
//
//  Created by pronebird on 20/04/2023.
//  Copyright © 2025 Mullvad VPN AB. All rights reserved.
//

import Foundation
import MullvadLogging
import MullvadREST
import MullvadSettings
import MullvadTypes
import Operations
import PacketTunnelCore
import WireGuardKitTypes

/**
 An operation that is responsible for performing account and device diagnostics and key rotation from within packet
 tunnel process.

 Packet tunnel runs this operation immediately as it starts, with the `rotateImmediatelyOnKeyMismatch` flag set to
 `true`, which forces key rotation to happen immediately when the key stored on the server does not match the key
 stored on the device, unless the last rotation attempt took place less than 15 seconds ago, in which case the key
 rotation is not performed.

 Other times, packet tunnel runs this operation with `rotateImmediatelyOnKeyMismatch` set to `false`, in which
 case it respects the 24-hour interval between key rotation retry attempts.
 */
final class DeviceCheckOperation: ResultOperation<DeviceCheck>, @unchecked Sendable {
    private let logger = Logger(label: "DeviceCheckOperation")

    private let remoteService: DeviceCheckRemoteServiceProtocol
    private let deviceStateAccessor: DeviceStateAccessorProtocol
    private let rotateImmediatelyOnKeyMismatch: Bool

    /// Network tasks started by this operation; all of them are cancelled in `operationDidCancel()`.
    private var tasks: [Cancellable] = []

    /// Creates a device check operation.
    ///
    /// - Parameters:
    ///   - dispatchQueue: Queue passed to `ResultOperation` as both the operation queue and the completion queue. It is
    ///     also the queue on which network results are processed.
    ///   - remoteSevice: Service used to fetch account and device data and to rotate the device key. (The label's
    ///     spelling is kept as-is for source compatibility with existing callers.)
    ///   - deviceStateAccessor: Reads and persists the device state.
    ///   - rotateImmediatelyOnKeyMismatch: Forwarded to `WgKeyRotation.shouldRotateFromPacketTunnel(rotateImmediately:)`
    ///     when deciding whether a key mismatch should trigger a rotation.
    ///   - completionHandler: Optional handler invoked with the operation result.
    init(
        dispatchQueue: DispatchQueue,
        remoteSevice: DeviceCheckRemoteServiceProtocol,
        deviceStateAccessor: DeviceStateAccessorProtocol,
        rotateImmediatelyOnKeyMismatch: Bool,
        completionHandler: CompletionHandler? = nil
    ) {
        self.remoteService = remoteSevice
        self.deviceStateAccessor = deviceStateAccessor
        self.rotateImmediatelyOnKeyMismatch = rotateImmediatelyOnKeyMismatch

        super.init(dispatchQueue: dispatchQueue, completionQueue: dispatchQueue, completionHandler: completionHandler)
    }

    override func main() {
        startFlow { result in
            self.finish(result: result)
        }
    }

    /// Cancels every network task started so far. Work that would start a new task afterwards checks `isCancelled`
    /// first (see `rotateKeyIfNeeded(completion:)`), because tasks appended after this method runs are never cancelled.
    override func operationDidCancel() {
        tasks.forEach { $0.cancel() }
    }

    // MARK: - Flow

    /**
     Begins the flow by reading the device state and then fetching account and device data. Calls `didReceiveData()`
     with the received data when done.

     Fails with `DeviceCheckError.invalidDeviceState` if the device is not logged in, or with the error thrown while
     reading the device state.
     */
    private func startFlow(completion: @escaping @Sendable (Result<DeviceCheck, Error>) -> Void) {
        do {
            guard case let .loggedIn(accountData, deviceData) = try deviceStateAccessor.read() else {
                throw DeviceCheckError.invalidDeviceState
            }

            fetchData(
                accountNumber: accountData.number,
                deviceIdentifier: deviceData.identifier
            ) { [self] accountResult, deviceResult in
                didReceiveData(accountResult: accountResult, deviceResult: deviceResult, completion: completion)
            }
        } catch {
            completion(.failure(error))
        }
    }

    /**
     Handles the received data results and initiates key rotation when the key stored on the server does not match
     the key stored on the device.

     When a rotation is performed, the resulting device verdict is `.active` if the rotation succeeded and
     `.keyMismatch` otherwise. Errors that cannot be mapped to a verdict are forwarded as failures.
     */
    private func didReceiveData(
        accountResult: Result<Account, Error>,
        deviceResult: Result<Device, Error>,
        completion: @escaping @Sendable (Result<DeviceCheck, Error>) -> Void
    ) {
        do {
            let accountVerdict = try accountVerdict(from: accountResult)
            let deviceVerdict = try deviceVerdict(from: deviceResult)

            // Do not rotate the key if account is invalid even if the API successfully returns a device.
            if accountVerdict != .invalid, deviceVerdict == .keyMismatch {
                rotateKeyIfNeeded { rotationResult in
                    completion(
                        rotationResult.map { rotationStatus in
                            DeviceCheck(
                                accountVerdict: accountVerdict,
                                deviceVerdict: rotationStatus.isSucceeded ? .active : .keyMismatch,
                                keyRotationStatus: rotationStatus
                            )
                        })
                }
            } else {
                completion(
                    .success(
                        DeviceCheck(
                            accountVerdict: accountVerdict,
                            deviceVerdict: deviceVerdict,
                            keyRotationStatus: .noAction
                        )))
            }
        } catch {
            completion(.failure(error))
        }
    }

    // MARK: - Data fetch

    /// Fetches account and device data concurrently. When both requests have finished, calls the completion handler
    /// on `dispatchQueue`, passing both results. A result that was never delivered stays `OperationError.cancelled`.
    private func fetchData(
        accountNumber: String, deviceIdentifier: String,
        completion: @escaping (Result<Account, Error>, Result<Device, Error>) -> Void
    ) {
        // Written from the remote service callbacks and read only in `notify`, which the dispatch group orders after
        // both `leave()` calls.
        nonisolated(unsafe) var accountResult: Result<Account, Error> = .failure(OperationError.cancelled)
        nonisolated(unsafe) var deviceResult: Result<Device, Error> = .failure(OperationError.cancelled)

        let dispatchGroup = DispatchGroup()

        dispatchGroup.enter()
        let accountTask = remoteService.getAccountData(accountNumber: accountNumber) { result in
            accountResult = result
            dispatchGroup.leave()
        }

        dispatchGroup.enter()
        let deviceTask = remoteService.getDevice(accountNumber: accountNumber, identifier: deviceIdentifier) { result in
            deviceResult = result
            dispatchGroup.leave()
        }

        tasks.append(contentsOf: [accountTask, deviceTask])

        dispatchGroup.notify(queue: dispatchQueue) {
            completion(accountResult, deviceResult)
        }
    }

    // MARK: - Key rotation

    /**
     Checks whether the key should be rotated by checking when the last rotation took place. If the conditions are
     satisfied, it rotates the device key: it marks the beginning of the key rotation, updates and persists the device
     state, and then asks the remote service to rotate the key.

     Completes with:
     - `.failure(OperationError.cancelled)` if the operation has already been cancelled;
     - `.failure` if the device state cannot be read or persisted, or the device is no longer logged in;
     - `.success(.noAction)` if rotation is not due yet;
     - `.success(.succeeded(_))` if the rotation completed and was persisted;
     - `.success(.attempted(_))` if the rotation request failed for any reason other than cancellation.
     */
    private func rotateKeyIfNeeded(completion: @escaping @Sendable (Result<KeyRotationStatus, Error>) -> Void) {
        // `operationDidCancel()` only cancels tasks that already exist in `tasks`. If cancellation was processed after
        // the data requests finished but before this point, a rotation started here would never be cancelled, and the
        // device state would still be rewritten on behalf of a cancelled operation. Bail out instead.
        guard !isCancelled else {
            completion(.failure(OperationError.cancelled))
            return
        }

        let deviceState: DeviceState
        do {
            deviceState = try deviceStateAccessor.read()
        } catch {
            logger.error(error: error, message: "Failed to read device state before rotating the key.")
            completion(.failure(error))
            return
        }

        guard case let .loggedIn(accountData, deviceData) = deviceState else {
            logger.debug("Will not attempt to rotate the key as device is no longer logged in.")
            completion(.failure(DeviceCheckError.invalidDeviceState))
            return
        }

        var keyRotation = WgKeyRotation(data: deviceData)
        guard keyRotation.shouldRotateFromPacketTunnel(rotateImmediately: rotateImmediatelyOnKeyMismatch) else {
            completion(.success(.noAction))
            return
        }

        let publicKey = keyRotation.beginAttempt()

        // Persist the attempt before talking to the API so that the attempt is recorded even if the request fails.
        do {
            try deviceStateAccessor.write(.loggedIn(accountData, keyRotation.data))
        } catch {
            logger.error(error: error, message: "Failed to persist updated device state before rotating the key.")
            completion(.failure(error))
            return
        }

        logger.debug("Rotate private key from packet tunnel.")

        let task = remoteService.rotateDeviceKey(
            accountNumber: accountData.number,
            identifier: deviceData.identifier,
            publicKey: publicKey
        ) { result in
            // Hop back onto the operation queue before touching device state.
            self.dispatchQueue.async {
                let returnResult = result.tryMap { device -> KeyRotationStatus in
                    try self.completeKeyRotation(device)
                    return .succeeded(Date())
                }
                .flatMapError { error in
                    self.logger.error(error: error, message: "Failed to rotate device key.")

                    // Cancellation propagates as a failure; any other failure is reported as an attempted rotation.
                    if error.isOperationCancellationError {
                        return .failure(error)
                    } else {
                        return .success(.attempted(Date()))
                    }
                }

                completion(returnResult)
            }
        }

        tasks.append(task)
    }

    /**
     Updates the device state with the new data received from `Device` and marks the key rotation as completed by
     swapping the current private key and erasing information about the last key rotation attempt.

     - Throws: `DeviceCheckError.invalidDeviceState` if the device is no longer logged in,
       `DeviceCheckError.keyRotationRace` if `WgKeyRotation.setCompleted(with:)` reports that the rotation cannot be
       completed, or any error thrown while reading or writing the device state.
     */
    private func completeKeyRotation(_ device: Device) throws {
        logger.debug("Successfully rotated device key. Persisting device state...")

        let deviceState = try deviceStateAccessor.read()
        guard case let .loggedIn(accountData, deviceData) = deviceState else {
            logger.debug("Will not persist device state after rotating the key because device is no longer logged in.")
            throw DeviceCheckError.invalidDeviceState
        }

        var keyRotation = WgKeyRotation(data: deviceData)
        let isCompleted = keyRotation.setCompleted(with: device)

        if isCompleted {
            do {
                try deviceStateAccessor.write(.loggedIn(accountData, keyRotation.data))
            } catch {
                logger.error(error: error, message: "Failed to persist device state after rotating the key.")
                throw error
            }
        } else {
            logger.debug("Cannot complete key rotation due to rotation race.")

            throw DeviceCheckError.keyRotationRace
        }
    }

    // MARK: - Private helpers

    /// Converts the account data result into an `AccountVerdict`.
    ///
    /// Returns `.active` if the account expiry is in the future, `.expired` otherwise, and `.invalid` when the API
    /// reports an invalid account. Any other error is rethrown.
    private func accountVerdict(from accountResult: Result<Account, Error>) throws -> AccountVerdict {
        do {
            let account = try accountResult.get()

            return account.expiry > Date() ? .active(account) : .expired(account)
        } catch let error as REST.Error where error.compareErrorCode(.invalidAccount) {
            return .invalid
        }
    }

    /// Converts the device result into a `DeviceVerdict`.
    ///
    /// Returns `.active` when the public key of the locally stored private key matches the key returned by the API,
    /// `.keyMismatch` otherwise, and `.revoked` when the API reports that the device was not found.
    /// Throws `DeviceCheckError.invalidDeviceState` when no device data is stored locally; other errors are rethrown.
    private func deviceVerdict(from deviceResult: Result<Device, Error>) throws -> DeviceVerdict {
        do {
            let deviceState = try deviceStateAccessor.read()
            guard let deviceData = deviceState.deviceData else { throw DeviceCheckError.invalidDeviceState }

            let device = try deviceResult.get()

            return deviceData.wgKeyData.privateKey.publicKey == device.pubkey ? .active : .keyMismatch
        } catch let error as REST.Error where error.compareErrorCode(.deviceNotFound) {
            return .revoked
        }
    }
}

/// An error used internally by `DeviceCheckOperation`.
public enum DeviceCheckError: LocalizedError, Equatable {
    /// The device state is not `.loggedIn`, so the device check cannot proceed.
    case invalidDeviceState

    /// Key rotation could not be marked as completed; most likely the main process rotated the key at the same time
    /// as the packet tunnel did.
    case keyRotationRace

    /// Human-readable description of the error, suitable for logging.
    public var errorDescription: String? {
        let description: String

        switch self {
        case .keyRotationRace:
            description = "Detected key rotation race condition."
        case .invalidDeviceState:
            description = "Cannot complete device check because device is no longer logged in."
        }

        return description
    }
}