summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorAndrew Dunham <andrew@du.nham.ca>2023-08-14 21:06:38 -0700
committerAndrew Dunham <andrew@du.nham.ca>2023-08-15 14:06:42 -0700
commit95d776bd8cdbbd2b4acc40e2fbccc65607b2818a (patch)
treea3f32e1aee95f286ad878353fee73228c37afc0b
parent9c4364e0b764e39084f207347b5cfd0dfba5a59d (diff)
downloadtailscale-95d776bd8cdbbd2b4acc40e2fbccc65607b2818a.tar.xz
tailscale-95d776bd8cdbbd2b4acc40e2fbccc65607b2818a.zip
wgengine/magicsock: only cache N most recent endpoints per-Addr
If a node is flapping or otherwise generating lots of STUN endpoints, we can end up caching a ton of useless values and sending them to peers. Instead, let's apply a fixed per-Addr limit of endpoints that we cache, so that we're only sending peers up to the N most recent. Updates tailscale/corp#13890 Signed-off-by: Andrew Dunham <andrew@du.nham.ca> Change-Id: I8079a05b44220c46da55016c0e5fc96dd2135ef8
-rw-r--r--cmd/tailscaled/depaware.txt3
-rw-r--r--wgengine/magicsock/endpoint_tracker.go248
-rw-r--r--wgengine/magicsock/endpoint_tracker_test.go187
-rw-r--r--wgengine/magicsock/magicsock.go79
-rw-r--r--wgengine/magicsock/magicsock_test.go112
5 files changed, 438 insertions, 191 deletions
diff --git a/cmd/tailscaled/depaware.txt b/cmd/tailscaled/depaware.txt
index a20495608..0a685d61f 100644
--- a/cmd/tailscaled/depaware.txt
+++ b/cmd/tailscaled/depaware.txt
@@ -292,6 +292,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
tailscale.com/tailcfg from tailscale.com/client/tailscale/apitype+
💣 tailscale.com/tempfork/device from tailscale.com/net/tstun/table
LD tailscale.com/tempfork/gliderlabs/ssh from tailscale.com/ssh/tailssh
+ tailscale.com/tempfork/heap from tailscale.com/wgengine/magicsock
tailscale.com/tka from tailscale.com/ipn/ipnlocal+
W tailscale.com/tsconst from tailscale.com/net/interfaces
tailscale.com/tsd from tailscale.com/cmd/tailscaled+
@@ -411,6 +412,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
golang.org/x/time/rate from gvisor.dev/gvisor/pkg/tcpip/stack+
bufio from compress/flate+
bytes from bufio+
+ cmp from slices
compress/flate from compress/gzip+
compress/gzip from golang.org/x/net/http2+
W compress/zlib from debug/pe
@@ -495,6 +497,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
runtime/debug from github.com/klauspost/compress/zstd+
runtime/pprof from tailscale.com/log/logheap+
runtime/trace from net/http/pprof
+ slices from tailscale.com/wgengine/magicsock
sort from compress/flate+
strconv from compress/flate+
strings from bufio+
diff --git a/wgengine/magicsock/endpoint_tracker.go b/wgengine/magicsock/endpoint_tracker.go
new file mode 100644
index 000000000..5caddd1a0
--- /dev/null
+++ b/wgengine/magicsock/endpoint_tracker.go
@@ -0,0 +1,248 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package magicsock
+
+import (
+ "net/netip"
+ "slices"
+ "sync"
+ "time"
+
+ "tailscale.com/tailcfg"
+ "tailscale.com/tempfork/heap"
+ "tailscale.com/util/mak"
+ "tailscale.com/util/set"
+)
+
+const (
+ // endpointTrackerLifetime is how long we continue advertising an
+ // endpoint after we last see it. This is intentionally chosen to be
+ // slightly longer than a full netcheck period.
+ endpointTrackerLifetime = 5*time.Minute + 10*time.Second
+
+ // endpointTrackerMaxPerAddr is how many cached addresses we track for
+ // a given netip.Addr. This allows e.g. restricting the number of STUN
+ // endpoints we cache (which usually have the same netip.Addr but
+ // different ports).
+ //
+ // The value of 6 is chosen because we can advertise up to 3 endpoints
+ // based on the STUN IP:
+ // 1. The STUN endpoint itself (EndpointSTUN)
+ // 2. The STUN IP with the local Tailscale port (EndpointSTUN4LocalPort)
+ // 3. The STUN IP with a portmapped port (EndpointPortmapped)
+ //
+ // Storing 6 endpoints in the cache means we can store up to 2 previous
+ // sets of endpoints.
+ endpointTrackerMaxPerAddr = 6
+)
+
+// endpointTrackerEntry is an entry in an endpointHeap that stores the state of
+// a given cached endpoint.
+type endpointTrackerEntry struct {
+ // endpoint is the cached endpoint.
+ endpoint tailcfg.Endpoint
+ // until is the time until which this endpoint is being cached.
+ until time.Time
+ // index is the index within the containing endpointHeap.
+ index int
+}
+
+// endpointHeap is an ordered heap of endpointTrackerEntry structs, ordered in
+// ascending order by the 'until' expiry time (i.e. oldest first).
+type endpointHeap []*endpointTrackerEntry
+
+var _ heap.Interface[*endpointTrackerEntry] = (*endpointHeap)(nil)
+
+// Len implements heap.Interface.
+func (eh endpointHeap) Len() int { return len(eh) }
+
+// Less implements heap.Interface.
+func (eh endpointHeap) Less(i, j int) bool {
+ // We want to store items so that the lowest item in the heap is the
+ // oldest, so that heap.Pop()-ing from the endpointHeap will remove the
+ // oldest entry.
+ return eh[i].until.Before(eh[j].until)
+}
+
+// Swap implements heap.Interface.
+func (eh endpointHeap) Swap(i, j int) {
+ eh[i], eh[j] = eh[j], eh[i]
+ eh[i].index = i
+ eh[j].index = j
+}
+
+// Push implements heap.Interface.
+func (eh *endpointHeap) Push(item *endpointTrackerEntry) {
+ n := len(*eh)
+ item.index = n
+ *eh = append(*eh, item)
+}
+
+// Pop implements heap.Interface.
+func (eh *endpointHeap) Pop() *endpointTrackerEntry {
+ old := *eh
+ n := len(old)
+ item := old[n-1]
+ old[n-1] = nil // avoid memory leak
+ item.index = -1 // for safety
+ *eh = old[0 : n-1]
+ return item
+}
+
+// Min returns a pointer to the minimum element in the heap, without removing
+// it. Since this is a min-heap ordered by the 'until' field, this returns the
+// chronologically "earliest" element in the heap.
+//
+// Len() must be non-zero.
+func (eh endpointHeap) Min() *endpointTrackerEntry {
+ return eh[0]
+}
+
+// endpointTracker caches endpoints that are advertised to peers. This allows
+// peers to still reach this node if there's a temporary endpoint flap; rather
+// than withdrawing an endpoint and then re-advertising it the next time we run
+// a netcheck, we keep advertising the endpoint until it's not present for a
+// defined timeout.
+//
+// See tailscale/tailscale#7877 for more information.
+type endpointTracker struct {
+ mu sync.Mutex
+ endpoints map[netip.Addr]*endpointHeap
+}
+
+// update takes as input the current sent of discovered endpoints and the
+// current time, and returns the set of endpoints plus any previous-cached and
+// non-expired endpoints that should be advertised to peers.
+func (et *endpointTracker) update(now time.Time, eps []tailcfg.Endpoint) (epsPlusCached []tailcfg.Endpoint) {
+ var inputEps set.Slice[netip.AddrPort]
+ for _, ep := range eps {
+ inputEps.Add(ep.Addr)
+ }
+
+ et.mu.Lock()
+ defer et.mu.Unlock()
+
+ // Extend endpoints that already exist in the cache. We do this before
+ // we remove expired endpoints, below, so we don't remove something
+ // that would otherwise have survived by extending.
+ until := now.Add(endpointTrackerLifetime)
+ for _, ep := range eps {
+ et.extendLocked(ep, until)
+ }
+
+ // Now that we've extended existing endpoints, remove everything that
+ // has expired.
+ et.removeExpiredLocked(now)
+
+ // Add entries from the input set of endpoints into the cache; we do
+ // this after removing expired ones so that we can store as many as
+ // possible, with space freed by the entries removed after expiry.
+ for _, ep := range eps {
+ et.addLocked(now, ep, until)
+ }
+
+ // Finally, add entries to the return array that aren't already there.
+ epsPlusCached = eps
+ for _, heap := range et.endpoints {
+ for _, ep := range *heap {
+ // If the endpoint was in the input list, or has expired, skip it.
+ if inputEps.Contains(ep.endpoint.Addr) {
+ continue
+ } else if now.After(ep.until) {
+ // Defense-in-depth; should never happen since
+ // we removed expired entries above, but ignore
+ // it anyway.
+ continue
+ }
+
+ // We haven't seen this endpoint; add to the return array
+ epsPlusCached = append(epsPlusCached, ep.endpoint)
+ }
+ }
+
+ return epsPlusCached
+}
+
+// extendLocked will update the expiry time of the provided endpoint in the
+// cache, if it is present. If it is not present, nothing will be done.
+//
+// et.mu must be held.
+func (et *endpointTracker) extendLocked(ep tailcfg.Endpoint, until time.Time) {
+ key := ep.Addr.Addr()
+ epHeap, found := et.endpoints[key]
+ if !found {
+ return
+ }
+
+ // Find the entry for this exact address; this loop is quick since we
+ // bound the number of items in the heap.
+ //
+ // TODO(andrew): this means we iterate over the entire heap once per
+ // endpoint; even if the heap is small, if we have a lot of input
+ // endpoints this can be expensive?
+ for i, entry := range *epHeap {
+ if entry.endpoint == ep {
+ entry.until = until
+ heap.Fix(epHeap, i)
+ return
+ }
+ }
+}
+
+// addLocked will store the provided endpoint(s) in the cache for a fixed
+// period of time, ensuring that the size of the endpoint cache remains below
+// the maximum.
+//
+// et.mu must be held.
+func (et *endpointTracker) addLocked(now time.Time, ep tailcfg.Endpoint, until time.Time) {
+ key := ep.Addr.Addr()
+
+ // Create or get the heap for this endpoint's addr
+ epHeap := et.endpoints[key]
+ if epHeap == nil {
+ epHeap = new(endpointHeap)
+ mak.Set(&et.endpoints, key, epHeap)
+ }
+
+ // Find the entry for this exact address; this loop is quick
+ // since we bound the number of items in the heap.
+ found := slices.ContainsFunc(*epHeap, func(v *endpointTrackerEntry) bool {
+ return v.endpoint == ep
+ })
+ if !found {
+ // Add address to heap; either the endpoint is new, or the heap
+ // was newly-created and thus empty.
+ heap.Push(epHeap, &endpointTrackerEntry{endpoint: ep, until: until})
+ }
+
+ // Now that we've added everything, pop from our heap until we're below
+ // the limit. This is a min-heap, so popping removes the lowest (and
+ // thus oldest) endpoint.
+ for epHeap.Len() > endpointTrackerMaxPerAddr {
+ heap.Pop(epHeap)
+ }
+}
+
+// removeExpired will remove all expired entries from the cache.
+//
+// et.mu must be held.
+func (et *endpointTracker) removeExpiredLocked(now time.Time) {
+ for k, epHeap := range et.endpoints {
+ // The minimum element is oldest/earliest endpoint; repeatedly
+ // pop from the heap while it's in the past.
+ for epHeap.Len() > 0 {
+ minElem := epHeap.Min()
+ if now.After(minElem.until) {
+ heap.Pop(epHeap)
+ } else {
+ break
+ }
+ }
+
+ if epHeap.Len() == 0 {
+ // Free up space in the map by removing the empty heap.
+ delete(et.endpoints, k)
+ }
+ }
+}
diff --git a/wgengine/magicsock/endpoint_tracker_test.go b/wgengine/magicsock/endpoint_tracker_test.go
new file mode 100644
index 000000000..b6a2699c1
--- /dev/null
+++ b/wgengine/magicsock/endpoint_tracker_test.go
@@ -0,0 +1,187 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package magicsock
+
+import (
+ "net/netip"
+ "reflect"
+ "slices"
+ "strings"
+ "testing"
+ "time"
+
+ "tailscale.com/tailcfg"
+)
+
+func TestEndpointTracker(t *testing.T) {
+ local := tailcfg.Endpoint{
+ Addr: netip.MustParseAddrPort("192.168.1.1:12345"),
+ Type: tailcfg.EndpointLocal,
+ }
+
+ stun4_1 := tailcfg.Endpoint{
+ Addr: netip.MustParseAddrPort("1.2.3.4:12345"),
+ Type: tailcfg.EndpointSTUN,
+ }
+ stun4_2 := tailcfg.Endpoint{
+ Addr: netip.MustParseAddrPort("5.6.7.8:12345"),
+ Type: tailcfg.EndpointSTUN,
+ }
+
+ stun6_1 := tailcfg.Endpoint{
+ Addr: netip.MustParseAddrPort("[2a09:8280:1::1111]:12345"),
+ Type: tailcfg.EndpointSTUN,
+ }
+ stun6_2 := tailcfg.Endpoint{
+ Addr: netip.MustParseAddrPort("[2a09:8280:1::2222]:12345"),
+ Type: tailcfg.EndpointSTUN,
+ }
+
+ start := time.Unix(1681503440, 0)
+
+ steps := []struct {
+ name string
+ now time.Time
+ eps []tailcfg.Endpoint
+ want []tailcfg.Endpoint
+ }{
+ {
+ name: "initial endpoints",
+ now: start,
+ eps: []tailcfg.Endpoint{local, stun4_1, stun6_1},
+ want: []tailcfg.Endpoint{local, stun4_1, stun6_1},
+ },
+ {
+ name: "no change",
+ now: start.Add(1 * time.Minute),
+ eps: []tailcfg.Endpoint{local, stun4_1, stun6_1},
+ want: []tailcfg.Endpoint{local, stun4_1, stun6_1},
+ },
+ {
+ name: "missing stun4",
+ now: start.Add(2 * time.Minute),
+ eps: []tailcfg.Endpoint{local, stun6_1},
+ want: []tailcfg.Endpoint{local, stun4_1, stun6_1},
+ },
+ {
+ name: "missing stun6",
+ now: start.Add(3 * time.Minute),
+ eps: []tailcfg.Endpoint{local, stun4_1},
+ want: []tailcfg.Endpoint{local, stun4_1, stun6_1},
+ },
+ {
+ name: "multiple STUN addresses within timeout",
+ now: start.Add(4 * time.Minute),
+ eps: []tailcfg.Endpoint{local, stun4_2, stun6_2},
+ want: []tailcfg.Endpoint{local, stun4_1, stun4_2, stun6_1, stun6_2},
+ },
+ {
+ name: "endpoint extended",
+ now: start.Add(3*time.Minute + endpointTrackerLifetime - 1),
+ eps: []tailcfg.Endpoint{local},
+ want: []tailcfg.Endpoint{
+ local, stun4_2, stun6_2,
+ // stun4_1 had its lifetime extended by the
+ // "missing stun6" test above to that start
+ // time plus the lifetime, while stun6 should
+ // have expired a minute sooner. It should thus
+ // be in this returned list.
+ stun4_1,
+ },
+ },
+ {
+ name: "after timeout",
+ now: start.Add(4*time.Minute + endpointTrackerLifetime + 1),
+ eps: []tailcfg.Endpoint{local, stun4_2, stun6_2},
+ want: []tailcfg.Endpoint{local, stun4_2, stun6_2},
+ },
+ {
+ name: "after timeout still caches",
+ now: start.Add(4*time.Minute + endpointTrackerLifetime + time.Minute),
+ eps: []tailcfg.Endpoint{local},
+ want: []tailcfg.Endpoint{local, stun4_2, stun6_2},
+ },
+ }
+
+ var et endpointTracker
+ for _, tt := range steps {
+ t.Logf("STEP: %s", tt.name)
+
+ got := et.update(tt.now, tt.eps)
+
+ // Sort both arrays for comparison
+ slices.SortFunc(got, func(a, b tailcfg.Endpoint) int {
+ return strings.Compare(a.Addr.String(), b.Addr.String())
+ })
+ slices.SortFunc(tt.want, func(a, b tailcfg.Endpoint) int {
+ return strings.Compare(a.Addr.String(), b.Addr.String())
+ })
+
+ if !reflect.DeepEqual(got, tt.want) {
+ t.Errorf("endpoints mismatch\ngot: %+v\nwant: %+v", got, tt.want)
+ }
+ }
+}
+
+func TestEndpointTrackerMaxNum(t *testing.T) {
+ start := time.Unix(1681503440, 0)
+
+ var allEndpoints []tailcfg.Endpoint // all created endpoints
+ mkEp := func(i int) tailcfg.Endpoint {
+ ep := tailcfg.Endpoint{
+ Addr: netip.AddrPortFrom(netip.MustParseAddr("1.2.3.4"), uint16(i)),
+ Type: tailcfg.EndpointSTUN,
+ }
+ allEndpoints = append(allEndpoints, ep)
+ return ep
+ }
+
+ var et endpointTracker
+
+ // Add more endpoints to the list than our limit
+ for i := 0; i <= endpointTrackerMaxPerAddr; i++ {
+ et.update(start.Add(time.Duration(i)*time.Second), []tailcfg.Endpoint{mkEp(10000 + i)})
+ }
+
+ // Now add two more, slightly later
+ got := et.update(start.Add(1*time.Minute), []tailcfg.Endpoint{
+ mkEp(10100),
+ mkEp(10101),
+ })
+
+ // We expect to get the last N endpoints per our per-Addr limit, since
+ // all of the endpoints have the same netip.Addr. The first endpoint(s)
+ // that we added were dropped because we had more than the limit for
+ // this Addr.
+ want := allEndpoints[len(allEndpoints)-endpointTrackerMaxPerAddr:]
+
+ compareEndpoints := func(got, want []tailcfg.Endpoint) {
+ t.Helper()
+ slices.SortFunc(want, func(a, b tailcfg.Endpoint) int {
+ return strings.Compare(a.Addr.String(), b.Addr.String())
+ })
+ slices.SortFunc(got, func(a, b tailcfg.Endpoint) int {
+ return strings.Compare(a.Addr.String(), b.Addr.String())
+ })
+ if !reflect.DeepEqual(got, want) {
+ t.Errorf("endpoints mismatch\ngot: %+v\nwant: %+v", got, want)
+ }
+ }
+ compareEndpoints(got, want)
+
+ // However, if we have more than our limit of endpoints passed in to
+ // the endpointTracker, we will return all of them (even if they're for
+ // the same address).
+ var inputEps []tailcfg.Endpoint
+ for i := 0; i < endpointTrackerMaxPerAddr+5; i++ {
+ inputEps = append(inputEps, tailcfg.Endpoint{
+ Addr: netip.AddrPortFrom(netip.MustParseAddr("1.2.3.4"), 10200+uint16(i)),
+ Type: tailcfg.EndpointSTUN,
+ })
+ }
+
+ want = inputEps
+ got = et.update(start.Add(2*time.Minute), inputEps)
+ compareEndpoints(got, want)
+}
diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go
index e552b8826..3f63547ae 100644
--- a/wgengine/magicsock/magicsock.go
+++ b/wgengine/magicsock/magicsock.go
@@ -53,7 +53,6 @@ import (
"tailscale.com/util/clientmetric"
"tailscale.com/util/mak"
"tailscale.com/util/ringbuffer"
- "tailscale.com/util/set"
"tailscale.com/util/uniq"
"tailscale.com/version"
"tailscale.com/wgengine/capture"
@@ -2594,11 +2593,6 @@ const (
// STUN-derived endpoint valid for. UDP NAT mappings typically
// expire at 30 seconds, so this is a few seconds shy of that.
endpointsFreshEnoughDuration = 27 * time.Second
-
- // endpointTrackerLifetime is how long we continue advertising an
- // endpoint after we last see it. This is intentionally chosen to be
- // slightly longer than a full netcheck period.
- endpointTrackerLifetime = 5*time.Minute + 10*time.Second
)
// Constants that are variable for testing.
@@ -2683,79 +2677,6 @@ type discoInfo struct {
lastPingTime time.Time
}
-type endpointTrackerEntry struct {
- endpoint tailcfg.Endpoint
- until time.Time
-}
-
-type endpointTracker struct {
- mu sync.Mutex
- cache map[netip.AddrPort]endpointTrackerEntry
-}
-
-func (et *endpointTracker) update(now time.Time, eps []tailcfg.Endpoint) (epsPlusCached []tailcfg.Endpoint) {
- epsPlusCached = eps
-
- var inputEps set.Slice[netip.AddrPort]
- for _, ep := range eps {
- inputEps.Add(ep.Addr)
- }
-
- et.mu.Lock()
- defer et.mu.Unlock()
-
- // Add entries to the return array that aren't already there.
- for k, ep := range et.cache {
- // If the endpoint was in the input list, or has expired, skip it.
- if inputEps.Contains(k) {
- continue
- } else if now.After(ep.until) {
- continue
- }
-
- // We haven't seen this endpoint; add to the return array
- epsPlusCached = append(epsPlusCached, ep.endpoint)
- }
-
- // Add entries from the original input array into the cache, and/or
- // extend the lifetime of entries that are already in the cache.
- until := now.Add(endpointTrackerLifetime)
- for _, ep := range eps {
- et.addLocked(now, ep, until)
- }
-
- // Remove everything that has now expired.
- et.removeExpiredLocked(now)
- return epsPlusCached
-}
-
-// add will store the provided endpoint(s) in the cache for a fixed period of
-// time, and remove any entries in the cache that have expired.
-//
-// et.mu must be held.
-func (et *endpointTracker) addLocked(now time.Time, ep tailcfg.Endpoint, until time.Time) {
- // If we already have an entry for this endpoint, update the timeout on
- // it; otherwise, add it.
- entry, found := et.cache[ep.Addr]
- if found {
- entry.until = until
- } else {
- entry = endpointTrackerEntry{ep, until}
- }
- mak.Set(&et.cache, ep.Addr, entry)
-}
-
-// removeExpired will remove all expired entries from the cache
-//
-// et.mu must be held
-func (et *endpointTracker) removeExpiredLocked(now time.Time) {
- for k, ep := range et.cache {
- if now.After(ep.until) {
- delete(et.cache, k)
- }
- }
-}
-
var (
metricNumPeers = clientmetric.NewGauge("magicsock_netmap_num_peers")
metricNumDERPConns = clientmetric.NewGauge("magicsock_num_derp_conns")
diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go
index b6bfef107..4c4153bf3 100644
--- a/wgengine/magicsock/magicsock_test.go
+++ b/wgengine/magicsock/magicsock_test.go
@@ -18,7 +18,6 @@ import (
"net/http/httptest"
"net/netip"
"os"
- "reflect"
"runtime"
"strconv"
"strings"
@@ -33,7 +32,6 @@ import (
"github.com/tailscale/wireguard-go/tun/tuntest"
"go4.org/mem"
"golang.org/x/exp/maps"
- "golang.org/x/exp/slices"
"golang.org/x/net/icmp"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
@@ -2341,116 +2339,6 @@ func TestIsWireGuardOnlyPeerWithMasquerade(t *testing.T) {
}
}
-func TestEndpointTracker(t *testing.T) {
- local := tailcfg.Endpoint{
- Addr: netip.MustParseAddrPort("192.168.1.1:12345"),
- Type: tailcfg.EndpointLocal,
- }
-
- stun4_1 := tailcfg.Endpoint{
- Addr: netip.MustParseAddrPort("1.2.3.4:12345"),
- Type: tailcfg.EndpointSTUN,
- }
- stun4_2 := tailcfg.Endpoint{
- Addr: netip.MustParseAddrPort("5.6.7.8:12345"),
- Type: tailcfg.EndpointSTUN,
- }
-
- stun6_1 := tailcfg.Endpoint{
- Addr: netip.MustParseAddrPort("[2a09:8280:1::1111]:12345"),
- Type: tailcfg.EndpointSTUN,
- }
- stun6_2 := tailcfg.Endpoint{
- Addr: netip.MustParseAddrPort("[2a09:8280:1::2222]:12345"),
- Type: tailcfg.EndpointSTUN,
- }
-
- start := time.Unix(1681503440, 0)
-
- steps := []struct {
- name string
- now time.Time
- eps []tailcfg.Endpoint
- want []tailcfg.Endpoint
- }{
- {
- name: "initial endpoints",
- now: start,
- eps: []tailcfg.Endpoint{local, stun4_1, stun6_1},
- want: []tailcfg.Endpoint{local, stun4_1, stun6_1},
- },
- {
- name: "no change",
- now: start.Add(1 * time.Minute),
- eps: []tailcfg.Endpoint{local, stun4_1, stun6_1},
- want: []tailcfg.Endpoint{local, stun4_1, stun6_1},
- },
- {
- name: "missing stun4",
- now: start.Add(2 * time.Minute),
- eps: []tailcfg.Endpoint{local, stun6_1},
- want: []tailcfg.Endpoint{local, stun4_1, stun6_1},
- },
- {
- name: "missing stun6",
- now: start.Add(3 * time.Minute),
- eps: []tailcfg.Endpoint{local, stun4_1},
- want: []tailcfg.Endpoint{local, stun4_1, stun6_1},
- },
- {
- name: "multiple STUN addresses within timeout",
- now: start.Add(4 * time.Minute),
- eps: []tailcfg.Endpoint{local, stun4_2, stun6_2},
- want: []tailcfg.Endpoint{local, stun4_1, stun4_2, stun6_1, stun6_2},
- },
- {
- name: "endpoint extended",
- now: start.Add(3*time.Minute + endpointTrackerLifetime - 1),
- eps: []tailcfg.Endpoint{local},
- want: []tailcfg.Endpoint{
- local, stun4_2, stun6_2,
- // stun4_1 had its lifetime extended by the
- // "missing stun6" test above to that start
- // time plus the lifetime, while stun6 should
- // have expired a minute sooner. It should thus
- // be in this returned list.
- stun4_1,
- },
- },
- {
- name: "after timeout",
- now: start.Add(4*time.Minute + endpointTrackerLifetime + 1),
- eps: []tailcfg.Endpoint{local, stun4_2, stun6_2},
- want: []tailcfg.Endpoint{local, stun4_2, stun6_2},
- },
- {
- name: "after timeout still caches",
- now: start.Add(4*time.Minute + endpointTrackerLifetime + time.Minute),
- eps: []tailcfg.Endpoint{local},
- want: []tailcfg.Endpoint{local, stun4_2, stun6_2},
- },
- }
-
- var et endpointTracker
- for _, tt := range steps {
- t.Logf("STEP: %s", tt.name)
-
- got := et.update(tt.now, tt.eps)
-
- // Sort both arrays for comparison
- slices.SortFunc(got, func(a, b tailcfg.Endpoint) int {
- return strings.Compare(a.Addr.String(), b.Addr.String())
- })
- slices.SortFunc(tt.want, func(a, b tailcfg.Endpoint) int {
- return strings.Compare(a.Addr.String(), b.Addr.String())
- })
-
- if !reflect.DeepEqual(got, tt.want) {
- t.Errorf("endpoints mismatch\ngot: %+v\nwant: %+v", got, tt.want)
- }
- }
-}
-
// applyNetworkMap is a test helper that sets the network map and
// configures WG.
func applyNetworkMap(t *testing.T, m *magicStack, nm *netmap.NetworkMap) {