diff options
| -rw-r--r-- | ios/MullvadVPN/REST/ServerRelaysResponse.swift | 2 | ||||
| -rw-r--r-- | ios/MullvadVPN/RelaySelector.swift | 40 |
2 files changed, 32 insertions, 10 deletions
diff --git a/ios/MullvadVPN/REST/ServerRelaysResponse.swift b/ios/MullvadVPN/REST/ServerRelaysResponse.swift index 1604e15ee2..1f93954d33 100644 --- a/ios/MullvadVPN/REST/ServerRelaysResponse.swift +++ b/ios/MullvadVPN/REST/ServerRelaysResponse.swift @@ -24,7 +24,7 @@ extension REST { let owned: Bool let location: String let provider: String - let weight: Int32 + let weight: UInt64 let ipv4AddrIn: IPv4Address let ipv6AddrIn: IPv6Address let publicKey: Data diff --git a/ios/MullvadVPN/RelaySelector.swift b/ios/MullvadVPN/RelaySelector.swift index a3f9f7e474..f0bb12ce7f 100644 --- a/ios/MullvadVPN/RelaySelector.swift +++ b/ios/MullvadVPN/RelaySelector.swift @@ -36,16 +36,11 @@ enum RelaySelector {} extension RelaySelector { static func evaluate(relays: REST.ServerRelaysResponse, constraints: RelayConstraints) -> RelaySelectorResult? { - let filteredRelays = Self.applyConstraints(constraints, relays: Self.parseRelaysResponse(relays)) - let totalWeight = filteredRelays.reduce(0) { $0 + $1.relay.weight } + let filteredRelays = applyConstraints(constraints, relays: Self.parseRelaysResponse(relays)) - guard totalWeight > 0 else { return nil } - guard var i = (0...totalWeight).randomElement() else { return nil } - - let relayWithLocation = filteredRelays.first { (relayWithLocation) -> Bool in - i -= relayWithLocation.relay.weight - return i <= 0 - }.unsafelyUnwrapped + guard let relayWithLocation = pickRandomRelay(relays: filteredRelays) else { + return nil + } guard let port = relays.wireguard.portRanges.randomElement()?.randomElement() else { return nil @@ -93,6 +88,33 @@ extension RelaySelector { } } + private static func pickRandomRelay(relays: [RelayWithLocation]) -> RelayWithLocation? { + let totalWeight = relays.reduce(0) { accummulatedWeight, relayWithLocation in + return accummulatedWeight + relayWithLocation.relay.weight + } + + // Return random relay when all relays within the list have zero weight. + guard totalWeight > 0 else { + return relays.randomElement() + } + + // Pick a random number in the range 1 - totalWeight. This choses the relay with a + // non-zero weight. + var i = (1...totalWeight).randomElement()! + + let randomRelay = relays.first { (relayWithLocation) -> Bool in + let (result, isOverflow) = i.subtractingReportingOverflow(relayWithLocation.relay.weight) + + i = isOverflow ? 0 : result + + return i == 0 + } + + precondition(randomRelay != nil, "At least one relay must've had a weight above 0") + + return randomRelay + } + private static func parseRelaysResponse(_ response: REST.ServerRelaysResponse) -> [RelayWithLocation] { return response.wireguard.relays.compactMap { (serverRelay) -> RelayWithLocation? in guard let serverLocation = response.locations[serverRelay.location] else { return nil } |
