summaryrefslogtreecommitdiffhomepage
path: root/ios/MullvadVPN/TunnelManager/RotateKeyOperation.swift
blob: 632dd89a0b3ba0f51c80d1f10b597b3a8b431022 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
//
//  RotateKeyOperation.swift
//  MullvadVPN
//
//  Created by pronebird on 15/12/2021.
//  Copyright © 2025 Mullvad VPN AB. All rights reserved.
//

import Foundation
import MullvadLogging
import MullvadREST
import MullvadSettings
import MullvadTypes
import Operations
import WireGuardKitTypes

/// Operation that replaces the device's WireGuard key on the backend when a rotation is due.
///
/// Outcomes:
/// - Fails with `InvalidDeviceStateError` when the interactor's device state is not `.loggedIn`.
/// - Succeeds immediately, without any network request, when `WgKeyRotation.shouldRotate` is `false`.
/// - Otherwise marks a rotation attempt, persists it, and asks the backend to rotate the key.
///   On success it persists the completed rotation and notifies the tunnel (if any) before finishing.
///   On failure it forwards the error to the interactor and finishes with that error.
class RotateKeyOperation: ResultOperation<Void>, @unchecked Sendable {
    private let logger = Logger(label: "RotateKeyOperation")
    private let interactor: TunnelInteractor
    private let devicesProxy: DeviceHandling

    /// In-flight REST request, kept so `operationDidCancel()` can cancel it.
    private var pendingRequest: Cancellable?

    init(dispatchQueue: DispatchQueue, interactor: TunnelInteractor, devicesProxy: DeviceHandling) {
        self.interactor = interactor
        self.devicesProxy = devicesProxy

        super.init(dispatchQueue: dispatchQueue, completionQueue: nil, completionHandler: nil)
    }

    // MARK: - Operation lifecycle

    override func main() {
        // Key rotation only makes sense for a logged-in device.
        guard case .loggedIn(let accountData, let deviceData) = interactor.deviceState else {
            finish(result: .failure(InvalidDeviceStateError()))
            return
        }

        // Captured by the escaping response handler below. It is only read there,
        // after `beginAttempt()` has already been applied.
        nonisolated(unsafe) var rotation = WgKeyRotation(data: deviceData)

        if !rotation.shouldRotate {
            logger.debug("Throttle private key rotation.")
            finish(result: .success(()))
            return
        }

        logger.debug("Private key is old enough, rotate right away.")

        // Record the attempt locally before talking to the backend, and keep the
        // public key that has to be uploaded.
        let newPublicKey = rotation.beginAttempt()
        interactor.setDeviceState(.loggedIn(accountData, rotation.data), persist: true)

        logger.debug("Replacing old key with new key on server...")

        pendingRequest = devicesProxy.rotateDeviceKey(
            accountNumber: accountData.number,
            identifier: deviceData.identifier,
            publicKey: newPublicKey,
            retryStrategy: .default
        ) { [self] result in
            // Hop back onto the operation's own queue before touching any state.
            dispatchQueue.async { [self] in
                do {
                    let device = try result.get()
                    handleSuccess(accountData: accountData, fetchedDevice: device, keyRotation: rotation)
                } catch {
                    handleError(error)
                }
            }
        }
    }

    override func operationDidCancel() {
        defer { pendingRequest = nil }
        pendingRequest?.cancel()
    }

    // MARK: - Response handling

    /// Marks the rotation as completed with the device returned by the backend, persists it,
    /// and notifies the tunnel (when one exists) before finishing successfully.
    private func handleSuccess(accountData: StoredAccountData, fetchedDevice: Device, keyRotation: WgKeyRotation) {
        logger.debug("Successfully rotated device key. Persisting device state...")

        var completedRotation = keyRotation
        _ = completedRotation.setCompleted(with: fetchedDevice)

        interactor.setDeviceState(.loggedIn(accountData, completedRotation.data), persist: true)

        // Without a tunnel there is nothing to notify; finish right away.
        guard let tunnel = interactor.tunnel else {
            finish(result: .success(()))
            return
        }

        // Ask the tunnel to pick up the new key, and finish once it has answered.
        _ = tunnel.notifyKeyRotation { [weak self] _ in
            self?.finish(result: .success(()))
        }
    }

    /// Logs the failure (except cancellation), hands it to the interactor, and finishes with it.
    private func handleError(_ error: Error) {
        let wasCancelled = error.isOperationCancellationError
        if !wasCancelled {
            logger.error(error: error, message: "Failed to rotate device key.")
        }

        interactor.handleRestError(error)
        finish(result: .failure(error))
    }
}