summaryrefslogtreecommitdiffhomepage
path: root/feature/conn25/ippool.go
blob: 4ae8918d49397f9dfa9a3af6a69f12cc415afcf0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
// Copyright (c) Tailscale Inc & contributors
// SPDX-License-Identifier: BSD-3-Clause

package conn25

import (
	"errors"
	"net/netip"

	"go4.org/netipx"
	"tailscale.com/util/set"
)

var (
	// errPoolExhausted is returned when no free address can be handed out:
	// either the pool contains no addresses at all, or every address in it is
	// currently in use.
	errPoolExhausted = errors.New("ip pool exhausted")

	// errNotOurAddress is returned when an address handed back to the pool is
	// not part of the pool's address set.
	errNotOurAddress = errors.New("not our address")

	// errAddrExists is returned when an address handed back to the pool is not
	// currently marked in use, for example because it was already returned.
	errAddrExists = errors.New("address already returned")
)

// ipSetIterator walks every address of a netipx.IPSet in round-robin order.
//
// netipx.IPSet.Ranges returns the minimum, sorted list of IP ranges covering
// the set, and each netipx.IPRange is an inclusive range within a single
// address family. The iterator therefore only needs to remember which range
// it is in and which address it handed out last. Each step yields that
// address's successor. When the current range is used up, it hops to the
// start of the next range. After the final range, it wraps back to the first.
type ipSetIterator struct {
	// ranges holds the address ranges that make up the pool.
	ranges []netipx.IPRange
	// last is the address most recently returned by next; the zero (invalid)
	// Addr means next has not been called yet.
	last netip.Addr
	// rangeIdx is the index into ranges of the range that last belongs to.
	rangeIdx int
}

// next yields the address that follows the one it returned previously.
//
// The first call returns the lowest address of the first range. Later calls
// step forward one address at a time. Once the end of a range is reached,
// next moves to the start of the following range, wrapping around to the
// first range after the last one. The only error is errPoolExhausted, which
// is returned when the iterator has no ranges at all.
func (ipsi *ipSetIterator) next() (netip.Addr, error) {
	switch {
	case len(ipsi.ranges) == 0:
		// Empty set: there is nothing to hand out.
		return netip.Addr{}, errPoolExhausted
	case !ipsi.last.IsValid():
		// First call: begin at the bottom of the first range.
		ipsi.last = ipsi.ranges[0].From()
	case ipsi.last != ipsi.ranges[ipsi.rangeIdx].To():
		// Still inside the current range: advance by one address.
		ipsi.last = ipsi.last.Next()
	default:
		// The current range is used up: continue with the next range,
		// wrapping to the first range after the final one.
		ipsi.rangeIdx = (ipsi.rangeIdx + 1) % len(ipsi.ranges)
		ipsi.last = ipsi.ranges[ipsi.rangeIdx].From()
	}
	return ipsi.last, nil
}

// newIPPool creates an ippool that hands out the addresses of ipset in
// round-robin order and tracks which of them are in use.
//
// When ipset is nil, the result is a zero-value ippool with none of its
// fields set.
func newIPPool(ipset *netipx.IPSet) *ippool {
	pool := new(ippool)
	if ipset != nil {
		pool.ipSet = ipset
		pool.ipSetIterator = &ipSetIterator{ranges: ipset.Ranges()}
		pool.inUse = &set.Set[netip.Addr]{}
	}
	return pool
}

// ippool allocates addresses from an IPSet in round-robin order and keeps
// track of which addresses are currently handed out.
type ippool struct {
	// ipSet is the full set of addresses the pool may allocate from.
	ipSet *netipx.IPSet
	// ipSetIterator yields candidate addresses from ipSet in round-robin order.
	ipSetIterator *ipSetIterator
	// inUse records addresses handed out by next and not yet returned.
	inUse *set.Set[netip.Addr]
}

// next allocates an address from the pool.
//
// It advances the round-robin iterator, skips every address currently marked
// in use, marks the chosen address in use, and returns it. On any error, the
// returned address is the zero netip.Addr. It returns errPoolExhausted when:
//   - every address in the pool is already in use,
//   - the pool contains no addresses, or
//   - the pool has no iterator at all (as with a pool created from a nil IPSet).
func (ipp *ippool) next() (netip.Addr, error) {
	if ipp.ipSetIterator == nil {
		// newIPPool(nil) returns a pool without an iterator; treat it as an
		// empty pool instead of dereferencing a nil pointer.
		return netip.Addr{}, errPoolExhausted
	}
	a, err := ipp.ipSetIterator.next()
	if err != nil {
		return netip.Addr{}, err
	}
	// The iterator cycles through every address in the set, so arriving back
	// at the first candidate means no address is free.
	startedAt := a
	for ipp.inUse.Contains(a) {
		if a, err = ipp.ipSetIterator.next(); err != nil {
			return netip.Addr{}, err
		}
		if a == startedAt {
			return netip.Addr{}, errPoolExhausted
		}
	}
	ipp.inUse.Add(a)
	return a, nil
}

// returnAddr releases a previously allocated address so that next can hand
// it out again.
//
// It returns:
//   - errNotOurAddress if a is not in the pool's IPSet. This is always the
//     case for a pool without an IPSet, such as one created from a nil IPSet.
//   - errAddrExists if a belongs to the pool but is not currently in use,
//     for example because it was already returned.
func (ipp *ippool) returnAddr(a netip.Addr) error {
	// A nil *netipx.IPSet cannot be queried (Contains reads its receiver), and
	// a pool without a set owns no addresses.
	if ipp.ipSet == nil || !ipp.ipSet.Contains(a) {
		return errNotOurAddress
	}
	if !ipp.inUse.Contains(a) {
		return errAddrExists
	}
	ipp.inUse.Delete(a)
	return nil
}