summaryrefslogtreecommitdiffhomepage
path: root/ios/MullvadRustRuntime/EphemeralPeerExchangeActor.swift
blob: 43ea7b5b32b1b52ed587691680285c329d9af001 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
//
//  EphemeralPeerExchangeActor.swift
//  PacketTunnel
//
//  Created by Marco Nikic on 2024-04-12.
//  Copyright © 2025 Mullvad VPN AB. All rights reserved.
//

import Foundation
import MullvadRustRuntimeProxy
import MullvadTypes
import NetworkExtension
import WireGuardKitTypes

/// Abstraction over the component that runs ephemeral peer negotiations for the tunnel.
public protocol EphemeralPeerExchangeActorProtocol {
    /// Starts an ephemeral peer negotiation.
    ///
    /// - Parameters:
    ///   - privateKey: The device's current private key.
    ///   - enablePostQuantum: Whether post-quantum key exchange should be requested.
    ///   - enableDaita: Whether DAITA should be requested.
    func startNegotiation(with privateKey: PrivateKey, enablePostQuantum: Bool, enableDaita: Bool)

    /// Cancels the negotiation currently in progress, if any.
    func endCurrentNegotiation()

    /// Resets negotiation state and ends any negotiation currently in progress.
    func reset()
}

/// Runs ephemeral peer exchanges on behalf of a tunnel provider.
///
/// At most one negotiation is tracked at a time; starting a new one cancels the previous one.
/// The peer exchange timeout for each negotiation is drawn from an iterator created by
/// `iteratorProvider`, which `reset()` re-creates.
public class EphemeralPeerExchangeActor: EphemeralPeerExchangeActorProtocol {
    /// A single in-flight negotiation, kept so it can be cancelled later.
    struct Negotiation {
        var negotiator: EphemeralPeerNegotiating

        /// Asks the underlying negotiator to cancel its key negotiation.
        func cancel() {
            negotiator.cancelKeyNegotiation()
        }
    }

    unowned let packetTunnel: any TunnelProvider
    /// The negotiation currently in progress, or `nil` when none is running.
    internal var negotiation: Negotiation?
    /// Supplies the peer exchange timeout for each successive negotiation attempt.
    /// It is always assigned in `init`, so it does not need to be optional.
    private var keyExchangeRetriesIterator: AnyIterator<Duration>
    private let iteratorProvider: () -> AnyIterator<Duration>
    private let negotiationProvider: EphemeralPeerNegotiating.Type

    /// Callback in the event of the negotiation failing on startup.
    var onFailure: () -> Void

    /// - Parameters:
    ///   - packetTunnel: The tunnel provider. It is held `unowned`, so it must outlive this actor.
    ///   - onFailure: Called when a negotiation fails to start.
    ///   - negotiationProvider: The negotiator type instantiated for every new negotiation.
    ///   - iteratorProvider: Produces the sequence of peer exchange timeouts. It is invoked once here
    ///     and again on every `reset()`.
    public init(
        packetTunnel: any TunnelProvider,
        onFailure: @escaping (() -> Void),
        negotiationProvider: EphemeralPeerNegotiating.Type = EphemeralPeerNegotiator.self,
        iteratorProvider: @escaping () -> AnyIterator<Duration>
    ) {
        self.packetTunnel = packetTunnel
        self.onFailure = onFailure
        self.negotiationProvider = negotiationProvider
        self.iteratorProvider = iteratorProvider
        self.keyExchangeRetriesIterator = iteratorProvider()
    }

    /// Starts a new key exchange.
    ///
    /// Any ongoing key negotiation is stopped before starting a new one.
    /// The timeout for the exchange is the next value of the retry iterator (10 seconds once the
    /// iterator is exhausted). The iterator is only rewound by `reset()`.
    ///
    /// If the negotiator fails to start, it is cancelled, `onFailure` is invoked, and no negotiation
    /// is recorded as in progress.
    ///
    /// - Parameters:
    ///   - privateKey: The device's current private key.
    ///   - enablePostQuantum: Passed through as `enable_post_quantum` in the peer parameters.
    ///   - enableDaita: Passed through as `enable_daita` in the peer parameters.
    public func startNegotiation(with privateKey: PrivateKey, enablePostQuantum: Bool, enableDaita: Bool) {
        endCurrentNegotiation()
        let negotiator = negotiationProvider.init()

        // This will become the new private key of the device.
        let ephemeralSharedKey = PrivateKey()

        // Take the next timeout from the retry iterator, falling back to 10 seconds once it runs out.
        // If the exchange does not complete within this time, the Rust side gives up on it.
        let peerExchangeTimeout = keyExchangeRetriesIterator.next() ?? .seconds(10)
        let peerParameters = EphemeralPeerParameters(
            peer_exchange_timeout: UInt64(peerExchangeTimeout.timeInterval),
            enable_post_quantum: enablePostQuantum,
            enable_daita: enableDaita,
            funcs: mapWgFunctions(functions: packetTunnel.wgFunctions())
        )

        let didStart = negotiator.startNegotiation(
            devicePublicKey: privateKey.publicKey,
            presharedKey: ephemeralSharedKey,
            peerReceiver: packetTunnel,
            ephemeralPeerParams: peerParameters
        )

        guard didStart else {
            // Cancel *this* negotiator to shut down any remaining use of the TCP connection on the
            // Rust side. `self.negotiation` is already nil here (cleared by `endCurrentNegotiation()`),
            // so cancelling it would be a no-op.
            negotiator.cancelKeyNegotiation()
            // Return without storing the failed negotiator. `onFailure` may start a new negotiation,
            // which must not be overwritten afterwards.
            onFailure()
            return
        }

        negotiation = Negotiation(negotiator: negotiator)
    }

    /// Maps the tunnel's WireGuard TCP function pointers into the struct the Rust side expects.
    private func mapWgFunctions(functions: WgFunctionPointers) -> WgTcpConnectionFunctions {
        var mappedFunctions = WgTcpConnectionFunctions()

        mappedFunctions.close_fn = functions.close
        mappedFunctions.open_fn = functions.open
        mappedFunctions.send_fn = functions.send
        mappedFunctions.recv_fn = functions.receive

        return mappedFunctions
    }

    /// Cancels the ongoing key exchange, if any, and forgets it.
    public func endCurrentNegotiation() {
        negotiation?.cancel()
        negotiation = nil
    }

    /// Re-creates the timeout iterator from `iteratorProvider`, then ends the current key exchange.
    public func reset() {
        keyExchangeRetriesIterator = iteratorProvider()
        endCurrentNegotiation()
    }
}