summaryrefslogtreecommitdiffhomepage
path: root/ios/MullvadREST/Transport/Socks5/Socks5EndpointReader.swift
blob: bf7c3244d96d3b86dc7e78e3f766cf1abc4469d6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
//
//  Socks5EndpointReader.swift
//  MullvadTransport
//
//  Created by pronebird on 21/10/2023.
//

import Foundation
import MullvadTypes
import Network

/// The object reading the endpoint data from connection.
/// The object reading the endpoint data from connection.
///
/// Reads an address of the expected `addressType` followed by a 2-byte port in network byte order:
/// - `.ipv4`: 4 address bytes, then the port.
/// - `.ipv6`: 16 address bytes, then the port.
/// - `.domainName`: a 1-byte length `n`, then `n` bytes of domain name, then the port.
///
/// The parsed endpoint is delivered via `onComplete`; connection errors, truncated input and
/// parse failures are delivered via `onFailure`.
struct Socks5EndpointReader: Sendable {
    /// Connection to the socks proxy.
    let connection: NWConnection

    /// The expected address type.
    let addressType: Socks5AddressType

    /// Completion handler called upon success.
    let onComplete: @Sendable (Socks5Endpoint) -> Void

    /// Failure handler.
    let onFailure: @Sendable (Error) -> Void

    /// The length of IPv4 address in bytes.
    private static let ipv4AddressLength = 4

    /// The length of IPv6 address in bytes.
    private static let ipv6AddressLength = 16

    /// The length of port in bytes.
    private static let portLength = MemoryLayout<UInt16>.size

    /// The length of the domain name length prefix in bytes.
    private static let domainLengthLength = MemoryLayout<UInt8>.size

    /// Start reading endpoint from connection.
    func perform() {
        switch addressType {
        case .ipv4:
            readBoundAddressAndPortInner(addressLength: Self.ipv4AddressLength)

        case .ipv6:
            readBoundAddressAndPortInner(addressLength: Self.ipv6AddressLength)

        case .domainName:
            // Domain names are length-prefixed, so the length must be read before the name itself.
            readBoundDomainNameLength { [self] domainLength in
                readBoundAddressAndPortInner(addressLength: domainLength)
            }
        }
    }

    /// Reads `addressLength` address bytes plus the port from the connection.
    /// Then it parses them and reports the result via `onComplete` / `onFailure`.
    ///
    /// - Parameter addressLength: Number of address bytes preceding the port.
    private func readBoundAddressAndPortInner(addressLength: Int) {
        // The entire length of address + port
        let byteSize = addressLength + Self.portLength

        connection.receive(exactLength: byteSize) { [self] addressData, _, _, error in
            if let error {
                onFailure(Socks5Error.remoteConnectionFailure(error))
            } else if let addressData {
                do {
                    let endpoint = try parseEndpoint(addressData: addressData, addressLength: addressLength)

                    onComplete(endpoint)
                } catch {
                    onFailure(error)
                }
            } else {
                onFailure(Socks5Error.unexpectedEndOfStream)
            }
        }
    }

    /// Reads the single-byte domain name length from the connection.
    ///
    /// - Parameter completion: Called with the decoded length on success. On error or when no
    ///   data arrives, `onFailure` is called instead and `completion` is not invoked.
    private func readBoundDomainNameLength(completion: @escaping @Sendable (Int) -> Void) {
        connection.receive(exactLength: Self.domainLengthLength) { [self] data, _, _, error in
            if let error {
                onFailure(Socks5Error.remoteConnectionFailure(error))
            } else if let domainNameLength = data?.first {
                completion(Int(domainNameLength))
            } else {
                onFailure(Socks5Error.unexpectedEndOfStream)
            }
        }
    }

    /// Parses `addressData`: `addressLength` address bytes followed by a big-endian `UInt16` port.
    ///
    /// - Parameters:
    ///   - addressData: Raw bytes received from the connection.
    ///   - addressLength: Number of address bytes preceding the port.
    /// - Returns: The endpoint matching `addressType`.
    /// - Throws: `Socks5Error.unexpectedEndOfStream` if `addressData` has the wrong size,
    ///   `.parseIPv4Address` / `.parseIPv6Address` for invalid IP bytes, or
    ///   `.decodeDomainName` if the domain name cannot be decoded.
    private func parseEndpoint(addressData: Data, addressLength: Int) throws -> Socks5Endpoint {
        guard addressData.count == addressLength + Self.portLength else { throw Socks5Error.unexpectedEndOfStream }

        // Use prefix/suffix rather than integer subscripts: `Data` may be a slice whose
        // `startIndex` is not zero, in which case `data[0..<n]` reads the wrong bytes or traps.
        let addressBytes = addressData.prefix(addressLength)

        // Port is passed in network byte order (big-endian). Assemble it byte by byte instead of
        // using `load(as:)`: the port follows a variable-length domain name, so it may start at an
        // odd offset, and `load(as:)` requires a pointer properly aligned for `UInt16`.
        let port = addressData.suffix(Self.portLength).reduce(UInt16(0)) { ($0 << 8) | UInt16($1) }

        // Parse address into endpoint.
        switch addressType {
        case .ipv4:
            guard let ipAddress = IPv4Address(addressBytes) else { throw Socks5Error.parseIPv4Address }

            return .ipv4(IPv4Endpoint(ip: ipAddress, port: port))

        case .ipv6:
            guard let ipAddress = IPv6Address(addressBytes) else { throw Socks5Error.parseIPv6Address }

            return .ipv6(IPv6Endpoint(ip: ipAddress, port: port))

        case .domainName:
            guard let hostname = String(bytes: addressBytes, encoding: .utf8),
                let endpoint = Socks5HostEndpoint(hostname: hostname, port: port)
            else {
                throw Socks5Error.decodeDomainName
            }
            return .domain(endpoint)
        }
    }
}