summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--feature/conn25/ippool.go93
-rw-r--r--feature/conn25/ippool_test.go87
2 files changed, 156 insertions, 24 deletions
diff --git a/feature/conn25/ippool.go b/feature/conn25/ippool.go
index e50186d88..4ae8918d4 100644
--- a/feature/conn25/ippool.go
+++ b/feature/conn25/ippool.go
@@ -8,17 +8,24 @@ import (
"net/netip"
"go4.org/netipx"
+ "tailscale.com/util/set"
)
// errPoolExhausted is returned when there are no more addresses to iterate over.
var errPoolExhausted = errors.New("ip pool exhausted")
-// ippool allows for iteration over all the addresses within a netipx.IPSet.
+// errNotOurAddress is returned if a provided address is not from our pool
+var errNotOurAddress = errors.New("not our address")
+
+// errAddrExists is returned if a returned address is already in the returned pool.
+var errAddrExists = errors.New("address already returned")
+
+// ipSetIterator allows for round robin iteration over all the addresses within a netipx.IPSet.
// netipx.IPSet has a Ranges call that returns the "minimum and sorted set of IP ranges that covers [the set]".
// netipx.IPRange is "an inclusive range of IP addresses from the same address family.". So we can iterate over
// all the addresses in the set by keeping a track of the last address we returned, calling Next on the last address
-// to get the new one, and if we run off the edge of the current range, starting on the next one.
-type ippool struct {
+// to get the new one, and if we run off the edge of the current range, starting on the next one, or back at the beginning.
+type ipSetIterator struct {
// ranges defines the addresses in the pool
ranges []netipx.IPRange
// last is internal tracking of which the last address provided was.
@@ -27,35 +34,75 @@ type ippool struct {
rangeIdx int
}
+// next returns the next address from the set.
+func (ipsi *ipSetIterator) next() (netip.Addr, error) {
+ if len(ipsi.ranges) == 0 {
+ // ipset is empty
+ return netip.Addr{}, errPoolExhausted
+ }
+ if !ipsi.last.IsValid() {
+ // not initialized yet
+ ipsi.last = ipsi.ranges[0].From()
+ return ipsi.last, nil
+ }
+ currRange := ipsi.ranges[ipsi.rangeIdx]
+ if ipsi.last == currRange.To() {
+ // then we need to move to the next range
+ ipsi.rangeIdx++
+ if ipsi.rangeIdx >= len(ipsi.ranges) {
+ // back to the beginning
+ ipsi.rangeIdx = 0
+ }
+ ipsi.last = ipsi.ranges[ipsi.rangeIdx].From()
+ return ipsi.last, nil
+ }
+ ipsi.last = ipsi.last.Next()
+ return ipsi.last, nil
+}
+
func newIPPool(ipset *netipx.IPSet) *ippool {
if ipset == nil {
return &ippool{}
}
- return &ippool{ranges: ipset.Ranges()}
+ return &ippool{
+ ipSet: ipset,
+ ipSetIterator: &ipSetIterator{ranges: ipset.Ranges()},
+ inUse: &set.Set[netip.Addr]{},
+ }
+}
+
+type ippool struct {
+ ipSet *netipx.IPSet
+ ipSetIterator *ipSetIterator
+ inUse *set.Set[netip.Addr]
}
-// next returns the next address from the set, or errPoolExhausted if we have
-// iterated over the whole set.
func (ipp *ippool) next() (netip.Addr, error) {
- if ipp.rangeIdx >= len(ipp.ranges) {
- // ipset is empty or we have iterated off the end
- return netip.Addr{}, errPoolExhausted
- }
- if !ipp.last.IsValid() {
- // not initialized yet
- ipp.last = ipp.ranges[0].From()
- return ipp.last, nil
+ a, err := ipp.ipSetIterator.next()
+ if err != nil {
+ return netip.Addr{}, err
}
- currRange := ipp.ranges[ipp.rangeIdx]
- if ipp.last == currRange.To() {
- // then we need to move to the next range
- ipp.rangeIdx++
- if ipp.rangeIdx >= len(ipp.ranges) {
+ startedAt := a
+ for ipp.inUse.Contains(a) {
+ a, err = ipp.ipSetIterator.next()
+ if err != nil {
+ return a, err
+ }
+ if a == startedAt {
return netip.Addr{}, errPoolExhausted
}
- ipp.last = ipp.ranges[ipp.rangeIdx].From()
- return ipp.last, nil
}
- ipp.last = ipp.last.Next()
- return ipp.last, nil
+ ipp.inUse.Add(a)
+ return a, nil
+}
+
+func (ipp *ippool) returnAddr(a netip.Addr) error {
+ if !ipp.ipSet.Contains(a) {
+ return errNotOurAddress
+ }
+ if !ipp.inUse.Contains(a) {
+ return errAddrExists
+ }
+ ipp.inUse.Delete(a)
+ return nil
}
diff --git a/feature/conn25/ippool_test.go b/feature/conn25/ippool_test.go
index ccfaad3eb..431ea6998 100644
--- a/feature/conn25/ippool_test.go
+++ b/feature/conn25/ippool_test.go
@@ -13,7 +13,7 @@ import (
)
func TestNext(t *testing.T) {
- a := ippool{}
+ a := ipSetIterator{}
_, err := a.next()
if !errors.Is(err, errPoolExhausted) {
t.Fatalf("expected errPoolExhausted, got %v", err)
@@ -58,3 +58,88 @@ func TestNext(t *testing.T) {
t.Fatalf("expected errPoolExhausted, got %v", err)
}
}
+
+// TestReturnAddr tests that if a pool is exhausted, an address can be returned to the
+// pool, and then that address will be handed out again.
+func TestReturnAddr(t *testing.T) {
+ addrString := "192.168.0.0"
+ // There's an IPPool with one address in it.
+ var isb netipx.IPSetBuilder
+ isb.AddRange(netipx.IPRangeFrom(netip.MustParseAddr(addrString), netip.MustParseAddr(addrString)))
+ ipset := must.Get(isb.IPSet())
+ ipp := newIPPool(ipset)
+ // The first time we call next we get the address.
+ addr, err := ipp.next()
+ if err != nil {
+ t.Fatalf("expected nil error, got: %v", err)
+ }
+ if addr != netip.MustParseAddr(addrString) {
+ t.Fatalf("want %v, got %v", addrString, addr)
+ }
+ // The second time we call next we get errPoolExhausted
+ _, err = ipp.next()
+ if !errors.Is(err, errPoolExhausted) {
+ t.Fatalf("expected errPoolExhausted, got %v", err)
+ }
+ // Return the addr to the pool
+ err = ipp.returnAddr(netip.MustParseAddr(addrString))
+ if err != nil {
+ t.Fatal(err)
+ }
+ // It's not possible to return addresses that are already in the pool.
+ err = ipp.returnAddr(netip.MustParseAddr(addrString))
+ if !errors.Is(err, errAddrExists) {
+ t.Fatalf("want errAddrExists, got: %v", err)
+ }
+ // When we call next we get the returned addr
+ addrAfterReturn, err := ipp.next()
+ if err != nil {
+ t.Fatalf("expected nil error, got: %v", err)
+ }
+ if addrAfterReturn != netip.MustParseAddr(addrString) {
+ t.Fatalf("want %v, got %v", addrString, addrAfterReturn)
+ }
+ // You can't return addresses that aren't from the pool.
+ err = ipp.returnAddr(netip.MustParseAddr("100.100.100.0"))
+ if !errors.Is(err, errNotOurAddress) {
+ t.Fatalf("want errNotOurAddress, got: %v", err)
+ }
+}
+
+// TestGettingReturnedAddresses tests that when addresses are returned to the IP Pool
+// they are then handed out in the order they were returned.
+func TestGettingReturnedAddresses(t *testing.T) {
+ var isb netipx.IPSetBuilder
+ isb.AddRange(netipx.IPRangeFrom(netip.MustParseAddr("192.168.0.0"), netip.MustParseAddr("192.168.0.4")))
+ ipset := must.Get(isb.IPSet())
+ ipp := newIPPool(ipset)
+ expectAddrNext := func(addrString string) {
+ t.Helper()
+ got, err := ipp.next()
+ if err != nil {
+ t.Fatalf("expected nil error, got: %v", err)
+ }
+ want := netip.MustParseAddr(addrString)
+ if want != got {
+ t.Fatalf("want %v; got %v", want, got)
+ }
+ }
+ expectErrPoolExhaustedNext := func() {
+ t.Helper()
+ _, err := ipp.next()
+ if !errors.Is(err, errPoolExhausted) {
+ t.Fatalf("expected errPoolExhausted; got %v", err)
+ }
+ }
+ expectAddrNext("192.168.0.0")
+ expectAddrNext("192.168.0.1")
+ expectAddrNext("192.168.0.2")
+ expectAddrNext("192.168.0.3")
+ expectAddrNext("192.168.0.4")
+ expectErrPoolExhaustedNext()
+ ipp.returnAddr(netip.MustParseAddr("192.168.0.2"))
+ ipp.returnAddr(netip.MustParseAddr("192.168.0.4"))
+ expectAddrNext("192.168.0.2")
+ expectAddrNext("192.168.0.4")
+ expectErrPoolExhaustedNext()
+}