summaryrefslogtreecommitdiffhomepage
path: root/net
diff options
context:
space:
mode:
authorNick Khyl <nickk@tailscale.com>2024-12-05 13:16:48 -0600
committerNick Khyl <nickk@tailscale.com>2024-12-05 13:16:48 -0600
commit0267fe83b200f1702a2fa0a395442c02a053fadb (patch)
tree63654c55225eeb834de59a5a0bc8d19033c6145b /net
parent87546a5edf6b6503a87eeb2d666baba57398a066 (diff)
downloadtailscale-1.78.0.tar.xz
tailscale-1.78.0.zip
VERSION.txt: this is v1.78.0v1.78.0
Signed-off-by: Nick Khyl <nickk@tailscale.com>
Diffstat (limited to 'net')
-rw-r--r--net/art/art_test.go40
-rw-r--r--net/art/table.go1282
-rw-r--r--net/dns/debian_resolvconf.go368
-rw-r--r--net/dns/direct_notlinux.go20
-rw-r--r--net/dns/flush_default.go20
-rw-r--r--net/dns/ini.go60
-rw-r--r--net/dns/ini_test.go76
-rw-r--r--net/dns/noop.go34
-rw-r--r--net/dns/resolvconf-workaround.sh124
-rw-r--r--net/dns/resolvconf.go60
-rw-r--r--net/dns/resolvconffile/resolvconffile.go248
-rw-r--r--net/dns/resolvconfpath_default.go22
-rw-r--r--net/dns/resolvconfpath_gokrazy.go22
-rw-r--r--net/dns/resolver/doh_test.go198
-rw-r--r--net/dns/resolver/macios_ext.go52
-rw-r--r--net/dns/resolver/tsdns_server_test.go666
-rw-r--r--net/dns/utf.go110
-rw-r--r--net/dns/utf_test.go48
-rw-r--r--net/dnscache/dnscache_test.go484
-rw-r--r--net/dnscache/messagecache_test.go582
-rw-r--r--net/dnsfallback/update-dns-fallbacks.go90
-rw-r--r--net/memnet/conn.go228
-rw-r--r--net/memnet/conn_test.go42
-rw-r--r--net/memnet/listener.go200
-rw-r--r--net/memnet/listener_test.go66
-rw-r--r--net/memnet/memnet.go16
-rw-r--r--net/memnet/pipe.go488
-rw-r--r--net/memnet/pipe_test.go234
-rw-r--r--net/netaddr/netaddr.go98
-rw-r--r--net/neterror/neterror.go164
-rw-r--r--net/neterror/neterror_linux.go52
-rw-r--r--net/neterror/neterror_linux_test.go108
-rw-r--r--net/neterror/neterror_windows.go32
-rw-r--r--net/netkernelconf/netkernelconf.go10
-rw-r--r--net/netknob/netknob.go58
-rw-r--r--net/netmon/netmon_darwin_test.go54
-rw-r--r--net/netmon/netmon_freebsd.go112
-rw-r--r--net/netmon/netmon_linux.go580
-rw-r--r--net/netmon/netmon_polling.go42
-rw-r--r--net/netmon/polling.go172
-rw-r--r--net/netns/netns_android.go150
-rw-r--r--net/netns/netns_default.go44
-rw-r--r--net/netns/netns_linux_test.go28
-rw-r--r--net/netns/netns_test.go156
-rw-r--r--net/netns/socks.go38
-rw-r--r--net/netstat/netstat.go70
-rw-r--r--net/netstat/netstat_noimpl.go28
-rw-r--r--net/netstat/netstat_test.go42
-rw-r--r--net/packet/doc.go30
-rw-r--r--net/packet/header.go132
-rw-r--r--net/packet/icmp.go56
-rw-r--r--net/packet/icmp6_test.go158
-rw-r--r--net/packet/ip4.go232
-rw-r--r--net/packet/ip6.go152
-rw-r--r--net/packet/tsmp_test.go146
-rw-r--r--net/packet/udp4.go116
-rw-r--r--net/packet/udp6.go108
-rw-r--r--net/ping/ping.go686
-rw-r--r--net/ping/ping_test.go700
-rw-r--r--net/portmapper/pcp_test.go124
-rw-r--r--net/proxymux/mux.go288
-rw-r--r--net/routetable/routetable_darwin.go72
-rw-r--r--net/routetable/routetable_freebsd.go56
-rw-r--r--net/routetable/routetable_other.go34
-rw-r--r--net/sockstats/sockstats.go242
-rw-r--r--net/sockstats/sockstats_noop.go76
-rw-r--r--net/sockstats/sockstats_tsgo_darwin.go60
-rw-r--r--net/speedtest/speedtest.go174
-rw-r--r--net/speedtest/speedtest_client.go82
-rw-r--r--net/speedtest/speedtest_server.go292
-rw-r--r--net/speedtest/speedtest_test.go166
-rw-r--r--net/stun/stun.go624
-rw-r--r--net/stun/stun_fuzzer.go24
-rw-r--r--net/tcpinfo/tcpinfo.go102
-rw-r--r--net/tcpinfo/tcpinfo_darwin.go66
-rw-r--r--net/tcpinfo/tcpinfo_linux.go66
-rw-r--r--net/tcpinfo/tcpinfo_other.go30
-rw-r--r--net/tlsdial/deps_test.go16
-rw-r--r--net/tsdial/dnsmap_test.go250
-rw-r--r--net/tsdial/dohclient.go200
-rw-r--r--net/tsdial/dohclient_test.go62
-rw-r--r--net/tshttpproxy/mksyscall.go22
-rw-r--r--net/tshttpproxy/tshttpproxy_linux.go48
-rw-r--r--net/tshttpproxy/tshttpproxy_synology_test.go752
-rw-r--r--net/tshttpproxy/tshttpproxy_windows.go552
-rw-r--r--net/tstun/fake.go116
-rw-r--r--net/tstun/ifstatus_noop.go36
-rw-r--r--net/tstun/ifstatus_windows.go218
-rw-r--r--net/tstun/linkattrs_linux.go126
-rw-r--r--net/tstun/linkattrs_notlinux.go24
-rw-r--r--net/tstun/mtu.go322
-rw-r--r--net/tstun/mtu_test.go198
-rw-r--r--net/tstun/tun_linux.go206
-rw-r--r--net/tstun/tun_macos.go50
-rw-r--r--net/tstun/tun_notwindows.go24
95 files changed, 8117 insertions, 8117 deletions
diff --git a/net/art/art_test.go b/net/art/art_test.go
index daf8553ca..e3a427107 100644
--- a/net/art/art_test.go
+++ b/net/art/art_test.go
@@ -1,20 +1,20 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-package art
-
-import (
- "os"
- "testing"
-
- "tailscale.com/util/cibuild"
-)
-
-func TestMain(m *testing.M) {
- if cibuild.On() {
- // Skip CI on GitHub for now
- // TODO: https://github.com/tailscale/tailscale/issues/7866
- os.Exit(0)
- }
- os.Exit(m.Run())
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package art
+
+import (
+ "os"
+ "testing"
+
+ "tailscale.com/util/cibuild"
+)
+
+func TestMain(m *testing.M) {
+ if cibuild.On() {
+ // Skip CI on GitHub for now
+ // TODO: https://github.com/tailscale/tailscale/issues/7866
+ os.Exit(0)
+ }
+ os.Exit(m.Run())
+}
diff --git a/net/art/table.go b/net/art/table.go
index fa3975778..2e130d82f 100644
--- a/net/art/table.go
+++ b/net/art/table.go
@@ -1,641 +1,641 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-// Package art provides a routing table that implements the Allotment Routing
-// Table (ART) algorithm by Donald Knuth, as described in the paper by Yoichi
-// Hariguchi.
-//
-// ART outperforms the traditional radix tree implementations for route lookups,
-// insertions, and deletions.
-//
-// For more information, see Yoichi Hariguchi's paper:
-// https://cseweb.ucsd.edu//~varghese/TEACH/cs228/artlookup.pdf
-package art
-
-import (
- "bytes"
- "encoding/binary"
- "fmt"
- "io"
- "math/bits"
- "net/netip"
- "strings"
- "sync"
-)
-
-const (
- debugInsert = false
- debugDelete = false
-)
-
-// Table is an IPv4 and IPv6 routing table.
-type Table[T any] struct {
- v4 strideTable[T]
- v6 strideTable[T]
- initOnce sync.Once
-}
-
-func (t *Table[T]) init() {
- t.initOnce.Do(func() {
- t.v4.prefix = netip.PrefixFrom(netip.IPv4Unspecified(), 0)
- t.v6.prefix = netip.PrefixFrom(netip.IPv6Unspecified(), 0)
- })
-}
-
-func (t *Table[T]) tableForAddr(addr netip.Addr) *strideTable[T] {
- if addr.Is6() {
- return &t.v6
- }
- return &t.v4
-}
-
-// Get does a route lookup for addr and returns the associated value, or nil if
-// no route matched.
-func (t *Table[T]) Get(addr netip.Addr) (ret T, ok bool) {
- t.init()
-
- // Ideally we would use addr.AsSlice here, but AsSlice is just
- // barely complex enough that it can't be inlined, and that in
- // turn causes the slice to escape to the heap. Using As16 and
- // manual slicing here helps the compiler keep Get alloc-free.
- st := t.tableForAddr(addr)
- rawAddr := addr.As16()
- bs := rawAddr[:]
- if addr.Is4() {
- bs = bs[12:]
- }
-
- i := 0
- // With path compression, we might skip over some address bits while walking
- // to a strideTable leaf. This means the leaf answer we find might not be
- // correct, because path compression took us down the wrong subtree. When
- // that happens, we have to backtrack and figure out which most specific
- // route further up the tree is relevant to addr, and return that.
- //
- // So, as we walk down the stride tables, each time we find a non-nil route
- // result, we have to remember it and the associated strideTable prefix.
- //
- // We could also deal with this edge case of path compression by checking
- // the strideTable prefix on each table as we descend, but that means we
- // have to pay N prefix.Contains checks on every route lookup (where N is
- // the number of strideTables in the path), rather than only paying M prefix
- // comparisons in the edge case (where M is the number of strideTables in
- // the path with a non-nil route of their own).
- const maxDepth = 16
- type prefixAndRoute struct {
- prefix netip.Prefix
- route T
- }
- strideMatch := make([]prefixAndRoute, 0, maxDepth)
-findLeaf:
- for {
- rt, rtOK, child := st.getValAndChild(bs[i])
- if rtOK {
- // This strideTable contains a route that may be relevant to our
- // search, remember it.
- strideMatch = append(strideMatch, prefixAndRoute{st.prefix, rt})
- }
- if child == nil {
- // No sub-routes further down, the last thing we recorded
- // in strideRoutes is tentatively the result, barring
- // misdirection from path compression.
- break findLeaf
- }
- st = child
- // Path compression means we may be skipping over some intermediate
- // tables. We have to skip forward to whatever depth st now references.
- i = st.prefix.Bits() / 8
- }
-
- // Walk backwards through the hits we recorded in strideRoutes and
- // stridePrefixes, returning the first one whose subtree matches addr.
- //
- // In the common case where path compression did not mislead us, we'll
- // return on the first loop iteration because the last route we recorded was
- // the correct most-specific route.
- for i := len(strideMatch) - 1; i >= 0; i-- {
- if m := strideMatch[i]; m.prefix.Contains(addr) {
- return m.route, true
- }
- }
-
- // We either found no route hits at all (both previous loops terminated
- // immediately), or we went on a wild goose chase down a compressed path for
- // the wrong prefix, and also found no usable routes on the way back up to
- // the root. This is a miss.
- return ret, false
-}
-
-// Insert adds pfx to the table, with value val.
-// If pfx is already present in the table, its value is set to val.
-func (t *Table[T]) Insert(pfx netip.Prefix, val T) {
- t.init()
-
- // The standard library doesn't enforce normalized prefixes (where
- // the non-prefix bits are all zero). These algorithms require
- // normalized prefixes, so do it upfront.
- pfx = pfx.Masked()
-
- if debugInsert {
- defer func() {
- fmt.Printf("%s", t.debugSummary())
- }()
- fmt.Printf("\ninsert: start pfx=%s\n", pfx)
- }
-
- st := t.tableForAddr(pfx.Addr())
-
- // This algorithm is full of off-by-one headaches that boil down
- // to the fact that pfx.Bits() has (2^n)+1 values, rather than
- // just 2^n. For example, an IPv4 prefix length can be 0 through
- // 32, which is 33 values.
- //
- // This extra possible value creates a lot of problems as we do
- // bits and bytes math to traverse strideTables below. So, we
- // treat the default route 0/0 specially here, that way the rest
- // of the logic goes back to having 2^n values to reason about,
- // which can be done in a nice and regular fashion with no edge
- // cases.
- if pfx.Bits() == 0 {
- if debugInsert {
- fmt.Printf("insert: default route\n")
- }
- st.insert(0, 0, val)
- return
- }
-
- // No matter what we do as we traverse strideTables, our final
- // action will be to insert the last 1-8 bits of pfx into a
- // strideTable somewhere.
- //
- // We calculate upfront the byte position of the end of the
- // prefix; the number of bits within that byte that contain prefix
- // data; and the prefix of the strideTable into which we'll
- // eventually insert.
- //
- // We need this in a couple different branches of the code below,
- // and because the possible values are 1-indexed (1 through 32 for
- // ipv4, 1 through 128 for ipv6), the math is very slightly
- // unusual to account for the off-by-one indexing. Do it once up
- // here, with this large comment, rather than reproduce the subtle
- // math in multiple places further down.
- finalByteIdx := (pfx.Bits() - 1) / 8
- finalBits := pfx.Bits() - (finalByteIdx * 8)
- finalStridePrefix, err := pfx.Addr().Prefix(finalByteIdx * 8)
- if err != nil {
- panic(fmt.Sprintf("invalid prefix requested: %s/%d", pfx.Addr(), finalByteIdx*8))
- }
- if debugInsert {
- fmt.Printf("insert: finalByteIdx=%d finalBits=%d finalStridePrefix=%s\n", finalByteIdx, finalBits, finalStridePrefix)
- }
-
- // The strideTable we want to insert into is potentially at the
- // end of a chain of strideTables, each one encoding 8 bits of the
- // prefix.
- //
- // We're expecting to walk down a path of tables, although with
- // prefix compression we may end up skipping some links in the
- // chain, or taking wrong turns and having to course correct.
- //
- // As we walk down the tree, byteIdx is the byte of bs we're
- // currently examining to choose our next step, and numBits is the
- // number of bits that remain in pfx, starting with the byte at
- // byteIdx inclusive.
- bs := pfx.Addr().AsSlice()
- byteIdx := 0
- numBits := pfx.Bits()
- for {
- if debugInsert {
- fmt.Printf("insert: loop byteIdx=%d numBits=%d st.prefix=%s\n", byteIdx, numBits, st.prefix)
- }
- if numBits <= 8 {
- if debugInsert {
- fmt.Printf("insert: existing leaf st.prefix=%s addr=%d/%d\n", st.prefix, bs[finalByteIdx], finalBits)
- }
- // We've reached the end of the prefix, whichever
- // strideTable we're looking at now is the place where we
- // need to insert.
- st.insert(bs[finalByteIdx], finalBits, val)
- return
- }
-
- // Otherwise, we need to go down at least one more level of
- // strideTables. With prefix compression, each level of
- // descent can have one of three outcomes: we find a place
- // where prefix compression is possible; a place where prefix
- // compression made us take a "wrong turn"; or a point along
- // our intended path that we have to keep following.
- child, created := st.getOrCreateChild(bs[byteIdx])
- switch {
- case created:
- // The subtree we need for pfx doesn't exist yet. The rest
- // of the path, if we were to create it, will consist of a
- // bunch of strideTables with a single child each. We can
- // use path compression to elide those intermediates, and
- // jump straight to the final strideTable that hosts this
- // prefix.
- child.prefix = finalStridePrefix
- child.insert(bs[finalByteIdx], finalBits, val)
- if debugInsert {
- fmt.Printf("insert: new leaf st.prefix=%s child.prefix=%s addr=%d/%d\n", st.prefix, child.prefix, bs[finalByteIdx], finalBits)
- }
- return
- case !prefixStrictlyContains(child.prefix, pfx):
- // child already exists, but its prefix does not contain
- // our destination. This means that the path between st
- // and child was compressed by a previous insertion, and
- // somewhere in the (implicit) compressed path we took a
- // wrong turn, into the wrong part of st's subtree.
- //
- // This is okay, because pfx and child.prefix must have a
- // common ancestor node somewhere between st and child. We
- // can figure out what node that is, and materialize it.
- //
- // Once we've done that, we can immediately complete the
- // remainder of the insertion in one of two ways, without
- // further traversal. See a little further down for what
- // those are.
- if debugInsert {
- fmt.Printf("insert: wrong turn, pfx=%s child.prefix=%s\n", pfx, child.prefix)
- }
- intermediatePrefix, addrOfExisting, addrOfNew := computePrefixSplit(child.prefix, pfx)
- intermediate := &strideTable[T]{prefix: intermediatePrefix} // TODO: make this whole thing be st.AddIntermediate or something?
- st.setChild(bs[byteIdx], intermediate)
- intermediate.setChild(addrOfExisting, child)
-
- if debugInsert {
- fmt.Printf("insert: new intermediate st.prefix=%s intermediate.prefix=%s child.prefix=%s\n", st.prefix, intermediate.prefix, child.prefix)
- }
-
- // Now, we have a chain of st -> intermediate -> child.
- //
- // pfx either lives in a different child of intermediate,
- // or in intermediate itself. For example, if we created
- // the intermediate 1.2.0.0/16, pfx=1.2.3.4/32 would have
- // to go into a new child of intermediate, but
- // pfx=1.2.0.0/18 would go into intermediate directly.
- if remain := pfx.Bits() - intermediate.prefix.Bits(); remain <= 8 {
- // pfx lives in intermediate.
- if debugInsert {
- fmt.Printf("insert: into intermediate intermediate.prefix=%s addr=%d/%d\n", intermediate.prefix, bs[finalByteIdx], finalBits)
- }
- intermediate.insert(bs[finalByteIdx], finalBits, val)
- } else {
- // pfx lives in a different child subtree of
- // intermediate. By definition this subtree doesn't
- // exist at all, otherwise we'd never have entered
- // this entire "wrong turn" codepath in the first
- // place.
- //
- // This means we can apply prefix compression as we
- // create this new child, and we're done.
- st, created = intermediate.getOrCreateChild(addrOfNew)
- if !created {
- panic("new child path unexpectedly exists during path decompression")
- }
- st.prefix = finalStridePrefix
- st.insert(bs[finalByteIdx], finalBits, val)
- if debugInsert {
- fmt.Printf("insert: new child st.prefix=%s addr=%d/%d\n", st.prefix, bs[finalByteIdx], finalBits)
- }
- }
-
- return
- default:
- // An expected child table exists along pfx's
- // path. Continue traversing downwards.
- st = child
- byteIdx = child.prefix.Bits() / 8
- numBits = pfx.Bits() - child.prefix.Bits()
- if debugInsert {
- fmt.Printf("insert: descend st.prefix=%s\n", st.prefix)
- }
- }
- }
-}
-
-// Delete removes pfx from the table, if it is present.
-func (t *Table[T]) Delete(pfx netip.Prefix) {
- t.init()
-
- // The standard library doesn't enforce normalized prefixes (where
- // the non-prefix bits are all zero). These algorithms require
- // normalized prefixes, so do it upfront.
- pfx = pfx.Masked()
-
- if debugDelete {
- defer func() {
- fmt.Printf("%s", t.debugSummary())
- }()
- fmt.Printf("\ndelete: start pfx=%s table:\n%s", pfx, t.debugSummary())
- }
-
- st := t.tableForAddr(pfx.Addr())
-
- // This algorithm is full of off-by-one headaches, just like
- // Insert. See the comment in Insert for more details. Bottom
- // line: we handle the default route as a special case, and that
- // simplifies the rest of the code slightly.
- if pfx.Bits() == 0 {
- if debugDelete {
- fmt.Printf("delete: default route\n")
- }
- st.delete(0, 0)
- return
- }
-
- // Deletion may drive the refcount of some strideTables down to
- // zero. We need to clean up these dangling tables, so we have to
- // keep track of which tables we touch on the way down, and which
- // strideEntry index each child is registered in.
- //
- // Note that the strideIndex and strideTables entries are off-by-one.
- // The child table pointer is recorded at i+1, but it is referenced by a
- // particular index in the parent table, at index i.
- //
- // In other words: entry number strideIndexes[0] in
- // strideTables[0] is the same pointer as strideTables[1].
- //
- // This results in some slightly odd array accesses further down
- // in this code, because in a single loop iteration we have to
- // write to strideTables[N] and strideIndexes[N-1].
- strideIdx := 0
- strideTables := [16]*strideTable[T]{st}
- strideIndexes := [15]uint8{}
-
- // Similar to Insert, navigate down the tree of strideTables,
- // looking for the one that houses this prefix. This part is
- // easier than with insertion, since we can bail if the path ends
- // early or takes an unexpected detour. However, unlike
- // insertion, there's a whole post-deletion cleanup phase later
- // on.
- //
- // As we walk down the tree, byteIdx is the byte of bs we're
- // currently examining to choose our next step, and numBits is the
- // number of bits that remain in pfx, starting with the byte at
- // byteIdx inclusive.
- bs := pfx.Addr().AsSlice()
- byteIdx := 0
- numBits := pfx.Bits()
- for numBits > 8 {
- if debugDelete {
- fmt.Printf("delete: loop byteIdx=%d numBits=%d st.prefix=%s\n", byteIdx, numBits, st.prefix)
- }
- child := st.getChild(bs[byteIdx])
- if child == nil {
- // Prefix can't exist in the table, because one of the
- // necessary strideTables doesn't exist.
- if debugDelete {
- fmt.Printf("delete: missing necessary child pfx=%s\n", pfx)
- }
- return
- }
- strideIndexes[strideIdx] = bs[byteIdx]
- strideTables[strideIdx+1] = child
- strideIdx++
-
- // Path compression means byteIdx can jump forwards
- // unpredictably. Recompute the next byte to look at from the
- // child we just found.
- byteIdx = child.prefix.Bits() / 8
- numBits = pfx.Bits() - child.prefix.Bits()
- st = child
-
- if debugDelete {
- fmt.Printf("delete: descend st.prefix=%s\n", st.prefix)
- }
- }
-
- // We reached a leaf stride table that seems to be in the right
- // spot. But path compression might have led us to the wrong
- // table.
- if !prefixStrictlyContains(st.prefix, pfx) {
- // Wrong table, the requested prefix can't exist since its
- // path led us to the wrong place.
- if debugDelete {
- fmt.Printf("delete: wrong leaf table pfx=%s\n", pfx)
- }
- return
- }
- if debugDelete {
- fmt.Printf("delete: delete from st.prefix=%s addr=%d/%d\n", st.prefix, bs[byteIdx], numBits)
- }
- if routeExisted := st.delete(bs[byteIdx], numBits); !routeExisted {
- // We're in the right strideTable, but pfx wasn't in
- // it. Refcounts haven't changed, so we can skip cleanup.
- if debugDelete {
- fmt.Printf("delete: prefix not present pfx=%s\n", pfx)
- }
- return
- }
-
- // st.delete reduced st's refcount by one. This table may now be
- // reclaimable, and depending on how we can reclaim it, the parent
- // tables may also need to be reclaimed. This loop ends as soon as
- // an iteration takes no action, or takes an action that doesn't
- // alter the parent table's refcounts.
- //
- // We start our walk back at strideTables[strideIdx], which
- // contains st.
- for strideIdx > 0 {
- cur := strideTables[strideIdx]
- if debugDelete {
- fmt.Printf("delete: GC? strideIdx=%d st.prefix=%s\n", strideIdx, cur.prefix)
- }
- if cur.routeRefs > 0 {
- // the strideTable has other route entries, it cannot be
- // deleted or compacted.
- if debugDelete {
- fmt.Printf("delete: has other routes st.prefix=%s\n", cur.prefix)
- }
- return
- }
- switch cur.childRefs {
- case 0:
- // no routeRefs and no childRefs, this table can be
- // deleted. This will alter the parent table's refcount,
- // so we'll have to look at it as well (in the next loop
- // iteration).
- if debugDelete {
- fmt.Printf("delete: remove st.prefix=%s\n", cur.prefix)
- }
- strideTables[strideIdx-1].deleteChild(strideIndexes[strideIdx-1])
- strideIdx--
- case 1:
- // This table has no routes, and a single child. Compact
- // this table out of existence by making the parent point
- // directly at the one child. This does not affect the
- // parent's refcounts, so the parent can't be eligible for
- // deletion or compaction, and we can stop.
- child := strideTables[strideIdx].findFirstChild() // only 1 child exists, by definition
- parent := strideTables[strideIdx-1]
- if debugDelete {
- fmt.Printf("delete: compact parent.prefix=%s st.prefix=%s child.prefix=%s\n", parent.prefix, cur.prefix, child.prefix)
- }
- strideTables[strideIdx-1].setChild(strideIndexes[strideIdx-1], child)
- return
- default:
- // This table has two or more children, so it's acting as a "fork in
- // the road" between two prefix subtrees. It cannot be deleted, and
- // thus no further cleanups are possible.
- if debugDelete {
- fmt.Printf("delete: fork table st.prefix=%s\n", cur.prefix)
- }
- return
- }
- }
-}
-
-// debugSummary prints the tree of allocated strideTables in t, with each
-// strideTable's refcount.
-func (t *Table[T]) debugSummary() string {
- t.init()
- var ret bytes.Buffer
- fmt.Fprintf(&ret, "v4: ")
- strideSummary(&ret, &t.v4, 4)
- fmt.Fprintf(&ret, "v6: ")
- strideSummary(&ret, &t.v6, 4)
- return ret.String()
-}
-
-func strideSummary[T any](w io.Writer, st *strideTable[T], indent int) {
- fmt.Fprintf(w, "%s: %d routes, %d children\n", st.prefix, st.routeRefs, st.childRefs)
- indent += 4
- st.treeDebugStringRec(w, 1, indent)
- for addr, child := range st.children {
- if child == nil {
- continue
- }
- fmt.Fprintf(w, "%s%d/8 (%02x/8): ", strings.Repeat(" ", indent), addr, addr)
- strideSummary(w, child, indent)
- }
-}
-
-// prefixStrictlyContains reports whether child is a prefix within
-// parent, but not parent itself.
-func prefixStrictlyContains(parent, child netip.Prefix) bool {
- return parent.Overlaps(child) && parent.Bits() < child.Bits()
-}
-
-// computePrefixSplit returns the smallest common prefix that contains
-// both a and b. lastCommon is 8-bit aligned, with aStride and bStride
-// indicating the value of the 8-bit stride immediately following
-// lastCommon.
-//
-// computePrefixSplit is used in constructing an intermediate
-// strideTable when a new prefix needs to be inserted in a compressed
-// table. It can be read as: given that a is already in the table, and
-// b is being inserted, what is the prefix of the new intermediate
-// strideTable that needs to be created, and at what addresses in that
-// new strideTable should a and b's subsequent strideTables be
-// attached?
-//
-// Note as a special case, this can be called with a==b. An example of
-// when this happens:
-// - We want to insert the prefix 1.2.0.0/16
-// - A strideTable exists for 1.2.0.0/16, because another child
-// prefix already exists (e.g. 1.2.3.4/32)
-// - The 1.0.0.0/8 strideTable does not exist, because path
-// compression removed it.
-//
-// In this scenario, the caller of computePrefixSplit ends up making a
-// "wrong turn" while traversing strideTables: it was looking for the
-// 1.0.0.0/8 table, but ended up at the 1.2.0.0/16 table. When this
-// happens, it will invoke computePrefixSplit(1.2.0.0/16, 1.2.0.0/16),
-// and we return 1.0.0.0/8 as the missing intermediate.
-func computePrefixSplit(a, b netip.Prefix) (lastCommon netip.Prefix, aStride, bStride uint8) {
- a = a.Masked()
- b = b.Masked()
- if a.Bits() == 0 || b.Bits() == 0 {
- panic("computePrefixSplit called with a default route")
- }
- if a.Addr().Is4() != b.Addr().Is4() {
- panic("computePrefixSplit called with mismatched address families")
- }
-
- minPrefixLen := a.Bits()
- if b.Bits() < minPrefixLen {
- minPrefixLen = b.Bits()
- }
-
- commonBits := commonBits(a.Addr(), b.Addr(), minPrefixLen)
- // We want to know how many 8-bit strides are shared between a and
- // b. Naively, this would be commonBits/8, but this introduces an
- // off-by-one error. This is due to the way our ART stores
- // prefixes whose length falls exactly on a stride boundary.
- //
- // Consider 192.168.1.0/24 and 192.168.0.0/16. commonBits
- // correctly reports that these prefixes have their first 16 bits
- // in common. However, in the ART they only share 1 common stride:
- // they both use the 192.0.0.0/8 strideTable, but 192.168.0.0/16
- // is stored as 168/8 within that table, and not as 0/0 in the
- // 192.168.0.0/16 table.
- //
- // So, when commonBits matches the length of one of the inputs and
- // falls on a boundary between strides, the strideTable one
- // further up from commonBits/8 is the one we need to create,
- // which means we have to adjust the stride count down by one.
- if commonBits == minPrefixLen {
- commonBits--
- }
- commonStrides := commonBits / 8
- lastCommon, err := a.Addr().Prefix(commonStrides * 8)
- if err != nil {
- panic(fmt.Sprintf("computePrefixSplit constructing common prefix: %v", err))
- }
- if a.Addr().Is4() {
- aStride = a.Addr().As4()[commonStrides]
- bStride = b.Addr().As4()[commonStrides]
- } else {
- aStride = a.Addr().As16()[commonStrides]
- bStride = b.Addr().As16()[commonStrides]
- }
- return lastCommon, aStride, bStride
-}
-
-// commonBits returns the number of common leading bits of a and b.
-// If the number of common bits exceeds maxBits, it returns maxBits
-// instead.
-func commonBits(a, b netip.Addr, maxBits int) int {
- if a.Is4() != b.Is4() {
- panic("commonStrides called with mismatched address families")
- }
- var common int
- // The following implements an old bit-twiddling trick to compute
- // the number of common leading bits: if you XOR two numbers
- // together, equal bits become 0 and unequal bits become 1. You
- // can then count the number of leading zeros (which is a single
- // instruction on modern CPUs) to get the answer.
- //
- // This code is a little more complex than just XOR + count
- // leading zeros, because IPv4 and IPv6 are different sizes, and
- // for IPv6 we have to do the math in two 64-bit chunks because Go
- // lacks a uint128 type.
- if a.Is4() {
- aNum, bNum := ipv4AsUint(a), ipv4AsUint(b)
- common = bits.LeadingZeros32(aNum ^ bNum)
- } else {
- aNumHi, aNumLo := ipv6AsUint(a)
- bNumHi, bNumLo := ipv6AsUint(b)
- common = bits.LeadingZeros64(aNumHi ^ bNumHi)
- if common == 64 {
- common += bits.LeadingZeros64(aNumLo ^ bNumLo)
- }
- }
- if common > maxBits {
- common = maxBits
- }
- return common
-}
-
-// ipv4AsUint returns ip as a uint32.
-func ipv4AsUint(ip netip.Addr) uint32 {
- bs := ip.As4()
- return binary.BigEndian.Uint32(bs[:])
-}
-
-// ipv6AsUint returns ip as a pair of uint64s.
-func ipv6AsUint(ip netip.Addr) (uint64, uint64) {
- bs := ip.As16()
- return binary.BigEndian.Uint64(bs[:8]), binary.BigEndian.Uint64(bs[8:])
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Package art provides a routing table that implements the Allotment Routing
+// Table (ART) algorithm by Donald Knuth, as described in the paper by Yoichi
+// Hariguchi.
+//
+// ART outperforms the traditional radix tree implementations for route lookups,
+// insertions, and deletions.
+//
+// For more information, see Yoichi Hariguchi's paper:
+// https://cseweb.ucsd.edu//~varghese/TEACH/cs228/artlookup.pdf
+package art
+
+import (
+ "bytes"
+ "encoding/binary"
+ "fmt"
+ "io"
+ "math/bits"
+ "net/netip"
+ "strings"
+ "sync"
+)
+
+const (
+ debugInsert = false
+ debugDelete = false
+)
+
+// Table is an IPv4 and IPv6 routing table.
+type Table[T any] struct {
+ v4 strideTable[T]
+ v6 strideTable[T]
+ initOnce sync.Once
+}
+
+func (t *Table[T]) init() {
+ t.initOnce.Do(func() {
+ t.v4.prefix = netip.PrefixFrom(netip.IPv4Unspecified(), 0)
+ t.v6.prefix = netip.PrefixFrom(netip.IPv6Unspecified(), 0)
+ })
+}
+
+func (t *Table[T]) tableForAddr(addr netip.Addr) *strideTable[T] {
+ if addr.Is6() {
+ return &t.v6
+ }
+ return &t.v4
+}
+
+// Get does a route lookup for addr and returns the associated value, or nil if
+// no route matched.
+func (t *Table[T]) Get(addr netip.Addr) (ret T, ok bool) {
+ t.init()
+
+ // Ideally we would use addr.AsSlice here, but AsSlice is just
+ // barely complex enough that it can't be inlined, and that in
+ // turn causes the slice to escape to the heap. Using As16 and
+ // manual slicing here helps the compiler keep Get alloc-free.
+ st := t.tableForAddr(addr)
+ rawAddr := addr.As16()
+ bs := rawAddr[:]
+ if addr.Is4() {
+ bs = bs[12:]
+ }
+
+ i := 0
+ // With path compression, we might skip over some address bits while walking
+ // to a strideTable leaf. This means the leaf answer we find might not be
+ // correct, because path compression took us down the wrong subtree. When
+ // that happens, we have to backtrack and figure out which most specific
+ // route further up the tree is relevant to addr, and return that.
+ //
+ // So, as we walk down the stride tables, each time we find a non-nil route
+ // result, we have to remember it and the associated strideTable prefix.
+ //
+ // We could also deal with this edge case of path compression by checking
+ // the strideTable prefix on each table as we descend, but that means we
+ // have to pay N prefix.Contains checks on every route lookup (where N is
+ // the number of strideTables in the path), rather than only paying M prefix
+ // comparisons in the edge case (where M is the number of strideTables in
+ // the path with a non-nil route of their own).
+ const maxDepth = 16
+ type prefixAndRoute struct {
+ prefix netip.Prefix
+ route T
+ }
+ strideMatch := make([]prefixAndRoute, 0, maxDepth)
+findLeaf:
+ for {
+ rt, rtOK, child := st.getValAndChild(bs[i])
+ if rtOK {
+ // This strideTable contains a route that may be relevant to our
+ // search, remember it.
+ strideMatch = append(strideMatch, prefixAndRoute{st.prefix, rt})
+ }
+ if child == nil {
+ // No sub-routes further down, the last thing we recorded
+ // in strideRoutes is tentatively the result, barring
+ // misdirection from path compression.
+ break findLeaf
+ }
+ st = child
+ // Path compression means we may be skipping over some intermediate
+ // tables. We have to skip forward to whatever depth st now references.
+ i = st.prefix.Bits() / 8
+ }
+
+ // Walk backwards through the hits we recorded in strideRoutes and
+ // stridePrefixes, returning the first one whose subtree matches addr.
+ //
+ // In the common case where path compression did not mislead us, we'll
+ // return on the first loop iteration because the last route we recorded was
+ // the correct most-specific route.
+ for i := len(strideMatch) - 1; i >= 0; i-- {
+ if m := strideMatch[i]; m.prefix.Contains(addr) {
+ return m.route, true
+ }
+ }
+
+ // We either found no route hits at all (both previous loops terminated
+ // immediately), or we went on a wild goose chase down a compressed path for
+ // the wrong prefix, and also found no usable routes on the way back up to
+ // the root. This is a miss.
+ return ret, false
+}
+
+// Insert adds pfx to the table, with value val.
+// If pfx is already present in the table, its value is set to val.
+func (t *Table[T]) Insert(pfx netip.Prefix, val T) {
+ t.init()
+
+ // The standard library doesn't enforce normalized prefixes (where
+ // the non-prefix bits are all zero). These algorithms require
+ // normalized prefixes, so do it upfront.
+ pfx = pfx.Masked()
+
+ if debugInsert {
+ defer func() {
+ fmt.Printf("%s", t.debugSummary())
+ }()
+ fmt.Printf("\ninsert: start pfx=%s\n", pfx)
+ }
+
+ st := t.tableForAddr(pfx.Addr())
+
+ // This algorithm is full of off-by-one headaches that boil down
+ // to the fact that pfx.Bits() has (2^n)+1 values, rather than
+ // just 2^n. For example, an IPv4 prefix length can be 0 through
+ // 32, which is 33 values.
+ //
+ // This extra possible value creates a lot of problems as we do
+ // bits and bytes math to traverse strideTables below. So, we
+ // treat the default route 0/0 specially here, that way the rest
+ // of the logic goes back to having 2^n values to reason about,
+ // which can be done in a nice and regular fashion with no edge
+ // cases.
+ if pfx.Bits() == 0 {
+ if debugInsert {
+ fmt.Printf("insert: default route\n")
+ }
+ st.insert(0, 0, val)
+ return
+ }
+
+ // No matter what we do as we traverse strideTables, our final
+ // action will be to insert the last 1-8 bits of pfx into a
+ // strideTable somewhere.
+ //
+ // We calculate upfront the byte position of the end of the
+ // prefix; the number of bits within that byte that contain prefix
+ // data; and the prefix of the strideTable into which we'll
+ // eventually insert.
+ //
+ // We need this in a couple different branches of the code below,
+ // and because the possible values are 1-indexed (1 through 32 for
+ // ipv4, 1 through 128 for ipv6), the math is very slightly
+ // unusual to account for the off-by-one indexing. Do it once up
+ // here, with this large comment, rather than reproduce the subtle
+ // math in multiple places further down.
+ finalByteIdx := (pfx.Bits() - 1) / 8
+ finalBits := pfx.Bits() - (finalByteIdx * 8)
+ finalStridePrefix, err := pfx.Addr().Prefix(finalByteIdx * 8)
+ if err != nil {
+ panic(fmt.Sprintf("invalid prefix requested: %s/%d", pfx.Addr(), finalByteIdx*8))
+ }
+ if debugInsert {
+ fmt.Printf("insert: finalByteIdx=%d finalBits=%d finalStridePrefix=%s\n", finalByteIdx, finalBits, finalStridePrefix)
+ }
+
+ // The strideTable we want to insert into is potentially at the
+ // end of a chain of strideTables, each one encoding 8 bits of the
+ // prefix.
+ //
+ // We're expecting to walk down a path of tables, although with
+ // prefix compression we may end up skipping some links in the
+ // chain, or taking wrong turns and having to course correct.
+ //
+ // As we walk down the tree, byteIdx is the byte of bs we're
+ // currently examining to choose our next step, and numBits is the
+ // number of bits that remain in pfx, starting with the byte at
+ // byteIdx inclusive.
+ bs := pfx.Addr().AsSlice()
+ byteIdx := 0
+ numBits := pfx.Bits()
+ for {
+ if debugInsert {
+ fmt.Printf("insert: loop byteIdx=%d numBits=%d st.prefix=%s\n", byteIdx, numBits, st.prefix)
+ }
+ if numBits <= 8 {
+ if debugInsert {
+ fmt.Printf("insert: existing leaf st.prefix=%s addr=%d/%d\n", st.prefix, bs[finalByteIdx], finalBits)
+ }
+ // We've reached the end of the prefix, whichever
+ // strideTable we're looking at now is the place where we
+ // need to insert.
+ st.insert(bs[finalByteIdx], finalBits, val)
+ return
+ }
+
+ // Otherwise, we need to go down at least one more level of
+ // strideTables. With prefix compression, each level of
+ // descent can have one of three outcomes: we find a place
+ // where prefix compression is possible; a place where prefix
+ // compression made us take a "wrong turn"; or a point along
+ // our intended path that we have to keep following.
+ child, created := st.getOrCreateChild(bs[byteIdx])
+ switch {
+ case created:
+ // The subtree we need for pfx doesn't exist yet. The rest
+ // of the path, if we were to create it, will consist of a
+ // bunch of strideTables with a single child each. We can
+ // use path compression to elide those intermediates, and
+ // jump straight to the final strideTable that hosts this
+ // prefix.
+ child.prefix = finalStridePrefix
+ child.insert(bs[finalByteIdx], finalBits, val)
+ if debugInsert {
+ fmt.Printf("insert: new leaf st.prefix=%s child.prefix=%s addr=%d/%d\n", st.prefix, child.prefix, bs[finalByteIdx], finalBits)
+ }
+ return
+ case !prefixStrictlyContains(child.prefix, pfx):
+ // child already exists, but its prefix does not contain
+ // our destination. This means that the path between st
+ // and child was compressed by a previous insertion, and
+ // somewhere in the (implicit) compressed path we took a
+ // wrong turn, into the wrong part of st's subtree.
+ //
+ // This is okay, because pfx and child.prefix must have a
+ // common ancestor node somewhere between st and child. We
+ // can figure out what node that is, and materialize it.
+ //
+ // Once we've done that, we can immediately complete the
+ // remainder of the insertion in one of two ways, without
+ // further traversal. See a little further down for what
+ // those are.
+ if debugInsert {
+ fmt.Printf("insert: wrong turn, pfx=%s child.prefix=%s\n", pfx, child.prefix)
+ }
+ intermediatePrefix, addrOfExisting, addrOfNew := computePrefixSplit(child.prefix, pfx)
+ intermediate := &strideTable[T]{prefix: intermediatePrefix} // TODO: make this whole thing be st.AddIntermediate or something?
+ st.setChild(bs[byteIdx], intermediate)
+ intermediate.setChild(addrOfExisting, child)
+
+ if debugInsert {
+ fmt.Printf("insert: new intermediate st.prefix=%s intermediate.prefix=%s child.prefix=%s\n", st.prefix, intermediate.prefix, child.prefix)
+ }
+
+ // Now, we have a chain of st -> intermediate -> child.
+ //
+ // pfx either lives in a different child of intermediate,
+ // or in intermediate itself. For example, if we created
+ // the intermediate 1.2.0.0/16, pfx=1.2.3.4/32 would have
+ // to go into a new child of intermediate, but
+ // pfx=1.2.0.0/18 would go into intermediate directly.
+ if remain := pfx.Bits() - intermediate.prefix.Bits(); remain <= 8 {
+ // pfx lives in intermediate.
+ if debugInsert {
+ fmt.Printf("insert: into intermediate intermediate.prefix=%s addr=%d/%d\n", intermediate.prefix, bs[finalByteIdx], finalBits)
+ }
+ intermediate.insert(bs[finalByteIdx], finalBits, val)
+ } else {
+ // pfx lives in a different child subtree of
+ // intermediate. By definition this subtree doesn't
+ // exist at all, otherwise we'd never have entered
+ // this entire "wrong turn" codepath in the first
+ // place.
+ //
+ // This means we can apply prefix compression as we
+ // create this new child, and we're done.
+ st, created = intermediate.getOrCreateChild(addrOfNew)
+ if !created {
+ panic("new child path unexpectedly exists during path decompression")
+ }
+ st.prefix = finalStridePrefix
+ st.insert(bs[finalByteIdx], finalBits, val)
+ if debugInsert {
+ fmt.Printf("insert: new child st.prefix=%s addr=%d/%d\n", st.prefix, bs[finalByteIdx], finalBits)
+ }
+ }
+
+ return
+ default:
+ // An expected child table exists along pfx's
+ // path. Continue traversing downwards.
+ st = child
+ byteIdx = child.prefix.Bits() / 8
+ numBits = pfx.Bits() - child.prefix.Bits()
+ if debugInsert {
+ fmt.Printf("insert: descend st.prefix=%s\n", st.prefix)
+ }
+ }
+ }
+}
+
+// Delete removes pfx from the table, if it is present.
+func (t *Table[T]) Delete(pfx netip.Prefix) {
+ t.init()
+
+ // The standard library doesn't enforce normalized prefixes (where
+ // the non-prefix bits are all zero). These algorithms require
+ // normalized prefixes, so do it upfront.
+ pfx = pfx.Masked()
+
+ if debugDelete {
+ defer func() {
+ fmt.Printf("%s", t.debugSummary())
+ }()
+ fmt.Printf("\ndelete: start pfx=%s table:\n%s", pfx, t.debugSummary())
+ }
+
+ st := t.tableForAddr(pfx.Addr())
+
+ // This algorithm is full of off-by-one headaches, just like
+ // Insert. See the comment in Insert for more details. Bottom
+ // line: we handle the default route as a special case, and that
+ // simplifies the rest of the code slightly.
+ if pfx.Bits() == 0 {
+ if debugDelete {
+ fmt.Printf("delete: default route\n")
+ }
+ st.delete(0, 0)
+ return
+ }
+
+ // Deletion may drive the refcount of some strideTables down to
+ // zero. We need to clean up these dangling tables, so we have to
+ // keep track of which tables we touch on the way down, and which
+ // strideEntry index each child is registered in.
+ //
+ // Note that the strideIndex and strideTables entries are off-by-one.
+ // The child table pointer is recorded at i+1, but it is referenced by a
+ // particular index in the parent table, at index i.
+ //
+ // In other words: entry number strideIndexes[0] in
+ // strideTables[0] is the same pointer as strideTables[1].
+ //
+ // This results in some slightly odd array accesses further down
+ // in this code, because in a single loop iteration we have to
+ // write to strideTables[N] and strideIndexes[N-1].
+ strideIdx := 0
+ strideTables := [16]*strideTable[T]{st}
+ strideIndexes := [15]uint8{}
+
+ // Similar to Insert, navigate down the tree of strideTables,
+ // looking for the one that houses this prefix. This part is
+ // easier than with insertion, since we can bail if the path ends
+ // early or takes an unexpected detour. However, unlike
+ // insertion, there's a whole post-deletion cleanup phase later
+ // on.
+ //
+ // As we walk down the tree, byteIdx is the byte of bs we're
+ // currently examining to choose our next step, and numBits is the
+ // number of bits that remain in pfx, starting with the byte at
+ // byteIdx inclusive.
+ bs := pfx.Addr().AsSlice()
+ byteIdx := 0
+ numBits := pfx.Bits()
+ for numBits > 8 {
+ if debugDelete {
+ fmt.Printf("delete: loop byteIdx=%d numBits=%d st.prefix=%s\n", byteIdx, numBits, st.prefix)
+ }
+ child := st.getChild(bs[byteIdx])
+ if child == nil {
+ // Prefix can't exist in the table, because one of the
+ // necessary strideTables doesn't exist.
+ if debugDelete {
+ fmt.Printf("delete: missing necessary child pfx=%s\n", pfx)
+ }
+ return
+ }
+ strideIndexes[strideIdx] = bs[byteIdx]
+ strideTables[strideIdx+1] = child
+ strideIdx++
+
+ // Path compression means byteIdx can jump forwards
+ // unpredictably. Recompute the next byte to look at from the
+ // child we just found.
+ byteIdx = child.prefix.Bits() / 8
+ numBits = pfx.Bits() - child.prefix.Bits()
+ st = child
+
+ if debugDelete {
+ fmt.Printf("delete: descend st.prefix=%s\n", st.prefix)
+ }
+ }
+
+ // We reached a leaf stride table that seems to be in the right
+ // spot. But path compression might have led us to the wrong
+ // table.
+ if !prefixStrictlyContains(st.prefix, pfx) {
+ // Wrong table, the requested prefix can't exist since its
+ // path led us to the wrong place.
+ if debugDelete {
+ fmt.Printf("delete: wrong leaf table pfx=%s\n", pfx)
+ }
+ return
+ }
+ if debugDelete {
+ fmt.Printf("delete: delete from st.prefix=%s addr=%d/%d\n", st.prefix, bs[byteIdx], numBits)
+ }
+ if routeExisted := st.delete(bs[byteIdx], numBits); !routeExisted {
+ // We're in the right strideTable, but pfx wasn't in
+ // it. Refcounts haven't changed, so we can skip cleanup.
+ if debugDelete {
+ fmt.Printf("delete: prefix not present pfx=%s\n", pfx)
+ }
+ return
+ }
+
+ // st.delete reduced st's refcount by one. This table may now be
+ // reclaimable, and depending on how we can reclaim it, the parent
+ // tables may also need to be reclaimed. This loop ends as soon as
+ // an iteration takes no action, or takes an action that doesn't
+ // alter the parent table's refcounts.
+ //
+ // We start our walk back at strideTables[strideIdx], which
+ // contains st.
+ for strideIdx > 0 {
+ cur := strideTables[strideIdx]
+ if debugDelete {
+ fmt.Printf("delete: GC? strideIdx=%d st.prefix=%s\n", strideIdx, cur.prefix)
+ }
+ if cur.routeRefs > 0 {
+ // the strideTable has other route entries, it cannot be
+ // deleted or compacted.
+ if debugDelete {
+ fmt.Printf("delete: has other routes st.prefix=%s\n", cur.prefix)
+ }
+ return
+ }
+ switch cur.childRefs {
+ case 0:
+ // no routeRefs and no childRefs, this table can be
+ // deleted. This will alter the parent table's refcount,
+ // so we'll have to look at it as well (in the next loop
+ // iteration).
+ if debugDelete {
+ fmt.Printf("delete: remove st.prefix=%s\n", cur.prefix)
+ }
+ strideTables[strideIdx-1].deleteChild(strideIndexes[strideIdx-1])
+ strideIdx--
+ case 1:
+ // This table has no routes, and a single child. Compact
+ // this table out of existence by making the parent point
+ // directly at the one child. This does not affect the
+ // parent's refcounts, so the parent can't be eligible for
+ // deletion or compaction, and we can stop.
+ child := strideTables[strideIdx].findFirstChild() // only 1 child exists, by definition
+ parent := strideTables[strideIdx-1]
+ if debugDelete {
+ fmt.Printf("delete: compact parent.prefix=%s st.prefix=%s child.prefix=%s\n", parent.prefix, cur.prefix, child.prefix)
+ }
+ strideTables[strideIdx-1].setChild(strideIndexes[strideIdx-1], child)
+ return
+ default:
+ // This table has two or more children, so it's acting as a "fork in
+ // the road" between two prefix subtrees. It cannot be deleted, and
+ // thus no further cleanups are possible.
+ if debugDelete {
+ fmt.Printf("delete: fork table st.prefix=%s\n", cur.prefix)
+ }
+ return
+ }
+ }
+}
+
+// debugSummary prints the tree of allocated strideTables in t, with each
+// strideTable's refcount.
+func (t *Table[T]) debugSummary() string {
+ t.init()
+ var ret bytes.Buffer
+ fmt.Fprintf(&ret, "v4: ")
+ strideSummary(&ret, &t.v4, 4)
+ fmt.Fprintf(&ret, "v6: ")
+ strideSummary(&ret, &t.v6, 4)
+ return ret.String()
+}
+
+func strideSummary[T any](w io.Writer, st *strideTable[T], indent int) {
+ fmt.Fprintf(w, "%s: %d routes, %d children\n", st.prefix, st.routeRefs, st.childRefs)
+ indent += 4
+ st.treeDebugStringRec(w, 1, indent)
+ for addr, child := range st.children {
+ if child == nil {
+ continue
+ }
+ fmt.Fprintf(w, "%s%d/8 (%02x/8): ", strings.Repeat(" ", indent), addr, addr)
+ strideSummary(w, child, indent)
+ }
+}
+
+// prefixStrictlyContains reports whether child is a prefix within
+// parent, but not parent itself.
+func prefixStrictlyContains(parent, child netip.Prefix) bool {
+ return parent.Overlaps(child) && parent.Bits() < child.Bits()
+}
+
+// computePrefixSplit returns the smallest common prefix that contains
+// both a and b. lastCommon is 8-bit aligned, with aStride and bStride
+// indicating the value of the 8-bit stride immediately following
+// lastCommon.
+//
+// computePrefixSplit is used in constructing an intermediate
+// strideTable when a new prefix needs to be inserted in a compressed
+// table. It can be read as: given that a is already in the table, and
+// b is being inserted, what is the prefix of the new intermediate
+// strideTable that needs to be created, and at what addresses in that
+// new strideTable should a and b's subsequent strideTables be
+// attached?
+//
+// Note as a special case, this can be called with a==b. An example of
+// when this happens:
+// - We want to insert the prefix 1.2.0.0/16
+// - A strideTable exists for 1.2.0.0/16, because another child
+// prefix already exists (e.g. 1.2.3.4/32)
+// - The 1.0.0.0/8 strideTable does not exist, because path
+// compression removed it.
+//
+// In this scenario, the caller of computePrefixSplit ends up making a
+// "wrong turn" while traversing strideTables: it was looking for the
+// 1.0.0.0/8 table, but ended up at the 1.2.0.0/16 table. When this
+// happens, it will invoke computePrefixSplit(1.2.0.0/16, 1.2.0.0/16),
+// and we return 1.0.0.0/8 as the missing intermediate.
+func computePrefixSplit(a, b netip.Prefix) (lastCommon netip.Prefix, aStride, bStride uint8) {
+ a = a.Masked()
+ b = b.Masked()
+ if a.Bits() == 0 || b.Bits() == 0 {
+ panic("computePrefixSplit called with a default route")
+ }
+ if a.Addr().Is4() != b.Addr().Is4() {
+ panic("computePrefixSplit called with mismatched address families")
+ }
+
+ minPrefixLen := a.Bits()
+ if b.Bits() < minPrefixLen {
+ minPrefixLen = b.Bits()
+ }
+
+ commonBits := commonBits(a.Addr(), b.Addr(), minPrefixLen)
+ // We want to know how many 8-bit strides are shared between a and
+ // b. Naively, this would be commonBits/8, but this introduces an
+ // off-by-one error. This is due to the way our ART stores
+ // prefixes whose length falls exactly on a stride boundary.
+ //
+ // Consider 192.168.1.0/24 and 192.168.0.0/16. commonBits
+ // correctly reports that these prefixes have their first 16 bits
+ // in common. However, in the ART they only share 1 common stride:
+ // they both use the 192.0.0.0/8 strideTable, but 192.168.0.0/16
+ // is stored as 168/8 within that table, and not as 0/0 in the
+ // 192.168.0.0/16 table.
+ //
+ // So, when commonBits matches the length of one of the inputs and
+ // falls on a boundary between strides, the strideTable one
+ // further up from commonBits/8 is the one we need to create,
+ // which means we have to adjust the stride count down by one.
+ if commonBits == minPrefixLen {
+ commonBits--
+ }
+ commonStrides := commonBits / 8
+ lastCommon, err := a.Addr().Prefix(commonStrides * 8)
+ if err != nil {
+ panic(fmt.Sprintf("computePrefixSplit constructing common prefix: %v", err))
+ }
+ if a.Addr().Is4() {
+ aStride = a.Addr().As4()[commonStrides]
+ bStride = b.Addr().As4()[commonStrides]
+ } else {
+ aStride = a.Addr().As16()[commonStrides]
+ bStride = b.Addr().As16()[commonStrides]
+ }
+ return lastCommon, aStride, bStride
+}
+
+// commonBits returns the number of common leading bits of a and b.
+// If the number of common bits exceeds maxBits, it returns maxBits
+// instead.
+func commonBits(a, b netip.Addr, maxBits int) int {
+ if a.Is4() != b.Is4() {
+ panic("commonStrides called with mismatched address families")
+ }
+ var common int
+ // The following implements an old bit-twiddling trick to compute
+ // the number of common leading bits: if you XOR two numbers
+ // together, equal bits become 0 and unequal bits become 1. You
+ // can then count the number of leading zeros (which is a single
+ // instruction on modern CPUs) to get the answer.
+ //
+ // This code is a little more complex than just XOR + count
+ // leading zeros, because IPv4 and IPv6 are different sizes, and
+ // for IPv6 we have to do the math in two 64-bit chunks because Go
+ // lacks a uint128 type.
+ if a.Is4() {
+ aNum, bNum := ipv4AsUint(a), ipv4AsUint(b)
+ common = bits.LeadingZeros32(aNum ^ bNum)
+ } else {
+ aNumHi, aNumLo := ipv6AsUint(a)
+ bNumHi, bNumLo := ipv6AsUint(b)
+ common = bits.LeadingZeros64(aNumHi ^ bNumHi)
+ if common == 64 {
+ common += bits.LeadingZeros64(aNumLo ^ bNumLo)
+ }
+ }
+ if common > maxBits {
+ common = maxBits
+ }
+ return common
+}
+
+// ipv4AsUint returns ip as a uint32.
+func ipv4AsUint(ip netip.Addr) uint32 {
+ bs := ip.As4()
+ return binary.BigEndian.Uint32(bs[:])
+}
+
+// ipv6AsUint returns ip as a pair of uint64s.
+func ipv6AsUint(ip netip.Addr) (uint64, uint64) {
+ bs := ip.As16()
+ return binary.BigEndian.Uint64(bs[:8]), binary.BigEndian.Uint64(bs[8:])
+}
diff --git a/net/dns/debian_resolvconf.go b/net/dns/debian_resolvconf.go
index 3ffc796e0..2a1fb18de 100644
--- a/net/dns/debian_resolvconf.go
+++ b/net/dns/debian_resolvconf.go
@@ -1,184 +1,184 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-//go:build linux || freebsd || openbsd
-
-package dns
-
-import (
- "bufio"
- "bytes"
- _ "embed"
- "fmt"
- "os"
- "os/exec"
- "path/filepath"
-
- "tailscale.com/atomicfile"
- "tailscale.com/types/logger"
-)
-
-//go:embed resolvconf-workaround.sh
-var workaroundScript []byte
-
-// resolvconfConfigName is the name of the config submitted to
-// resolvconf.
-// The name starts with 'tun' in order to match the hardcoded
-// interface order in debian resolvconf, which will place this
-// configuration ahead of regular network links. In theory, this
-// doesn't matter because we then fix things up to ensure our config
-// is the only one in use, but in case that fails, this will make our
-// configuration slightly preferred.
-// The 'inet' suffix has no specific meaning, but conventionally
-// resolvconf implementations encourage adding a suffix roughly
-// indicating where the config came from, and "inet" is the "none of
-// the above" value (rather than, say, "ppp" or "dhcp").
-const resolvconfConfigName = "tun-tailscale.inet"
-
-// resolvconfLibcHookPath is the directory containing libc update
-// scripts, which are run by Debian resolvconf when /etc/resolv.conf
-// has been updated.
-const resolvconfLibcHookPath = "/etc/resolvconf/update-libc.d"
-
-// resolvconfHookPath is the name of the libc hook script we install
-// to force Tailscale's DNS config to take effect.
-var resolvconfHookPath = filepath.Join(resolvconfLibcHookPath, "tailscale")
-
-// resolvconfManager manages DNS configuration using the Debian
-// implementation of the `resolvconf` program, written by Thomas Hood.
-type resolvconfManager struct {
- logf logger.Logf
- listRecordsPath string
- interfacesDir string
- scriptInstalled bool // libc update script has been installed
-}
-
-func newDebianResolvconfManager(logf logger.Logf) (*resolvconfManager, error) {
- ret := &resolvconfManager{
- logf: logf,
- listRecordsPath: "/lib/resolvconf/list-records",
- interfacesDir: "/etc/resolvconf/run/interface", // panic fallback if nothing seems to work
- }
-
- if _, err := os.Stat(ret.listRecordsPath); os.IsNotExist(err) {
- // This might be a Debian system from before the big /usr
- // merge, try /usr instead.
- ret.listRecordsPath = "/usr" + ret.listRecordsPath
- }
- // The runtime directory is currently (2020-04) canonically
- // /etc/resolvconf/run, but the manpage is making noise about
- // switching to /run/resolvconf and dropping the /etc path. So,
- // let's probe the possible directories and use the first one
- // that works.
- for _, path := range []string{
- "/etc/resolvconf/run/interface",
- "/run/resolvconf/interface",
- "/var/run/resolvconf/interface",
- } {
- if _, err := os.Stat(path); err == nil {
- ret.interfacesDir = path
- break
- }
- }
- if ret.interfacesDir == "" {
- // None of the paths seem to work, use the canonical location
- // that the current manpage says to use.
- ret.interfacesDir = "/etc/resolvconf/run/interfaces"
- }
-
- return ret, nil
-}
-
-func (m *resolvconfManager) deleteTailscaleConfig() error {
- cmd := exec.Command("resolvconf", "-d", resolvconfConfigName)
- out, err := cmd.CombinedOutput()
- if err != nil {
- return fmt.Errorf("running %s: %s", cmd, out)
- }
- return nil
-}
-
-func (m *resolvconfManager) SetDNS(config OSConfig) error {
- if !m.scriptInstalled {
- m.logf("injecting resolvconf workaround script")
- if err := os.MkdirAll(resolvconfLibcHookPath, 0755); err != nil {
- return err
- }
- if err := atomicfile.WriteFile(resolvconfHookPath, workaroundScript, 0755); err != nil {
- return err
- }
- m.scriptInstalled = true
- }
-
- if config.IsZero() {
- if err := m.deleteTailscaleConfig(); err != nil {
- return err
- }
- } else {
- stdin := new(bytes.Buffer)
- writeResolvConf(stdin, config.Nameservers, config.SearchDomains) // dns_direct.go
-
- // This resolvconf implementation doesn't support exclusive
- // mode or interface priorities, so it will end up blending
- // our configuration with other sources. However, this will
- // get fixed up by the script we injected above.
- cmd := exec.Command("resolvconf", "-a", resolvconfConfigName)
- cmd.Stdin = stdin
- out, err := cmd.CombinedOutput()
- if err != nil {
- return fmt.Errorf("running %s: %s", cmd, out)
- }
- }
-
- return nil
-}
-
-func (m *resolvconfManager) SupportsSplitDNS() bool {
- return false
-}
-
-func (m *resolvconfManager) GetBaseConfig() (OSConfig, error) {
- var bs bytes.Buffer
-
- cmd := exec.Command(m.listRecordsPath)
- // list-records assumes it's being run with CWD set to the
- // interfaces runtime dir, and returns nonsense otherwise.
- cmd.Dir = m.interfacesDir
- cmd.Stdout = &bs
- if err := cmd.Run(); err != nil {
- return OSConfig{}, err
- }
-
- var conf bytes.Buffer
- sc := bufio.NewScanner(&bs)
- for sc.Scan() {
- if sc.Text() == resolvconfConfigName {
- continue
- }
- bs, err := os.ReadFile(filepath.Join(m.interfacesDir, sc.Text()))
- if err != nil {
- if os.IsNotExist(err) {
- // Probably raced with a deletion, that's okay.
- continue
- }
- return OSConfig{}, err
- }
- conf.Write(bs)
- conf.WriteByte('\n')
- }
-
- return readResolv(&conf)
-}
-
-func (m *resolvconfManager) Close() error {
- if err := m.deleteTailscaleConfig(); err != nil {
- return err
- }
-
- if m.scriptInstalled {
- m.logf("removing resolvconf workaround script")
- os.Remove(resolvconfHookPath) // Best-effort
- }
-
- return nil
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build linux || freebsd || openbsd
+
+package dns
+
+import (
+ "bufio"
+ "bytes"
+ _ "embed"
+ "fmt"
+ "os"
+ "os/exec"
+ "path/filepath"
+
+ "tailscale.com/atomicfile"
+ "tailscale.com/types/logger"
+)
+
+//go:embed resolvconf-workaround.sh
+var workaroundScript []byte
+
+// resolvconfConfigName is the name of the config submitted to
+// resolvconf.
+// The name starts with 'tun' in order to match the hardcoded
+// interface order in debian resolvconf, which will place this
+// configuration ahead of regular network links. In theory, this
+// doesn't matter because we then fix things up to ensure our config
+// is the only one in use, but in case that fails, this will make our
+// configuration slightly preferred.
+// The 'inet' suffix has no specific meaning, but conventionally
+// resolvconf implementations encourage adding a suffix roughly
+// indicating where the config came from, and "inet" is the "none of
+// the above" value (rather than, say, "ppp" or "dhcp").
+const resolvconfConfigName = "tun-tailscale.inet"
+
+// resolvconfLibcHookPath is the directory containing libc update
+// scripts, which are run by Debian resolvconf when /etc/resolv.conf
+// has been updated.
+const resolvconfLibcHookPath = "/etc/resolvconf/update-libc.d"
+
+// resolvconfHookPath is the name of the libc hook script we install
+// to force Tailscale's DNS config to take effect.
+var resolvconfHookPath = filepath.Join(resolvconfLibcHookPath, "tailscale")
+
+// resolvconfManager manages DNS configuration using the Debian
+// implementation of the `resolvconf` program, written by Thomas Hood.
+type resolvconfManager struct {
+ logf logger.Logf
+ listRecordsPath string
+ interfacesDir string
+ scriptInstalled bool // libc update script has been installed
+}
+
+func newDebianResolvconfManager(logf logger.Logf) (*resolvconfManager, error) {
+ ret := &resolvconfManager{
+ logf: logf,
+ listRecordsPath: "/lib/resolvconf/list-records",
+ interfacesDir: "/etc/resolvconf/run/interface", // panic fallback if nothing seems to work
+ }
+
+ if _, err := os.Stat(ret.listRecordsPath); os.IsNotExist(err) {
+ // This might be a Debian system from before the big /usr
+ // merge, try /usr instead.
+ ret.listRecordsPath = "/usr" + ret.listRecordsPath
+ }
+ // The runtime directory is currently (2020-04) canonically
+ // /etc/resolvconf/run, but the manpage is making noise about
+ // switching to /run/resolvconf and dropping the /etc path. So,
+ // let's probe the possible directories and use the first one
+ // that works.
+ for _, path := range []string{
+ "/etc/resolvconf/run/interface",
+ "/run/resolvconf/interface",
+ "/var/run/resolvconf/interface",
+ } {
+ if _, err := os.Stat(path); err == nil {
+ ret.interfacesDir = path
+ break
+ }
+ }
+ if ret.interfacesDir == "" {
+ // None of the paths seem to work, use the canonical location
+ // that the current manpage says to use.
+ ret.interfacesDir = "/etc/resolvconf/run/interfaces"
+ }
+
+ return ret, nil
+}
+
+func (m *resolvconfManager) deleteTailscaleConfig() error {
+ cmd := exec.Command("resolvconf", "-d", resolvconfConfigName)
+ out, err := cmd.CombinedOutput()
+ if err != nil {
+ return fmt.Errorf("running %s: %s", cmd, out)
+ }
+ return nil
+}
+
+func (m *resolvconfManager) SetDNS(config OSConfig) error {
+ if !m.scriptInstalled {
+ m.logf("injecting resolvconf workaround script")
+ if err := os.MkdirAll(resolvconfLibcHookPath, 0755); err != nil {
+ return err
+ }
+ if err := atomicfile.WriteFile(resolvconfHookPath, workaroundScript, 0755); err != nil {
+ return err
+ }
+ m.scriptInstalled = true
+ }
+
+ if config.IsZero() {
+ if err := m.deleteTailscaleConfig(); err != nil {
+ return err
+ }
+ } else {
+ stdin := new(bytes.Buffer)
+ writeResolvConf(stdin, config.Nameservers, config.SearchDomains) // dns_direct.go
+
+ // This resolvconf implementation doesn't support exclusive
+ // mode or interface priorities, so it will end up blending
+ // our configuration with other sources. However, this will
+ // get fixed up by the script we injected above.
+ cmd := exec.Command("resolvconf", "-a", resolvconfConfigName)
+ cmd.Stdin = stdin
+ out, err := cmd.CombinedOutput()
+ if err != nil {
+ return fmt.Errorf("running %s: %s", cmd, out)
+ }
+ }
+
+ return nil
+}
+
+func (m *resolvconfManager) SupportsSplitDNS() bool {
+ return false
+}
+
+func (m *resolvconfManager) GetBaseConfig() (OSConfig, error) {
+ var bs bytes.Buffer
+
+ cmd := exec.Command(m.listRecordsPath)
+ // list-records assumes it's being run with CWD set to the
+ // interfaces runtime dir, and returns nonsense otherwise.
+ cmd.Dir = m.interfacesDir
+ cmd.Stdout = &bs
+ if err := cmd.Run(); err != nil {
+ return OSConfig{}, err
+ }
+
+ var conf bytes.Buffer
+ sc := bufio.NewScanner(&bs)
+ for sc.Scan() {
+ if sc.Text() == resolvconfConfigName {
+ continue
+ }
+ bs, err := os.ReadFile(filepath.Join(m.interfacesDir, sc.Text()))
+ if err != nil {
+ if os.IsNotExist(err) {
+ // Probably raced with a deletion, that's okay.
+ continue
+ }
+ return OSConfig{}, err
+ }
+ conf.Write(bs)
+ conf.WriteByte('\n')
+ }
+
+ return readResolv(&conf)
+}
+
+func (m *resolvconfManager) Close() error {
+ if err := m.deleteTailscaleConfig(); err != nil {
+ return err
+ }
+
+ if m.scriptInstalled {
+ m.logf("removing resolvconf workaround script")
+ os.Remove(resolvconfHookPath) // Best-effort
+ }
+
+ return nil
+}
diff --git a/net/dns/direct_notlinux.go b/net/dns/direct_notlinux.go
index c221ca1be..5bd8093d6 100644
--- a/net/dns/direct_notlinux.go
+++ b/net/dns/direct_notlinux.go
@@ -1,10 +1,10 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-//go:build !linux
-
-package dns
-
-func (m *directManager) runFileWatcher() {
- // Not implemented on other platforms. Maybe it could resort to polling.
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build !linux
+
+package dns
+
+func (m *directManager) runFileWatcher() {
+ // Not implemented on other platforms. Maybe it could resort to polling.
+}
diff --git a/net/dns/flush_default.go b/net/dns/flush_default.go
index eb6d9da41..73e446389 100644
--- a/net/dns/flush_default.go
+++ b/net/dns/flush_default.go
@@ -1,10 +1,10 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-//go:build !windows
-
-package dns
-
-func flushCaches() error {
- return nil
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build !windows
+
+package dns
+
+func flushCaches() error {
+ return nil
+}
diff --git a/net/dns/ini.go b/net/dns/ini.go
index 1e47d606e..deec04019 100644
--- a/net/dns/ini.go
+++ b/net/dns/ini.go
@@ -1,30 +1,30 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-//go:build windows
-
-package dns
-
-import (
- "regexp"
- "strings"
-)
-
-// parseIni parses a basic .ini file, used for wsl.conf.
-func parseIni(data string) map[string]map[string]string {
- sectionRE := regexp.MustCompile(`^\[([^]]+)\]`)
- kvRE := regexp.MustCompile(`^\s*(\w+)\s*=\s*([^#]*)`)
-
- ini := map[string]map[string]string{}
- var section string
- for _, line := range strings.Split(data, "\n") {
- if res := sectionRE.FindStringSubmatch(line); len(res) > 1 {
- section = res[1]
- ini[section] = map[string]string{}
- } else if res := kvRE.FindStringSubmatch(line); len(res) > 2 {
- k, v := strings.TrimSpace(res[1]), strings.TrimSpace(res[2])
- ini[section][k] = v
- }
- }
- return ini
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build windows
+
+package dns
+
+import (
+ "regexp"
+ "strings"
+)
+
+// parseIni parses a basic .ini file, used for wsl.conf.
+func parseIni(data string) map[string]map[string]string {
+ sectionRE := regexp.MustCompile(`^\[([^]]+)\]`)
+ kvRE := regexp.MustCompile(`^\s*(\w+)\s*=\s*([^#]*)`)
+
+ ini := map[string]map[string]string{}
+ var section string
+ for _, line := range strings.Split(data, "\n") {
+ if res := sectionRE.FindStringSubmatch(line); len(res) > 1 {
+ section = res[1]
+ ini[section] = map[string]string{}
+ } else if res := kvRE.FindStringSubmatch(line); len(res) > 2 {
+ k, v := strings.TrimSpace(res[1]), strings.TrimSpace(res[2])
+ ini[section][k] = v
+ }
+ }
+ return ini
+}
diff --git a/net/dns/ini_test.go b/net/dns/ini_test.go
index 3afe7009c..0e9eaa672 100644
--- a/net/dns/ini_test.go
+++ b/net/dns/ini_test.go
@@ -1,38 +1,38 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-//go:build windows
-
-package dns
-
-import (
- "reflect"
- "testing"
-)
-
-func TestParseIni(t *testing.T) {
- var tests = []struct {
- src string
- want map[string]map[string]string
- }{
- {
- src: `# appended wsl.conf file
-[automount]
- enabled = true
- root=/mnt/
-# added by tailscale
-[network] # trailing comment
-generateResolvConf = false # trailing comment`,
- want: map[string]map[string]string{
- "automount": {"enabled": "true", "root": "/mnt/"},
- "network": {"generateResolvConf": "false"},
- },
- },
- }
- for _, test := range tests {
- got := parseIni(test.src)
- if !reflect.DeepEqual(got, test.want) {
- t.Errorf("for:\n%s\ngot: %v\nwant: %v", test.src, got, test.want)
- }
- }
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build windows
+
+package dns
+
+import (
+ "reflect"
+ "testing"
+)
+
+func TestParseIni(t *testing.T) {
+ var tests = []struct {
+ src string
+ want map[string]map[string]string
+ }{
+ {
+ src: `# appended wsl.conf file
+[automount]
+ enabled = true
+ root=/mnt/
+# added by tailscale
+[network] # trailing comment
+generateResolvConf = false # trailing comment`,
+ want: map[string]map[string]string{
+ "automount": {"enabled": "true", "root": "/mnt/"},
+ "network": {"generateResolvConf": "false"},
+ },
+ },
+ }
+ for _, test := range tests {
+ got := parseIni(test.src)
+ if !reflect.DeepEqual(got, test.want) {
+ t.Errorf("for:\n%s\ngot: %v\nwant: %v", test.src, got, test.want)
+ }
+ }
+}
diff --git a/net/dns/noop.go b/net/dns/noop.go
index 9466b57a0..c90162668 100644
--- a/net/dns/noop.go
+++ b/net/dns/noop.go
@@ -1,17 +1,17 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-package dns
-
-type noopManager struct{}
-
-func (m noopManager) SetDNS(OSConfig) error { return nil }
-func (m noopManager) SupportsSplitDNS() bool { return false }
-func (m noopManager) Close() error { return nil }
-func (m noopManager) GetBaseConfig() (OSConfig, error) {
- return OSConfig{}, ErrGetBaseConfigNotSupported
-}
-
-func NewNoopManager() (noopManager, error) {
- return noopManager{}, nil
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package dns
+
+type noopManager struct{}
+
+func (m noopManager) SetDNS(OSConfig) error { return nil }
+func (m noopManager) SupportsSplitDNS() bool { return false }
+func (m noopManager) Close() error { return nil }
+func (m noopManager) GetBaseConfig() (OSConfig, error) {
+ return OSConfig{}, ErrGetBaseConfigNotSupported
+}
+
+func NewNoopManager() (noopManager, error) {
+ return noopManager{}, nil
+}
diff --git a/net/dns/resolvconf-workaround.sh b/net/dns/resolvconf-workaround.sh
index aec6708a0..254b3949b 100644
--- a/net/dns/resolvconf-workaround.sh
+++ b/net/dns/resolvconf-workaround.sh
@@ -1,62 +1,62 @@
-#!/bin/sh
-# Copyright (c) Tailscale Inc & AUTHORS
-# SPDX-License-Identifier: BSD-3-Clause
-#
-# This script is a workaround for a vpn-unfriendly behavior of the
-# original resolvconf by Thomas Hood. Unlike the `openresolv`
-# implementation (whose binary is also called resolvconf,
-# confusingly), the original resolvconf lacks a way to specify
-# "exclusive mode" for a provider configuration. In practice, this
-# means that if Tailscale wants to install a DNS configuration, that
-# config will get "blended" with the configs from other sources,
-# rather than override those other sources.
-#
-# This script gets installed at /etc/resolvconf/update-libc.d, which
-# is a directory of hook scripts that get run after resolvconf's libc
-# helper has finished rewriting /etc/resolv.conf. It's meant to notify
-# consumers of resolv.conf of a new configuration.
-#
-# Instead, we use that hook mechanism to reach into resolvconf's
-# stuff, and rewrite the libc-generated resolv.conf to exclusively
-# contain Tailscale's configuration - effectively implementing
-# exclusive mode ourselves in post-production.
-
-set -e
-
-if [ -n "$TAILSCALE_RESOLVCONF_HOOK_LOOP" ]; then
- # Hook script being invoked by itself, skip.
- exit 0
-fi
-
-if [ ! -f tun-tailscale.inet ]; then
- # Tailscale isn't trying to manage DNS, do nothing.
- exit 0
-fi
-
-if ! grep resolvconf /etc/resolv.conf >/dev/null; then
- # resolvconf isn't managing /etc/resolv.conf, do nothing.
- exit 0
-fi
-
-# Write out a modified /etc/resolv.conf containing just our config.
-(
- if [ -f /etc/resolvconf/resolv.conf.d/head ]; then
- cat /etc/resolvconf/resolv.conf.d/head
- fi
- echo "# Tailscale workaround applied to set exclusive DNS configuration."
- cat tun-tailscale.inet
- if [ -f /etc/resolvconf/resolv.conf.d/base ]; then
- # Keep options and sortlist, discard other base things since
- # they're the things we're trying to override.
- grep -e 'sortlist ' -e 'options ' /etc/resolvconf/resolv.conf.d/base || true
- fi
- if [ -f /etc/resolvconf/resolv.conf.d/tail ]; then
- cat /etc/resolvconf/resolv.conf.d/tail
- fi
-) >/etc/resolv.conf
-
-if [ -d /etc/resolvconf/update-libc.d ] ; then
- # Re-notify libc watchers that we've changed resolv.conf again.
- export TAILSCALE_RESOLVCONF_HOOK_LOOP=1
- exec run-parts /etc/resolvconf/update-libc.d
-fi
+#!/bin/sh
+# Copyright (c) Tailscale Inc & AUTHORS
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# This script is a workaround for a vpn-unfriendly behavior of the
+# original resolvconf by Thomas Hood. Unlike the `openresolv`
+# implementation (whose binary is also called resolvconf,
+# confusingly), the original resolvconf lacks a way to specify
+# "exclusive mode" for a provider configuration. In practice, this
+# means that if Tailscale wants to install a DNS configuration, that
+# config will get "blended" with the configs from other sources,
+# rather than override those other sources.
+#
+# This script gets installed at /etc/resolvconf/update-libc.d, which
+# is a directory of hook scripts that get run after resolvconf's libc
+# helper has finished rewriting /etc/resolv.conf. It's meant to notify
+# consumers of resolv.conf of a new configuration.
+#
+# Instead, we use that hook mechanism to reach into resolvconf's
+# stuff, and rewrite the libc-generated resolv.conf to exclusively
+# contain Tailscale's configuration - effectively implementing
+# exclusive mode ourselves in post-production.
+
+set -e
+
+if [ -n "$TAILSCALE_RESOLVCONF_HOOK_LOOP" ]; then
+ # Hook script being invoked by itself, skip.
+ exit 0
+fi
+
+if [ ! -f tun-tailscale.inet ]; then
+ # Tailscale isn't trying to manage DNS, do nothing.
+ exit 0
+fi
+
+if ! grep resolvconf /etc/resolv.conf >/dev/null; then
+ # resolvconf isn't managing /etc/resolv.conf, do nothing.
+ exit 0
+fi
+
+# Write out a modified /etc/resolv.conf containing just our config.
+(
+ if [ -f /etc/resolvconf/resolv.conf.d/head ]; then
+ cat /etc/resolvconf/resolv.conf.d/head
+ fi
+ echo "# Tailscale workaround applied to set exclusive DNS configuration."
+ cat tun-tailscale.inet
+ if [ -f /etc/resolvconf/resolv.conf.d/base ]; then
+ # Keep options and sortlist, discard other base things since
+ # they're the things we're trying to override.
+ grep -e 'sortlist ' -e 'options ' /etc/resolvconf/resolv.conf.d/base || true
+ fi
+ if [ -f /etc/resolvconf/resolv.conf.d/tail ]; then
+ cat /etc/resolvconf/resolv.conf.d/tail
+ fi
+) >/etc/resolv.conf
+
+if [ -d /etc/resolvconf/update-libc.d ] ; then
+ # Re-notify libc watchers that we've changed resolv.conf again.
+ export TAILSCALE_RESOLVCONF_HOOK_LOOP=1
+ exec run-parts /etc/resolvconf/update-libc.d
+fi
diff --git a/net/dns/resolvconf.go b/net/dns/resolvconf.go
index ca584ffcc..9e2a41c4a 100644
--- a/net/dns/resolvconf.go
+++ b/net/dns/resolvconf.go
@@ -1,30 +1,30 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-//go:build linux || freebsd || openbsd
-
-package dns
-
-import (
- "bytes"
- "os/exec"
-)
-
-func resolvconfStyle() string {
- if _, err := exec.LookPath("resolvconf"); err != nil {
- return ""
- }
- output, err := exec.Command("resolvconf", "--version").CombinedOutput()
- if err != nil {
- // Debian resolvconf doesn't understand --version, and
- // exits with a specific error code.
- if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() == 99 {
- return "debian"
- }
- }
- if bytes.HasPrefix(output, []byte("Debian resolvconf")) {
- return "debian"
- }
- // Treat everything else as openresolv, by far the more popular implementation.
- return "openresolv"
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build linux || freebsd || openbsd
+
+package dns
+
+import (
+ "bytes"
+ "os/exec"
+)
+
+func resolvconfStyle() string {
+ if _, err := exec.LookPath("resolvconf"); err != nil {
+ return ""
+ }
+ output, err := exec.Command("resolvconf", "--version").CombinedOutput()
+ if err != nil {
+ // Debian resolvconf doesn't understand --version, and
+ // exits with a specific error code.
+ if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() == 99 {
+ return "debian"
+ }
+ }
+ if bytes.HasPrefix(output, []byte("Debian resolvconf")) {
+ return "debian"
+ }
+ // Treat everything else as openresolv, by far the more popular implementation.
+ return "openresolv"
+}
diff --git a/net/dns/resolvconffile/resolvconffile.go b/net/dns/resolvconffile/resolvconffile.go
index 753000f6d..66c1600d8 100644
--- a/net/dns/resolvconffile/resolvconffile.go
+++ b/net/dns/resolvconffile/resolvconffile.go
@@ -1,124 +1,124 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-// Package resolvconffile parses & serializes /etc/resolv.conf-style files.
-//
-// It's a leaf package so both net/dns and net/dns/resolver can depend
-// on it and we can unify a handful of implementations.
-//
-// The package is verbosely named to disambiguate it from resolvconf
-// the daemon, which Tailscale also supports.
-package resolvconffile
-
-import (
- "bufio"
- "bytes"
- "fmt"
- "io"
- "net/netip"
- "os"
- "strings"
-
- "tailscale.com/util/dnsname"
-)
-
-// Path is the canonical location of resolv.conf.
-const Path = "/etc/resolv.conf"
-
-// Config represents a resolv.conf(5) file.
-type Config struct {
- // Nameservers are the IP addresses of the nameservers to use.
- Nameservers []netip.Addr
-
- // SearchDomains are the domain suffixes to use when expanding
- // single-label name queries. SearchDomains is additive to
- // whatever non-Tailscale search domains the OS has.
- SearchDomains []dnsname.FQDN
-}
-
-// Write writes c to w. It does so in one Write call.
-func (c *Config) Write(w io.Writer) error {
- buf := new(bytes.Buffer)
- io.WriteString(buf, "# resolv.conf(5) file generated by tailscale\n")
- io.WriteString(buf, "# For more info, see https://tailscale.com/s/resolvconf-overwrite\n")
- io.WriteString(buf, "# DO NOT EDIT THIS FILE BY HAND -- CHANGES WILL BE OVERWRITTEN\n\n")
- for _, ns := range c.Nameservers {
- io.WriteString(buf, "nameserver ")
- io.WriteString(buf, ns.String())
- io.WriteString(buf, "\n")
- }
- if len(c.SearchDomains) > 0 {
- io.WriteString(buf, "search")
- for _, domain := range c.SearchDomains {
- io.WriteString(buf, " ")
- io.WriteString(buf, domain.WithoutTrailingDot())
- }
- io.WriteString(buf, "\n")
- }
- _, err := w.Write(buf.Bytes())
- return err
-}
-
-// Parse parses a resolv.conf file from r.
-func Parse(r io.Reader) (*Config, error) {
- config := new(Config)
- scanner := bufio.NewScanner(r)
- for scanner.Scan() {
- line := scanner.Text()
- line, _, _ = strings.Cut(line, "#") // remove any comments
- line = strings.TrimSpace(line)
-
- if s, ok := strings.CutPrefix(line, "nameserver"); ok {
- nameserver := strings.TrimSpace(s)
- if len(nameserver) == len(s) {
- return nil, fmt.Errorf("missing space after \"nameserver\" in %q", line)
- }
- ip, err := netip.ParseAddr(nameserver)
- if err != nil {
- return nil, err
- }
- config.Nameservers = append(config.Nameservers, ip)
- continue
- }
-
- if s, ok := strings.CutPrefix(line, "search"); ok {
- domains := strings.TrimSpace(s)
- if len(domains) == len(s) {
- // No leading space?!
- return nil, fmt.Errorf("missing space after \"search\" in %q", line)
- }
- for len(domains) > 0 {
- domain := domains
- i := strings.IndexAny(domain, " \t")
- if i != -1 {
- domain = domain[:i]
- domains = strings.TrimSpace(domains[i+1:])
- } else {
- domains = ""
- }
- fqdn, err := dnsname.ToFQDN(domain)
- if err != nil {
- return nil, fmt.Errorf("parsing search domain %q in %q: %w", domain, line, err)
- }
- config.SearchDomains = append(config.SearchDomains, fqdn)
- }
- }
- }
- return config, nil
-}
-
-// ParseFile parses the named resolv.conf file.
-func ParseFile(name string) (*Config, error) {
- fi, err := os.Stat(name)
- if err != nil {
- return nil, err
- }
- if n := fi.Size(); n > 10<<10 {
- return nil, fmt.Errorf("unexpectedly large %q file: %d bytes", name, n)
- }
- all, err := os.ReadFile(name)
- if err != nil {
- return nil, err
- }
- return Parse(bytes.NewReader(all))
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Package resolvconffile parses & serializes /etc/resolv.conf-style files.
+//
+// It's a leaf package so both net/dns and net/dns/resolver can depend
+// on it and we can unify a handful of implementations.
+//
+// The package is verbosely named to disambiguate it from resolvconf
+// the daemon, which Tailscale also supports.
+package resolvconffile
+
+import (
+ "bufio"
+ "bytes"
+ "fmt"
+ "io"
+ "net/netip"
+ "os"
+ "strings"
+
+ "tailscale.com/util/dnsname"
+)
+
+// Path is the canonical location of resolv.conf.
+const Path = "/etc/resolv.conf"
+
+// Config represents a resolv.conf(5) file.
+type Config struct {
+ // Nameservers are the IP addresses of the nameservers to use.
+ Nameservers []netip.Addr
+
+ // SearchDomains are the domain suffixes to use when expanding
+ // single-label name queries. SearchDomains is additive to
+ // whatever non-Tailscale search domains the OS has.
+ SearchDomains []dnsname.FQDN
+}
+
+// Write writes c to w. It does so in one Write call.
+func (c *Config) Write(w io.Writer) error {
+ buf := new(bytes.Buffer)
+ io.WriteString(buf, "# resolv.conf(5) file generated by tailscale\n")
+ io.WriteString(buf, "# For more info, see https://tailscale.com/s/resolvconf-overwrite\n")
+ io.WriteString(buf, "# DO NOT EDIT THIS FILE BY HAND -- CHANGES WILL BE OVERWRITTEN\n\n")
+ for _, ns := range c.Nameservers {
+ io.WriteString(buf, "nameserver ")
+ io.WriteString(buf, ns.String())
+ io.WriteString(buf, "\n")
+ }
+ if len(c.SearchDomains) > 0 {
+ io.WriteString(buf, "search")
+ for _, domain := range c.SearchDomains {
+ io.WriteString(buf, " ")
+ io.WriteString(buf, domain.WithoutTrailingDot())
+ }
+ io.WriteString(buf, "\n")
+ }
+ _, err := w.Write(buf.Bytes())
+ return err
+}
+
+// Parse parses a resolv.conf file from r.
+func Parse(r io.Reader) (*Config, error) {
+ config := new(Config)
+ scanner := bufio.NewScanner(r)
+ for scanner.Scan() {
+ line := scanner.Text()
+ line, _, _ = strings.Cut(line, "#") // remove any comments
+ line = strings.TrimSpace(line)
+
+ if s, ok := strings.CutPrefix(line, "nameserver"); ok {
+ nameserver := strings.TrimSpace(s)
+ if len(nameserver) == len(s) {
+ return nil, fmt.Errorf("missing space after \"nameserver\" in %q", line)
+ }
+ ip, err := netip.ParseAddr(nameserver)
+ if err != nil {
+ return nil, err
+ }
+ config.Nameservers = append(config.Nameservers, ip)
+ continue
+ }
+
+ if s, ok := strings.CutPrefix(line, "search"); ok {
+ domains := strings.TrimSpace(s)
+ if len(domains) == len(s) {
+ // No leading space?!
+ return nil, fmt.Errorf("missing space after \"search\" in %q", line)
+ }
+ for len(domains) > 0 {
+ domain := domains
+ i := strings.IndexAny(domain, " \t")
+ if i != -1 {
+ domain = domain[:i]
+ domains = strings.TrimSpace(domains[i+1:])
+ } else {
+ domains = ""
+ }
+ fqdn, err := dnsname.ToFQDN(domain)
+ if err != nil {
+ return nil, fmt.Errorf("parsing search domain %q in %q: %w", domain, line, err)
+ }
+ config.SearchDomains = append(config.SearchDomains, fqdn)
+ }
+ }
+ }
+ return config, nil
+}
+
+// ParseFile parses the named resolv.conf file.
+func ParseFile(name string) (*Config, error) {
+ fi, err := os.Stat(name)
+ if err != nil {
+ return nil, err
+ }
+ if n := fi.Size(); n > 10<<10 {
+ return nil, fmt.Errorf("unexpectedly large %q file: %d bytes", name, n)
+ }
+ all, err := os.ReadFile(name)
+ if err != nil {
+ return nil, err
+ }
+ return Parse(bytes.NewReader(all))
+}
diff --git a/net/dns/resolvconfpath_default.go b/net/dns/resolvconfpath_default.go
index 57e82c4c7..02f24a0cf 100644
--- a/net/dns/resolvconfpath_default.go
+++ b/net/dns/resolvconfpath_default.go
@@ -1,11 +1,11 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-//go:build !gokrazy
-
-package dns
-
-const (
- resolvConf = "/etc/resolv.conf"
- backupConf = "/etc/resolv.pre-tailscale-backup.conf"
-)
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build !gokrazy
+
+package dns
+
+const (
+ resolvConf = "/etc/resolv.conf"
+ backupConf = "/etc/resolv.pre-tailscale-backup.conf"
+)
diff --git a/net/dns/resolvconfpath_gokrazy.go b/net/dns/resolvconfpath_gokrazy.go
index f0759b0e3..6315596d2 100644
--- a/net/dns/resolvconfpath_gokrazy.go
+++ b/net/dns/resolvconfpath_gokrazy.go
@@ -1,11 +1,11 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-//go:build gokrazy
-
-package dns
-
-const (
- resolvConf = "/tmp/resolv.conf"
- backupConf = "/tmp/resolv.pre-tailscale-backup.conf"
-)
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build gokrazy
+
+package dns
+
+const (
+ resolvConf = "/tmp/resolv.conf"
+ backupConf = "/tmp/resolv.pre-tailscale-backup.conf"
+)
diff --git a/net/dns/resolver/doh_test.go b/net/dns/resolver/doh_test.go
index a9c284761..d9ef970c2 100644
--- a/net/dns/resolver/doh_test.go
+++ b/net/dns/resolver/doh_test.go
@@ -1,99 +1,99 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-package resolver
-
-import (
- "context"
- "flag"
- "net/http"
- "testing"
-
- "golang.org/x/net/dns/dnsmessage"
- "tailscale.com/net/dns/publicdns"
-)
-
-var testDoH = flag.Bool("test-doh", false, "do real DoH tests against the network")
-
-const someDNSID = 123 // something non-zero as a test; in violation of spec's SHOULD of 0
-
-func someDNSQuestion(t testing.TB) []byte {
- b := dnsmessage.NewBuilder(nil, dnsmessage.Header{
- OpCode: 0, // query
- RecursionDesired: true,
- ID: someDNSID,
- })
- b.StartQuestions() // err
- b.Question(dnsmessage.Question{
- Name: dnsmessage.MustNewName("tailscale.com."),
- Type: dnsmessage.TypeA,
- Class: dnsmessage.ClassINET,
- })
- msg, err := b.Finish()
- if err != nil {
- t.Fatal(err)
- }
- return msg
-}
-
-func TestDoH(t *testing.T) {
- if !*testDoH {
- t.Skip("skipping manual test without --test-doh flag")
- }
- prefixes := publicdns.KnownDoHPrefixes()
- if len(prefixes) == 0 {
- t.Fatal("no known DoH")
- }
-
- f := &forwarder{}
-
- for _, urlBase := range prefixes {
- t.Run(urlBase, func(t *testing.T) {
- c, ok := f.getKnownDoHClientForProvider(urlBase)
- if !ok {
- t.Fatal("expected DoH")
- }
- res, err := f.sendDoH(context.Background(), urlBase, c, someDNSQuestion(t))
- if err != nil {
- t.Fatal(err)
- }
- c.Transport.(*http.Transport).CloseIdleConnections()
-
- var p dnsmessage.Parser
- h, err := p.Start(res)
- if err != nil {
- t.Fatal(err)
- }
- if h.ID != someDNSID {
- t.Errorf("response DNS ID = %v; want %v", h.ID, someDNSID)
- }
-
- p.SkipAllQuestions()
- aa, err := p.AllAnswers()
- if err != nil {
- t.Fatal(err)
- }
- if len(aa) == 0 {
- t.Fatal("no answers")
- }
- for _, r := range aa {
- t.Logf("got: %v", r.GoString())
- }
- })
- }
-}
-
-func TestDoHV6Fallback(t *testing.T) {
- for _, base := range publicdns.KnownDoHPrefixes() {
- for _, ip := range publicdns.DoHIPsOfBase(base) {
- if ip.Is4() {
- ip6, ok := publicdns.DoHV6(base)
- if !ok {
- t.Errorf("no v6 DoH known for %v", ip)
- } else if !ip6.Is6() {
- t.Errorf("dohV6(%q) returned non-v6 address %v", base, ip6)
- }
- }
- }
- }
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package resolver
+
+import (
+ "context"
+ "flag"
+ "net/http"
+ "testing"
+
+ "golang.org/x/net/dns/dnsmessage"
+ "tailscale.com/net/dns/publicdns"
+)
+
+var testDoH = flag.Bool("test-doh", false, "do real DoH tests against the network")
+
+const someDNSID = 123 // something non-zero as a test; in violation of spec's SHOULD of 0
+
+func someDNSQuestion(t testing.TB) []byte {
+ b := dnsmessage.NewBuilder(nil, dnsmessage.Header{
+ OpCode: 0, // query
+ RecursionDesired: true,
+ ID: someDNSID,
+ })
+ b.StartQuestions() // err
+ b.Question(dnsmessage.Question{
+ Name: dnsmessage.MustNewName("tailscale.com."),
+ Type: dnsmessage.TypeA,
+ Class: dnsmessage.ClassINET,
+ })
+ msg, err := b.Finish()
+ if err != nil {
+ t.Fatal(err)
+ }
+ return msg
+}
+
+func TestDoH(t *testing.T) {
+ if !*testDoH {
+ t.Skip("skipping manual test without --test-doh flag")
+ }
+ prefixes := publicdns.KnownDoHPrefixes()
+ if len(prefixes) == 0 {
+ t.Fatal("no known DoH")
+ }
+
+ f := &forwarder{}
+
+ for _, urlBase := range prefixes {
+ t.Run(urlBase, func(t *testing.T) {
+ c, ok := f.getKnownDoHClientForProvider(urlBase)
+ if !ok {
+ t.Fatal("expected DoH")
+ }
+ res, err := f.sendDoH(context.Background(), urlBase, c, someDNSQuestion(t))
+ if err != nil {
+ t.Fatal(err)
+ }
+ c.Transport.(*http.Transport).CloseIdleConnections()
+
+ var p dnsmessage.Parser
+ h, err := p.Start(res)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if h.ID != someDNSID {
+ t.Errorf("response DNS ID = %v; want %v", h.ID, someDNSID)
+ }
+
+ p.SkipAllQuestions()
+ aa, err := p.AllAnswers()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(aa) == 0 {
+ t.Fatal("no answers")
+ }
+ for _, r := range aa {
+ t.Logf("got: %v", r.GoString())
+ }
+ })
+ }
+}
+
+func TestDoHV6Fallback(t *testing.T) {
+ for _, base := range publicdns.KnownDoHPrefixes() {
+ for _, ip := range publicdns.DoHIPsOfBase(base) {
+ if ip.Is4() {
+ ip6, ok := publicdns.DoHV6(base)
+ if !ok {
+ t.Errorf("no v6 DoH known for %v", ip)
+ } else if !ip6.Is6() {
+ t.Errorf("dohV6(%q) returned non-v6 address %v", base, ip6)
+ }
+ }
+ }
+ }
+}
diff --git a/net/dns/resolver/macios_ext.go b/net/dns/resolver/macios_ext.go
index e3f979c19..37cccc7f0 100644
--- a/net/dns/resolver/macios_ext.go
+++ b/net/dns/resolver/macios_ext.go
@@ -1,26 +1,26 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-//go:build ts_macext && (darwin || ios)
-
-package resolver
-
-import (
- "errors"
- "net"
-
- "tailscale.com/net/netmon"
- "tailscale.com/net/netns"
-)
-
-func init() {
- initListenConfig = initListenConfigNetworkExtension
-}
-
-func initListenConfigNetworkExtension(nc *net.ListenConfig, netMon *netmon.Monitor, tunName string) error {
- nif, ok := netMon.InterfaceState().Interface[tunName]
- if !ok {
- return errors.New("utun not found")
- }
- return netns.SetListenConfigInterfaceIndex(nc, nif.Interface.Index)
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build ts_macext && (darwin || ios)
+
+package resolver
+
+import (
+ "errors"
+ "net"
+
+ "tailscale.com/net/netmon"
+ "tailscale.com/net/netns"
+)
+
+func init() {
+ initListenConfig = initListenConfigNetworkExtension
+}
+
+func initListenConfigNetworkExtension(nc *net.ListenConfig, netMon *netmon.Monitor, tunName string) error {
+ nif, ok := netMon.InterfaceState().Interface[tunName]
+ if !ok {
+ return errors.New("utun not found")
+ }
+ return netns.SetListenConfigInterfaceIndex(nc, nif.Interface.Index)
+}
diff --git a/net/dns/resolver/tsdns_server_test.go b/net/dns/resolver/tsdns_server_test.go
index 82fd3bebf..be47cdfbc 100644
--- a/net/dns/resolver/tsdns_server_test.go
+++ b/net/dns/resolver/tsdns_server_test.go
@@ -1,333 +1,333 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-package resolver
-
-import (
- "fmt"
- "net"
- "net/netip"
- "strings"
- "testing"
-
- "github.com/miekg/dns"
-)
-
-// This file exists to isolate the test infrastructure
-// that depends on github.com/miekg/dns
-// from the rest, which only depends on dnsmessage.
-
-// resolveToIP returns a handler function which responds
-// to queries of type A it receives with an A record containing ipv4,
-// to queries of type AAAA with an AAAA record containing ipv6,
-// to queries of type NS with an NS record containing name.
-func resolveToIP(ipv4, ipv6 netip.Addr, ns string) dns.HandlerFunc {
- return func(w dns.ResponseWriter, req *dns.Msg) {
- m := new(dns.Msg)
- m.SetReply(req)
-
- if len(req.Question) != 1 {
- panic("not a single-question request")
- }
- question := req.Question[0]
-
- var ans dns.RR
- switch question.Qtype {
- case dns.TypeA:
- ans = &dns.A{
- Hdr: dns.RR_Header{
- Name: question.Name,
- Rrtype: dns.TypeA,
- Class: dns.ClassINET,
- },
- A: ipv4.AsSlice(),
- }
- case dns.TypeAAAA:
- ans = &dns.AAAA{
- Hdr: dns.RR_Header{
- Name: question.Name,
- Rrtype: dns.TypeAAAA,
- Class: dns.ClassINET,
- },
- AAAA: ipv6.AsSlice(),
- }
- case dns.TypeNS:
- ans = &dns.NS{
- Hdr: dns.RR_Header{
- Name: question.Name,
- Rrtype: dns.TypeNS,
- Class: dns.ClassINET,
- },
- Ns: ns,
- }
- }
-
- m.Answer = append(m.Answer, ans)
- w.WriteMsg(m)
- }
-}
-
-// resolveToIPLowercase returns a handler function which canonicalizes responses
-// by lowercasing the question and answer names, and responds
-// to queries of type A it receives with an A record containing ipv4,
-// to queries of type AAAA with an AAAA record containing ipv6,
-// to queries of type NS with an NS record containing name.
-func resolveToIPLowercase(ipv4, ipv6 netip.Addr, ns string) dns.HandlerFunc {
- return func(w dns.ResponseWriter, req *dns.Msg) {
- m := new(dns.Msg)
- m.SetReply(req)
-
- if len(req.Question) != 1 {
- panic("not a single-question request")
- }
- m.Question[0].Name = strings.ToLower(m.Question[0].Name)
- question := req.Question[0]
-
- var ans dns.RR
- switch question.Qtype {
- case dns.TypeA:
- ans = &dns.A{
- Hdr: dns.RR_Header{
- Name: question.Name,
- Rrtype: dns.TypeA,
- Class: dns.ClassINET,
- },
- A: ipv4.AsSlice(),
- }
- case dns.TypeAAAA:
- ans = &dns.AAAA{
- Hdr: dns.RR_Header{
- Name: question.Name,
- Rrtype: dns.TypeAAAA,
- Class: dns.ClassINET,
- },
- AAAA: ipv6.AsSlice(),
- }
- case dns.TypeNS:
- ans = &dns.NS{
- Hdr: dns.RR_Header{
- Name: question.Name,
- Rrtype: dns.TypeNS,
- Class: dns.ClassINET,
- },
- Ns: ns,
- }
- }
-
- m.Answer = append(m.Answer, ans)
- w.WriteMsg(m)
- }
-}
-
-// resolveToTXT returns a handler function which responds to queries of type TXT
-// it receives with the strings in txts.
-func resolveToTXT(txts []string, ednsMaxSize uint16) dns.HandlerFunc {
- return func(w dns.ResponseWriter, req *dns.Msg) {
- m := new(dns.Msg)
- m.SetReply(req)
-
- if len(req.Question) != 1 {
- panic("not a single-question request")
- }
- question := req.Question[0]
-
- if question.Qtype != dns.TypeTXT {
- w.WriteMsg(m)
- return
- }
-
- ans := &dns.TXT{
- Hdr: dns.RR_Header{
- Name: question.Name,
- Rrtype: dns.TypeTXT,
- Class: dns.ClassINET,
- },
- Txt: txts,
- }
-
- m.Answer = append(m.Answer, ans)
-
- queryInfo := &dns.TXT{
- Hdr: dns.RR_Header{
- Name: "query-info.test.",
- Rrtype: dns.TypeTXT,
- Class: dns.ClassINET,
- },
- }
-
- if edns := req.IsEdns0(); edns == nil {
- queryInfo.Txt = []string{"EDNS=false"}
- } else {
- queryInfo.Txt = []string{"EDNS=true", fmt.Sprintf("maxSize=%v", edns.UDPSize())}
- }
-
- m.Extra = append(m.Extra, queryInfo)
-
- if ednsMaxSize > 0 {
- m.SetEdns0(ednsMaxSize, false)
- }
-
- if err := w.WriteMsg(m); err != nil {
- panic(err)
- }
- }
-}
-
-var resolveToNXDOMAIN = dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
- m := new(dns.Msg)
- m.SetRcode(req, dns.RcodeNameError)
- w.WriteMsg(m)
-})
-
-// weirdoGoCNAMEHandler returns a DNS handler that satisfies
-// Go's weird Resolver.LookupCNAME (read its godoc carefully!).
-//
-// This doesn't even return a CNAME record, because that's not
-// what Go looks for.
-func weirdoGoCNAMEHandler(target string) dns.HandlerFunc {
- return func(w dns.ResponseWriter, req *dns.Msg) {
- m := new(dns.Msg)
- m.SetReply(req)
- question := req.Question[0]
-
- switch question.Qtype {
- case dns.TypeA:
- m.Answer = append(m.Answer, &dns.CNAME{
- Hdr: dns.RR_Header{
- Name: target,
- Rrtype: dns.TypeCNAME,
- Class: dns.ClassINET,
- Ttl: 600,
- },
- Target: target,
- })
- case dns.TypeAAAA:
- m.Answer = append(m.Answer, &dns.AAAA{
- Hdr: dns.RR_Header{
- Name: target,
- Rrtype: dns.TypeAAAA,
- Class: dns.ClassINET,
- Ttl: 600,
- },
- AAAA: net.ParseIP("1::2"),
- })
- }
- w.WriteMsg(m)
- }
-}
-
-// dnsHandler returns a handler that replies with the answers/options
-// provided.
-//
-// Types supported: netip.Addr.
-func dnsHandler(answers ...any) dns.HandlerFunc {
- return func(w dns.ResponseWriter, req *dns.Msg) {
- m := new(dns.Msg)
- m.SetReply(req)
- if len(req.Question) != 1 {
- panic("not a single-question request")
- }
- m.RecursionAvailable = true // to stop net package's errLameReferral on empty replies
-
- question := req.Question[0]
- for _, a := range answers {
- switch a := a.(type) {
- default:
- panic(fmt.Sprintf("unsupported dnsHandler arg %T", a))
- case netip.Addr:
- ip := a
- if ip.Is4() {
- m.Answer = append(m.Answer, &dns.A{
- Hdr: dns.RR_Header{
- Name: question.Name,
- Rrtype: dns.TypeA,
- Class: dns.ClassINET,
- },
- A: ip.AsSlice(),
- })
- } else if ip.Is6() {
- m.Answer = append(m.Answer, &dns.AAAA{
- Hdr: dns.RR_Header{
- Name: question.Name,
- Rrtype: dns.TypeAAAA,
- Class: dns.ClassINET,
- },
- AAAA: ip.AsSlice(),
- })
- }
- case dns.PTR:
- ptr := a
- ptr.Hdr = dns.RR_Header{
- Name: question.Name,
- Rrtype: dns.TypePTR,
- Class: dns.ClassINET,
- }
- m.Answer = append(m.Answer, &ptr)
- case dns.CNAME:
- c := a
- c.Hdr = dns.RR_Header{
- Name: question.Name,
- Rrtype: dns.TypeCNAME,
- Class: dns.ClassINET,
- Ttl: 600,
- }
- m.Answer = append(m.Answer, &c)
- case dns.TXT:
- txt := a
- txt.Hdr = dns.RR_Header{
- Name: question.Name,
- Rrtype: dns.TypeTXT,
- Class: dns.ClassINET,
- }
- m.Answer = append(m.Answer, &txt)
- case dns.SRV:
- srv := a
- srv.Hdr = dns.RR_Header{
- Name: question.Name,
- Rrtype: dns.TypeSRV,
- Class: dns.ClassINET,
- }
- m.Answer = append(m.Answer, &srv)
- case dns.NS:
- rr := a
- rr.Hdr = dns.RR_Header{
- Name: question.Name,
- Rrtype: dns.TypeNS,
- Class: dns.ClassINET,
- }
- m.Answer = append(m.Answer, &rr)
- }
- }
- w.WriteMsg(m)
- }
-}
-
-func serveDNS(tb testing.TB, addr string, records ...any) *dns.Server {
- if len(records)%2 != 0 {
- panic("must have an even number of record values")
- }
- mux := dns.NewServeMux()
- for i := 0; i < len(records); i += 2 {
- name := records[i].(string)
- handler := records[i+1].(dns.Handler)
- mux.Handle(name, handler)
- }
- waitch := make(chan struct{})
- server := &dns.Server{
- Addr: addr,
- Net: "udp",
- Handler: mux,
- NotifyStartedFunc: func() { close(waitch) },
- ReusePort: true,
- }
-
- go func() {
- err := server.ListenAndServe()
- if err != nil {
- panic(fmt.Sprintf("ListenAndServe(%q): %v", addr, err))
- }
- }()
-
- <-waitch
- return server
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package resolver
+
+import (
+ "fmt"
+ "net"
+ "net/netip"
+ "strings"
+ "testing"
+
+ "github.com/miekg/dns"
+)
+
+// This file exists to isolate the test infrastructure
+// that depends on github.com/miekg/dns
+// from the rest, which only depends on dnsmessage.
+
+// resolveToIP returns a handler function which responds
+// to queries of type A it receives with an A record containing ipv4,
+// to queries of type AAAA with an AAAA record containing ipv6,
+// to queries of type NS with an NS record containing name.
+func resolveToIP(ipv4, ipv6 netip.Addr, ns string) dns.HandlerFunc {
+ return func(w dns.ResponseWriter, req *dns.Msg) {
+ m := new(dns.Msg)
+ m.SetReply(req)
+
+ if len(req.Question) != 1 {
+ panic("not a single-question request")
+ }
+ question := req.Question[0]
+
+ var ans dns.RR
+ switch question.Qtype {
+ case dns.TypeA:
+ ans = &dns.A{
+ Hdr: dns.RR_Header{
+ Name: question.Name,
+ Rrtype: dns.TypeA,
+ Class: dns.ClassINET,
+ },
+ A: ipv4.AsSlice(),
+ }
+ case dns.TypeAAAA:
+ ans = &dns.AAAA{
+ Hdr: dns.RR_Header{
+ Name: question.Name,
+ Rrtype: dns.TypeAAAA,
+ Class: dns.ClassINET,
+ },
+ AAAA: ipv6.AsSlice(),
+ }
+ case dns.TypeNS:
+ ans = &dns.NS{
+ Hdr: dns.RR_Header{
+ Name: question.Name,
+ Rrtype: dns.TypeNS,
+ Class: dns.ClassINET,
+ },
+ Ns: ns,
+ }
+ }
+
+ m.Answer = append(m.Answer, ans)
+ w.WriteMsg(m)
+ }
+}
+
+// resolveToIPLowercase returns a handler function which canonicalizes responses
+// by lowercasing the question and answer names, and responds
+// to queries of type A it receives with an A record containing ipv4,
+// to queries of type AAAA with an AAAA record containing ipv6,
+// to queries of type NS with an NS record containing name.
+func resolveToIPLowercase(ipv4, ipv6 netip.Addr, ns string) dns.HandlerFunc {
+ return func(w dns.ResponseWriter, req *dns.Msg) {
+ m := new(dns.Msg)
+ m.SetReply(req)
+
+ if len(req.Question) != 1 {
+ panic("not a single-question request")
+ }
+ m.Question[0].Name = strings.ToLower(m.Question[0].Name)
+ question := req.Question[0]
+
+ var ans dns.RR
+ switch question.Qtype {
+ case dns.TypeA:
+ ans = &dns.A{
+ Hdr: dns.RR_Header{
+ Name: question.Name,
+ Rrtype: dns.TypeA,
+ Class: dns.ClassINET,
+ },
+ A: ipv4.AsSlice(),
+ }
+ case dns.TypeAAAA:
+ ans = &dns.AAAA{
+ Hdr: dns.RR_Header{
+ Name: question.Name,
+ Rrtype: dns.TypeAAAA,
+ Class: dns.ClassINET,
+ },
+ AAAA: ipv6.AsSlice(),
+ }
+ case dns.TypeNS:
+ ans = &dns.NS{
+ Hdr: dns.RR_Header{
+ Name: question.Name,
+ Rrtype: dns.TypeNS,
+ Class: dns.ClassINET,
+ },
+ Ns: ns,
+ }
+ }
+
+ m.Answer = append(m.Answer, ans)
+ w.WriteMsg(m)
+ }
+}
+
+// resolveToTXT returns a handler function which responds to queries of type TXT
+// it receives with the strings in txts.
+func resolveToTXT(txts []string, ednsMaxSize uint16) dns.HandlerFunc {
+ return func(w dns.ResponseWriter, req *dns.Msg) {
+ m := new(dns.Msg)
+ m.SetReply(req)
+
+ if len(req.Question) != 1 {
+ panic("not a single-question request")
+ }
+ question := req.Question[0]
+
+ if question.Qtype != dns.TypeTXT {
+ w.WriteMsg(m)
+ return
+ }
+
+ ans := &dns.TXT{
+ Hdr: dns.RR_Header{
+ Name: question.Name,
+ Rrtype: dns.TypeTXT,
+ Class: dns.ClassINET,
+ },
+ Txt: txts,
+ }
+
+ m.Answer = append(m.Answer, ans)
+
+ queryInfo := &dns.TXT{
+ Hdr: dns.RR_Header{
+ Name: "query-info.test.",
+ Rrtype: dns.TypeTXT,
+ Class: dns.ClassINET,
+ },
+ }
+
+ if edns := req.IsEdns0(); edns == nil {
+ queryInfo.Txt = []string{"EDNS=false"}
+ } else {
+ queryInfo.Txt = []string{"EDNS=true", fmt.Sprintf("maxSize=%v", edns.UDPSize())}
+ }
+
+ m.Extra = append(m.Extra, queryInfo)
+
+ if ednsMaxSize > 0 {
+ m.SetEdns0(ednsMaxSize, false)
+ }
+
+ if err := w.WriteMsg(m); err != nil {
+ panic(err)
+ }
+ }
+}
+
+var resolveToNXDOMAIN = dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
+ m := new(dns.Msg)
+ m.SetRcode(req, dns.RcodeNameError)
+ w.WriteMsg(m)
+})
+
+// weirdoGoCNAMEHandler returns a DNS handler that satisfies
+// Go's weird Resolver.LookupCNAME (read its godoc carefully!).
+//
+// This doesn't even return a CNAME record, because that's not
+// what Go looks for.
+func weirdoGoCNAMEHandler(target string) dns.HandlerFunc {
+ return func(w dns.ResponseWriter, req *dns.Msg) {
+ m := new(dns.Msg)
+ m.SetReply(req)
+ question := req.Question[0]
+
+ switch question.Qtype {
+ case dns.TypeA:
+ m.Answer = append(m.Answer, &dns.CNAME{
+ Hdr: dns.RR_Header{
+ Name: target,
+ Rrtype: dns.TypeCNAME,
+ Class: dns.ClassINET,
+ Ttl: 600,
+ },
+ Target: target,
+ })
+ case dns.TypeAAAA:
+ m.Answer = append(m.Answer, &dns.AAAA{
+ Hdr: dns.RR_Header{
+ Name: target,
+ Rrtype: dns.TypeAAAA,
+ Class: dns.ClassINET,
+ Ttl: 600,
+ },
+ AAAA: net.ParseIP("1::2"),
+ })
+ }
+ w.WriteMsg(m)
+ }
+}
+
+// dnsHandler returns a handler that replies with the answers/options
+// provided.
+//
+// Types supported: netip.Addr.
+func dnsHandler(answers ...any) dns.HandlerFunc {
+ return func(w dns.ResponseWriter, req *dns.Msg) {
+ m := new(dns.Msg)
+ m.SetReply(req)
+ if len(req.Question) != 1 {
+ panic("not a single-question request")
+ }
+ m.RecursionAvailable = true // to stop net package's errLameReferral on empty replies
+
+ question := req.Question[0]
+ for _, a := range answers {
+ switch a := a.(type) {
+ default:
+ panic(fmt.Sprintf("unsupported dnsHandler arg %T", a))
+ case netip.Addr:
+ ip := a
+ if ip.Is4() {
+ m.Answer = append(m.Answer, &dns.A{
+ Hdr: dns.RR_Header{
+ Name: question.Name,
+ Rrtype: dns.TypeA,
+ Class: dns.ClassINET,
+ },
+ A: ip.AsSlice(),
+ })
+ } else if ip.Is6() {
+ m.Answer = append(m.Answer, &dns.AAAA{
+ Hdr: dns.RR_Header{
+ Name: question.Name,
+ Rrtype: dns.TypeAAAA,
+ Class: dns.ClassINET,
+ },
+ AAAA: ip.AsSlice(),
+ })
+ }
+ case dns.PTR:
+ ptr := a
+ ptr.Hdr = dns.RR_Header{
+ Name: question.Name,
+ Rrtype: dns.TypePTR,
+ Class: dns.ClassINET,
+ }
+ m.Answer = append(m.Answer, &ptr)
+ case dns.CNAME:
+ c := a
+ c.Hdr = dns.RR_Header{
+ Name: question.Name,
+ Rrtype: dns.TypeCNAME,
+ Class: dns.ClassINET,
+ Ttl: 600,
+ }
+ m.Answer = append(m.Answer, &c)
+ case dns.TXT:
+ txt := a
+ txt.Hdr = dns.RR_Header{
+ Name: question.Name,
+ Rrtype: dns.TypeTXT,
+ Class: dns.ClassINET,
+ }
+ m.Answer = append(m.Answer, &txt)
+ case dns.SRV:
+ srv := a
+ srv.Hdr = dns.RR_Header{
+ Name: question.Name,
+ Rrtype: dns.TypeSRV,
+ Class: dns.ClassINET,
+ }
+ m.Answer = append(m.Answer, &srv)
+ case dns.NS:
+ rr := a
+ rr.Hdr = dns.RR_Header{
+ Name: question.Name,
+ Rrtype: dns.TypeNS,
+ Class: dns.ClassINET,
+ }
+ m.Answer = append(m.Answer, &rr)
+ }
+ }
+ w.WriteMsg(m)
+ }
+}
+
+func serveDNS(tb testing.TB, addr string, records ...any) *dns.Server {
+ if len(records)%2 != 0 {
+ panic("must have an even number of record values")
+ }
+ mux := dns.NewServeMux()
+ for i := 0; i < len(records); i += 2 {
+ name := records[i].(string)
+ handler := records[i+1].(dns.Handler)
+ mux.Handle(name, handler)
+ }
+ waitch := make(chan struct{})
+ server := &dns.Server{
+ Addr: addr,
+ Net: "udp",
+ Handler: mux,
+ NotifyStartedFunc: func() { close(waitch) },
+ ReusePort: true,
+ }
+
+ go func() {
+ err := server.ListenAndServe()
+ if err != nil {
+ panic(fmt.Sprintf("ListenAndServe(%q): %v", addr, err))
+ }
+ }()
+
+ <-waitch
+ return server
+}
diff --git a/net/dns/utf.go b/net/dns/utf.go
index 0c1db69ac..267829c05 100644
--- a/net/dns/utf.go
+++ b/net/dns/utf.go
@@ -1,55 +1,55 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-package dns
-
-// This code is only used in Windows builds, but is in an
-// OS-independent file so tests can run all the time.
-
-import (
- "bytes"
- "encoding/binary"
- "unicode/utf16"
-)
-
-// maybeUnUTF16 tries to detect whether bs contains UTF-16, and if so
-// translates it to regular UTF-8.
-//
-// Some of wsl.exe's output get printed as UTF-16, which breaks a
-// bunch of things. Try to detect this by looking for a zero byte in
-// the first few bytes of output (which will appear if any of those
-// codepoints are basic ASCII - very likely). From that we can infer
-// that UTF-16 is being printed, and the byte order in use, and we
-// decode that back to UTF-8.
-//
-// https://github.com/microsoft/WSL/issues/4607
-func maybeUnUTF16(bs []byte) []byte {
- if len(bs)%2 != 0 {
- // Can't be complete UTF-16.
- return bs
- }
- checkLen := 20
- if len(bs) < checkLen {
- checkLen = len(bs)
- }
- zeroOff := bytes.IndexByte(bs[:checkLen], 0)
- if zeroOff == -1 {
- return bs
- }
-
- // We assume wsl.exe is trying to print an ASCII codepoint,
- // meaning the zero byte is in the upper 8 bits of the
- // codepoint. That means we can use the zero's byte offset to
- // work out if we're seeing little-endian or big-endian
- // UTF-16.
- var endian binary.ByteOrder = binary.LittleEndian
- if zeroOff%2 == 0 {
- endian = binary.BigEndian
- }
-
- var u16 []uint16
- for i := 0; i < len(bs); i += 2 {
- u16 = append(u16, endian.Uint16(bs[i:]))
- }
- return []byte(string(utf16.Decode(u16)))
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package dns
+
+// This code is only used in Windows builds, but is in an
+// OS-independent file so tests can run all the time.
+
+import (
+ "bytes"
+ "encoding/binary"
+ "unicode/utf16"
+)
+
+// maybeUnUTF16 tries to detect whether bs contains UTF-16, and if so
+// translates it to regular UTF-8.
+//
+// Some of wsl.exe's output get printed as UTF-16, which breaks a
+// bunch of things. Try to detect this by looking for a zero byte in
+// the first few bytes of output (which will appear if any of those
+// codepoints are basic ASCII - very likely). From that we can infer
+// that UTF-16 is being printed, and the byte order in use, and we
+// decode that back to UTF-8.
+//
+// https://github.com/microsoft/WSL/issues/4607
+func maybeUnUTF16(bs []byte) []byte {
+ if len(bs)%2 != 0 {
+ // Can't be complete UTF-16.
+ return bs
+ }
+ checkLen := 20
+ if len(bs) < checkLen {
+ checkLen = len(bs)
+ }
+ zeroOff := bytes.IndexByte(bs[:checkLen], 0)
+ if zeroOff == -1 {
+ return bs
+ }
+
+ // We assume wsl.exe is trying to print an ASCII codepoint,
+ // meaning the zero byte is in the upper 8 bits of the
+ // codepoint. That means we can use the zero's byte offset to
+ // work out if we're seeing little-endian or big-endian
+ // UTF-16.
+ var endian binary.ByteOrder = binary.LittleEndian
+ if zeroOff%2 == 0 {
+ endian = binary.BigEndian
+ }
+
+ var u16 []uint16
+ for i := 0; i < len(bs); i += 2 {
+ u16 = append(u16, endian.Uint16(bs[i:]))
+ }
+ return []byte(string(utf16.Decode(u16)))
+}
diff --git a/net/dns/utf_test.go b/net/dns/utf_test.go
index b5fd37262..fcf593497 100644
--- a/net/dns/utf_test.go
+++ b/net/dns/utf_test.go
@@ -1,24 +1,24 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-package dns
-
-import "testing"
-
-func TestMaybeUnUTF16(t *testing.T) {
- tests := []struct {
- in string
- want string
- }{
- {"abc", "abc"}, // UTF-8
- {"a\x00b\x00c\x00", "abc"}, // UTF-16-LE
- {"\x00a\x00b\x00c", "abc"}, // UTF-16-BE
- }
-
- for _, test := range tests {
- got := string(maybeUnUTF16([]byte(test.in)))
- if got != test.want {
- t.Errorf("maybeUnUTF16(%q) = %q, want %q", test.in, got, test.want)
- }
- }
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package dns
+
+import "testing"
+
+func TestMaybeUnUTF16(t *testing.T) {
+ tests := []struct {
+ in string
+ want string
+ }{
+ {"abc", "abc"}, // UTF-8
+ {"a\x00b\x00c\x00", "abc"}, // UTF-16-LE
+ {"\x00a\x00b\x00c", "abc"}, // UTF-16-BE
+ }
+
+ for _, test := range tests {
+ got := string(maybeUnUTF16([]byte(test.in)))
+ if got != test.want {
+ t.Errorf("maybeUnUTF16(%q) = %q, want %q", test.in, got, test.want)
+ }
+ }
+}
diff --git a/net/dnscache/dnscache_test.go b/net/dnscache/dnscache_test.go
index ef4249b74..6a4b96931 100644
--- a/net/dnscache/dnscache_test.go
+++ b/net/dnscache/dnscache_test.go
@@ -1,242 +1,242 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-package dnscache
-
-import (
- "context"
- "errors"
- "flag"
- "fmt"
- "net"
- "net/netip"
- "reflect"
- "testing"
- "time"
-
- "tailscale.com/tstest"
-)
-
-var dialTest = flag.String("dial-test", "", "if non-empty, addr:port to test dial")
-
-func TestDialer(t *testing.T) {
- if *dialTest == "" {
- t.Skip("skipping; --dial-test is blank")
- }
- r := &Resolver{Logf: t.Logf}
- var std net.Dialer
- dialer := Dialer(std.DialContext, r)
- t0 := time.Now()
- ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
- defer cancel()
- c, err := dialer(ctx, "tcp", *dialTest)
- if err != nil {
- t.Fatal(err)
- }
- t.Logf("dialed in %v", time.Since(t0))
- c.Close()
-}
-
-func TestDialCall_DNSWasTrustworthy(t *testing.T) {
- type step struct {
- ip netip.Addr // IP we pretended to dial
- err error // the dial error or nil for success
- }
- mustIP := netip.MustParseAddr
- errFail := errors.New("some connect failure")
- tests := []struct {
- name string
- steps []step
- want bool
- }{
- {
- name: "no-info",
- want: false,
- },
- {
- name: "previous-dial",
- steps: []step{
- {mustIP("2003::1"), nil},
- {mustIP("2003::1"), errFail},
- },
- want: true,
- },
- {
- name: "no-previous-dial",
- steps: []step{
- {mustIP("2003::1"), errFail},
- },
- want: false,
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- d := &dialer{
- pastConnect: map[netip.Addr]time.Time{},
- }
- dc := &dialCall{
- d: d,
- }
- for _, st := range tt.steps {
- dc.noteDialResult(st.ip, st.err)
- }
- got := dc.dnsWasTrustworthy()
- if got != tt.want {
- t.Errorf("got %v; want %v", got, tt.want)
- }
- })
- }
-}
-
-func TestDialCall_uniqueIPs(t *testing.T) {
- dc := &dialCall{}
- mustIP := netip.MustParseAddr
- errFail := errors.New("some connect failure")
- dc.noteDialResult(mustIP("2003::1"), errFail)
- dc.noteDialResult(mustIP("2003::2"), errFail)
- got := dc.uniqueIPs([]netip.Addr{
- mustIP("2003::1"),
- mustIP("2003::2"),
- mustIP("2003::2"),
- mustIP("2003::3"),
- mustIP("2003::3"),
- mustIP("2003::4"),
- mustIP("2003::4"),
- })
- want := []netip.Addr{
- mustIP("2003::3"),
- mustIP("2003::4"),
- }
- if !reflect.DeepEqual(got, want) {
- t.Errorf("got %v; want %v", got, want)
- }
-}
-
-func TestResolverAllHostStaticResult(t *testing.T) {
- r := &Resolver{
- Logf: t.Logf,
- SingleHost: "foo.bar",
- SingleHostStaticResult: []netip.Addr{
- netip.MustParseAddr("2001:4860:4860::8888"),
- netip.MustParseAddr("2001:4860:4860::8844"),
- netip.MustParseAddr("8.8.8.8"),
- netip.MustParseAddr("8.8.4.4"),
- },
- }
- ip4, ip6, allIPs, err := r.LookupIP(context.Background(), "foo.bar")
- if err != nil {
- t.Fatal(err)
- }
- if got, want := ip4.String(), "8.8.8.8"; got != want {
- t.Errorf("ip4 got %q; want %q", got, want)
- }
- if got, want := ip6.String(), "2001:4860:4860::8888"; got != want {
- t.Errorf("ip4 got %q; want %q", got, want)
- }
- if got, want := fmt.Sprintf("%q", allIPs), `["2001:4860:4860::8888" "2001:4860:4860::8844" "8.8.8.8" "8.8.4.4"]`; got != want {
- t.Errorf("allIPs got %q; want %q", got, want)
- }
-
- _, _, _, err = r.LookupIP(context.Background(), "bad")
- if got, want := fmt.Sprint(err), `dnscache: unexpected hostname "bad" doesn't match expected "foo.bar"`; got != want {
- t.Errorf("bad dial error got %q; want %q", got, want)
- }
-}
-
-func TestShouldTryBootstrap(t *testing.T) {
- tstest.Replace(t, &debug, func() bool { return true })
-
- type step struct {
- ip netip.Addr // IP we pretended to dial
- err error // the dial error or nil for success
- }
-
- canceled, cancel := context.WithCancel(context.Background())
- cancel()
-
- deadlineExceeded, cancel := context.WithTimeout(context.Background(), 0)
- defer cancel()
-
- ctx := context.Background()
- errFailed := errors.New("some failure")
-
- cacheWithFallback := &Resolver{
- Logf: t.Logf,
- LookupIPFallback: func(_ context.Context, _ string) ([]netip.Addr, error) {
- panic("unimplemented")
- },
- }
- cacheNoFallback := &Resolver{Logf: t.Logf}
-
- testCases := []struct {
- name string
- steps []step
- ctx context.Context
- err error
- noFallback bool
- want bool
- }{
- {
- name: "no-error",
- ctx: ctx,
- err: nil,
- want: false,
- },
- {
- name: "canceled",
- ctx: canceled,
- err: errFailed,
- want: false,
- },
- {
- name: "deadline-exceeded",
- ctx: deadlineExceeded,
- err: errFailed,
- want: false,
- },
- {
- name: "no-fallback",
- ctx: ctx,
- err: errFailed,
- noFallback: true,
- want: false,
- },
- {
- name: "dns-was-trustworthy",
- ctx: ctx,
- err: errFailed,
- steps: []step{
- {netip.MustParseAddr("2003::1"), nil},
- {netip.MustParseAddr("2003::1"), errFailed},
- },
- want: false,
- },
- {
- name: "should-bootstrap",
- ctx: ctx,
- err: errFailed,
- want: true,
- },
- }
-
- for _, tt := range testCases {
- t.Run(tt.name, func(t *testing.T) {
- d := &dialer{
- pastConnect: map[netip.Addr]time.Time{},
- }
- if tt.noFallback {
- d.dnsCache = cacheNoFallback
- } else {
- d.dnsCache = cacheWithFallback
- }
- dc := &dialCall{d: d}
- for _, st := range tt.steps {
- dc.noteDialResult(st.ip, st.err)
- }
- got := d.shouldTryBootstrap(tt.ctx, tt.err, dc)
- if got != tt.want {
- t.Errorf("got %v; want %v", got, tt.want)
- }
- })
- }
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package dnscache
+
+import (
+ "context"
+ "errors"
+ "flag"
+ "fmt"
+ "net"
+ "net/netip"
+ "reflect"
+ "testing"
+ "time"
+
+ "tailscale.com/tstest"
+)
+
+var dialTest = flag.String("dial-test", "", "if non-empty, addr:port to test dial")
+
+func TestDialer(t *testing.T) {
+ if *dialTest == "" {
+ t.Skip("skipping; --dial-test is blank")
+ }
+ r := &Resolver{Logf: t.Logf}
+ var std net.Dialer
+ dialer := Dialer(std.DialContext, r)
+ t0 := time.Now()
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+ defer cancel()
+ c, err := dialer(ctx, "tcp", *dialTest)
+ if err != nil {
+ t.Fatal(err)
+ }
+ t.Logf("dialed in %v", time.Since(t0))
+ c.Close()
+}
+
+func TestDialCall_DNSWasTrustworthy(t *testing.T) {
+ type step struct {
+ ip netip.Addr // IP we pretended to dial
+ err error // the dial error or nil for success
+ }
+ mustIP := netip.MustParseAddr
+ errFail := errors.New("some connect failure")
+ tests := []struct {
+ name string
+ steps []step
+ want bool
+ }{
+ {
+ name: "no-info",
+ want: false,
+ },
+ {
+ name: "previous-dial",
+ steps: []step{
+ {mustIP("2003::1"), nil},
+ {mustIP("2003::1"), errFail},
+ },
+ want: true,
+ },
+ {
+ name: "no-previous-dial",
+ steps: []step{
+ {mustIP("2003::1"), errFail},
+ },
+ want: false,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ d := &dialer{
+ pastConnect: map[netip.Addr]time.Time{},
+ }
+ dc := &dialCall{
+ d: d,
+ }
+ for _, st := range tt.steps {
+ dc.noteDialResult(st.ip, st.err)
+ }
+ got := dc.dnsWasTrustworthy()
+ if got != tt.want {
+ t.Errorf("got %v; want %v", got, tt.want)
+ }
+ })
+ }
+}
+
+func TestDialCall_uniqueIPs(t *testing.T) {
+ dc := &dialCall{}
+ mustIP := netip.MustParseAddr
+ errFail := errors.New("some connect failure")
+ dc.noteDialResult(mustIP("2003::1"), errFail)
+ dc.noteDialResult(mustIP("2003::2"), errFail)
+ got := dc.uniqueIPs([]netip.Addr{
+ mustIP("2003::1"),
+ mustIP("2003::2"),
+ mustIP("2003::2"),
+ mustIP("2003::3"),
+ mustIP("2003::3"),
+ mustIP("2003::4"),
+ mustIP("2003::4"),
+ })
+ want := []netip.Addr{
+ mustIP("2003::3"),
+ mustIP("2003::4"),
+ }
+ if !reflect.DeepEqual(got, want) {
+ t.Errorf("got %v; want %v", got, want)
+ }
+}
+
+func TestResolverAllHostStaticResult(t *testing.T) {
+ r := &Resolver{
+ Logf: t.Logf,
+ SingleHost: "foo.bar",
+ SingleHostStaticResult: []netip.Addr{
+ netip.MustParseAddr("2001:4860:4860::8888"),
+ netip.MustParseAddr("2001:4860:4860::8844"),
+ netip.MustParseAddr("8.8.8.8"),
+ netip.MustParseAddr("8.8.4.4"),
+ },
+ }
+ ip4, ip6, allIPs, err := r.LookupIP(context.Background(), "foo.bar")
+ if err != nil {
+ t.Fatal(err)
+ }
+ if got, want := ip4.String(), "8.8.8.8"; got != want {
+ t.Errorf("ip4 got %q; want %q", got, want)
+ }
+ if got, want := ip6.String(), "2001:4860:4860::8888"; got != want {
+ t.Errorf("ip4 got %q; want %q", got, want)
+ }
+ if got, want := fmt.Sprintf("%q", allIPs), `["2001:4860:4860::8888" "2001:4860:4860::8844" "8.8.8.8" "8.8.4.4"]`; got != want {
+ t.Errorf("allIPs got %q; want %q", got, want)
+ }
+
+ _, _, _, err = r.LookupIP(context.Background(), "bad")
+ if got, want := fmt.Sprint(err), `dnscache: unexpected hostname "bad" doesn't match expected "foo.bar"`; got != want {
+ t.Errorf("bad dial error got %q; want %q", got, want)
+ }
+}
+
+func TestShouldTryBootstrap(t *testing.T) {
+ tstest.Replace(t, &debug, func() bool { return true })
+
+ type step struct {
+ ip netip.Addr // IP we pretended to dial
+ err error // the dial error or nil for success
+ }
+
+ canceled, cancel := context.WithCancel(context.Background())
+ cancel()
+
+ deadlineExceeded, cancel := context.WithTimeout(context.Background(), 0)
+ defer cancel()
+
+ ctx := context.Background()
+ errFailed := errors.New("some failure")
+
+ cacheWithFallback := &Resolver{
+ Logf: t.Logf,
+ LookupIPFallback: func(_ context.Context, _ string) ([]netip.Addr, error) {
+ panic("unimplemented")
+ },
+ }
+ cacheNoFallback := &Resolver{Logf: t.Logf}
+
+ testCases := []struct {
+ name string
+ steps []step
+ ctx context.Context
+ err error
+ noFallback bool
+ want bool
+ }{
+ {
+ name: "no-error",
+ ctx: ctx,
+ err: nil,
+ want: false,
+ },
+ {
+ name: "canceled",
+ ctx: canceled,
+ err: errFailed,
+ want: false,
+ },
+ {
+ name: "deadline-exceeded",
+ ctx: deadlineExceeded,
+ err: errFailed,
+ want: false,
+ },
+ {
+ name: "no-fallback",
+ ctx: ctx,
+ err: errFailed,
+ noFallback: true,
+ want: false,
+ },
+ {
+ name: "dns-was-trustworthy",
+ ctx: ctx,
+ err: errFailed,
+ steps: []step{
+ {netip.MustParseAddr("2003::1"), nil},
+ {netip.MustParseAddr("2003::1"), errFailed},
+ },
+ want: false,
+ },
+ {
+ name: "should-bootstrap",
+ ctx: ctx,
+ err: errFailed,
+ want: true,
+ },
+ }
+
+ for _, tt := range testCases {
+ t.Run(tt.name, func(t *testing.T) {
+ d := &dialer{
+ pastConnect: map[netip.Addr]time.Time{},
+ }
+ if tt.noFallback {
+ d.dnsCache = cacheNoFallback
+ } else {
+ d.dnsCache = cacheWithFallback
+ }
+ dc := &dialCall{d: d}
+ for _, st := range tt.steps {
+ dc.noteDialResult(st.ip, st.err)
+ }
+ got := d.shouldTryBootstrap(tt.ctx, tt.err, dc)
+ if got != tt.want {
+ t.Errorf("got %v; want %v", got, tt.want)
+ }
+ })
+ }
+}
diff --git a/net/dnscache/messagecache_test.go b/net/dnscache/messagecache_test.go
index 41fc33448..18af32459 100644
--- a/net/dnscache/messagecache_test.go
+++ b/net/dnscache/messagecache_test.go
@@ -1,291 +1,291 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-package dnscache
-
-import (
- "bytes"
- "context"
- "errors"
- "fmt"
- "net"
- "runtime"
- "testing"
- "time"
-
- "golang.org/x/net/dns/dnsmessage"
- "tailscale.com/tstest"
-)
-
-func TestMessageCache(t *testing.T) {
- clock := tstest.NewClock(tstest.ClockOpts{
- Start: time.Date(1987, 11, 1, 0, 0, 0, 0, time.UTC),
- })
- mc := &MessageCache{Clock: clock.Now}
- mc.SetMaxCacheSize(2)
- clock.Advance(time.Second)
-
- var out bytes.Buffer
- if err := mc.ReplyFromCache(&out, makeQ(1, "foo.com.")); err != ErrCacheMiss {
- t.Fatalf("unexpected error: %v", err)
- }
-
- if err := mc.AddCacheEntry(
- makeQ(2, "foo.com."),
- makeRes(2, "FOO.COM.", ttlOpt(10),
- &dnsmessage.AResource{A: [4]byte{127, 0, 0, 1}},
- &dnsmessage.AResource{A: [4]byte{127, 0, 0, 2}})); err != nil {
- t.Fatal(err)
- }
-
- // Expect cache hit, with 10 seconds remaining.
- out.Reset()
- if err := mc.ReplyFromCache(&out, makeQ(3, "foo.com.")); err != nil {
- t.Fatalf("expected cache hit; got: %v", err)
- }
- if p := mustParseResponse(t, out.Bytes()); p.TxID != 3 {
- t.Errorf("TxID = %v; want %v", p.TxID, 3)
- } else if p.TTL != 10 {
- t.Errorf("TTL = %v; want 10", p.TTL)
- }
-
- // One second elapses, expect a cache hit, with 9 seconds
- // remaining.
- clock.Advance(time.Second)
- out.Reset()
- if err := mc.ReplyFromCache(&out, makeQ(4, "foo.com.")); err != nil {
- t.Fatalf("expected cache hit; got: %v", err)
- }
- if p := mustParseResponse(t, out.Bytes()); p.TxID != 4 {
- t.Errorf("TxID = %v; want %v", p.TxID, 4)
- } else if p.TTL != 9 {
- t.Errorf("TTL = %v; want 9", p.TTL)
- }
-
- // Expect cache miss on MX record.
- if err := mc.ReplyFromCache(&out, makeQ(4, "foo.com.", dnsmessage.TypeMX)); err != ErrCacheMiss {
- t.Fatalf("expected cache miss on MX; got: %v", err)
- }
- // Expect cache miss on CHAOS class.
- if err := mc.ReplyFromCache(&out, makeQ(4, "foo.com.", dnsmessage.ClassCHAOS)); err != ErrCacheMiss {
- t.Fatalf("expected cache miss on CHAOS; got: %v", err)
- }
-
- // Ten seconds elapses; expect a cache miss.
- clock.Advance(10 * time.Second)
- if err := mc.ReplyFromCache(&out, makeQ(5, "foo.com.")); err != ErrCacheMiss {
- t.Fatalf("expected cache miss, got: %v", err)
- }
-}
-
-type parsedMeta struct {
- TxID uint16
- TTL uint32
-}
-
-func mustParseResponse(t testing.TB, r []byte) (ret parsedMeta) {
- t.Helper()
- var p dnsmessage.Parser
- h, err := p.Start(r)
- if err != nil {
- t.Fatal(err)
- }
- ret.TxID = h.ID
- qq, err := p.AllQuestions()
- if err != nil {
- t.Fatalf("AllQuestions: %v", err)
- }
- if len(qq) != 1 {
- t.Fatalf("num questions = %v; want 1", len(qq))
- }
- aa, err := p.AllAnswers()
- if err != nil {
- t.Fatalf("AllAnswers: %v", err)
- }
- for _, r := range aa {
- if ret.TTL == 0 {
- ret.TTL = r.Header.TTL
- }
- if ret.TTL != r.Header.TTL {
- t.Fatal("mixed TTLs")
- }
- }
- return ret
-}
-
-type responseOpt bool
-
-type ttlOpt uint32
-
-func makeQ(txID uint16, name string, opt ...any) []byte {
- opt = append(opt, responseOpt(false))
- return makeDNSPkt(txID, name, opt...)
-}
-
-func makeRes(txID uint16, name string, opt ...any) []byte {
- opt = append(opt, responseOpt(true))
- return makeDNSPkt(txID, name, opt...)
-}
-
-func makeDNSPkt(txID uint16, name string, opt ...any) []byte {
- typ := dnsmessage.TypeA
- class := dnsmessage.ClassINET
- var response bool
- var answers []dnsmessage.ResourceBody
- var ttl uint32 = 1 // one second by default
- for _, o := range opt {
- switch o := o.(type) {
- case dnsmessage.Type:
- typ = o
- case dnsmessage.Class:
- class = o
- case responseOpt:
- response = bool(o)
- case dnsmessage.ResourceBody:
- answers = append(answers, o)
- case ttlOpt:
- ttl = uint32(o)
- default:
- panic(fmt.Sprintf("unknown opt type %T", o))
- }
- }
- qname := dnsmessage.MustNewName(name)
- msg := dnsmessage.Message{
- Header: dnsmessage.Header{ID: txID, Response: response},
- Questions: []dnsmessage.Question{
- {
- Name: qname,
- Type: typ,
- Class: class,
- },
- },
- }
- for _, rb := range answers {
- msg.Answers = append(msg.Answers, dnsmessage.Resource{
- Header: dnsmessage.ResourceHeader{
- Name: qname,
- Type: typ,
- Class: class,
- TTL: ttl,
- },
- Body: rb,
- })
- }
- buf, err := msg.Pack()
- if err != nil {
- panic(err)
- }
- return buf
-}
-
-func TestASCIILowerName(t *testing.T) {
- n := asciiLowerName(dnsmessage.MustNewName("Foo.COM."))
- if got, want := n.String(), "foo.com."; got != want {
- t.Errorf("got = %q; want %q", got, want)
- }
-}
-
-func TestGetDNSQueryCacheKey(t *testing.T) {
- tests := []struct {
- name string
- pkt []byte
- want msgQ
- txID uint16
- anyTX bool
- }{
- {
- name: "empty",
- },
- {
- name: "a",
- pkt: makeQ(123, "foo.com."),
- want: msgQ{"foo.com.", dnsmessage.TypeA},
- txID: 123,
- },
- {
- name: "aaaa",
- pkt: makeQ(6, "foo.com.", dnsmessage.TypeAAAA),
- want: msgQ{"foo.com.", dnsmessage.TypeAAAA},
- txID: 6,
- },
- {
- name: "normalize_case",
- pkt: makeQ(123, "FoO.CoM."),
- want: msgQ{"foo.com.", dnsmessage.TypeA},
- txID: 123,
- },
- {
- name: "ignore_response",
- pkt: makeRes(123, "foo.com."),
- },
- {
- name: "ignore_question_with_answers",
- pkt: makeQ(2, "foo.com.", &dnsmessage.AResource{A: [4]byte{127, 0, 0, 1}}),
- },
- {
- name: "whatever_go_generates", // in case Go's net package grows functionality we don't handle
- pkt: getGoNetPacketDNSQuery("from-go.foo."),
- want: msgQ{"from-go.foo.", dnsmessage.TypeA},
- anyTX: true,
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- got, gotTX, ok := getDNSQueryCacheKey(tt.pkt)
- if !ok {
- if tt.txID == 0 && got == (msgQ{}) {
- return
- }
- t.Fatal("failed")
- }
- if got != tt.want {
- t.Errorf("got %+v, want %+v", got, tt.want)
- }
- if gotTX != tt.txID && !tt.anyTX {
- t.Errorf("got tx %v, want %v", gotTX, tt.txID)
- }
- })
- }
-}
-
-func getGoNetPacketDNSQuery(name string) []byte {
- if runtime.GOOS == "windows" {
- // On Windows, Go's net.Resolver doesn't use the DNS client.
- // See https://github.com/golang/go/issues/33097 which
- // was approved but not yet implemented.
- // For now just pretend it's implemented to make this test
- // pass on Windows with complicated the caller.
- return makeQ(123, name)
- }
- res := make(chan []byte, 1)
- r := &net.Resolver{
- PreferGo: true,
- Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
- return goResolverConn(res), nil
- },
- }
- r.LookupIP(context.Background(), "ip4", name)
- return <-res
-}
-
-type goResolverConn chan<- []byte
-
-func (goResolverConn) Close() error { return nil }
-func (goResolverConn) LocalAddr() net.Addr { return todoAddr{} }
-func (goResolverConn) RemoteAddr() net.Addr { return todoAddr{} }
-func (goResolverConn) SetDeadline(t time.Time) error { return nil }
-func (goResolverConn) SetReadDeadline(t time.Time) error { return nil }
-func (goResolverConn) SetWriteDeadline(t time.Time) error { return nil }
-func (goResolverConn) Read([]byte) (int, error) { return 0, errors.New("boom") }
-func (c goResolverConn) Write(p []byte) (int, error) {
- select {
- case c <- p[2:]: // skip 2 byte length for TCP mode DNS query
- default:
- }
- return 0, errors.New("boom")
-}
-
-type todoAddr struct{}
-
-func (todoAddr) Network() string { return "unused" }
-func (todoAddr) String() string { return "unused-todoAddr" }
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package dnscache
+
+import (
+ "bytes"
+ "context"
+ "errors"
+ "fmt"
+ "net"
+ "runtime"
+ "testing"
+ "time"
+
+ "golang.org/x/net/dns/dnsmessage"
+ "tailscale.com/tstest"
+)
+
+func TestMessageCache(t *testing.T) {
+ clock := tstest.NewClock(tstest.ClockOpts{
+ Start: time.Date(1987, 11, 1, 0, 0, 0, 0, time.UTC),
+ })
+ mc := &MessageCache{Clock: clock.Now}
+ mc.SetMaxCacheSize(2)
+ clock.Advance(time.Second)
+
+ var out bytes.Buffer
+ if err := mc.ReplyFromCache(&out, makeQ(1, "foo.com.")); err != ErrCacheMiss {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ if err := mc.AddCacheEntry(
+ makeQ(2, "foo.com."),
+ makeRes(2, "FOO.COM.", ttlOpt(10),
+ &dnsmessage.AResource{A: [4]byte{127, 0, 0, 1}},
+ &dnsmessage.AResource{A: [4]byte{127, 0, 0, 2}})); err != nil {
+ t.Fatal(err)
+ }
+
+ // Expect cache hit, with 10 seconds remaining.
+ out.Reset()
+ if err := mc.ReplyFromCache(&out, makeQ(3, "foo.com.")); err != nil {
+ t.Fatalf("expected cache hit; got: %v", err)
+ }
+ if p := mustParseResponse(t, out.Bytes()); p.TxID != 3 {
+ t.Errorf("TxID = %v; want %v", p.TxID, 3)
+ } else if p.TTL != 10 {
+ t.Errorf("TTL = %v; want 10", p.TTL)
+ }
+
+ // One second elapses, expect a cache hit, with 9 seconds
+ // remaining.
+ clock.Advance(time.Second)
+ out.Reset()
+ if err := mc.ReplyFromCache(&out, makeQ(4, "foo.com.")); err != nil {
+ t.Fatalf("expected cache hit; got: %v", err)
+ }
+ if p := mustParseResponse(t, out.Bytes()); p.TxID != 4 {
+ t.Errorf("TxID = %v; want %v", p.TxID, 4)
+ } else if p.TTL != 9 {
+ t.Errorf("TTL = %v; want 9", p.TTL)
+ }
+
+ // Expect cache miss on MX record.
+ if err := mc.ReplyFromCache(&out, makeQ(4, "foo.com.", dnsmessage.TypeMX)); err != ErrCacheMiss {
+ t.Fatalf("expected cache miss on MX; got: %v", err)
+ }
+ // Expect cache miss on CHAOS class.
+ if err := mc.ReplyFromCache(&out, makeQ(4, "foo.com.", dnsmessage.ClassCHAOS)); err != ErrCacheMiss {
+ t.Fatalf("expected cache miss on CHAOS; got: %v", err)
+ }
+
+ // Ten seconds elapses; expect a cache miss.
+ clock.Advance(10 * time.Second)
+ if err := mc.ReplyFromCache(&out, makeQ(5, "foo.com.")); err != ErrCacheMiss {
+ t.Fatalf("expected cache miss, got: %v", err)
+ }
+}
+
+type parsedMeta struct {
+ TxID uint16
+ TTL uint32
+}
+
+func mustParseResponse(t testing.TB, r []byte) (ret parsedMeta) {
+ t.Helper()
+ var p dnsmessage.Parser
+ h, err := p.Start(r)
+ if err != nil {
+ t.Fatal(err)
+ }
+ ret.TxID = h.ID
+ qq, err := p.AllQuestions()
+ if err != nil {
+ t.Fatalf("AllQuestions: %v", err)
+ }
+ if len(qq) != 1 {
+ t.Fatalf("num questions = %v; want 1", len(qq))
+ }
+ aa, err := p.AllAnswers()
+ if err != nil {
+ t.Fatalf("AllAnswers: %v", err)
+ }
+ for _, r := range aa {
+ if ret.TTL == 0 {
+ ret.TTL = r.Header.TTL
+ }
+ if ret.TTL != r.Header.TTL {
+ t.Fatal("mixed TTLs")
+ }
+ }
+ return ret
+}
+
+type responseOpt bool
+
+type ttlOpt uint32
+
+func makeQ(txID uint16, name string, opt ...any) []byte {
+ opt = append(opt, responseOpt(false))
+ return makeDNSPkt(txID, name, opt...)
+}
+
+func makeRes(txID uint16, name string, opt ...any) []byte {
+ opt = append(opt, responseOpt(true))
+ return makeDNSPkt(txID, name, opt...)
+}
+
+func makeDNSPkt(txID uint16, name string, opt ...any) []byte {
+ typ := dnsmessage.TypeA
+ class := dnsmessage.ClassINET
+ var response bool
+ var answers []dnsmessage.ResourceBody
+ var ttl uint32 = 1 // one second by default
+ for _, o := range opt {
+ switch o := o.(type) {
+ case dnsmessage.Type:
+ typ = o
+ case dnsmessage.Class:
+ class = o
+ case responseOpt:
+ response = bool(o)
+ case dnsmessage.ResourceBody:
+ answers = append(answers, o)
+ case ttlOpt:
+ ttl = uint32(o)
+ default:
+ panic(fmt.Sprintf("unknown opt type %T", o))
+ }
+ }
+ qname := dnsmessage.MustNewName(name)
+ msg := dnsmessage.Message{
+ Header: dnsmessage.Header{ID: txID, Response: response},
+ Questions: []dnsmessage.Question{
+ {
+ Name: qname,
+ Type: typ,
+ Class: class,
+ },
+ },
+ }
+ for _, rb := range answers {
+ msg.Answers = append(msg.Answers, dnsmessage.Resource{
+ Header: dnsmessage.ResourceHeader{
+ Name: qname,
+ Type: typ,
+ Class: class,
+ TTL: ttl,
+ },
+ Body: rb,
+ })
+ }
+ buf, err := msg.Pack()
+ if err != nil {
+ panic(err)
+ }
+ return buf
+}
+
+func TestASCIILowerName(t *testing.T) {
+ n := asciiLowerName(dnsmessage.MustNewName("Foo.COM."))
+ if got, want := n.String(), "foo.com."; got != want {
+ t.Errorf("got = %q; want %q", got, want)
+ }
+}
+
+func TestGetDNSQueryCacheKey(t *testing.T) {
+ tests := []struct {
+ name string
+ pkt []byte
+ want msgQ
+ txID uint16
+ anyTX bool
+ }{
+ {
+ name: "empty",
+ },
+ {
+ name: "a",
+ pkt: makeQ(123, "foo.com."),
+ want: msgQ{"foo.com.", dnsmessage.TypeA},
+ txID: 123,
+ },
+ {
+ name: "aaaa",
+ pkt: makeQ(6, "foo.com.", dnsmessage.TypeAAAA),
+ want: msgQ{"foo.com.", dnsmessage.TypeAAAA},
+ txID: 6,
+ },
+ {
+ name: "normalize_case",
+ pkt: makeQ(123, "FoO.CoM."),
+ want: msgQ{"foo.com.", dnsmessage.TypeA},
+ txID: 123,
+ },
+ {
+ name: "ignore_response",
+ pkt: makeRes(123, "foo.com."),
+ },
+ {
+ name: "ignore_question_with_answers",
+ pkt: makeQ(2, "foo.com.", &dnsmessage.AResource{A: [4]byte{127, 0, 0, 1}}),
+ },
+ {
+ name: "whatever_go_generates", // in case Go's net package grows functionality we don't handle
+ pkt: getGoNetPacketDNSQuery("from-go.foo."),
+ want: msgQ{"from-go.foo.", dnsmessage.TypeA},
+ anyTX: true,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got, gotTX, ok := getDNSQueryCacheKey(tt.pkt)
+ if !ok {
+ if tt.txID == 0 && got == (msgQ{}) {
+ return
+ }
+ t.Fatal("failed")
+ }
+ if got != tt.want {
+ t.Errorf("got %+v, want %+v", got, tt.want)
+ }
+ if gotTX != tt.txID && !tt.anyTX {
+ t.Errorf("got tx %v, want %v", gotTX, tt.txID)
+ }
+ })
+ }
+}
+
+func getGoNetPacketDNSQuery(name string) []byte {
+ if runtime.GOOS == "windows" {
+ // On Windows, Go's net.Resolver doesn't use the DNS client.
+ // See https://github.com/golang/go/issues/33097 which
+ // was approved but not yet implemented.
+ // For now just pretend it's implemented to make this test
+ // pass on Windows with complicated the caller.
+ return makeQ(123, name)
+ }
+ res := make(chan []byte, 1)
+ r := &net.Resolver{
+ PreferGo: true,
+ Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
+ return goResolverConn(res), nil
+ },
+ }
+ r.LookupIP(context.Background(), "ip4", name)
+ return <-res
+}
+
+type goResolverConn chan<- []byte
+
+func (goResolverConn) Close() error { return nil }
+func (goResolverConn) LocalAddr() net.Addr { return todoAddr{} }
+func (goResolverConn) RemoteAddr() net.Addr { return todoAddr{} }
+func (goResolverConn) SetDeadline(t time.Time) error { return nil }
+func (goResolverConn) SetReadDeadline(t time.Time) error { return nil }
+func (goResolverConn) SetWriteDeadline(t time.Time) error { return nil }
+func (goResolverConn) Read([]byte) (int, error) { return 0, errors.New("boom") }
+func (c goResolverConn) Write(p []byte) (int, error) {
+ select {
+ case c <- p[2:]: // skip 2 byte length for TCP mode DNS query
+ default:
+ }
+ return 0, errors.New("boom")
+}
+
+type todoAddr struct{}
+
+func (todoAddr) Network() string { return "unused" }
+func (todoAddr) String() string { return "unused-todoAddr" }
diff --git a/net/dnsfallback/update-dns-fallbacks.go b/net/dnsfallback/update-dns-fallbacks.go
index 384e77e10..ebbfc2ad1 100644
--- a/net/dnsfallback/update-dns-fallbacks.go
+++ b/net/dnsfallback/update-dns-fallbacks.go
@@ -1,45 +1,45 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-//go:build ignore
-
-package main
-
-import (
- "encoding/json"
- "fmt"
- "log"
- "net/http"
- "os"
-
- "tailscale.com/tailcfg"
-)
-
-func main() {
- res, err := http.Get("https://login.tailscale.com/derpmap/default")
- if err != nil {
- log.Fatal(err)
- }
- if res.StatusCode != 200 {
- res.Write(os.Stderr)
- os.Exit(1)
- }
- dm := new(tailcfg.DERPMap)
- if err := json.NewDecoder(res.Body).Decode(dm); err != nil {
- log.Fatal(err)
- }
- for rid, r := range dm.Regions {
- // Names misleading to check into git, as this is a
- // static snapshot and doesn't reflect the live DERP
- // map.
- r.RegionCode = fmt.Sprintf("r%d", rid)
- r.RegionName = r.RegionCode
- }
- out, err := json.MarshalIndent(dm, "", "\t")
- if err != nil {
- log.Fatal(err)
- }
- if err := os.WriteFile("dns-fallback-servers.json", out, 0644); err != nil {
- log.Fatal(err)
- }
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build ignore
+
+package main
+
+import (
+ "encoding/json"
+ "fmt"
+ "log"
+ "net/http"
+ "os"
+
+ "tailscale.com/tailcfg"
+)
+
+func main() {
+ res, err := http.Get("https://login.tailscale.com/derpmap/default")
+ if err != nil {
+ log.Fatal(err)
+ }
+ if res.StatusCode != 200 {
+ res.Write(os.Stderr)
+ os.Exit(1)
+ }
+ dm := new(tailcfg.DERPMap)
+ if err := json.NewDecoder(res.Body).Decode(dm); err != nil {
+ log.Fatal(err)
+ }
+ for rid, r := range dm.Regions {
+ // Names misleading to check into git, as this is a
+ // static snapshot and doesn't reflect the live DERP
+ // map.
+ r.RegionCode = fmt.Sprintf("r%d", rid)
+ r.RegionName = r.RegionCode
+ }
+ out, err := json.MarshalIndent(dm, "", "\t")
+ if err != nil {
+ log.Fatal(err)
+ }
+ if err := os.WriteFile("dns-fallback-servers.json", out, 0644); err != nil {
+ log.Fatal(err)
+ }
+}
diff --git a/net/memnet/conn.go b/net/memnet/conn.go
index a9e1fd399..f599612d9 100644
--- a/net/memnet/conn.go
+++ b/net/memnet/conn.go
@@ -1,114 +1,114 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-package memnet
-
-import (
- "net"
- "net/netip"
- "time"
-)
-
-// NetworkName is the network name returned by [net.Addr.Network]
-// for [net.Conn.LocalAddr] and [net.Conn.RemoteAddr] from the [Conn] type.
-const NetworkName = "mem"
-
-// Conn is a net.Conn that can additionally have its reads and writes blocked and unblocked.
-type Conn interface {
- net.Conn
-
- // SetReadBlock blocks or unblocks the Read method of this Conn.
- // It reports an error if the existing value matches the new value,
- // or if the Conn has been Closed.
- SetReadBlock(bool) error
-
- // SetWriteBlock blocks or unblocks the Write method of this Conn.
- // It reports an error if the existing value matches the new value,
- // or if the Conn has been Closed.
- SetWriteBlock(bool) error
-}
-
-// NewConn creates a pair of Conns that are wired together by pipes.
-func NewConn(name string, maxBuf int) (Conn, Conn) {
- r := NewPipe(name+"|0", maxBuf)
- w := NewPipe(name+"|1", maxBuf)
-
- return &connHalf{r: r, w: w}, &connHalf{r: w, w: r}
-}
-
-// NewTCPConn creates a pair of Conns that are wired together by pipes.
-func NewTCPConn(src, dst netip.AddrPort, maxBuf int) (local Conn, remote Conn) {
- r := NewPipe(src.String(), maxBuf)
- w := NewPipe(dst.String(), maxBuf)
-
- lAddr := net.TCPAddrFromAddrPort(src)
- rAddr := net.TCPAddrFromAddrPort(dst)
-
- return &connHalf{r: r, w: w, remote: rAddr, local: lAddr}, &connHalf{r: w, w: r, remote: lAddr, local: rAddr}
-}
-
-type connAddr string
-
-func (a connAddr) Network() string { return NetworkName }
-func (a connAddr) String() string { return string(a) }
-
-type connHalf struct {
- local, remote net.Addr
- r, w *Pipe
-}
-
-func (c *connHalf) LocalAddr() net.Addr {
- if c.local != nil {
- return c.local
- }
- return connAddr(c.r.name)
-}
-
-func (c *connHalf) RemoteAddr() net.Addr {
- if c.remote != nil {
- return c.remote
- }
- return connAddr(c.w.name)
-}
-
-func (c *connHalf) Read(b []byte) (n int, err error) {
- return c.r.Read(b)
-}
-func (c *connHalf) Write(b []byte) (n int, err error) {
- return c.w.Write(b)
-}
-
-func (c *connHalf) Close() error {
- if err := c.w.Close(); err != nil {
- return err
- }
- return c.r.Close()
-}
-
-func (c *connHalf) SetDeadline(t time.Time) error {
- err1 := c.SetReadDeadline(t)
- err2 := c.SetWriteDeadline(t)
- if err1 != nil {
- return err1
- }
- return err2
-}
-func (c *connHalf) SetReadDeadline(t time.Time) error {
- return c.r.SetReadDeadline(t)
-}
-func (c *connHalf) SetWriteDeadline(t time.Time) error {
- return c.w.SetWriteDeadline(t)
-}
-
-func (c *connHalf) SetReadBlock(b bool) error {
- if b {
- return c.r.Block()
- }
- return c.r.Unblock()
-}
-func (c *connHalf) SetWriteBlock(b bool) error {
- if b {
- return c.w.Block()
- }
- return c.w.Unblock()
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package memnet
+
+import (
+ "net"
+ "net/netip"
+ "time"
+)
+
+// NetworkName is the network name returned by [net.Addr.Network]
+// for [net.Conn.LocalAddr] and [net.Conn.RemoteAddr] from the [Conn] type.
+const NetworkName = "mem"
+
+// Conn is a net.Conn that can additionally have its reads and writes blocked and unblocked.
+type Conn interface {
+ net.Conn
+
+ // SetReadBlock blocks or unblocks the Read method of this Conn.
+ // It reports an error if the existing value matches the new value,
+ // or if the Conn has been Closed.
+ SetReadBlock(bool) error
+
+ // SetWriteBlock blocks or unblocks the Write method of this Conn.
+ // It reports an error if the existing value matches the new value,
+ // or if the Conn has been Closed.
+ SetWriteBlock(bool) error
+}
+
+// NewConn creates a pair of Conns that are wired together by pipes.
+func NewConn(name string, maxBuf int) (Conn, Conn) {
+ r := NewPipe(name+"|0", maxBuf)
+ w := NewPipe(name+"|1", maxBuf)
+
+ return &connHalf{r: r, w: w}, &connHalf{r: w, w: r}
+}
+
+// NewTCPConn creates a pair of Conns that are wired together by pipes.
+func NewTCPConn(src, dst netip.AddrPort, maxBuf int) (local Conn, remote Conn) {
+ r := NewPipe(src.String(), maxBuf)
+ w := NewPipe(dst.String(), maxBuf)
+
+ lAddr := net.TCPAddrFromAddrPort(src)
+ rAddr := net.TCPAddrFromAddrPort(dst)
+
+ return &connHalf{r: r, w: w, remote: rAddr, local: lAddr}, &connHalf{r: w, w: r, remote: lAddr, local: rAddr}
+}
+
+type connAddr string
+
+func (a connAddr) Network() string { return NetworkName }
+func (a connAddr) String() string { return string(a) }
+
+type connHalf struct {
+ local, remote net.Addr
+ r, w *Pipe
+}
+
+func (c *connHalf) LocalAddr() net.Addr {
+ if c.local != nil {
+ return c.local
+ }
+ return connAddr(c.r.name)
+}
+
+func (c *connHalf) RemoteAddr() net.Addr {
+ if c.remote != nil {
+ return c.remote
+ }
+ return connAddr(c.w.name)
+}
+
+func (c *connHalf) Read(b []byte) (n int, err error) {
+ return c.r.Read(b)
+}
+func (c *connHalf) Write(b []byte) (n int, err error) {
+ return c.w.Write(b)
+}
+
+func (c *connHalf) Close() error {
+ if err := c.w.Close(); err != nil {
+ return err
+ }
+ return c.r.Close()
+}
+
+func (c *connHalf) SetDeadline(t time.Time) error {
+ err1 := c.SetReadDeadline(t)
+ err2 := c.SetWriteDeadline(t)
+ if err1 != nil {
+ return err1
+ }
+ return err2
+}
+func (c *connHalf) SetReadDeadline(t time.Time) error {
+ return c.r.SetReadDeadline(t)
+}
+func (c *connHalf) SetWriteDeadline(t time.Time) error {
+ return c.w.SetWriteDeadline(t)
+}
+
+func (c *connHalf) SetReadBlock(b bool) error {
+ if b {
+ return c.r.Block()
+ }
+ return c.r.Unblock()
+}
+func (c *connHalf) SetWriteBlock(b bool) error {
+ if b {
+ return c.w.Block()
+ }
+ return c.w.Unblock()
+}
diff --git a/net/memnet/conn_test.go b/net/memnet/conn_test.go
index 743ce5248..3eec80bc6 100644
--- a/net/memnet/conn_test.go
+++ b/net/memnet/conn_test.go
@@ -1,21 +1,21 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-package memnet
-
-import (
- "net"
- "testing"
-
- "golang.org/x/net/nettest"
-)
-
-func TestConn(t *testing.T) {
- nettest.TestConn(t, func() (c1 net.Conn, c2 net.Conn, stop func(), err error) {
- c1, c2 = NewConn("test", bufferSize)
- return c1, c2, func() {
- c1.Close()
- c2.Close()
- }, nil
- })
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package memnet
+
+import (
+ "net"
+ "testing"
+
+ "golang.org/x/net/nettest"
+)
+
+func TestConn(t *testing.T) {
+ nettest.TestConn(t, func() (c1 net.Conn, c2 net.Conn, stop func(), err error) {
+ c1, c2 = NewConn("test", bufferSize)
+ return c1, c2, func() {
+ c1.Close()
+ c2.Close()
+ }, nil
+ })
+}
diff --git a/net/memnet/listener.go b/net/memnet/listener.go
index d84a2e443..d1364d790 100644
--- a/net/memnet/listener.go
+++ b/net/memnet/listener.go
@@ -1,100 +1,100 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-package memnet
-
-import (
- "context"
- "net"
- "strings"
- "sync"
-)
-
-const (
- bufferSize = 256 * 1024
-)
-
-// Listener is a net.Listener using NewConn to create pairs of network
-// connections connected in memory using a buffered pipe. It also provides a
-// Dial method to establish new connections.
-type Listener struct {
- addr connAddr
- ch chan Conn
- closeOnce sync.Once
- closed chan struct{}
-
- // NewConn, if non-nil, is called to create a new pair of connections
- // when dialing. If nil, NewConn is used.
- NewConn func(network, addr string, maxBuf int) (Conn, Conn)
-}
-
-// Listen returns a new Listener for the provided address.
-func Listen(addr string) *Listener {
- return &Listener{
- addr: connAddr(addr),
- ch: make(chan Conn),
- closed: make(chan struct{}),
- }
-}
-
-// Addr implements net.Listener.Addr.
-func (l *Listener) Addr() net.Addr {
- return l.addr
-}
-
-// Close closes the pipe listener.
-func (l *Listener) Close() error {
- l.closeOnce.Do(func() {
- close(l.closed)
- })
- return nil
-}
-
-// Accept blocks until a new connection is available or the listener is closed.
-func (l *Listener) Accept() (net.Conn, error) {
- select {
- case c := <-l.ch:
- return c, nil
- case <-l.closed:
- return nil, net.ErrClosed
- }
-}
-
-// Dial connects to the listener using the provided context.
-// The provided Context must be non-nil. If the context expires before the
-// connection is complete, an error is returned. Once successfully connected
-// any expiration of the context will not affect the connection.
-func (l *Listener) Dial(ctx context.Context, network, addr string) (_ net.Conn, err error) {
- if !strings.HasSuffix(network, "tcp") {
- return nil, net.UnknownNetworkError(network)
- }
- if connAddr(addr) != l.addr {
- return nil, &net.AddrError{
- Err: "invalid address",
- Addr: addr,
- }
- }
-
- newConn := l.NewConn
- if newConn == nil {
- newConn = func(network, addr string, maxBuf int) (Conn, Conn) {
- return NewConn(addr, maxBuf)
- }
- }
- c, s := newConn(network, addr, bufferSize)
- defer func() {
- if err != nil {
- c.Close()
- s.Close()
- }
- }()
-
- select {
- case <-ctx.Done():
- return nil, ctx.Err()
- case <-l.closed:
- return nil, net.ErrClosed
- case l.ch <- s:
- return c, nil
- }
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package memnet
+
+import (
+ "context"
+ "net"
+ "strings"
+ "sync"
+)
+
+const (
+ bufferSize = 256 * 1024
+)
+
+// Listener is a net.Listener using NewConn to create pairs of network
+// connections connected in memory using a buffered pipe. It also provides a
+// Dial method to establish new connections.
+type Listener struct {
+ addr connAddr
+ ch chan Conn
+ closeOnce sync.Once
+ closed chan struct{}
+
+ // NewConn, if non-nil, is called to create a new pair of connections
+ // when dialing. If nil, NewConn is used.
+ NewConn func(network, addr string, maxBuf int) (Conn, Conn)
+}
+
+// Listen returns a new Listener for the provided address.
+func Listen(addr string) *Listener {
+ return &Listener{
+ addr: connAddr(addr),
+ ch: make(chan Conn),
+ closed: make(chan struct{}),
+ }
+}
+
+// Addr implements net.Listener.Addr.
+func (l *Listener) Addr() net.Addr {
+ return l.addr
+}
+
+// Close closes the pipe listener.
+func (l *Listener) Close() error {
+ l.closeOnce.Do(func() {
+ close(l.closed)
+ })
+ return nil
+}
+
+// Accept blocks until a new connection is available or the listener is closed.
+func (l *Listener) Accept() (net.Conn, error) {
+ select {
+ case c := <-l.ch:
+ return c, nil
+ case <-l.closed:
+ return nil, net.ErrClosed
+ }
+}
+
+// Dial connects to the listener using the provided context.
+// The provided Context must be non-nil. If the context expires before the
+// connection is complete, an error is returned. Once successfully connected
+// any expiration of the context will not affect the connection.
+func (l *Listener) Dial(ctx context.Context, network, addr string) (_ net.Conn, err error) {
+ if !strings.HasSuffix(network, "tcp") {
+ return nil, net.UnknownNetworkError(network)
+ }
+ if connAddr(addr) != l.addr {
+ return nil, &net.AddrError{
+ Err: "invalid address",
+ Addr: addr,
+ }
+ }
+
+ newConn := l.NewConn
+ if newConn == nil {
+ newConn = func(network, addr string, maxBuf int) (Conn, Conn) {
+ return NewConn(addr, maxBuf)
+ }
+ }
+ c, s := newConn(network, addr, bufferSize)
+ defer func() {
+ if err != nil {
+ c.Close()
+ s.Close()
+ }
+ }()
+
+ select {
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ case <-l.closed:
+ return nil, net.ErrClosed
+ case l.ch <- s:
+ return c, nil
+ }
+}
diff --git a/net/memnet/listener_test.go b/net/memnet/listener_test.go
index 73b67841a..989d5e9e4 100644
--- a/net/memnet/listener_test.go
+++ b/net/memnet/listener_test.go
@@ -1,33 +1,33 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-package memnet
-
-import (
- "context"
- "testing"
-)
-
-func TestListener(t *testing.T) {
- l := Listen("srv.local")
- defer l.Close()
- go func() {
- c, err := l.Accept()
- if err != nil {
- t.Error(err)
- return
- }
- defer c.Close()
- }()
-
- if c, err := l.Dial(context.Background(), "tcp", "invalid"); err == nil {
- c.Close()
- t.Fatalf("dial to invalid address succeeded")
- }
- c, err := l.Dial(context.Background(), "tcp", "srv.local")
- if err != nil {
- t.Fatalf("dial failed: %v", err)
- return
- }
- c.Close()
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package memnet
+
+import (
+ "context"
+ "testing"
+)
+
+func TestListener(t *testing.T) {
+ l := Listen("srv.local")
+ defer l.Close()
+ go func() {
+ c, err := l.Accept()
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ defer c.Close()
+ }()
+
+ if c, err := l.Dial(context.Background(), "tcp", "invalid"); err == nil {
+ c.Close()
+ t.Fatalf("dial to invalid address succeeded")
+ }
+ c, err := l.Dial(context.Background(), "tcp", "srv.local")
+ if err != nil {
+ t.Fatalf("dial failed: %v", err)
+ return
+ }
+ c.Close()
+}
diff --git a/net/memnet/memnet.go b/net/memnet/memnet.go
index c8799bc17..2fc13b4b2 100644
--- a/net/memnet/memnet.go
+++ b/net/memnet/memnet.go
@@ -1,8 +1,8 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-// Package memnet implements an in-memory network implementation.
-// It is useful for dialing and listening on in-memory addresses
-// in tests and other situations where you don't want to use the
-// network.
-package memnet
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Package memnet implements an in-memory network implementation.
+// It is useful for dialing and listening on in-memory addresses
+// in tests and other situations where you don't want to use the
+// network.
+package memnet
diff --git a/net/memnet/pipe.go b/net/memnet/pipe.go
index 471635083..51bee1090 100644
--- a/net/memnet/pipe.go
+++ b/net/memnet/pipe.go
@@ -1,244 +1,244 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-package memnet
-
-import (
- "bytes"
- "context"
- "fmt"
- "io"
- "log"
- "net"
- "os"
- "sync"
- "time"
-)
-
-const debugPipe = false
-
-// Pipe implements an in-memory FIFO with timeouts.
-type Pipe struct {
- name string
- maxBuf int
- mu sync.Mutex
- cnd *sync.Cond
-
- blocked bool
- closed bool
- buf bytes.Buffer
- readTimeout time.Time
- writeTimeout time.Time
- cancelReadTimer func()
- cancelWriteTimer func()
-}
-
-// NewPipe creates a Pipe with a buffer size fixed at maxBuf.
-func NewPipe(name string, maxBuf int) *Pipe {
- p := &Pipe{
- name: name,
- maxBuf: maxBuf,
- }
- p.cnd = sync.NewCond(&p.mu)
- return p
-}
-
-// readOrBlock attempts to read from the buffer, if the buffer is empty and
-// the connection hasn't been closed it will block until there is a change.
-func (p *Pipe) readOrBlock(b []byte) (int, error) {
- p.mu.Lock()
- defer p.mu.Unlock()
- if !p.readTimeout.IsZero() && !time.Now().Before(p.readTimeout) {
- return 0, os.ErrDeadlineExceeded
- }
- if p.blocked {
- p.cnd.Wait()
- return 0, nil
- }
-
- n, err := p.buf.Read(b)
- // err will either be nil or io.EOF.
- if err == io.EOF {
- if p.closed {
- return n, err
- }
- // Wait for something to change.
- p.cnd.Wait()
- }
- return n, nil
-}
-
-// Read implements io.Reader.
-// Once the buffer is drained (i.e. after Close), subsequent calls will
-// return io.EOF.
-func (p *Pipe) Read(b []byte) (n int, err error) {
- if debugPipe {
- orig := b
- defer func() {
- log.Printf("Pipe(%q).Read(%q) n=%d, err=%v", p.name, string(orig[:n]), n, err)
- }()
- }
- for n == 0 {
- n2, err := p.readOrBlock(b)
- if err != nil {
- return n2, err
- }
- n += n2
- }
- p.cnd.Signal()
- return n, nil
-}
-
-// writeOrBlock attempts to write to the buffer, if the buffer is full it will
-// block until there is a change.
-func (p *Pipe) writeOrBlock(b []byte) (int, error) {
- p.mu.Lock()
- defer p.mu.Unlock()
- if p.closed {
- return 0, net.ErrClosed
- }
- if !p.writeTimeout.IsZero() && !time.Now().Before(p.writeTimeout) {
- return 0, os.ErrDeadlineExceeded
- }
- if p.blocked {
- p.cnd.Wait()
- return 0, nil
- }
-
- // Optimistically we want to write the entire slice.
- n := len(b)
- if limit := p.maxBuf - p.buf.Len(); limit < n {
- // However, we don't have enough capacity to write everything.
- n = limit
- }
- if n == 0 {
- // Wait for something to change.
- p.cnd.Wait()
- return 0, nil
- }
-
- p.buf.Write(b[:n])
- p.cnd.Signal()
- return n, nil
-}
-
-// Write implements io.Writer.
-func (p *Pipe) Write(b []byte) (n int, err error) {
- if debugPipe {
- orig := b
- defer func() {
- log.Printf("Pipe(%q).Write(%q) n=%d, err=%v", p.name, string(orig), n, err)
- }()
- }
- for len(b) > 0 {
- n2, err := p.writeOrBlock(b)
- if err != nil {
- return n + n2, err
- }
- n += n2
- b = b[n2:]
- }
- return n, nil
-}
-
-// Close closes the pipe.
-func (p *Pipe) Close() error {
- p.mu.Lock()
- defer p.mu.Unlock()
- p.closed = true
- p.blocked = false
- if p.cancelWriteTimer != nil {
- p.cancelWriteTimer()
- p.cancelWriteTimer = nil
- }
- if p.cancelReadTimer != nil {
- p.cancelReadTimer()
- p.cancelReadTimer = nil
- }
- p.cnd.Broadcast()
-
- return nil
-}
-
-func (p *Pipe) deadlineTimer(t time.Time) func() {
- if t.IsZero() {
- return nil
- }
- if t.Before(time.Now()) {
- p.cnd.Broadcast()
- return nil
- }
- ctx, cancel := context.WithDeadline(context.Background(), t)
- go func() {
- <-ctx.Done()
- if ctx.Err() == context.DeadlineExceeded {
- p.cnd.Broadcast()
- }
- }()
- return cancel
-}
-
-// SetReadDeadline sets the deadline for future Read calls.
-func (p *Pipe) SetReadDeadline(t time.Time) error {
- p.mu.Lock()
- defer p.mu.Unlock()
- p.readTimeout = t
- // If we already have a deadline, cancel it and create a new one.
- if p.cancelReadTimer != nil {
- p.cancelReadTimer()
- p.cancelReadTimer = nil
- }
- p.cancelReadTimer = p.deadlineTimer(t)
- return nil
-}
-
-// SetWriteDeadline sets the deadline for future Write calls.
-func (p *Pipe) SetWriteDeadline(t time.Time) error {
- p.mu.Lock()
- defer p.mu.Unlock()
- p.writeTimeout = t
- // If we already have a deadline, cancel it and create a new one.
- if p.cancelWriteTimer != nil {
- p.cancelWriteTimer()
- p.cancelWriteTimer = nil
- }
- p.cancelWriteTimer = p.deadlineTimer(t)
- return nil
-}
-
-// Block will cause all calls to Read and Write to block until they either
-// timeout, are unblocked or the pipe is closed.
-func (p *Pipe) Block() error {
- p.mu.Lock()
- defer p.mu.Unlock()
- closed := p.closed
- blocked := p.blocked
- p.blocked = true
-
- if closed {
- return fmt.Errorf("memnet.Pipe(%q).Block: closed", p.name)
- }
- if blocked {
- return fmt.Errorf("memnet.Pipe(%q).Block: already blocked", p.name)
- }
- p.cnd.Broadcast()
- return nil
-}
-
-// Unblock will cause all blocked Read/Write calls to continue execution.
-func (p *Pipe) Unblock() error {
- p.mu.Lock()
- defer p.mu.Unlock()
- closed := p.closed
- blocked := p.blocked
- p.blocked = false
-
- if closed {
- return fmt.Errorf("memnet.Pipe(%q).Block: closed", p.name)
- }
- if !blocked {
- return fmt.Errorf("memnet.Pipe(%q).Block: already unblocked", p.name)
- }
- p.cnd.Broadcast()
- return nil
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package memnet
+
+import (
+ "bytes"
+ "context"
+ "fmt"
+ "io"
+ "log"
+ "net"
+ "os"
+ "sync"
+ "time"
+)
+
+const debugPipe = false
+
+// Pipe implements an in-memory FIFO with timeouts.
+type Pipe struct {
+ name string
+ maxBuf int
+ mu sync.Mutex
+ cnd *sync.Cond
+
+ blocked bool
+ closed bool
+ buf bytes.Buffer
+ readTimeout time.Time
+ writeTimeout time.Time
+ cancelReadTimer func()
+ cancelWriteTimer func()
+}
+
+// NewPipe creates a Pipe with a buffer size fixed at maxBuf.
+func NewPipe(name string, maxBuf int) *Pipe {
+ p := &Pipe{
+ name: name,
+ maxBuf: maxBuf,
+ }
+ p.cnd = sync.NewCond(&p.mu)
+ return p
+}
+
+// readOrBlock attempts to read from the buffer, if the buffer is empty and
+// the connection hasn't been closed it will block until there is a change.
+func (p *Pipe) readOrBlock(b []byte) (int, error) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ if !p.readTimeout.IsZero() && !time.Now().Before(p.readTimeout) {
+ return 0, os.ErrDeadlineExceeded
+ }
+ if p.blocked {
+ p.cnd.Wait()
+ return 0, nil
+ }
+
+ n, err := p.buf.Read(b)
+ // err will either be nil or io.EOF.
+ if err == io.EOF {
+ if p.closed {
+ return n, err
+ }
+ // Wait for something to change.
+ p.cnd.Wait()
+ }
+ return n, nil
+}
+
+// Read implements io.Reader.
+// Once the buffer is drained (i.e. after Close), subsequent calls will
+// return io.EOF.
+func (p *Pipe) Read(b []byte) (n int, err error) {
+ if debugPipe {
+ orig := b
+ defer func() {
+ log.Printf("Pipe(%q).Read(%q) n=%d, err=%v", p.name, string(orig[:n]), n, err)
+ }()
+ }
+ for n == 0 {
+ n2, err := p.readOrBlock(b)
+ if err != nil {
+ return n2, err
+ }
+ n += n2
+ }
+ p.cnd.Signal()
+ return n, nil
+}
+
+// writeOrBlock attempts to write to the buffer, if the buffer is full it will
+// block until there is a change.
+func (p *Pipe) writeOrBlock(b []byte) (int, error) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ if p.closed {
+ return 0, net.ErrClosed
+ }
+ if !p.writeTimeout.IsZero() && !time.Now().Before(p.writeTimeout) {
+ return 0, os.ErrDeadlineExceeded
+ }
+ if p.blocked {
+ p.cnd.Wait()
+ return 0, nil
+ }
+
+ // Optimistically we want to write the entire slice.
+ n := len(b)
+ if limit := p.maxBuf - p.buf.Len(); limit < n {
+ // However, we don't have enough capacity to write everything.
+ n = limit
+ }
+ if n == 0 {
+ // Wait for something to change.
+ p.cnd.Wait()
+ return 0, nil
+ }
+
+ p.buf.Write(b[:n])
+ p.cnd.Signal()
+ return n, nil
+}
+
+// Write implements io.Writer.
+func (p *Pipe) Write(b []byte) (n int, err error) {
+ if debugPipe {
+ orig := b
+ defer func() {
+ log.Printf("Pipe(%q).Write(%q) n=%d, err=%v", p.name, string(orig), n, err)
+ }()
+ }
+ for len(b) > 0 {
+ n2, err := p.writeOrBlock(b)
+ if err != nil {
+ return n + n2, err
+ }
+ n += n2
+ b = b[n2:]
+ }
+ return n, nil
+}
+
+// Close closes the pipe.
+func (p *Pipe) Close() error {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ p.closed = true
+ p.blocked = false
+ if p.cancelWriteTimer != nil {
+ p.cancelWriteTimer()
+ p.cancelWriteTimer = nil
+ }
+ if p.cancelReadTimer != nil {
+ p.cancelReadTimer()
+ p.cancelReadTimer = nil
+ }
+ p.cnd.Broadcast()
+
+ return nil
+}
+
+func (p *Pipe) deadlineTimer(t time.Time) func() {
+ if t.IsZero() {
+ return nil
+ }
+ if t.Before(time.Now()) {
+ p.cnd.Broadcast()
+ return nil
+ }
+ ctx, cancel := context.WithDeadline(context.Background(), t)
+ go func() {
+ <-ctx.Done()
+ if ctx.Err() == context.DeadlineExceeded {
+ p.cnd.Broadcast()
+ }
+ }()
+ return cancel
+}
+
+// SetReadDeadline sets the deadline for future Read calls.
+func (p *Pipe) SetReadDeadline(t time.Time) error {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ p.readTimeout = t
+ // If we already have a deadline, cancel it and create a new one.
+ if p.cancelReadTimer != nil {
+ p.cancelReadTimer()
+ p.cancelReadTimer = nil
+ }
+ p.cancelReadTimer = p.deadlineTimer(t)
+ return nil
+}
+
+// SetWriteDeadline sets the deadline for future Write calls.
+func (p *Pipe) SetWriteDeadline(t time.Time) error {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ p.writeTimeout = t
+ // If we already have a deadline, cancel it and create a new one.
+ if p.cancelWriteTimer != nil {
+ p.cancelWriteTimer()
+ p.cancelWriteTimer = nil
+ }
+ p.cancelWriteTimer = p.deadlineTimer(t)
+ return nil
+}
+
+// Block will cause all calls to Read and Write to block until they either
+// timeout, are unblocked or the pipe is closed.
+func (p *Pipe) Block() error {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ closed := p.closed
+ blocked := p.blocked
+ p.blocked = true
+
+ if closed {
+ return fmt.Errorf("memnet.Pipe(%q).Block: closed", p.name)
+ }
+ if blocked {
+ return fmt.Errorf("memnet.Pipe(%q).Block: already blocked", p.name)
+ }
+ p.cnd.Broadcast()
+ return nil
+}
+
+// Unblock will cause all blocked Read/Write calls to continue execution.
+func (p *Pipe) Unblock() error {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ closed := p.closed
+ blocked := p.blocked
+ p.blocked = false
+
+ if closed {
+ return fmt.Errorf("memnet.Pipe(%q).Block: closed", p.name)
+ }
+ if !blocked {
+ return fmt.Errorf("memnet.Pipe(%q).Block: already unblocked", p.name)
+ }
+ p.cnd.Broadcast()
+ return nil
+}
diff --git a/net/memnet/pipe_test.go b/net/memnet/pipe_test.go
index a86d65388..b3775cf7f 100644
--- a/net/memnet/pipe_test.go
+++ b/net/memnet/pipe_test.go
@@ -1,117 +1,117 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-package memnet
-
-import (
- "errors"
- "fmt"
- "os"
- "testing"
- "time"
-)
-
-func TestPipeHello(t *testing.T) {
- p := NewPipe("p1", 1<<16)
- msg := "Hello, World!"
- if n, err := p.Write([]byte(msg)); err != nil {
- t.Fatal(err)
- } else if n != len(msg) {
- t.Errorf("p.Write(%q) n=%d, want %d", msg, n, len(msg))
- }
- b := make([]byte, len(msg))
- if n, err := p.Read(b); err != nil {
- t.Fatal(err)
- } else if n != len(b) {
- t.Errorf("p.Read(%q) n=%d, want %d", string(b[:n]), n, len(b))
- }
- if got := string(b); got != msg {
- t.Errorf("p.Read: %q, want %q", got, msg)
- }
-}
-
-func TestPipeTimeout(t *testing.T) {
- t.Run("write", func(t *testing.T) {
- p := NewPipe("p1", 1<<16)
- p.SetWriteDeadline(time.Now().Add(-1 * time.Second))
- n, err := p.Write([]byte{'h'})
- if !errors.Is(err, os.ErrDeadlineExceeded) {
- t.Errorf("missing write timeout got err: %v", err)
- }
- if n != 0 {
- t.Errorf("n=%d on timeout", n)
- }
- })
- t.Run("read", func(t *testing.T) {
- p := NewPipe("p1", 1<<16)
- p.Write([]byte{'h'})
-
- p.SetReadDeadline(time.Now().Add(-1 * time.Second))
- b := make([]byte, 1)
- n, err := p.Read(b)
- if !errors.Is(err, os.ErrDeadlineExceeded) {
- t.Errorf("missing read timeout got err: %v", err)
- }
- if n != 0 {
- t.Errorf("n=%d on timeout", n)
- }
- })
- t.Run("block-write", func(t *testing.T) {
- p := NewPipe("p1", 1<<16)
- p.SetWriteDeadline(time.Now().Add(10 * time.Millisecond))
- if err := p.Block(); err != nil {
- t.Fatal(err)
- }
- if _, err := p.Write([]byte{'h'}); !errors.Is(err, os.ErrDeadlineExceeded) {
- t.Fatalf("want write timeout got: %v", err)
- }
- })
- t.Run("block-read", func(t *testing.T) {
- p := NewPipe("p1", 1<<16)
- p.Write([]byte{'h', 'i'})
- p.SetReadDeadline(time.Now().Add(10 * time.Millisecond))
- b := make([]byte, 1)
- if err := p.Block(); err != nil {
- t.Fatal(err)
- }
- if _, err := p.Read(b); !errors.Is(err, os.ErrDeadlineExceeded) {
- t.Fatalf("want read timeout got: %v", err)
- }
- })
-}
-
-func TestLimit(t *testing.T) {
- p := NewPipe("p1", 1)
- errCh := make(chan error)
- go func() {
- n, err := p.Write([]byte{'a', 'b', 'c'})
- if err != nil {
- errCh <- err
- } else if n != 3 {
- errCh <- fmt.Errorf("p.Write n=%d, want 3", n)
- } else {
- errCh <- nil
- }
- }()
- b := make([]byte, 3)
-
- if n, err := p.Read(b); err != nil {
- t.Fatal(err)
- } else if n != 1 {
- t.Errorf("Read(%q): n=%d want 1", string(b), n)
- }
- if n, err := p.Read(b); err != nil {
- t.Fatal(err)
- } else if n != 1 {
- t.Errorf("Read(%q): n=%d want 1", string(b), n)
- }
- if n, err := p.Read(b); err != nil {
- t.Fatal(err)
- } else if n != 1 {
- t.Errorf("Read(%q): n=%d want 1", string(b), n)
- }
-
- if err := <-errCh; err != nil {
- t.Error(err)
- }
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package memnet
+
+import (
+ "errors"
+ "fmt"
+ "os"
+ "testing"
+ "time"
+)
+
+func TestPipeHello(t *testing.T) {
+ p := NewPipe("p1", 1<<16)
+ msg := "Hello, World!"
+ if n, err := p.Write([]byte(msg)); err != nil {
+ t.Fatal(err)
+ } else if n != len(msg) {
+ t.Errorf("p.Write(%q) n=%d, want %d", msg, n, len(msg))
+ }
+ b := make([]byte, len(msg))
+ if n, err := p.Read(b); err != nil {
+ t.Fatal(err)
+ } else if n != len(b) {
+ t.Errorf("p.Read(%q) n=%d, want %d", string(b[:n]), n, len(b))
+ }
+ if got := string(b); got != msg {
+ t.Errorf("p.Read: %q, want %q", got, msg)
+ }
+}
+
+func TestPipeTimeout(t *testing.T) {
+ t.Run("write", func(t *testing.T) {
+ p := NewPipe("p1", 1<<16)
+ p.SetWriteDeadline(time.Now().Add(-1 * time.Second))
+ n, err := p.Write([]byte{'h'})
+ if !errors.Is(err, os.ErrDeadlineExceeded) {
+ t.Errorf("missing write timeout got err: %v", err)
+ }
+ if n != 0 {
+ t.Errorf("n=%d on timeout", n)
+ }
+ })
+ t.Run("read", func(t *testing.T) {
+ p := NewPipe("p1", 1<<16)
+ p.Write([]byte{'h'})
+
+ p.SetReadDeadline(time.Now().Add(-1 * time.Second))
+ b := make([]byte, 1)
+ n, err := p.Read(b)
+ if !errors.Is(err, os.ErrDeadlineExceeded) {
+ t.Errorf("missing read timeout got err: %v", err)
+ }
+ if n != 0 {
+ t.Errorf("n=%d on timeout", n)
+ }
+ })
+ t.Run("block-write", func(t *testing.T) {
+ p := NewPipe("p1", 1<<16)
+ p.SetWriteDeadline(time.Now().Add(10 * time.Millisecond))
+ if err := p.Block(); err != nil {
+ t.Fatal(err)
+ }
+ if _, err := p.Write([]byte{'h'}); !errors.Is(err, os.ErrDeadlineExceeded) {
+ t.Fatalf("want write timeout got: %v", err)
+ }
+ })
+ t.Run("block-read", func(t *testing.T) {
+ p := NewPipe("p1", 1<<16)
+ p.Write([]byte{'h', 'i'})
+ p.SetReadDeadline(time.Now().Add(10 * time.Millisecond))
+ b := make([]byte, 1)
+ if err := p.Block(); err != nil {
+ t.Fatal(err)
+ }
+ if _, err := p.Read(b); !errors.Is(err, os.ErrDeadlineExceeded) {
+ t.Fatalf("want read timeout got: %v", err)
+ }
+ })
+}
+
+func TestLimit(t *testing.T) {
+ p := NewPipe("p1", 1)
+ errCh := make(chan error)
+ go func() {
+ n, err := p.Write([]byte{'a', 'b', 'c'})
+ if err != nil {
+ errCh <- err
+ } else if n != 3 {
+ errCh <- fmt.Errorf("p.Write n=%d, want 3", n)
+ } else {
+ errCh <- nil
+ }
+ }()
+ b := make([]byte, 3)
+
+ if n, err := p.Read(b); err != nil {
+ t.Fatal(err)
+ } else if n != 1 {
+ t.Errorf("Read(%q): n=%d want 1", string(b), n)
+ }
+ if n, err := p.Read(b); err != nil {
+ t.Fatal(err)
+ } else if n != 1 {
+ t.Errorf("Read(%q): n=%d want 1", string(b), n)
+ }
+ if n, err := p.Read(b); err != nil {
+ t.Fatal(err)
+ } else if n != 1 {
+ t.Errorf("Read(%q): n=%d want 1", string(b), n)
+ }
+
+ if err := <-errCh; err != nil {
+ t.Error(err)
+ }
+}
diff --git a/net/netaddr/netaddr.go b/net/netaddr/netaddr.go
index 1ab6c053a..6f85a52b7 100644
--- a/net/netaddr/netaddr.go
+++ b/net/netaddr/netaddr.go
@@ -1,49 +1,49 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-// Package netaddr is a transitional package while we finish migrating from inet.af/netaddr
-// to Go 1.18's net/netip.
-//
-// TODO(bradfitz): delete this package eventually. Tracking bug is
-// https://github.com/tailscale/tailscale/issues/5162
-package netaddr
-
-import (
- "net"
- "net/netip"
-)
-
-// IPv4 returns the IP of the IPv4 address a.b.c.d.
-func IPv4(a, b, c, d uint8) netip.Addr {
- return netip.AddrFrom4([4]byte{a, b, c, d})
-}
-
-// Unmap returns the provided AddrPort with its Addr IP component Unmap'ed.
-//
-// See https://github.com/golang/go/issues/53607#issuecomment-1203466984
-func Unmap(ap netip.AddrPort) netip.AddrPort {
- return netip.AddrPortFrom(ap.Addr().Unmap(), ap.Port())
-}
-
-// FromStdIPNet returns an IPPrefix from the standard library's IPNet type.
-// If std is invalid, ok is false.
-func FromStdIPNet(std *net.IPNet) (prefix netip.Prefix, ok bool) {
- ip, ok := netip.AddrFromSlice(std.IP)
- if !ok {
- return netip.Prefix{}, false
- }
- ip = ip.Unmap()
-
- if l := len(std.Mask); l != net.IPv4len && l != net.IPv6len {
- // Invalid mask.
- return netip.Prefix{}, false
- }
-
- ones, bits := std.Mask.Size()
- if ones == 0 && bits == 0 {
- // IPPrefix does not support non-contiguous masks.
- return netip.Prefix{}, false
- }
-
- return netip.PrefixFrom(ip, ones), true
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Package netaddr is a transitional package while we finish migrating from inet.af/netaddr
+// to Go 1.18's net/netip.
+//
+// TODO(bradfitz): delete this package eventually. Tracking bug is
+// https://github.com/tailscale/tailscale/issues/5162
+package netaddr
+
+import (
+ "net"
+ "net/netip"
+)
+
+// IPv4 returns the IP of the IPv4 address a.b.c.d.
+func IPv4(a, b, c, d uint8) netip.Addr {
+ return netip.AddrFrom4([4]byte{a, b, c, d})
+}
+
+// Unmap returns the provided AddrPort with its Addr IP component Unmap'ed.
+//
+// See https://github.com/golang/go/issues/53607#issuecomment-1203466984
+func Unmap(ap netip.AddrPort) netip.AddrPort {
+ return netip.AddrPortFrom(ap.Addr().Unmap(), ap.Port())
+}
+
+// FromStdIPNet returns an IPPrefix from the standard library's IPNet type.
+// If std is invalid, ok is false.
+func FromStdIPNet(std *net.IPNet) (prefix netip.Prefix, ok bool) {
+ ip, ok := netip.AddrFromSlice(std.IP)
+ if !ok {
+ return netip.Prefix{}, false
+ }
+ ip = ip.Unmap()
+
+ if l := len(std.Mask); l != net.IPv4len && l != net.IPv6len {
+ // Invalid mask.
+ return netip.Prefix{}, false
+ }
+
+ ones, bits := std.Mask.Size()
+ if ones == 0 && bits == 0 {
+ // IPPrefix does not support non-contiguous masks.
+ return netip.Prefix{}, false
+ }
+
+ return netip.PrefixFrom(ip, ones), true
+}
diff --git a/net/neterror/neterror.go b/net/neterror/neterror.go
index e2387440d..f570b8930 100644
--- a/net/neterror/neterror.go
+++ b/net/neterror/neterror.go
@@ -1,82 +1,82 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-// Package neterror classifies network errors.
-package neterror
-
-import (
- "errors"
- "fmt"
- "runtime"
- "syscall"
-)
-
-var errEPERM error = syscall.EPERM // box it into interface just once
-
-// TreatAsLostUDP reports whether err is an error from a UDP send
-// operation that should be treated as a UDP packet that just got
-// lost.
-//
-// Notably, on Linux this reports true for EPERM errors (from outbound
-// firewall blocks) which aren't really send errors; they're just
-// sends that are never going to make it because the local OS blocked
-// it.
-func TreatAsLostUDP(err error) bool {
- if err == nil {
- return false
- }
- switch runtime.GOOS {
- case "linux":
- // Linux, while not documented in the man page,
- // returns EPERM when there's an OUTPUT rule with -j
- // DROP or -j REJECT. We use this very specific
- // Linux+EPERM check rather than something super broad
- // like net.Error.Temporary which could be anything.
- //
- // For now we only do this on Linux, as such outgoing
- // firewall violations mapping to syscall errors
- // hasn't yet been observed on other OSes.
- return errors.Is(err, errEPERM)
- }
- return false
-}
-
-var packetWasTruncated func(error) bool // non-nil on Windows at least
-
-// PacketWasTruncated reports whether err indicates truncation but the RecvFrom
-// that generated err was otherwise successful. On Windows, Go's UDP RecvFrom
-// calls WSARecvFrom which returns the WSAEMSGSIZE error code when the received
-// datagram is larger than the provided buffer. When that happens, both a valid
-// size and an error are returned (as per the partial fix for golang/go#14074).
-// If the WSAEMSGSIZE error is returned, then we ignore the error to get
-// semantics similar to the POSIX operating systems. One caveat is that it
-// appears that the source address is not returned when WSAEMSGSIZE occurs, but
-// we do not currently look at the source address.
-func PacketWasTruncated(err error) bool {
- if packetWasTruncated == nil {
- return false
- }
- return packetWasTruncated(err)
-}
-
-var shouldDisableUDPGSO func(error) bool // non-nil on Linux
-
-func ShouldDisableUDPGSO(err error) bool {
- if shouldDisableUDPGSO == nil {
- return false
- }
- return shouldDisableUDPGSO(err)
-}
-
-type ErrUDPGSODisabled struct {
- OnLaddr string
- RetryErr error
-}
-
-func (e ErrUDPGSODisabled) Error() string {
- return fmt.Sprintf("disabled UDP GSO on %s, NIC(s) may not support checksum offload", e.OnLaddr)
-}
-
-func (e ErrUDPGSODisabled) Unwrap() error {
- return e.RetryErr
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Package neterror classifies network errors.
+package neterror
+
+import (
+ "errors"
+ "fmt"
+ "runtime"
+ "syscall"
+)
+
+var errEPERM error = syscall.EPERM // box it into interface just once
+
+// TreatAsLostUDP reports whether err is an error from a UDP send
+// operation that should be treated as a UDP packet that just got
+// lost.
+//
+// Notably, on Linux this reports true for EPERM errors (from outbound
+// firewall blocks) which aren't really send errors; they're just
+// sends that are never going to make it because the local OS blocked
+// it.
+func TreatAsLostUDP(err error) bool {
+ if err == nil {
+ return false
+ }
+ switch runtime.GOOS {
+ case "linux":
+ // Linux, while not documented in the man page,
+ // returns EPERM when there's an OUTPUT rule with -j
+ // DROP or -j REJECT. We use this very specific
+ // Linux+EPERM check rather than something super broad
+ // like net.Error.Temporary which could be anything.
+ //
+ // For now we only do this on Linux, as such outgoing
+ // firewall violations mapping to syscall errors
+ // hasn't yet been observed on other OSes.
+ return errors.Is(err, errEPERM)
+ }
+ return false
+}
+
+var packetWasTruncated func(error) bool // non-nil on Windows at least
+
+// PacketWasTruncated reports whether err indicates truncation but the RecvFrom
+// that generated err was otherwise successful. On Windows, Go's UDP RecvFrom
+// calls WSARecvFrom which returns the WSAEMSGSIZE error code when the received
+// datagram is larger than the provided buffer. When that happens, both a valid
+// size and an error are returned (as per the partial fix for golang/go#14074).
+// If the WSAEMSGSIZE error is returned, then we ignore the error to get
+// semantics similar to the POSIX operating systems. One caveat is that it
+// appears that the source address is not returned when WSAEMSGSIZE occurs, but
+// we do not currently look at the source address.
+func PacketWasTruncated(err error) bool {
+ if packetWasTruncated == nil {
+ return false
+ }
+ return packetWasTruncated(err)
+}
+
+var shouldDisableUDPGSO func(error) bool // non-nil on Linux
+
+func ShouldDisableUDPGSO(err error) bool {
+ if shouldDisableUDPGSO == nil {
+ return false
+ }
+ return shouldDisableUDPGSO(err)
+}
+
+type ErrUDPGSODisabled struct {
+ OnLaddr string
+ RetryErr error
+}
+
+func (e ErrUDPGSODisabled) Error() string {
+ return fmt.Sprintf("disabled UDP GSO on %s, NIC(s) may not support checksum offload", e.OnLaddr)
+}
+
+func (e ErrUDPGSODisabled) Unwrap() error {
+ return e.RetryErr
+}
diff --git a/net/neterror/neterror_linux.go b/net/neterror/neterror_linux.go
index 857367fe8..3f402dd30 100644
--- a/net/neterror/neterror_linux.go
+++ b/net/neterror/neterror_linux.go
@@ -1,26 +1,26 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-package neterror
-
-import (
- "errors"
- "os"
-
- "golang.org/x/sys/unix"
-)
-
-func init() {
- shouldDisableUDPGSO = func(err error) bool {
- var serr *os.SyscallError
- if errors.As(err, &serr) {
- // EIO is returned by udp_send_skb() if the device driver does not
- // have tx checksumming enabled, which is a hard requirement of
- // UDP_SEGMENT. See:
- // https://git.kernel.org/pub/scm/docs/man-pages/man-pages.git/tree/man7/udp.7?id=806eabd74910447f21005160e90957bde4db0183#n228
- // https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/net/ipv4/udp.c?h=v6.2&id=c9c3395d5e3dcc6daee66c6908354d47bf98cb0c#n942
- return serr.Err == unix.EIO
- }
- return false
- }
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package neterror
+
+import (
+ "errors"
+ "os"
+
+ "golang.org/x/sys/unix"
+)
+
+func init() {
+ shouldDisableUDPGSO = func(err error) bool {
+ var serr *os.SyscallError
+ if errors.As(err, &serr) {
+ // EIO is returned by udp_send_skb() if the device driver does not
+ // have tx checksumming enabled, which is a hard requirement of
+ // UDP_SEGMENT. See:
+ // https://git.kernel.org/pub/scm/docs/man-pages/man-pages.git/tree/man7/udp.7?id=806eabd74910447f21005160e90957bde4db0183#n228
+ // https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/net/ipv4/udp.c?h=v6.2&id=c9c3395d5e3dcc6daee66c6908354d47bf98cb0c#n942
+ return serr.Err == unix.EIO
+ }
+ return false
+ }
+}
diff --git a/net/neterror/neterror_linux_test.go b/net/neterror/neterror_linux_test.go
index 5b9906074..1d600d6b6 100644
--- a/net/neterror/neterror_linux_test.go
+++ b/net/neterror/neterror_linux_test.go
@@ -1,54 +1,54 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-package neterror
-
-import (
- "errors"
- "net"
- "os"
- "syscall"
- "testing"
-)
-
-func TestTreatAsLostUDP(t *testing.T) {
- tests := []struct {
- name string
- err error
- want bool
- }{
- {"nil", nil, false},
- {"non-nil", errors.New("foo"), false},
- {"eperm", syscall.EPERM, true},
- {
- name: "operror",
- err: &net.OpError{
- Op: "write",
- Err: &os.SyscallError{
- Syscall: "sendto",
- Err: syscall.EPERM,
- },
- },
- want: true,
- },
- {
- name: "host_unreach",
- err: &net.OpError{
- Op: "write",
- Err: &os.SyscallError{
- Syscall: "sendto",
- Err: syscall.EHOSTUNREACH,
- },
- },
- want: false,
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- if got := TreatAsLostUDP(tt.err); got != tt.want {
- t.Errorf("got = %v; want %v", got, tt.want)
- }
- })
- }
-
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package neterror
+
+import (
+ "errors"
+ "net"
+ "os"
+ "syscall"
+ "testing"
+)
+
+func TestTreatAsLostUDP(t *testing.T) {
+ tests := []struct {
+ name string
+ err error
+ want bool
+ }{
+ {"nil", nil, false},
+ {"non-nil", errors.New("foo"), false},
+ {"eperm", syscall.EPERM, true},
+ {
+ name: "operror",
+ err: &net.OpError{
+ Op: "write",
+ Err: &os.SyscallError{
+ Syscall: "sendto",
+ Err: syscall.EPERM,
+ },
+ },
+ want: true,
+ },
+ {
+ name: "host_unreach",
+ err: &net.OpError{
+ Op: "write",
+ Err: &os.SyscallError{
+ Syscall: "sendto",
+ Err: syscall.EHOSTUNREACH,
+ },
+ },
+ want: false,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if got := TreatAsLostUDP(tt.err); got != tt.want {
+ t.Errorf("got = %v; want %v", got, tt.want)
+ }
+ })
+ }
+
+}
diff --git a/net/neterror/neterror_windows.go b/net/neterror/neterror_windows.go
index bf112f5ed..c293ec4a9 100644
--- a/net/neterror/neterror_windows.go
+++ b/net/neterror/neterror_windows.go
@@ -1,16 +1,16 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-package neterror
-
-import (
- "errors"
-
- "golang.org/x/sys/windows"
-)
-
-func init() {
- packetWasTruncated = func(err error) bool {
- return errors.Is(err, windows.WSAEMSGSIZE)
- }
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package neterror
+
+import (
+ "errors"
+
+ "golang.org/x/sys/windows"
+)
+
+func init() {
+ packetWasTruncated = func(err error) bool {
+ return errors.Is(err, windows.WSAEMSGSIZE)
+ }
+}
diff --git a/net/netkernelconf/netkernelconf.go b/net/netkernelconf/netkernelconf.go
index 3ea502b37..23ec9c5b6 100644
--- a/net/netkernelconf/netkernelconf.go
+++ b/net/netkernelconf/netkernelconf.go
@@ -1,5 +1,5 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-// Package netkernelconf contains code for checking kernel netdev config.
-package netkernelconf
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Package netkernelconf contains code for checking kernel netdev config.
+package netkernelconf
diff --git a/net/netknob/netknob.go b/net/netknob/netknob.go
index 53171f424..0b271fc95 100644
--- a/net/netknob/netknob.go
+++ b/net/netknob/netknob.go
@@ -1,29 +1,29 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-// Package netknob has Tailscale network knobs.
-package netknob
-
-import (
- "runtime"
- "time"
-)
-
-// PlatformTCPKeepAlive returns the default net.Dialer.KeepAlive
-// value for the current runtime.GOOS.
-func PlatformTCPKeepAlive() time.Duration {
- switch runtime.GOOS {
- case "ios", "android":
- // Disable TCP keep-alives on mobile platforms.
- // See https://github.com/golang/go/issues/48622.
- //
- // TODO(bradfitz): in 1.17.x, try disabling TCP
- // keep-alives on for all platforms.
- return -1
- }
-
- // Otherwise, default to 30 seconds, which is mostly what we
- // used to do. In some places we used the zero value, which Go
- // defaults to 15 seconds. But 30 seconds is fine.
- return 30 * time.Second
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Package netknob has Tailscale network knobs.
+package netknob
+
+import (
+ "runtime"
+ "time"
+)
+
+// PlatformTCPKeepAlive returns the default net.Dialer.KeepAlive
+// value for the current runtime.GOOS.
+func PlatformTCPKeepAlive() time.Duration {
+ switch runtime.GOOS {
+ case "ios", "android":
+ // Disable TCP keep-alives on mobile platforms.
+ // See https://github.com/golang/go/issues/48622.
+ //
+ // TODO(bradfitz): in 1.17.x, try disabling TCP
+ // keep-alives on for all platforms.
+ return -1
+ }
+
+ // Otherwise, default to 30 seconds, which is mostly what we
+ // used to do. In some places we used the zero value, which Go
+ // defaults to 15 seconds. But 30 seconds is fine.
+ return 30 * time.Second
+}
diff --git a/net/netmon/netmon_darwin_test.go b/net/netmon/netmon_darwin_test.go
index 84c67cf6f..77a212683 100644
--- a/net/netmon/netmon_darwin_test.go
+++ b/net/netmon/netmon_darwin_test.go
@@ -1,27 +1,27 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-package netmon
-
-import (
- "encoding/hex"
- "strings"
- "testing"
-
- "golang.org/x/net/route"
-)
-
-func TestIssue1416RIB(t *testing.T) {
- const ribHex = `32 00 05 10 30 00 00 00 00 00 00 00 04 00 00 00 14 12 04 00 06 03 06 00 65 6e 30 ac 87 a3 19 7f 82 00 00 00 0e 12 00 00 00 00 06 00 91 e0 f0 01 00 00`
- rtmMsg, err := hex.DecodeString(strings.ReplaceAll(ribHex, " ", ""))
- if err != nil {
- t.Fatal(err)
- }
- msgs, err := route.ParseRIB(route.RIBTypeRoute, rtmMsg)
- if err != nil {
- t.Logf("ParseRIB: %v", err)
- t.Skip("skipping on known failure; see https://github.com/tailscale/tailscale/issues/1416")
- t.Fatal(err)
- }
- t.Logf("Got: %#v", msgs)
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package netmon
+
+import (
+ "encoding/hex"
+ "strings"
+ "testing"
+
+ "golang.org/x/net/route"
+)
+
+func TestIssue1416RIB(t *testing.T) {
+ const ribHex = `32 00 05 10 30 00 00 00 00 00 00 00 04 00 00 00 14 12 04 00 06 03 06 00 65 6e 30 ac 87 a3 19 7f 82 00 00 00 0e 12 00 00 00 00 06 00 91 e0 f0 01 00 00`
+ rtmMsg, err := hex.DecodeString(strings.ReplaceAll(ribHex, " ", ""))
+ if err != nil {
+ t.Fatal(err)
+ }
+ msgs, err := route.ParseRIB(route.RIBTypeRoute, rtmMsg)
+ if err != nil {
+ t.Logf("ParseRIB: %v", err)
+ t.Skip("skipping on known failure; see https://github.com/tailscale/tailscale/issues/1416")
+ t.Fatal(err)
+ }
+ t.Logf("Got: %#v", msgs)
+}
diff --git a/net/netmon/netmon_freebsd.go b/net/netmon/netmon_freebsd.go
index 30480a1d3..724f964c9 100644
--- a/net/netmon/netmon_freebsd.go
+++ b/net/netmon/netmon_freebsd.go
@@ -1,56 +1,56 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-package netmon
-
-import (
- "bufio"
- "fmt"
- "net"
- "strings"
-
- "tailscale.com/types/logger"
-)
-
-// unspecifiedMessage is a minimal message implementation that should not
-// be ignored. In general, OS-specific implementations should use better
-// types and avoid this if they can.
-type unspecifiedMessage struct{}
-
-func (unspecifiedMessage) ignore() bool { return false }
-
-// devdConn implements osMon using devd(8).
-type devdConn struct {
- conn net.Conn
-}
-
-func newOSMon(logf logger.Logf, m *Monitor) (osMon, error) {
- conn, err := net.Dial("unixpacket", "/var/run/devd.seqpacket.pipe")
- if err != nil {
- logf("devd dial error: %v, falling back to polling method", err)
- return newPollingMon(logf, m)
- }
- return &devdConn{conn}, nil
-}
-
-func (c *devdConn) IsInterestingInterface(iface string) bool { return true }
-
-func (c *devdConn) Close() error {
- return c.conn.Close()
-}
-
-func (c *devdConn) Receive() (message, error) {
- for {
- msg, err := bufio.NewReader(c.conn).ReadString('\n')
- if err != nil {
- return nil, fmt.Errorf("reading devd socket: %v", err)
- }
- // Only return messages related to the network subsystem.
- if !strings.Contains(msg, "system=IFNET") {
- continue
- }
- // TODO: this is where the devd-specific message would
- // get converted into a "standard" event message and returned.
- return unspecifiedMessage{}, nil
- }
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package netmon
+
+import (
+ "bufio"
+ "fmt"
+ "net"
+ "strings"
+
+ "tailscale.com/types/logger"
+)
+
+// unspecifiedMessage is a minimal message implementation that should not
+// be ignored. In general, OS-specific implementations should use better
+// types and avoid this if they can.
+type unspecifiedMessage struct{}
+
+func (unspecifiedMessage) ignore() bool { return false }
+
+// devdConn implements osMon using devd(8).
+type devdConn struct {
+ conn net.Conn
+}
+
+func newOSMon(logf logger.Logf, m *Monitor) (osMon, error) {
+ conn, err := net.Dial("unixpacket", "/var/run/devd.seqpacket.pipe")
+ if err != nil {
+ logf("devd dial error: %v, falling back to polling method", err)
+ return newPollingMon(logf, m)
+ }
+ return &devdConn{conn}, nil
+}
+
+func (c *devdConn) IsInterestingInterface(iface string) bool { return true }
+
+func (c *devdConn) Close() error {
+ return c.conn.Close()
+}
+
+func (c *devdConn) Receive() (message, error) {
+ for {
+ msg, err := bufio.NewReader(c.conn).ReadString('\n')
+ if err != nil {
+ return nil, fmt.Errorf("reading devd socket: %v", err)
+ }
+ // Only return messages related to the network subsystem.
+ if !strings.Contains(msg, "system=IFNET") {
+ continue
+ }
+ // TODO: this is where the devd-specific message would
+ // get converted into a "standard" event message and returned.
+ return unspecifiedMessage{}, nil
+ }
+}
diff --git a/net/netmon/netmon_linux.go b/net/netmon/netmon_linux.go
index dd23dd342..888afa92d 100644
--- a/net/netmon/netmon_linux.go
+++ b/net/netmon/netmon_linux.go
@@ -1,290 +1,290 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-//go:build !android
-
-package netmon
-
-import (
- "net"
- "net/netip"
- "time"
-
- "github.com/jsimonetti/rtnetlink"
- "github.com/mdlayher/netlink"
- "golang.org/x/sys/unix"
- "tailscale.com/envknob"
- "tailscale.com/net/tsaddr"
- "tailscale.com/types/logger"
-)
-
-var debugNetlinkMessages = envknob.RegisterBool("TS_DEBUG_NETLINK")
-
-// unspecifiedMessage is a minimal message implementation that should not
-// be ignored. In general, OS-specific implementations should use better
-// types and avoid this if they can.
-type unspecifiedMessage struct{}
-
-func (unspecifiedMessage) ignore() bool { return false }
-
-// nlConn wraps a *netlink.Conn and returns a monitor.Message
-// instead of a netlink.Message. Currently, messages are discarded,
-// but down the line, when messages trigger different logic depending
-// on the type of event, this provides the capability of handling
-// each architecture-specific message in a generic fashion.
-type nlConn struct {
- logf logger.Logf
- conn *netlink.Conn
- buffered []netlink.Message
-
- // addrCache maps interface indices to a set of addresses, and is
- // used to suppress duplicate RTM_NEWADDR messages. It is populated
- // by RTM_NEWADDR messages and de-populated by RTM_DELADDR. See
- // issue #4282.
- addrCache map[uint32]map[netip.Addr]bool
-}
-
-func newOSMon(logf logger.Logf, m *Monitor) (osMon, error) {
- conn, err := netlink.Dial(unix.NETLINK_ROUTE, &netlink.Config{
- // Routes get us most of the events of interest, but we need
- // address as well to cover things like DHCP deciding to give
- // us a new address upon renewal - routing wouldn't change,
- // but all reachability would.
- Groups: unix.RTMGRP_IPV4_IFADDR | unix.RTMGRP_IPV6_IFADDR |
- unix.RTMGRP_IPV4_ROUTE | unix.RTMGRP_IPV6_ROUTE |
- unix.RTMGRP_IPV4_RULE, // no IPV6_RULE in x/sys/unix
- })
- if err != nil {
- // Google Cloud Run does not implement NETLINK_ROUTE RTMGRP support
- logf("monitor_linux: AF_NETLINK RTMGRP failed, falling back to polling")
- return newPollingMon(logf, m)
- }
- return &nlConn{logf: logf, conn: conn, addrCache: make(map[uint32]map[netip.Addr]bool)}, nil
-}
-
-func (c *nlConn) IsInterestingInterface(iface string) bool { return true }
-
-func (c *nlConn) Close() error { return c.conn.Close() }
-
-func (c *nlConn) Receive() (message, error) {
- if len(c.buffered) == 0 {
- var err error
- c.buffered, err = c.conn.Receive()
- if err != nil {
- return nil, err
- }
- if len(c.buffered) == 0 {
- // Unexpected. Not seen in wild, but sleep defensively.
- time.Sleep(time.Second)
- return ignoreMessage{}, nil
- }
- }
- msg := c.buffered[0]
- c.buffered = c.buffered[1:]
-
- // See https://github.com/torvalds/linux/blob/master/include/uapi/linux/rtnetlink.h
- // And https://man7.org/linux/man-pages/man7/rtnetlink.7.html
- switch msg.Header.Type {
- case unix.RTM_NEWADDR, unix.RTM_DELADDR:
- var rmsg rtnetlink.AddressMessage
- if err := rmsg.UnmarshalBinary(msg.Data); err != nil {
- c.logf("failed to parse type %v: %v", msg.Header.Type, err)
- return unspecifiedMessage{}, nil
- }
-
- nip := netaddrIP(rmsg.Attributes.Address)
-
- if debugNetlinkMessages() {
- typ := "RTM_NEWADDR"
- if msg.Header.Type == unix.RTM_DELADDR {
- typ = "RTM_DELADDR"
- }
-
- // label attributes are seemingly only populated for IPv4 addresses in the wild.
- label := rmsg.Attributes.Label
- if label == "" {
- itf, err := net.InterfaceByIndex(int(rmsg.Index))
- if err == nil {
- label = itf.Name
- }
- }
-
- c.logf("%s: %s(%d) %s / %s", typ, label, rmsg.Index, rmsg.Attributes.Address, rmsg.Attributes.Local)
- }
-
- addrs := c.addrCache[rmsg.Index]
-
- // Ignore duplicate RTM_NEWADDR messages using c.addrCache to
- // detect them. See nlConn.addrcache and issue #4282.
- if msg.Header.Type == unix.RTM_NEWADDR {
- if addrs == nil {
- addrs = make(map[netip.Addr]bool)
- c.addrCache[rmsg.Index] = addrs
- }
-
- if addrs[nip] {
- if debugNetlinkMessages() {
- c.logf("ignored duplicate RTM_NEWADDR for %s", nip)
- }
- return ignoreMessage{}, nil
- }
-
- addrs[nip] = true
- } else { // msg.Header.Type == unix.RTM_DELADDR
- if addrs != nil {
- delete(addrs, nip)
- }
-
- if len(addrs) == 0 {
- delete(c.addrCache, rmsg.Index)
- }
- }
-
- nam := &newAddrMessage{
- IfIndex: rmsg.Index,
- Addr: nip,
- Delete: msg.Header.Type == unix.RTM_DELADDR,
- }
- if debugNetlinkMessages() {
- c.logf("%+v", nam)
- }
- return nam, nil
- case unix.RTM_NEWROUTE, unix.RTM_DELROUTE:
- typeStr := "RTM_NEWROUTE"
- if msg.Header.Type == unix.RTM_DELROUTE {
- typeStr = "RTM_DELROUTE"
- }
- var rmsg rtnetlink.RouteMessage
- if err := rmsg.UnmarshalBinary(msg.Data); err != nil {
- c.logf("%s: failed to parse: %v", typeStr, err)
- return unspecifiedMessage{}, nil
- }
- src := netaddrIPPrefix(rmsg.Attributes.Src, rmsg.SrcLength)
- dst := netaddrIPPrefix(rmsg.Attributes.Dst, rmsg.DstLength)
- gw := netaddrIP(rmsg.Attributes.Gateway)
-
- if msg.Header.Type == unix.RTM_NEWROUTE &&
- (rmsg.Attributes.Table == 255 || rmsg.Attributes.Table == 254) &&
- (dst.Addr().IsMulticast() || dst.Addr().IsLinkLocalUnicast()) {
-
- if debugNetlinkMessages() {
- c.logf("%s ignored", typeStr)
- }
-
- // Normal Linux route changes on new interface coming up; don't log or react.
- return ignoreMessage{}, nil
- }
-
- if rmsg.Table == tsTable && dst.IsSingleIP() {
- // Don't log. Spammy and normal to see a bunch of these on start-up,
- // which we make ourselves.
- } else if tsaddr.IsTailscaleIP(dst.Addr()) {
- // Verbose only.
- c.logf("%s: [v1] src=%v, dst=%v, gw=%v, outif=%v, table=%v", typeStr,
- condNetAddrPrefix(src), condNetAddrPrefix(dst), condNetAddrIP(gw),
- rmsg.Attributes.OutIface, rmsg.Attributes.Table)
- } else {
- c.logf("%s: src=%v, dst=%v, gw=%v, outif=%v, table=%v", typeStr,
- condNetAddrPrefix(src), condNetAddrPrefix(dst), condNetAddrIP(gw),
- rmsg.Attributes.OutIface, rmsg.Attributes.Table)
- }
- if msg.Header.Type == unix.RTM_DELROUTE {
- // Just logging it for now.
- // (Debugging https://github.com/tailscale/tailscale/issues/643)
- return unspecifiedMessage{}, nil
- }
-
- nrm := &newRouteMessage{
- Table: rmsg.Table,
- Src: src,
- Dst: dst,
- Gateway: gw,
- }
- if debugNetlinkMessages() {
- c.logf("%+v", nrm)
- }
- return nrm, nil
- case unix.RTM_NEWRULE:
- // Probably ourselves adding it.
- return ignoreMessage{}, nil
- case unix.RTM_DELRULE:
- // For https://github.com/tailscale/tailscale/issues/1591 where
- // systemd-networkd deletes our rules.
- var rmsg rtnetlink.RouteMessage
- err := rmsg.UnmarshalBinary(msg.Data)
- if err != nil {
- c.logf("ip rule deleted; failed to parse netlink message: %v", err)
- } else {
- c.logf("ip rule deleted: %+v", rmsg)
- // On `ip -4 rule del pref 5210 table main`, logs:
- // monitor: ip rule deleted: {Family:2 DstLength:0 SrcLength:0 Tos:0 Table:254 Protocol:0 Scope:0 Type:1 Flags:0 Attributes:{Dst:<nil> Src:<nil> Gateway:<nil> OutIface:0 Priority:5210 Table:254 Mark:4294967295 Expires:<nil> Metrics:<nil> Multipath:[]}}
- }
- rdm := ipRuleDeletedMessage{
- table: rmsg.Table,
- priority: rmsg.Attributes.Priority,
- }
- if debugNetlinkMessages() {
- c.logf("%+v", rdm)
- }
- return rdm, nil
- case unix.RTM_NEWLINK, unix.RTM_DELLINK:
- // This is an unhandled message, but don't print an error.
- // See https://github.com/tailscale/tailscale/issues/6806
- return unspecifiedMessage{}, nil
- default:
- c.logf("unhandled netlink msg type %+v, %q", msg.Header, msg.Data)
- return unspecifiedMessage{}, nil
- }
-}
-
-func netaddrIP(std net.IP) netip.Addr {
- ip, _ := netip.AddrFromSlice(std)
- return ip.Unmap()
-}
-
-func netaddrIPPrefix(std net.IP, bits uint8) netip.Prefix {
- ip, _ := netip.AddrFromSlice(std)
- return netip.PrefixFrom(ip.Unmap(), int(bits))
-}
-
-func condNetAddrPrefix(ipp netip.Prefix) string {
- if !ipp.Addr().IsValid() {
- return ""
- }
- return ipp.String()
-}
-
-func condNetAddrIP(ip netip.Addr) string {
- if !ip.IsValid() {
- return ""
- }
- return ip.String()
-}
-
-// newRouteMessage is a message for a new route being added.
-type newRouteMessage struct {
- Src, Dst netip.Prefix
- Gateway netip.Addr
- Table uint8
-}
-
-const tsTable = 52
-
-func (m *newRouteMessage) ignore() bool {
- return m.Table == tsTable || tsaddr.IsTailscaleIP(m.Dst.Addr())
-}
-
-// newAddrMessage is a message for a new address being added.
-type newAddrMessage struct {
- Delete bool
- Addr netip.Addr
- IfIndex uint32 // interface index
-}
-
-func (m *newAddrMessage) ignore() bool {
- return tsaddr.IsTailscaleIP(m.Addr)
-}
-
-type ignoreMessage struct{}
-
-func (ignoreMessage) ignore() bool { return true }
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build !android
+
+package netmon
+
+import (
+ "net"
+ "net/netip"
+ "time"
+
+ "github.com/jsimonetti/rtnetlink"
+ "github.com/mdlayher/netlink"
+ "golang.org/x/sys/unix"
+ "tailscale.com/envknob"
+ "tailscale.com/net/tsaddr"
+ "tailscale.com/types/logger"
+)
+
+var debugNetlinkMessages = envknob.RegisterBool("TS_DEBUG_NETLINK")
+
+// unspecifiedMessage is a minimal message implementation that should not
+// be ignored. In general, OS-specific implementations should use better
+// types and avoid this if they can.
+type unspecifiedMessage struct{}
+
+func (unspecifiedMessage) ignore() bool { return false }
+
+// nlConn wraps a *netlink.Conn and returns a monitor.Message
+// instead of a netlink.Message. Currently, messages are discarded,
+// but down the line, when messages trigger different logic depending
+// on the type of event, this provides the capability of handling
+// each architecture-specific message in a generic fashion.
+type nlConn struct {
+ logf logger.Logf
+ conn *netlink.Conn
+ buffered []netlink.Message
+
+ // addrCache maps interface indices to a set of addresses, and is
+ // used to suppress duplicate RTM_NEWADDR messages. It is populated
+ // by RTM_NEWADDR messages and de-populated by RTM_DELADDR. See
+ // issue #4282.
+ addrCache map[uint32]map[netip.Addr]bool
+}
+
+func newOSMon(logf logger.Logf, m *Monitor) (osMon, error) {
+ conn, err := netlink.Dial(unix.NETLINK_ROUTE, &netlink.Config{
+ // Routes get us most of the events of interest, but we need
+ // address as well to cover things like DHCP deciding to give
+ // us a new address upon renewal - routing wouldn't change,
+ // but all reachability would.
+ Groups: unix.RTMGRP_IPV4_IFADDR | unix.RTMGRP_IPV6_IFADDR |
+ unix.RTMGRP_IPV4_ROUTE | unix.RTMGRP_IPV6_ROUTE |
+ unix.RTMGRP_IPV4_RULE, // no IPV6_RULE in x/sys/unix
+ })
+ if err != nil {
+ // Google Cloud Run does not implement NETLINK_ROUTE RTMGRP support
+ logf("monitor_linux: AF_NETLINK RTMGRP failed, falling back to polling")
+ return newPollingMon(logf, m)
+ }
+ return &nlConn{logf: logf, conn: conn, addrCache: make(map[uint32]map[netip.Addr]bool)}, nil
+}
+
+func (c *nlConn) IsInterestingInterface(iface string) bool { return true }
+
+func (c *nlConn) Close() error { return c.conn.Close() }
+
+func (c *nlConn) Receive() (message, error) {
+ if len(c.buffered) == 0 {
+ var err error
+ c.buffered, err = c.conn.Receive()
+ if err != nil {
+ return nil, err
+ }
+ if len(c.buffered) == 0 {
+ // Unexpected. Not seen in wild, but sleep defensively.
+ time.Sleep(time.Second)
+ return ignoreMessage{}, nil
+ }
+ }
+ msg := c.buffered[0]
+ c.buffered = c.buffered[1:]
+
+ // See https://github.com/torvalds/linux/blob/master/include/uapi/linux/rtnetlink.h
+ // And https://man7.org/linux/man-pages/man7/rtnetlink.7.html
+ switch msg.Header.Type {
+ case unix.RTM_NEWADDR, unix.RTM_DELADDR:
+ var rmsg rtnetlink.AddressMessage
+ if err := rmsg.UnmarshalBinary(msg.Data); err != nil {
+ c.logf("failed to parse type %v: %v", msg.Header.Type, err)
+ return unspecifiedMessage{}, nil
+ }
+
+ nip := netaddrIP(rmsg.Attributes.Address)
+
+ if debugNetlinkMessages() {
+ typ := "RTM_NEWADDR"
+ if msg.Header.Type == unix.RTM_DELADDR {
+ typ = "RTM_DELADDR"
+ }
+
+ // label attributes are seemingly only populated for IPv4 addresses in the wild.
+ label := rmsg.Attributes.Label
+ if label == "" {
+ itf, err := net.InterfaceByIndex(int(rmsg.Index))
+ if err == nil {
+ label = itf.Name
+ }
+ }
+
+ c.logf("%s: %s(%d) %s / %s", typ, label, rmsg.Index, rmsg.Attributes.Address, rmsg.Attributes.Local)
+ }
+
+ addrs := c.addrCache[rmsg.Index]
+
+ // Ignore duplicate RTM_NEWADDR messages using c.addrCache to
+ // detect them. See nlConn.addrcache and issue #4282.
+ if msg.Header.Type == unix.RTM_NEWADDR {
+ if addrs == nil {
+ addrs = make(map[netip.Addr]bool)
+ c.addrCache[rmsg.Index] = addrs
+ }
+
+ if addrs[nip] {
+ if debugNetlinkMessages() {
+ c.logf("ignored duplicate RTM_NEWADDR for %s", nip)
+ }
+ return ignoreMessage{}, nil
+ }
+
+ addrs[nip] = true
+ } else { // msg.Header.Type == unix.RTM_DELADDR
+ if addrs != nil {
+ delete(addrs, nip)
+ }
+
+ if len(addrs) == 0 {
+ delete(c.addrCache, rmsg.Index)
+ }
+ }
+
+ nam := &newAddrMessage{
+ IfIndex: rmsg.Index,
+ Addr: nip,
+ Delete: msg.Header.Type == unix.RTM_DELADDR,
+ }
+ if debugNetlinkMessages() {
+ c.logf("%+v", nam)
+ }
+ return nam, nil
+ case unix.RTM_NEWROUTE, unix.RTM_DELROUTE:
+ typeStr := "RTM_NEWROUTE"
+ if msg.Header.Type == unix.RTM_DELROUTE {
+ typeStr = "RTM_DELROUTE"
+ }
+ var rmsg rtnetlink.RouteMessage
+ if err := rmsg.UnmarshalBinary(msg.Data); err != nil {
+ c.logf("%s: failed to parse: %v", typeStr, err)
+ return unspecifiedMessage{}, nil
+ }
+ src := netaddrIPPrefix(rmsg.Attributes.Src, rmsg.SrcLength)
+ dst := netaddrIPPrefix(rmsg.Attributes.Dst, rmsg.DstLength)
+ gw := netaddrIP(rmsg.Attributes.Gateway)
+
+ if msg.Header.Type == unix.RTM_NEWROUTE &&
+ (rmsg.Attributes.Table == 255 || rmsg.Attributes.Table == 254) &&
+ (dst.Addr().IsMulticast() || dst.Addr().IsLinkLocalUnicast()) {
+
+ if debugNetlinkMessages() {
+ c.logf("%s ignored", typeStr)
+ }
+
+ // Normal Linux route changes on new interface coming up; don't log or react.
+ return ignoreMessage{}, nil
+ }
+
+ if rmsg.Table == tsTable && dst.IsSingleIP() {
+ // Don't log. Spammy and normal to see a bunch of these on start-up,
+ // which we make ourselves.
+ } else if tsaddr.IsTailscaleIP(dst.Addr()) {
+ // Verbose only.
+ c.logf("%s: [v1] src=%v, dst=%v, gw=%v, outif=%v, table=%v", typeStr,
+ condNetAddrPrefix(src), condNetAddrPrefix(dst), condNetAddrIP(gw),
+ rmsg.Attributes.OutIface, rmsg.Attributes.Table)
+ } else {
+ c.logf("%s: src=%v, dst=%v, gw=%v, outif=%v, table=%v", typeStr,
+ condNetAddrPrefix(src), condNetAddrPrefix(dst), condNetAddrIP(gw),
+ rmsg.Attributes.OutIface, rmsg.Attributes.Table)
+ }
+ if msg.Header.Type == unix.RTM_DELROUTE {
+ // Just logging it for now.
+ // (Debugging https://github.com/tailscale/tailscale/issues/643)
+ return unspecifiedMessage{}, nil
+ }
+
+ nrm := &newRouteMessage{
+ Table: rmsg.Table,
+ Src: src,
+ Dst: dst,
+ Gateway: gw,
+ }
+ if debugNetlinkMessages() {
+ c.logf("%+v", nrm)
+ }
+ return nrm, nil
+ case unix.RTM_NEWRULE:
+ // Probably ourselves adding it.
+ return ignoreMessage{}, nil
+ case unix.RTM_DELRULE:
+ // For https://github.com/tailscale/tailscale/issues/1591 where
+ // systemd-networkd deletes our rules.
+ var rmsg rtnetlink.RouteMessage
+ err := rmsg.UnmarshalBinary(msg.Data)
+ if err != nil {
+ c.logf("ip rule deleted; failed to parse netlink message: %v", err)
+ } else {
+ c.logf("ip rule deleted: %+v", rmsg)
+ // On `ip -4 rule del pref 5210 table main`, logs:
+ // monitor: ip rule deleted: {Family:2 DstLength:0 SrcLength:0 Tos:0 Table:254 Protocol:0 Scope:0 Type:1 Flags:0 Attributes:{Dst:<nil> Src:<nil> Gateway:<nil> OutIface:0 Priority:5210 Table:254 Mark:4294967295 Expires:<nil> Metrics:<nil> Multipath:[]}}
+ }
+ rdm := ipRuleDeletedMessage{
+ table: rmsg.Table,
+ priority: rmsg.Attributes.Priority,
+ }
+ if debugNetlinkMessages() {
+ c.logf("%+v", rdm)
+ }
+ return rdm, nil
+ case unix.RTM_NEWLINK, unix.RTM_DELLINK:
+ // This is an unhandled message, but don't print an error.
+ // See https://github.com/tailscale/tailscale/issues/6806
+ return unspecifiedMessage{}, nil
+ default:
+ c.logf("unhandled netlink msg type %+v, %q", msg.Header, msg.Data)
+ return unspecifiedMessage{}, nil
+ }
+}
+
+func netaddrIP(std net.IP) netip.Addr {
+ ip, _ := netip.AddrFromSlice(std)
+ return ip.Unmap()
+}
+
+func netaddrIPPrefix(std net.IP, bits uint8) netip.Prefix {
+ ip, _ := netip.AddrFromSlice(std)
+ return netip.PrefixFrom(ip.Unmap(), int(bits))
+}
+
+func condNetAddrPrefix(ipp netip.Prefix) string {
+ if !ipp.Addr().IsValid() {
+ return ""
+ }
+ return ipp.String()
+}
+
+func condNetAddrIP(ip netip.Addr) string {
+ if !ip.IsValid() {
+ return ""
+ }
+ return ip.String()
+}
+
+// newRouteMessage is a message for a new route being added.
+type newRouteMessage struct {
+ Src, Dst netip.Prefix
+ Gateway netip.Addr
+ Table uint8
+}
+
+const tsTable = 52
+
+func (m *newRouteMessage) ignore() bool {
+ return m.Table == tsTable || tsaddr.IsTailscaleIP(m.Dst.Addr())
+}
+
+// newAddrMessage is a message for a new address being added.
+type newAddrMessage struct {
+ Delete bool
+ Addr netip.Addr
+ IfIndex uint32 // interface index
+}
+
+func (m *newAddrMessage) ignore() bool {
+ return tsaddr.IsTailscaleIP(m.Addr)
+}
+
+type ignoreMessage struct{}
+
+func (ignoreMessage) ignore() bool { return true }
diff --git a/net/netmon/netmon_polling.go b/net/netmon/netmon_polling.go
index 3d6f94731..1ce4a51de 100644
--- a/net/netmon/netmon_polling.go
+++ b/net/netmon/netmon_polling.go
@@ -1,21 +1,21 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-//go:build (!linux && !freebsd && !windows && !darwin) || android
-
-package netmon
-
-import (
- "tailscale.com/types/logger"
-)
-
-func newOSMon(logf logger.Logf, m *Monitor) (osMon, error) {
- return newPollingMon(logf, m)
-}
-
-// unspecifiedMessage is a minimal message implementation that should not
-// be ignored. In general, OS-specific implementations should use better
-// types and avoid this if they can.
-type unspecifiedMessage struct{}
-
-func (unspecifiedMessage) ignore() bool { return false }
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build (!linux && !freebsd && !windows && !darwin) || android
+
+package netmon
+
+import (
+ "tailscale.com/types/logger"
+)
+
+func newOSMon(logf logger.Logf, m *Monitor) (osMon, error) {
+ return newPollingMon(logf, m)
+}
+
+// unspecifiedMessage is a minimal message implementation that should not
+// be ignored. In general, OS-specific implementations should use better
+// types and avoid this if they can.
+type unspecifiedMessage struct{}
+
+func (unspecifiedMessage) ignore() bool { return false }
diff --git a/net/netmon/polling.go b/net/netmon/polling.go
index ce1618ed6..bb7210b94 100644
--- a/net/netmon/polling.go
+++ b/net/netmon/polling.go
@@ -1,86 +1,86 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-//go:build !windows && !darwin
-
-package netmon
-
-import (
- "bytes"
- "errors"
- "os"
- "runtime"
- "sync"
- "time"
-
- "tailscale.com/types/logger"
-)
-
-func newPollingMon(logf logger.Logf, m *Monitor) (osMon, error) {
- return &pollingMon{
- logf: logf,
- m: m,
- stop: make(chan struct{}),
- }, nil
-}
-
-// pollingMon is a bad but portable implementation of the link monitor
-// that works by polling the interface state every 10 seconds, in lieu
-// of anything to subscribe to.
-type pollingMon struct {
- logf logger.Logf
- m *Monitor
-
- closeOnce sync.Once
- stop chan struct{}
-}
-
-func (pm *pollingMon) IsInterestingInterface(iface string) bool {
- return true
-}
-
-func (pm *pollingMon) Close() error {
- pm.closeOnce.Do(func() {
- close(pm.stop)
- })
- return nil
-}
-
-func (pm *pollingMon) isCloudRun() bool {
- // https: //cloud.google.com/run/docs/reference/container-contract#env-vars
- if os.Getenv("K_REVISION") == "" || os.Getenv("K_CONFIGURATION") == "" ||
- os.Getenv("K_SERVICE") == "" || os.Getenv("PORT") == "" {
- return false
- }
- vers, err := os.ReadFile("/proc/version")
- if err != nil {
- pm.logf("Failed to read /proc/version: %v", err)
- return false
- }
- return string(bytes.TrimSpace(vers)) == "Linux version 4.4.0 #1 SMP Sun Jan 10 15:06:54 PST 2016"
-}
-
-func (pm *pollingMon) Receive() (message, error) {
- d := 10 * time.Second
- if runtime.GOOS == "android" {
- // We'll have Android notify the link monitor to wake up earlier,
- // so this can go very slowly there, to save battery.
- // https://github.com/tailscale/tailscale/issues/1427
- d = 10 * time.Minute
- } else if pm.isCloudRun() {
- // Cloud Run routes never change at runtime. the containers are killed within
- // 15 minutes by default, set the interval long enough to be effectively infinite.
- pm.logf("monitor polling: Cloud Run detected, reduce polling interval to 24h")
- d = 24 * time.Hour
- }
- timer := time.NewTimer(d)
- defer timer.Stop()
- for {
- select {
- case <-timer.C:
- return unspecifiedMessage{}, nil
- case <-pm.stop:
- return nil, errors.New("stopped")
- }
- }
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build !windows && !darwin
+
+package netmon
+
+import (
+ "bytes"
+ "errors"
+ "os"
+ "runtime"
+ "sync"
+ "time"
+
+ "tailscale.com/types/logger"
+)
+
+func newPollingMon(logf logger.Logf, m *Monitor) (osMon, error) {
+ return &pollingMon{
+ logf: logf,
+ m: m,
+ stop: make(chan struct{}),
+ }, nil
+}
+
+// pollingMon is a bad but portable implementation of the link monitor
+// that works by polling the interface state every 10 seconds, in lieu
+// of anything to subscribe to.
+type pollingMon struct {
+ logf logger.Logf
+ m *Monitor
+
+ closeOnce sync.Once
+ stop chan struct{}
+}
+
+func (pm *pollingMon) IsInterestingInterface(iface string) bool {
+ return true
+}
+
+func (pm *pollingMon) Close() error {
+ pm.closeOnce.Do(func() {
+ close(pm.stop)
+ })
+ return nil
+}
+
+func (pm *pollingMon) isCloudRun() bool {
+ // https: //cloud.google.com/run/docs/reference/container-contract#env-vars
+ if os.Getenv("K_REVISION") == "" || os.Getenv("K_CONFIGURATION") == "" ||
+ os.Getenv("K_SERVICE") == "" || os.Getenv("PORT") == "" {
+ return false
+ }
+ vers, err := os.ReadFile("/proc/version")
+ if err != nil {
+ pm.logf("Failed to read /proc/version: %v", err)
+ return false
+ }
+ return string(bytes.TrimSpace(vers)) == "Linux version 4.4.0 #1 SMP Sun Jan 10 15:06:54 PST 2016"
+}
+
+func (pm *pollingMon) Receive() (message, error) {
+ d := 10 * time.Second
+ if runtime.GOOS == "android" {
+ // We'll have Android notify the link monitor to wake up earlier,
+ // so this can go very slowly there, to save battery.
+ // https://github.com/tailscale/tailscale/issues/1427
+ d = 10 * time.Minute
+ } else if pm.isCloudRun() {
+ // Cloud Run routes never change at runtime. the containers are killed within
+ // 15 minutes by default, set the interval long enough to be effectively infinite.
+ pm.logf("monitor polling: Cloud Run detected, reduce polling interval to 24h")
+ d = 24 * time.Hour
+ }
+ timer := time.NewTimer(d)
+ defer timer.Stop()
+ for {
+ select {
+ case <-timer.C:
+ return unspecifiedMessage{}, nil
+ case <-pm.stop:
+ return nil, errors.New("stopped")
+ }
+ }
+}
diff --git a/net/netns/netns_android.go b/net/netns/netns_android.go
index 162e5c79a..79084ff11 100644
--- a/net/netns/netns_android.go
+++ b/net/netns/netns_android.go
@@ -1,75 +1,75 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-//go:build android
-
-package netns
-
-import (
- "fmt"
- "sync"
- "syscall"
-
- "tailscale.com/net/netmon"
- "tailscale.com/types/logger"
-)
-
-var (
- androidProtectFuncMu sync.Mutex
- androidProtectFunc func(fd int) error
-)
-
-// UseSocketMark reports whether SO_MARK is in use. Android does not use SO_MARK.
-func UseSocketMark() bool {
- return false
-}
-
-// SetAndroidProtectFunc register a func that Android provides that JNI calls into
-// https://developer.android.com/reference/android/net/VpnService#protect(int)
-// which is documented as:
-//
-// "Protect a socket from VPN connections. After protecting, data sent
-// through this socket will go directly to the underlying network, so
-// its traffic will not be forwarded through the VPN. This method is
-// useful if some connections need to be kept outside of VPN. For
-// example, a VPN tunnel should protect itself if its destination is
-// covered by VPN routes. Otherwise its outgoing packets will be sent
-// back to the VPN interface and cause an infinite loop. This method
-// will fail if the application is not prepared or is revoked."
-//
-// A nil func disables the use the hook.
-//
-// This indirection is necessary because this is the supported, stable
-// interface to use on Android, and doing the sockopts to set the
-// fwmark return errors on Android. The actual implementation of
-// VpnService.protect ends up doing an IPC to another process on
-// Android, asking for the fwmark to be set.
-func SetAndroidProtectFunc(f func(fd int) error) {
- androidProtectFuncMu.Lock()
- defer androidProtectFuncMu.Unlock()
- androidProtectFunc = f
-}
-
-func control(logger.Logf, *netmon.Monitor) func(network, address string, c syscall.RawConn) error {
- return controlC
-}
-
-// controlC marks c as necessary to dial in a separate network namespace.
-//
-// It's intentionally the same signature as net.Dialer.Control
-// and net.ListenConfig.Control.
-func controlC(network, address string, c syscall.RawConn) error {
- var sockErr error
- err := c.Control(func(fd uintptr) {
- androidProtectFuncMu.Lock()
- f := androidProtectFunc
- androidProtectFuncMu.Unlock()
- if f != nil {
- sockErr = f(int(fd))
- }
- })
- if err != nil {
- return fmt.Errorf("RawConn.Control on %T: %w", c, err)
- }
- return sockErr
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build android
+
+package netns
+
+import (
+ "fmt"
+ "sync"
+ "syscall"
+
+ "tailscale.com/net/netmon"
+ "tailscale.com/types/logger"
+)
+
+var (
+ androidProtectFuncMu sync.Mutex
+ androidProtectFunc func(fd int) error
+)
+
+// UseSocketMark reports whether SO_MARK is in use. Android does not use SO_MARK.
+func UseSocketMark() bool {
+ return false
+}
+
+// SetAndroidProtectFunc register a func that Android provides that JNI calls into
+// https://developer.android.com/reference/android/net/VpnService#protect(int)
+// which is documented as:
+//
+// "Protect a socket from VPN connections. After protecting, data sent
+// through this socket will go directly to the underlying network, so
+// its traffic will not be forwarded through the VPN. This method is
+// useful if some connections need to be kept outside of VPN. For
+// example, a VPN tunnel should protect itself if its destination is
+// covered by VPN routes. Otherwise its outgoing packets will be sent
+// back to the VPN interface and cause an infinite loop. This method
+// will fail if the application is not prepared or is revoked."
+//
+// A nil func disables the use the hook.
+//
+// This indirection is necessary because this is the supported, stable
+// interface to use on Android, and doing the sockopts to set the
+// fwmark return errors on Android. The actual implementation of
+// VpnService.protect ends up doing an IPC to another process on
+// Android, asking for the fwmark to be set.
+func SetAndroidProtectFunc(f func(fd int) error) {
+ androidProtectFuncMu.Lock()
+ defer androidProtectFuncMu.Unlock()
+ androidProtectFunc = f
+}
+
+func control(logger.Logf, *netmon.Monitor) func(network, address string, c syscall.RawConn) error {
+ return controlC
+}
+
+// controlC marks c as necessary to dial in a separate network namespace.
+//
+// It's intentionally the same signature as net.Dialer.Control
+// and net.ListenConfig.Control.
+func controlC(network, address string, c syscall.RawConn) error {
+ var sockErr error
+ err := c.Control(func(fd uintptr) {
+ androidProtectFuncMu.Lock()
+ f := androidProtectFunc
+ androidProtectFuncMu.Unlock()
+ if f != nil {
+ sockErr = f(int(fd))
+ }
+ })
+ if err != nil {
+ return fmt.Errorf("RawConn.Control on %T: %w", c, err)
+ }
+ return sockErr
+}
diff --git a/net/netns/netns_default.go b/net/netns/netns_default.go
index 94f24d8fa..02db19e75 100644
--- a/net/netns/netns_default.go
+++ b/net/netns/netns_default.go
@@ -1,22 +1,22 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-//go:build !linux && !windows && !darwin
-
-package netns
-
-import (
- "syscall"
-
- "tailscale.com/net/netmon"
- "tailscale.com/types/logger"
-)
-
-func control(logger.Logf, *netmon.Monitor) func(network, address string, c syscall.RawConn) error {
- return controlC
-}
-
-// controlC does nothing to c.
-func controlC(network, address string, c syscall.RawConn) error {
- return nil
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build !linux && !windows && !darwin
+
+package netns
+
+import (
+ "syscall"
+
+ "tailscale.com/net/netmon"
+ "tailscale.com/types/logger"
+)
+
+func control(logger.Logf, *netmon.Monitor) func(network, address string, c syscall.RawConn) error {
+ return controlC
+}
+
+// controlC does nothing to c.
+func controlC(network, address string, c syscall.RawConn) error {
+ return nil
+}
diff --git a/net/netns/netns_linux_test.go b/net/netns/netns_linux_test.go
index a5000f37f..cc221bcb1 100644
--- a/net/netns/netns_linux_test.go
+++ b/net/netns/netns_linux_test.go
@@ -1,14 +1,14 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-package netns
-
-import (
- "testing"
-)
-
-func TestSocketMarkWorks(t *testing.T) {
- _ = socketMarkWorks()
- // we cannot actually assert whether the test runner has SO_MARK available
- // or not, as we don't know. We're just checking that it doesn't panic.
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package netns
+
+import (
+ "testing"
+)
+
+func TestSocketMarkWorks(t *testing.T) {
+ _ = socketMarkWorks()
+ // we cannot actually assert whether the test runner has SO_MARK available
+ // or not, as we don't know. We're just checking that it doesn't panic.
+}
diff --git a/net/netns/netns_test.go b/net/netns/netns_test.go
index 82f919b94..1c6d699ac 100644
--- a/net/netns/netns_test.go
+++ b/net/netns/netns_test.go
@@ -1,78 +1,78 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-// Package netns contains the common code for using the Go net package
-// in a logical "network namespace" to avoid routing loops where
-// Tailscale-created packets would otherwise loop back through
-// Tailscale routes.
-//
-// Despite the name netns, the exact mechanism used differs by
-// operating system, and perhaps even by version of the OS.
-//
-// The netns package also handles connecting via SOCKS proxies when
-// configured by the environment.
-package netns
-
-import (
- "flag"
- "testing"
-)
-
-var extNetwork = flag.Bool("use-external-network", false, "use the external network in tests")
-
-func TestDial(t *testing.T) {
- if !*extNetwork {
- t.Skip("skipping test without --use-external-network")
- }
- d := NewDialer(t.Logf, nil)
- c, err := d.Dial("tcp", "google.com:80")
- if err != nil {
- t.Fatal(err)
- }
- defer c.Close()
- t.Logf("got addr %v", c.RemoteAddr())
-
- c, err = d.Dial("tcp4", "google.com:80")
- if err != nil {
- t.Fatal(err)
- }
- defer c.Close()
- t.Logf("got addr %v", c.RemoteAddr())
-}
-
-func TestIsLocalhost(t *testing.T) {
- tests := []struct {
- name string
- host string
- want bool
- }{
- {"IPv4 loopback", "127.0.0.1", true},
- {"IPv4 !loopback", "192.168.0.1", false},
- {"IPv4 loopback with port", "127.0.0.1:1", true},
- {"IPv4 !loopback with port", "192.168.0.1:1", false},
- {"IPv4 unspecified", "0.0.0.0", false},
- {"IPv4 unspecified with port", "0.0.0.0:1", false},
- {"IPv6 loopback", "::1", true},
- {"IPv6 !loopback", "2001:4860:4860::8888", false},
- {"IPv6 loopback with port", "[::1]:1", true},
- {"IPv6 !loopback with port", "[2001:4860:4860::8888]:1", false},
- {"IPv6 unspecified", "::", false},
- {"IPv6 unspecified with port", "[::]:1", false},
- {"empty", "", false},
- {"hostname", "example.com", false},
- {"localhost", "localhost", true},
- {"localhost6", "localhost6", true},
- {"localhost with port", "localhost:1", true},
- {"localhost6 with port", "localhost6:1", true},
- {"ip6-localhost", "ip6-localhost", true},
- {"ip6-localhost with port", "ip6-localhost:1", true},
- {"ip6-loopback", "ip6-loopback", true},
- {"ip6-loopback with port", "ip6-loopback:1", true},
- }
-
- for _, test := range tests {
- if got := isLocalhost(test.host); got != test.want {
- t.Errorf("isLocalhost(%q) = %v, want %v", test.name, got, test.want)
- }
- }
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Package netns contains the common code for using the Go net package
+// in a logical "network namespace" to avoid routing loops where
+// Tailscale-created packets would otherwise loop back through
+// Tailscale routes.
+//
+// Despite the name netns, the exact mechanism used differs by
+// operating system, and perhaps even by version of the OS.
+//
+// The netns package also handles connecting via SOCKS proxies when
+// configured by the environment.
+package netns
+
+import (
+ "flag"
+ "testing"
+)
+
+var extNetwork = flag.Bool("use-external-network", false, "use the external network in tests")
+
+func TestDial(t *testing.T) {
+ if !*extNetwork {
+ t.Skip("skipping test without --use-external-network")
+ }
+ d := NewDialer(t.Logf, nil)
+ c, err := d.Dial("tcp", "google.com:80")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+ t.Logf("got addr %v", c.RemoteAddr())
+
+ c, err = d.Dial("tcp4", "google.com:80")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+ t.Logf("got addr %v", c.RemoteAddr())
+}
+
+func TestIsLocalhost(t *testing.T) {
+ tests := []struct {
+ name string
+ host string
+ want bool
+ }{
+ {"IPv4 loopback", "127.0.0.1", true},
+ {"IPv4 !loopback", "192.168.0.1", false},
+ {"IPv4 loopback with port", "127.0.0.1:1", true},
+ {"IPv4 !loopback with port", "192.168.0.1:1", false},
+ {"IPv4 unspecified", "0.0.0.0", false},
+ {"IPv4 unspecified with port", "0.0.0.0:1", false},
+ {"IPv6 loopback", "::1", true},
+ {"IPv6 !loopback", "2001:4860:4860::8888", false},
+ {"IPv6 loopback with port", "[::1]:1", true},
+ {"IPv6 !loopback with port", "[2001:4860:4860::8888]:1", false},
+ {"IPv6 unspecified", "::", false},
+ {"IPv6 unspecified with port", "[::]:1", false},
+ {"empty", "", false},
+ {"hostname", "example.com", false},
+ {"localhost", "localhost", true},
+ {"localhost6", "localhost6", true},
+ {"localhost with port", "localhost:1", true},
+ {"localhost6 with port", "localhost6:1", true},
+ {"ip6-localhost", "ip6-localhost", true},
+ {"ip6-localhost with port", "ip6-localhost:1", true},
+ {"ip6-loopback", "ip6-loopback", true},
+ {"ip6-loopback with port", "ip6-loopback:1", true},
+ }
+
+ for _, test := range tests {
+ if got := isLocalhost(test.host); got != test.want {
+ t.Errorf("isLocalhost(%q) = %v, want %v", test.name, got, test.want)
+ }
+ }
+}
diff --git a/net/netns/socks.go b/net/netns/socks.go
index eea69d865..a3d10d3ae 100644
--- a/net/netns/socks.go
+++ b/net/netns/socks.go
@@ -1,19 +1,19 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-//go:build !ios && !js
-
-package netns
-
-import "golang.org/x/net/proxy"
-
-func init() {
- wrapDialer = wrapSocks
-}
-
-func wrapSocks(d Dialer) Dialer {
- if cd, ok := proxy.FromEnvironmentUsing(d).(Dialer); ok {
- return cd
- }
- return d
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build !ios && !js
+
+package netns
+
+import "golang.org/x/net/proxy"
+
+func init() {
+ wrapDialer = wrapSocks
+}
+
+func wrapSocks(d Dialer) Dialer {
+ if cd, ok := proxy.FromEnvironmentUsing(d).(Dialer); ok {
+ return cd
+ }
+ return d
+}
diff --git a/net/netstat/netstat.go b/net/netstat/netstat.go
index 53c7d7757..53121dc52 100644
--- a/net/netstat/netstat.go
+++ b/net/netstat/netstat.go
@@ -1,35 +1,35 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-// Package netstat returns the local machine's network connection table.
-package netstat
-
-import (
- "errors"
- "net/netip"
- "runtime"
-)
-
-var ErrNotImplemented = errors.New("not implemented for GOOS=" + runtime.GOOS)
-
-type Entry struct {
- Local, Remote netip.AddrPort
- Pid int
- State string // TODO: type?
- OSMetadata OSMetadata
-}
-
-// Table contains local machine's TCP connection entries.
-//
-// Currently only TCP (IPv4 and IPv6) are included.
-type Table struct {
- Entries []Entry
-}
-
-// Get returns the connection table.
-//
-// It returns ErrNotImplemented if the table is not available for the
-// current operating system.
-func Get() (*Table, error) {
- return get()
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Package netstat returns the local machine's network connection table.
+package netstat
+
+import (
+ "errors"
+ "net/netip"
+ "runtime"
+)
+
+var ErrNotImplemented = errors.New("not implemented for GOOS=" + runtime.GOOS)
+
+type Entry struct {
+ Local, Remote netip.AddrPort
+ Pid int
+ State string // TODO: type?
+ OSMetadata OSMetadata
+}
+
+// Table contains local machine's TCP connection entries.
+//
+// Currently only TCP (IPv4 and IPv6) are included.
+type Table struct {
+ Entries []Entry
+}
+
+// Get returns the connection table.
+//
+// It returns ErrNotImplemented if the table is not available for the
+// current operating system.
+func Get() (*Table, error) {
+ return get()
+}
diff --git a/net/netstat/netstat_noimpl.go b/net/netstat/netstat_noimpl.go
index e455c8ce9..608b1a617 100644
--- a/net/netstat/netstat_noimpl.go
+++ b/net/netstat/netstat_noimpl.go
@@ -1,14 +1,14 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-//go:build !windows
-
-package netstat
-
-// OSMetadata includes any additional OS-specific information that may be
-// obtained during the retrieval of a given Entry.
-type OSMetadata struct{}
-
-func get() (*Table, error) {
- return nil, ErrNotImplemented
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build !windows
+
+package netstat
+
+// OSMetadata includes any additional OS-specific information that may be
+// obtained during the retrieval of a given Entry.
+type OSMetadata struct{}
+
+func get() (*Table, error) {
+ return nil, ErrNotImplemented
+}
diff --git a/net/netstat/netstat_test.go b/net/netstat/netstat_test.go
index 38827df5e..74f4fcec0 100644
--- a/net/netstat/netstat_test.go
+++ b/net/netstat/netstat_test.go
@@ -1,21 +1,21 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-package netstat
-
-import (
- "testing"
-)
-
-func TestGet(t *testing.T) {
- nt, err := Get()
- if err == ErrNotImplemented {
- t.Skip("TODO: not implemented")
- }
- if err != nil {
- t.Fatal(err)
- }
- for _, e := range nt.Entries {
- t.Logf("Entry: %+v", e)
- }
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package netstat
+
+import (
+ "testing"
+)
+
+func TestGet(t *testing.T) {
+ nt, err := Get()
+ if err == ErrNotImplemented {
+ t.Skip("TODO: not implemented")
+ }
+ if err != nil {
+ t.Fatal(err)
+ }
+ for _, e := range nt.Entries {
+ t.Logf("Entry: %+v", e)
+ }
+}
diff --git a/net/packet/doc.go b/net/packet/doc.go
index ce6c0c307..f3cb93db8 100644
--- a/net/packet/doc.go
+++ b/net/packet/doc.go
@@ -1,15 +1,15 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-// Package packet contains packet parsing and marshaling utilities.
-//
-// Parsed provides allocation-free minimal packet header decoding, for
-// use in packet filtering. The other types in the package are for
-// constructing and marshaling packets into []bytes.
-//
-// To support allocation-free parsing, this package defines IPv4 and
-// IPv6 address types. You should prefer to use netaddr's types,
-// except where you absolutely need allocation-free IP handling
-// (i.e. in the tunnel datapath) and are willing to implement all
-// codepaths and data structures twice, once per IP family.
-package packet
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Package packet contains packet parsing and marshaling utilities.
+//
+// Parsed provides allocation-free minimal packet header decoding, for
+// use in packet filtering. The other types in the package are for
+// constructing and marshaling packets into []bytes.
+//
+// To support allocation-free parsing, this package defines IPv4 and
+// IPv6 address types. You should prefer to use netaddr's types,
+// except where you absolutely need allocation-free IP handling
+// (i.e. in the tunnel datapath) and are willing to implement all
+// codepaths and data structures twice, once per IP family.
+package packet
diff --git a/net/packet/header.go b/net/packet/header.go
index dbe84429a..0b1947c0a 100644
--- a/net/packet/header.go
+++ b/net/packet/header.go
@@ -1,66 +1,66 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-package packet
-
-import (
- "errors"
- "math"
-)
-
-const tcpHeaderLength = 20
-const sctpHeaderLength = 12
-
-// maxPacketLength is the largest length that all headers support.
-// IPv4 headers using uint16 for this forces an upper bound of 64KB.
-const maxPacketLength = math.MaxUint16
-
-var (
- // errSmallBuffer is returned when Marshal receives a buffer
- // too small to contain the header to marshal.
- errSmallBuffer = errors.New("buffer too small")
- // errLargePacket is returned when Marshal receives a payload
- // larger than the maximum representable size in header
- // fields.
- errLargePacket = errors.New("packet too large")
-)
-
-// Header is a packet header capable of marshaling itself into a byte
-// buffer.
-type Header interface {
- // Len returns the length of the marshaled packet.
- Len() int
- // Marshal serializes the header into buf, which must be at
- // least Len() bytes long. Implementations of Marshal assume
- // that bytes after the first Len() are payload bytes for the
- // purpose of computing length and checksum fields. Marshal
- // implementations must not allocate memory.
- Marshal(buf []byte) error
-}
-
-// HeaderChecksummer is implemented by Header implementations that
-// need to do a checksum over their payloads.
-type HeaderChecksummer interface {
- Header
-
- // WriteCheck writes the correct checksum into buf, which should
- // be be the already-marshalled header and payload.
- WriteChecksum(buf []byte)
-}
-
-// Generate generates a new packet with the given Header and
-// payload. This function allocates memory, see Header.Marshal for an
-// allocation-free option.
-func Generate(h Header, payload []byte) []byte {
- hlen := h.Len()
- buf := make([]byte, hlen+len(payload))
-
- copy(buf[hlen:], payload)
- h.Marshal(buf)
-
- if hc, ok := h.(HeaderChecksummer); ok {
- hc.WriteChecksum(buf)
- }
-
- return buf
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package packet
+
+import (
+ "errors"
+ "math"
+)
+
+const tcpHeaderLength = 20
+const sctpHeaderLength = 12
+
+// maxPacketLength is the largest length that all headers support.
+// IPv4 headers using uint16 for this forces an upper bound of 64KB.
+const maxPacketLength = math.MaxUint16
+
+var (
+ // errSmallBuffer is returned when Marshal receives a buffer
+ // too small to contain the header to marshal.
+ errSmallBuffer = errors.New("buffer too small")
+ // errLargePacket is returned when Marshal receives a payload
+ // larger than the maximum representable size in header
+ // fields.
+ errLargePacket = errors.New("packet too large")
+)
+
+// Header is a packet header capable of marshaling itself into a byte
+// buffer.
+type Header interface {
+ // Len returns the length of the marshaled packet.
+ Len() int
+ // Marshal serializes the header into buf, which must be at
+ // least Len() bytes long. Implementations of Marshal assume
+ // that bytes after the first Len() are payload bytes for the
+ // purpose of computing length and checksum fields. Marshal
+ // implementations must not allocate memory.
+ Marshal(buf []byte) error
+}
+
+// HeaderChecksummer is implemented by Header implementations that
+// need to do a checksum over their payloads.
+type HeaderChecksummer interface {
+ Header
+
+ // WriteCheck writes the correct checksum into buf, which should
+ // be be the already-marshalled header and payload.
+ WriteChecksum(buf []byte)
+}
+
+// Generate generates a new packet with the given Header and
+// payload. This function allocates memory, see Header.Marshal for an
+// allocation-free option.
+func Generate(h Header, payload []byte) []byte {
+ hlen := h.Len()
+ buf := make([]byte, hlen+len(payload))
+
+ copy(buf[hlen:], payload)
+ h.Marshal(buf)
+
+ if hc, ok := h.(HeaderChecksummer); ok {
+ hc.WriteChecksum(buf)
+ }
+
+ return buf
+}
diff --git a/net/packet/icmp.go b/net/packet/icmp.go
index 89a7aaa32..7b86edd81 100644
--- a/net/packet/icmp.go
+++ b/net/packet/icmp.go
@@ -1,28 +1,28 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-package packet
-
-import (
- crand "crypto/rand"
-
- "encoding/binary"
-)
-
-// ICMPEchoPayload generates a new random ID/Sequence pair, and returns a uint32
-// derived from them, along with the id, sequence and given payload in a buffer.
-// It returns an error if the random source could not be read.
-func ICMPEchoPayload(payload []byte) (idSeq uint32, buf []byte) {
- buf = make([]byte, len(payload)+4)
-
- // make a completely random id/sequence combo, which is very unlikely to
- // collide with a running ping sequence on the host system. Errors are
- // ignored, that would result in collisions, but errors reading from the
- // random device are rare, and will cause this process universe to soon end.
- crand.Read(buf[:4])
-
- idSeq = binary.LittleEndian.Uint32(buf)
- copy(buf[4:], payload)
-
- return
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package packet
+
+import (
+ crand "crypto/rand"
+
+ "encoding/binary"
+)
+
+// ICMPEchoPayload generates a new random ID/Sequence pair, and returns a uint32
+// derived from them, along with the id, sequence and given payload in a buffer.
+// It returns an error if the random source could not be read.
+func ICMPEchoPayload(payload []byte) (idSeq uint32, buf []byte) {
+ buf = make([]byte, len(payload)+4)
+
+ // make a completely random id/sequence combo, which is very unlikely to
+ // collide with a running ping sequence on the host system. Errors are
+ // ignored, that would result in collisions, but errors reading from the
+ // random device are rare, and will cause this process universe to soon end.
+ crand.Read(buf[:4])
+
+ idSeq = binary.LittleEndian.Uint32(buf)
+ copy(buf[4:], payload)
+
+ return
+}
diff --git a/net/packet/icmp6_test.go b/net/packet/icmp6_test.go
index f34883ca4..c2fab353a 100644
--- a/net/packet/icmp6_test.go
+++ b/net/packet/icmp6_test.go
@@ -1,79 +1,79 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-package packet
-
-import (
- "net/netip"
- "testing"
-
- "tailscale.com/types/ipproto"
-)
-
-func TestICMPv6PingResponse(t *testing.T) {
- pingHdr := ICMP6Header{
- IP6Header: IP6Header{
- Src: netip.MustParseAddr("1::1"),
- Dst: netip.MustParseAddr("2::2"),
- IPProto: ipproto.ICMPv6,
- },
- Type: ICMP6EchoRequest,
- Code: ICMP6NoCode,
- }
-
- // echoReqLen is 2 bytes identifier + 2 bytes seq number.
- // https://datatracker.ietf.org/doc/html/rfc4443#section-4.1
- // Packet.IsEchoRequest verifies that these 4 bytes are present.
- const echoReqLen = 4
- buf := make([]byte, pingHdr.Len()+echoReqLen)
- if err := pingHdr.Marshal(buf); err != nil {
- t.Fatal(err)
- }
-
- var p Parsed
- p.Decode(buf)
- if !p.IsEchoRequest() {
- t.Fatalf("not an echo request, got: %+v", p)
- }
-
- pingHdr.ToResponse()
- buf = make([]byte, pingHdr.Len()+echoReqLen)
- if err := pingHdr.Marshal(buf); err != nil {
- t.Fatal(err)
- }
-
- p.Decode(buf)
- if p.IsEchoRequest() {
- t.Fatalf("unexpectedly still an echo request: %+v", p)
- }
- if !p.IsEchoResponse() {
- t.Fatalf("not an echo response: %+v", p)
- }
-}
-
-func TestICMPv6Checksum(t *testing.T) {
- const req = "\x60\x0f\x07\x00\x00\x10\x3a\x40\xfd\x7a\x11\x5c\xa1\xe0\xab\x12" +
- "\x48\x43\xcd\x96\x62\x7b\x65\x28\x26\x07\xf8\xb0\x40\x0a\x08\x07" +
- "\x00\x00\x00\x00\x00\x00\x20\x0e\x80\x00\x4a\x9a\x2e\xea\x00\x02" +
- "\x61\xb1\x9e\xad\x00\x06\x45\xaa"
- // The packet that we'd originally generated incorrectly, but with the checksum
- // bytes fixed per WireShark's correct calculation:
- const wantRes = "\x60\x00\xf8\xff\x00\x10\x3a\x40\x26\x07\xf8\xb0\x40\x0a\x08\x07" +
- "\x00\x00\x00\x00\x00\x00\x20\x0e\xfd\x7a\x11\x5c\xa1\xe0\xab\x12" +
- "\x48\x43\xcd\x96\x62\x7b\x65\x28\x81\x00\x49\x9a\x2e\xea\x00\x02" +
- "\x61\xb1\x9e\xad\x00\x06\x45\xaa"
-
- var p Parsed
- p.Decode([]byte(req))
- if !p.IsEchoRequest() {
- t.Fatalf("not an echo request, got: %+v", p)
- }
-
- h := p.ICMP6Header()
- h.ToResponse()
- pong := Generate(&h, p.Payload())
-
- if string(pong) != wantRes {
- t.Errorf("wrong packet\n\n got: %x\nwant: %x", pong, wantRes)
- }
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package packet
+
+import (
+ "net/netip"
+ "testing"
+
+ "tailscale.com/types/ipproto"
+)
+
+func TestICMPv6PingResponse(t *testing.T) {
+ pingHdr := ICMP6Header{
+ IP6Header: IP6Header{
+ Src: netip.MustParseAddr("1::1"),
+ Dst: netip.MustParseAddr("2::2"),
+ IPProto: ipproto.ICMPv6,
+ },
+ Type: ICMP6EchoRequest,
+ Code: ICMP6NoCode,
+ }
+
+ // echoReqLen is 2 bytes identifier + 2 bytes seq number.
+ // https://datatracker.ietf.org/doc/html/rfc4443#section-4.1
+ // Packet.IsEchoRequest verifies that these 4 bytes are present.
+ const echoReqLen = 4
+ buf := make([]byte, pingHdr.Len()+echoReqLen)
+ if err := pingHdr.Marshal(buf); err != nil {
+ t.Fatal(err)
+ }
+
+ var p Parsed
+ p.Decode(buf)
+ if !p.IsEchoRequest() {
+ t.Fatalf("not an echo request, got: %+v", p)
+ }
+
+ pingHdr.ToResponse()
+ buf = make([]byte, pingHdr.Len()+echoReqLen)
+ if err := pingHdr.Marshal(buf); err != nil {
+ t.Fatal(err)
+ }
+
+ p.Decode(buf)
+ if p.IsEchoRequest() {
+ t.Fatalf("unexpectedly still an echo request: %+v", p)
+ }
+ if !p.IsEchoResponse() {
+ t.Fatalf("not an echo response: %+v", p)
+ }
+}
+
+func TestICMPv6Checksum(t *testing.T) {
+ const req = "\x60\x0f\x07\x00\x00\x10\x3a\x40\xfd\x7a\x11\x5c\xa1\xe0\xab\x12" +
+ "\x48\x43\xcd\x96\x62\x7b\x65\x28\x26\x07\xf8\xb0\x40\x0a\x08\x07" +
+ "\x00\x00\x00\x00\x00\x00\x20\x0e\x80\x00\x4a\x9a\x2e\xea\x00\x02" +
+ "\x61\xb1\x9e\xad\x00\x06\x45\xaa"
+ // The packet that we'd originally generated incorrectly, but with the checksum
+ // bytes fixed per WireShark's correct calculation:
+ const wantRes = "\x60\x00\xf8\xff\x00\x10\x3a\x40\x26\x07\xf8\xb0\x40\x0a\x08\x07" +
+ "\x00\x00\x00\x00\x00\x00\x20\x0e\xfd\x7a\x11\x5c\xa1\xe0\xab\x12" +
+ "\x48\x43\xcd\x96\x62\x7b\x65\x28\x81\x00\x49\x9a\x2e\xea\x00\x02" +
+ "\x61\xb1\x9e\xad\x00\x06\x45\xaa"
+
+ var p Parsed
+ p.Decode([]byte(req))
+ if !p.IsEchoRequest() {
+ t.Fatalf("not an echo request, got: %+v", p)
+ }
+
+ h := p.ICMP6Header()
+ h.ToResponse()
+ pong := Generate(&h, p.Payload())
+
+ if string(pong) != wantRes {
+ t.Errorf("wrong packet\n\n got: %x\nwant: %x", pong, wantRes)
+ }
+}
diff --git a/net/packet/ip4.go b/net/packet/ip4.go
index 967a8dba7..596bc766d 100644
--- a/net/packet/ip4.go
+++ b/net/packet/ip4.go
@@ -1,116 +1,116 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-package packet
-
-import (
- "encoding/binary"
- "errors"
- "net/netip"
-
- "tailscale.com/types/ipproto"
-)
-
-// ip4HeaderLength is the length of an IPv4 header with no IP options.
-const ip4HeaderLength = 20
-
-// IP4Header represents an IPv4 packet header.
-type IP4Header struct {
- IPProto ipproto.Proto
- IPID uint16
- Src netip.Addr
- Dst netip.Addr
-}
-
-// Len implements Header.
-func (h IP4Header) Len() int {
- return ip4HeaderLength
-}
-
-var errWrongFamily = errors.New("wrong address family for src/dst IP")
-
-// Marshal implements Header.
-func (h IP4Header) Marshal(buf []byte) error {
- if len(buf) < h.Len() {
- return errSmallBuffer
- }
- if len(buf) > maxPacketLength {
- return errLargePacket
- }
- if !h.Src.Is4() || !h.Dst.Is4() {
- return errWrongFamily
- }
-
- buf[0] = 0x40 | (byte(h.Len() >> 2)) // IPv4 + IHL
- buf[1] = 0x00 // DSCP + ECN
- binary.BigEndian.PutUint16(buf[2:4], uint16(len(buf))) // Total length
- binary.BigEndian.PutUint16(buf[4:6], h.IPID) // ID
- binary.BigEndian.PutUint16(buf[6:8], 0) // Flags + fragment offset
- buf[8] = 64 // TTL
- buf[9] = uint8(h.IPProto) // Inner protocol
- // Blank checksum. This is necessary even though we overwrite
- // it later, because the checksum computation runs over these
- // bytes and expects them to be zero.
- binary.BigEndian.PutUint16(buf[10:12], 0)
- src := h.Src.As4()
- dst := h.Dst.As4()
- copy(buf[12:16], src[:])
- copy(buf[16:20], dst[:])
-
- binary.BigEndian.PutUint16(buf[10:12], ip4Checksum(buf[0:20])) // Checksum
-
- return nil
-}
-
-// ToResponse implements Header.
-func (h *IP4Header) ToResponse() {
- h.Src, h.Dst = h.Dst, h.Src
- // Flip the bits in the IPID. If incoming IPIDs are distinct, so are these.
- h.IPID = ^h.IPID
-}
-
-// ip4Checksum computes an IPv4 checksum, as specified in
-// https://tools.ietf.org/html/rfc1071
-func ip4Checksum(b []byte) uint16 {
- var ac uint32
- i := 0
- n := len(b)
- for n >= 2 {
- ac += uint32(binary.BigEndian.Uint16(b[i : i+2]))
- n -= 2
- i += 2
- }
- if n == 1 {
- ac += uint32(b[i]) << 8
- }
- for (ac >> 16) > 0 {
- ac = (ac >> 16) + (ac & 0xffff)
- }
- return uint16(^ac)
-}
-
-// ip4PseudoHeaderOffset is the number of bytes by which the IPv4 UDP
-// pseudo-header is smaller than the real IPv4 header.
-const ip4PseudoHeaderOffset = 8
-
-// marshalPseudo serializes h into buf in the "pseudo-header" form
-// required when calculating UDP checksums. The pseudo-header starts
-// at buf[ip4PseudoHeaderOffset] so as to abut the following UDP
-// header, while leaving enough space in buf for a full IPv4 header.
-func (h IP4Header) marshalPseudo(buf []byte) error {
- if len(buf) < h.Len() {
- return errSmallBuffer
- }
- if len(buf) > maxPacketLength {
- return errLargePacket
- }
-
- length := len(buf) - h.Len()
- src, dst := h.Src.As4(), h.Dst.As4()
- copy(buf[8:12], src[:])
- copy(buf[12:16], dst[:])
- buf[16] = 0x0
- buf[17] = uint8(h.IPProto)
- binary.BigEndian.PutUint16(buf[18:20], uint16(length))
- return nil
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package packet
+
+import (
+ "encoding/binary"
+ "errors"
+ "net/netip"
+
+ "tailscale.com/types/ipproto"
+)
+
+// ip4HeaderLength is the length of an IPv4 header with no IP options.
+const ip4HeaderLength = 20
+
+// IP4Header represents an IPv4 packet header.
+type IP4Header struct {
+ IPProto ipproto.Proto
+ IPID uint16
+ Src netip.Addr
+ Dst netip.Addr
+}
+
+// Len implements Header.
+func (h IP4Header) Len() int {
+ return ip4HeaderLength
+}
+
+var errWrongFamily = errors.New("wrong address family for src/dst IP")
+
+// Marshal implements Header.
+func (h IP4Header) Marshal(buf []byte) error {
+ if len(buf) < h.Len() {
+ return errSmallBuffer
+ }
+ if len(buf) > maxPacketLength {
+ return errLargePacket
+ }
+ if !h.Src.Is4() || !h.Dst.Is4() {
+ return errWrongFamily
+ }
+
+ buf[0] = 0x40 | (byte(h.Len() >> 2)) // IPv4 + IHL
+ buf[1] = 0x00 // DSCP + ECN
+ binary.BigEndian.PutUint16(buf[2:4], uint16(len(buf))) // Total length
+ binary.BigEndian.PutUint16(buf[4:6], h.IPID) // ID
+ binary.BigEndian.PutUint16(buf[6:8], 0) // Flags + fragment offset
+ buf[8] = 64 // TTL
+ buf[9] = uint8(h.IPProto) // Inner protocol
+ // Blank checksum. This is necessary even though we overwrite
+ // it later, because the checksum computation runs over these
+ // bytes and expects them to be zero.
+ binary.BigEndian.PutUint16(buf[10:12], 0)
+ src := h.Src.As4()
+ dst := h.Dst.As4()
+ copy(buf[12:16], src[:])
+ copy(buf[16:20], dst[:])
+
+ binary.BigEndian.PutUint16(buf[10:12], ip4Checksum(buf[0:20])) // Checksum
+
+ return nil
+}
+
+// ToResponse implements Header.
+func (h *IP4Header) ToResponse() {
+ h.Src, h.Dst = h.Dst, h.Src
+ // Flip the bits in the IPID. If incoming IPIDs are distinct, so are these.
+ h.IPID = ^h.IPID
+}
+
+// ip4Checksum computes an IPv4 checksum, as specified in
+// https://tools.ietf.org/html/rfc1071
+func ip4Checksum(b []byte) uint16 {
+ var ac uint32
+ i := 0
+ n := len(b)
+ for n >= 2 {
+ ac += uint32(binary.BigEndian.Uint16(b[i : i+2]))
+ n -= 2
+ i += 2
+ }
+ if n == 1 {
+ ac += uint32(b[i]) << 8
+ }
+ for (ac >> 16) > 0 {
+ ac = (ac >> 16) + (ac & 0xffff)
+ }
+ return uint16(^ac)
+}
+
+// ip4PseudoHeaderOffset is the number of bytes by which the IPv4 UDP
+// pseudo-header is smaller than the real IPv4 header.
+const ip4PseudoHeaderOffset = 8
+
+// marshalPseudo serializes h into buf in the "pseudo-header" form
+// required when calculating UDP checksums. The pseudo-header starts
+// at buf[ip4PseudoHeaderOffset] so as to abut the following UDP
+// header, while leaving enough space in buf for a full IPv4 header.
+func (h IP4Header) marshalPseudo(buf []byte) error {
+ if len(buf) < h.Len() {
+ return errSmallBuffer
+ }
+ if len(buf) > maxPacketLength {
+ return errLargePacket
+ }
+
+ length := len(buf) - h.Len()
+ src, dst := h.Src.As4(), h.Dst.As4()
+ copy(buf[8:12], src[:])
+ copy(buf[12:16], dst[:])
+ buf[16] = 0x0
+ buf[17] = uint8(h.IPProto)
+ binary.BigEndian.PutUint16(buf[18:20], uint16(length))
+ return nil
+}
diff --git a/net/packet/ip6.go b/net/packet/ip6.go
index d26b9a161..cebc46c53 100644
--- a/net/packet/ip6.go
+++ b/net/packet/ip6.go
@@ -1,76 +1,76 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-package packet
-
-import (
- "encoding/binary"
- "net/netip"
-
- "tailscale.com/types/ipproto"
-)
-
-// ip6HeaderLength is the length of an IPv6 header with no IP options.
-const ip6HeaderLength = 40
-
-// IP6Header represents an IPv6 packet header.
-type IP6Header struct {
- IPProto ipproto.Proto
- IPID uint32 // only lower 20 bits used
- Src netip.Addr
- Dst netip.Addr
-}
-
-// Len implements Header.
-func (h IP6Header) Len() int {
- return ip6HeaderLength
-}
-
-// Marshal implements Header.
-func (h IP6Header) Marshal(buf []byte) error {
- if len(buf) < h.Len() {
- return errSmallBuffer
- }
- if len(buf) > maxPacketLength {
- return errLargePacket
- }
-
- binary.BigEndian.PutUint32(buf[:4], h.IPID&0x000FFFFF)
- buf[0] = 0x60
- binary.BigEndian.PutUint16(buf[4:6], uint16(len(buf)-ip6HeaderLength)) // Total length
- buf[6] = uint8(h.IPProto) // Inner protocol
- buf[7] = 64 // TTL
- src, dst := h.Src.As16(), h.Dst.As16()
- copy(buf[8:24], src[:])
- copy(buf[24:40], dst[:])
-
- return nil
-}
-
-// ToResponse implements Header.
-func (h *IP6Header) ToResponse() {
- h.Src, h.Dst = h.Dst, h.Src
- // Flip the bits in the IPID. If incoming IPIDs are distinct, so are these.
- h.IPID = (^h.IPID) & 0x000FFFFF
-}
-
-// marshalPseudo serializes h into buf in the "pseudo-header" form
-// required when calculating UDP checksums.
-func (h IP6Header) marshalPseudo(buf []byte, proto ipproto.Proto) error {
- if len(buf) < h.Len() {
- return errSmallBuffer
- }
- if len(buf) > maxPacketLength {
- return errLargePacket
- }
-
- src, dst := h.Src.As16(), h.Dst.As16()
- copy(buf[:16], src[:])
- copy(buf[16:32], dst[:])
- binary.BigEndian.PutUint32(buf[32:36], uint32(len(buf)-h.Len()))
- buf[36] = 0
- buf[37] = 0
- buf[38] = 0
- buf[39] = byte(proto) // NextProto
- return nil
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package packet
+
+import (
+ "encoding/binary"
+ "net/netip"
+
+ "tailscale.com/types/ipproto"
+)
+
+// ip6HeaderLength is the length of an IPv6 header with no IP options.
+const ip6HeaderLength = 40
+
+// IP6Header represents an IPv6 packet header.
+type IP6Header struct {
+ IPProto ipproto.Proto
+ IPID uint32 // only lower 20 bits used
+ Src netip.Addr
+ Dst netip.Addr
+}
+
+// Len implements Header.
+func (h IP6Header) Len() int {
+ return ip6HeaderLength
+}
+
+// Marshal implements Header.
+func (h IP6Header) Marshal(buf []byte) error {
+ if len(buf) < h.Len() {
+ return errSmallBuffer
+ }
+ if len(buf) > maxPacketLength {
+ return errLargePacket
+ }
+
+ binary.BigEndian.PutUint32(buf[:4], h.IPID&0x000FFFFF)
+ buf[0] = 0x60
+ binary.BigEndian.PutUint16(buf[4:6], uint16(len(buf)-ip6HeaderLength)) // Total length
+ buf[6] = uint8(h.IPProto) // Inner protocol
+ buf[7] = 64 // TTL
+ src, dst := h.Src.As16(), h.Dst.As16()
+ copy(buf[8:24], src[:])
+ copy(buf[24:40], dst[:])
+
+ return nil
+}
+
+// ToResponse implements Header.
+func (h *IP6Header) ToResponse() {
+ h.Src, h.Dst = h.Dst, h.Src
+ // Flip the bits in the IPID. If incoming IPIDs are distinct, so are these.
+ h.IPID = (^h.IPID) & 0x000FFFFF
+}
+
+// marshalPseudo serializes h into buf in the "pseudo-header" form
+// required when calculating UDP checksums.
+func (h IP6Header) marshalPseudo(buf []byte, proto ipproto.Proto) error {
+ if len(buf) < h.Len() {
+ return errSmallBuffer
+ }
+ if len(buf) > maxPacketLength {
+ return errLargePacket
+ }
+
+ src, dst := h.Src.As16(), h.Dst.As16()
+ copy(buf[:16], src[:])
+ copy(buf[16:32], dst[:])
+ binary.BigEndian.PutUint32(buf[32:36], uint32(len(buf)-h.Len()))
+ buf[36] = 0
+ buf[37] = 0
+ buf[38] = 0
+ buf[39] = byte(proto) // NextProto
+ return nil
+}
diff --git a/net/packet/tsmp_test.go b/net/packet/tsmp_test.go
index e261e6a41..4ec24e1ea 100644
--- a/net/packet/tsmp_test.go
+++ b/net/packet/tsmp_test.go
@@ -1,73 +1,73 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-package packet
-
-import (
- "net/netip"
- "testing"
-)
-
-func TestTailscaleRejectedHeader(t *testing.T) {
- tests := []struct {
- h TailscaleRejectedHeader
- wantStr string
- }{
- {
- h: TailscaleRejectedHeader{
- IPSrc: netip.MustParseAddr("5.5.5.5"),
- IPDst: netip.MustParseAddr("1.2.3.4"),
- Src: netip.MustParseAddrPort("1.2.3.4:567"),
- Dst: netip.MustParseAddrPort("5.5.5.5:443"),
- Proto: TCP,
- Reason: RejectedDueToACLs,
- },
- wantStr: "TSMP-reject-flow{TCP 1.2.3.4:567 > 5.5.5.5:443}: acl",
- },
- {
- h: TailscaleRejectedHeader{
- IPSrc: netip.MustParseAddr("2::2"),
- IPDst: netip.MustParseAddr("1::1"),
- Src: netip.MustParseAddrPort("[1::1]:567"),
- Dst: netip.MustParseAddrPort("[2::2]:443"),
- Proto: UDP,
- Reason: RejectedDueToShieldsUp,
- },
- wantStr: "TSMP-reject-flow{UDP [1::1]:567 > [2::2]:443}: shields",
- },
- {
- h: TailscaleRejectedHeader{
- IPSrc: netip.MustParseAddr("2::2"),
- IPDst: netip.MustParseAddr("1::1"),
- Src: netip.MustParseAddrPort("[1::1]:567"),
- Dst: netip.MustParseAddrPort("[2::2]:443"),
- Proto: UDP,
- Reason: RejectedDueToIPForwarding,
- MaybeBroken: true,
- },
- wantStr: "TSMP-reject-flow{UDP [1::1]:567 > [2::2]:443}: host-ip-forwarding-unavailable",
- },
- }
- for i, tt := range tests {
- gotStr := tt.h.String()
- if gotStr != tt.wantStr {
- t.Errorf("%v. String = %q; want %q", i, gotStr, tt.wantStr)
- continue
- }
- pkt := make([]byte, tt.h.Len())
- tt.h.Marshal(pkt)
-
- var p Parsed
- p.Decode(pkt)
- t.Logf("Parsed: %+v", p)
- t.Logf("Parsed: %s", p.String())
- back, ok := p.AsTailscaleRejectedHeader()
- if !ok {
- t.Errorf("%v. %q (%02x) didn't parse back", i, gotStr, pkt)
- continue
- }
- if back != tt.h {
- t.Errorf("%v. %q parsed back as %q", i, tt.h, back)
- }
- }
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package packet
+
+import (
+ "net/netip"
+ "testing"
+)
+
+func TestTailscaleRejectedHeader(t *testing.T) {
+ tests := []struct {
+ h TailscaleRejectedHeader
+ wantStr string
+ }{
+ {
+ h: TailscaleRejectedHeader{
+ IPSrc: netip.MustParseAddr("5.5.5.5"),
+ IPDst: netip.MustParseAddr("1.2.3.4"),
+ Src: netip.MustParseAddrPort("1.2.3.4:567"),
+ Dst: netip.MustParseAddrPort("5.5.5.5:443"),
+ Proto: TCP,
+ Reason: RejectedDueToACLs,
+ },
+ wantStr: "TSMP-reject-flow{TCP 1.2.3.4:567 > 5.5.5.5:443}: acl",
+ },
+ {
+ h: TailscaleRejectedHeader{
+ IPSrc: netip.MustParseAddr("2::2"),
+ IPDst: netip.MustParseAddr("1::1"),
+ Src: netip.MustParseAddrPort("[1::1]:567"),
+ Dst: netip.MustParseAddrPort("[2::2]:443"),
+ Proto: UDP,
+ Reason: RejectedDueToShieldsUp,
+ },
+ wantStr: "TSMP-reject-flow{UDP [1::1]:567 > [2::2]:443}: shields",
+ },
+ {
+ h: TailscaleRejectedHeader{
+ IPSrc: netip.MustParseAddr("2::2"),
+ IPDst: netip.MustParseAddr("1::1"),
+ Src: netip.MustParseAddrPort("[1::1]:567"),
+ Dst: netip.MustParseAddrPort("[2::2]:443"),
+ Proto: UDP,
+ Reason: RejectedDueToIPForwarding,
+ MaybeBroken: true,
+ },
+ wantStr: "TSMP-reject-flow{UDP [1::1]:567 > [2::2]:443}: host-ip-forwarding-unavailable",
+ },
+ }
+ for i, tt := range tests {
+ gotStr := tt.h.String()
+ if gotStr != tt.wantStr {
+ t.Errorf("%v. String = %q; want %q", i, gotStr, tt.wantStr)
+ continue
+ }
+ pkt := make([]byte, tt.h.Len())
+ tt.h.Marshal(pkt)
+
+ var p Parsed
+ p.Decode(pkt)
+ t.Logf("Parsed: %+v", p)
+ t.Logf("Parsed: %s", p.String())
+ back, ok := p.AsTailscaleRejectedHeader()
+ if !ok {
+ t.Errorf("%v. %q (%02x) didn't parse back", i, gotStr, pkt)
+ continue
+ }
+ if back != tt.h {
+ t.Errorf("%v. %q parsed back as %q", i, tt.h, back)
+ }
+ }
+}
diff --git a/net/packet/udp4.go b/net/packet/udp4.go
index 0d5bca73e..c8761baef 100644
--- a/net/packet/udp4.go
+++ b/net/packet/udp4.go
@@ -1,58 +1,58 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-package packet
-
-import (
- "encoding/binary"
-
- "tailscale.com/types/ipproto"
-)
-
-// udpHeaderLength is the size of the UDP packet header, not including
-// the outer IP header.
-const udpHeaderLength = 8
-
-// UDP4Header is an IPv4+UDP header.
-type UDP4Header struct {
- IP4Header
- SrcPort uint16
- DstPort uint16
-}
-
-// Len implements Header.
-func (h UDP4Header) Len() int {
- return h.IP4Header.Len() + udpHeaderLength
-}
-
-// Marshal implements Header.
-func (h UDP4Header) Marshal(buf []byte) error {
- if len(buf) < h.Len() {
- return errSmallBuffer
- }
- if len(buf) > maxPacketLength {
- return errLargePacket
- }
- // The caller does not need to set this.
- h.IPProto = ipproto.UDP
-
- length := len(buf) - h.IP4Header.Len()
- binary.BigEndian.PutUint16(buf[20:22], h.SrcPort)
- binary.BigEndian.PutUint16(buf[22:24], h.DstPort)
- binary.BigEndian.PutUint16(buf[24:26], uint16(length))
- binary.BigEndian.PutUint16(buf[26:28], 0) // blank checksum
-
- // UDP checksum with IP pseudo header.
- h.IP4Header.marshalPseudo(buf)
- binary.BigEndian.PutUint16(buf[26:28], ip4Checksum(buf[ip4PseudoHeaderOffset:]))
-
- h.IP4Header.Marshal(buf)
-
- return nil
-}
-
-// ToResponse implements Header.
-func (h *UDP4Header) ToResponse() {
- h.SrcPort, h.DstPort = h.DstPort, h.SrcPort
- h.IP4Header.ToResponse()
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package packet
+
+import (
+ "encoding/binary"
+
+ "tailscale.com/types/ipproto"
+)
+
+// udpHeaderLength is the size of the UDP packet header, not including
+// the outer IP header.
+const udpHeaderLength = 8
+
+// UDP4Header is an IPv4+UDP header.
+type UDP4Header struct {
+ IP4Header
+ SrcPort uint16
+ DstPort uint16
+}
+
+// Len implements Header.
+func (h UDP4Header) Len() int {
+ return h.IP4Header.Len() + udpHeaderLength
+}
+
+// Marshal implements Header.
+func (h UDP4Header) Marshal(buf []byte) error {
+ if len(buf) < h.Len() {
+ return errSmallBuffer
+ }
+ if len(buf) > maxPacketLength {
+ return errLargePacket
+ }
+ // The caller does not need to set this.
+ h.IPProto = ipproto.UDP
+
+ length := len(buf) - h.IP4Header.Len()
+ binary.BigEndian.PutUint16(buf[20:22], h.SrcPort)
+ binary.BigEndian.PutUint16(buf[22:24], h.DstPort)
+ binary.BigEndian.PutUint16(buf[24:26], uint16(length))
+ binary.BigEndian.PutUint16(buf[26:28], 0) // blank checksum
+
+ // UDP checksum with IP pseudo header.
+ h.IP4Header.marshalPseudo(buf)
+ binary.BigEndian.PutUint16(buf[26:28], ip4Checksum(buf[ip4PseudoHeaderOffset:]))
+
+ h.IP4Header.Marshal(buf)
+
+ return nil
+}
+
+// ToResponse implements Header.
+func (h *UDP4Header) ToResponse() {
+ h.SrcPort, h.DstPort = h.DstPort, h.SrcPort
+ h.IP4Header.ToResponse()
+}
diff --git a/net/packet/udp6.go b/net/packet/udp6.go
index 10fdcb99e..c8634b508 100644
--- a/net/packet/udp6.go
+++ b/net/packet/udp6.go
@@ -1,54 +1,54 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-package packet
-
-import (
- "encoding/binary"
-
- "tailscale.com/types/ipproto"
-)
-
-// UDP6Header is an IPv6+UDP header.
-type UDP6Header struct {
- IP6Header
- SrcPort uint16
- DstPort uint16
-}
-
-// Len implements Header.
-func (h UDP6Header) Len() int {
- return h.IP6Header.Len() + udpHeaderLength
-}
-
-// Marshal implements Header.
-func (h UDP6Header) Marshal(buf []byte) error {
- if len(buf) < h.Len() {
- return errSmallBuffer
- }
- if len(buf) > maxPacketLength {
- return errLargePacket
- }
- // The caller does not need to set this.
- h.IPProto = ipproto.UDP
-
- length := len(buf) - h.IP6Header.Len()
- binary.BigEndian.PutUint16(buf[40:42], h.SrcPort)
- binary.BigEndian.PutUint16(buf[42:44], h.DstPort)
- binary.BigEndian.PutUint16(buf[44:46], uint16(length))
- binary.BigEndian.PutUint16(buf[46:48], 0) // blank checksum
-
- // UDP checksum with IP pseudo header.
- h.IP6Header.marshalPseudo(buf, ipproto.UDP)
- binary.BigEndian.PutUint16(buf[46:48], ip4Checksum(buf[:]))
-
- h.IP6Header.Marshal(buf)
-
- return nil
-}
-
-// ToResponse implements Header.
-func (h *UDP6Header) ToResponse() {
- h.SrcPort, h.DstPort = h.DstPort, h.SrcPort
- h.IP6Header.ToResponse()
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package packet
+
+import (
+ "encoding/binary"
+
+ "tailscale.com/types/ipproto"
+)
+
+// UDP6Header is an IPv6+UDP header.
+type UDP6Header struct {
+ IP6Header
+ SrcPort uint16
+ DstPort uint16
+}
+
+// Len implements Header.
+func (h UDP6Header) Len() int {
+ return h.IP6Header.Len() + udpHeaderLength
+}
+
+// Marshal implements Header.
+func (h UDP6Header) Marshal(buf []byte) error {
+ if len(buf) < h.Len() {
+ return errSmallBuffer
+ }
+ if len(buf) > maxPacketLength {
+ return errLargePacket
+ }
+ // The caller does not need to set this.
+ h.IPProto = ipproto.UDP
+
+ length := len(buf) - h.IP6Header.Len()
+ binary.BigEndian.PutUint16(buf[40:42], h.SrcPort)
+ binary.BigEndian.PutUint16(buf[42:44], h.DstPort)
+ binary.BigEndian.PutUint16(buf[44:46], uint16(length))
+ binary.BigEndian.PutUint16(buf[46:48], 0) // blank checksum
+
+ // UDP checksum with IP pseudo header.
+ h.IP6Header.marshalPseudo(buf, ipproto.UDP)
+ binary.BigEndian.PutUint16(buf[46:48], ip4Checksum(buf[:]))
+
+ h.IP6Header.Marshal(buf)
+
+ return nil
+}
+
+// ToResponse implements Header.
+func (h *UDP6Header) ToResponse() {
+ h.SrcPort, h.DstPort = h.DstPort, h.SrcPort
+ h.IP6Header.ToResponse()
+}
diff --git a/net/ping/ping.go b/net/ping/ping.go
index 01f3dcf2c..f2093292a 100644
--- a/net/ping/ping.go
+++ b/net/ping/ping.go
@@ -1,343 +1,343 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-// Package ping allows sending ICMP echo requests to a host in order to
-// determine network latency.
-package ping
-
-import (
- "bytes"
- "context"
- "crypto/rand"
- "encoding/binary"
- "fmt"
- "io"
- "log"
- "net"
- "net/netip"
- "sync"
- "sync/atomic"
- "time"
-
- "golang.org/x/net/icmp"
- "golang.org/x/net/ipv4"
- "golang.org/x/net/ipv6"
- "tailscale.com/types/logger"
- "tailscale.com/util/mak"
- "tailscale.com/util/multierr"
-)
-
-const (
- v4Type = "ip4:icmp"
- v6Type = "ip6:icmp"
-)
-
-type response struct {
- t time.Time
- err error
-}
-
-type outstanding struct {
- ch chan response
- data []byte
-}
-
-// PacketListener defines the interface required to listen to packages
-// on an address.
-type ListenPacketer interface {
- ListenPacket(ctx context.Context, typ string, addr string) (net.PacketConn, error)
-}
-
-// Pinger represents a set of ICMP echo requests to be sent at a single time.
-//
-// A new instance should be created for each concurrent set of ping requests;
-// this type should not be reused.
-type Pinger struct {
- lp ListenPacketer
-
- // closed guards against send incrementing the waitgroup concurrently with close.
- closed atomic.Bool
- Logf logger.Logf
- Verbose bool
- timeNow func() time.Time
- id uint16 // uint16 per RFC 792
- wg sync.WaitGroup
-
- // Following fields protected by mu
- mu sync.Mutex
- // conns is a map of "type" to net.PacketConn, type is either
- // "ip4:icmp" or "ip6:icmp"
- conns map[string]net.PacketConn
- seq uint16 // uint16 per RFC 792
- pings map[uint16]outstanding
-}
-
-// New creates a new Pinger. The Context provided will be used to create
-// network listeners, and to set an absolute deadline (if any) on the net.Conn
-func New(ctx context.Context, logf logger.Logf, lp ListenPacketer) *Pinger {
- var id [2]byte
- if _, err := io.ReadFull(rand.Reader, id[:]); err != nil {
- panic("net/ping: New:" + err.Error())
- }
-
- return &Pinger{
- lp: lp,
- Logf: logf,
- timeNow: time.Now,
- id: binary.LittleEndian.Uint16(id[:]),
- pings: make(map[uint16]outstanding),
- }
-}
-
-func (p *Pinger) mkconn(ctx context.Context, typ, addr string) (net.PacketConn, error) {
- if p.closed.Load() {
- return nil, net.ErrClosed
- }
-
- c, err := p.lp.ListenPacket(ctx, typ, addr)
- if err != nil {
- return nil, err
- }
-
- // Start by setting the deadline from the context; note that this
- // applies to all future I/O, so we only need to do it once.
- deadline, ok := ctx.Deadline()
- if ok {
- if err := c.SetReadDeadline(deadline); err != nil {
- return nil, err
- }
- }
-
- p.wg.Add(1)
- go p.run(ctx, c, typ)
-
- return c, err
-}
-
-// getConn creates or returns a conn matching typ which is ip4:icmp
-// or ip6:icmp.
-func (p *Pinger) getConn(ctx context.Context, typ string) (net.PacketConn, error) {
- p.mu.Lock()
- defer p.mu.Unlock()
- if c, ok := p.conns[typ]; ok {
- return c, nil
- }
-
- var addr = "0.0.0.0"
- if typ == v6Type {
- addr = "::"
- }
- c, err := p.mkconn(ctx, typ, addr)
- if err != nil {
- return nil, err
- }
- mak.Set(&p.conns, typ, c)
- return c, nil
-}
-
-func (p *Pinger) logf(format string, a ...any) {
- if p.Logf != nil {
- p.Logf(format, a...)
- } else {
- log.Printf(format, a...)
- }
-}
-
-func (p *Pinger) vlogf(format string, a ...any) {
- if p.Verbose {
- p.logf(format, a...)
- }
-}
-
-func (p *Pinger) Close() error {
- p.closed.Store(true)
-
- p.mu.Lock()
- conns := p.conns
- p.conns = nil
- p.mu.Unlock()
-
- var errors []error
- for _, c := range conns {
- if err := c.Close(); err != nil {
- errors = append(errors, err)
- }
- }
-
- p.wg.Wait()
- p.cleanupOutstanding()
-
- return multierr.New(errors...)
-}
-
-func (p *Pinger) run(ctx context.Context, conn net.PacketConn, typ string) {
- defer p.wg.Done()
- defer func() {
- conn.Close()
- p.mu.Lock()
- delete(p.conns, typ)
- p.mu.Unlock()
- }()
- buf := make([]byte, 1500)
-
-loop:
- for {
- select {
- case <-ctx.Done():
- break loop
- default:
- }
-
- n, _, err := conn.ReadFrom(buf)
- if err != nil {
- // Ignore temporary errors; everything else is fatal
- if netErr, ok := err.(net.Error); !ok || !netErr.Temporary() {
- break
- }
- continue
- }
-
- p.handleResponse(buf[:n], p.timeNow(), typ)
- }
-}
-
-func (p *Pinger) cleanupOutstanding() {
- // Complete outstanding requests
- p.mu.Lock()
- defer p.mu.Unlock()
- for _, o := range p.pings {
- o.ch <- response{err: net.ErrClosed}
- }
-}
-
-func (p *Pinger) handleResponse(buf []byte, now time.Time, typ string) {
- // We need to handle responding to both IPv4
- // and IPv6.
- var icmpType icmp.Type
- switch typ {
- case v4Type:
- icmpType = ipv4.ICMPTypeEchoReply
- case v6Type:
- icmpType = ipv6.ICMPTypeEchoReply
- default:
- p.vlogf("handleResponse: unknown icmp.Type")
- return
- }
-
- m, err := icmp.ParseMessage(icmpType.Protocol(), buf)
- if err != nil {
- p.vlogf("handleResponse: invalid packet: %v", err)
- return
- }
-
- if m.Type != icmpType {
- p.vlogf("handleResponse: wanted m.Type=%d; got %d", icmpType, m.Type)
- return
- }
-
- resp, ok := m.Body.(*icmp.Echo)
- if !ok || resp == nil {
- p.vlogf("handleResponse: wanted body=*icmp.Echo; got %v", m.Body)
- return
- }
-
- // We assume we sent this if the ID in the response is ours.
- if uint16(resp.ID) != p.id {
- p.vlogf("handleResponse: wanted ID=%d; got %d", p.id, resp.ID)
- return
- }
-
- // Search for existing running echo request
- var o outstanding
- p.mu.Lock()
- if o, ok = p.pings[uint16(resp.Seq)]; ok {
- // Ensure that the data matches before we delete from our map,
- // so a future correct packet will be handled correctly.
- if bytes.Equal(resp.Data, o.data) {
- delete(p.pings, uint16(resp.Seq))
- } else {
- p.vlogf("handleResponse: got response for Seq %d with mismatched data", resp.Seq)
- ok = false
- }
- } else {
- p.vlogf("handleResponse: got response for unknown Seq %d", resp.Seq)
- }
- p.mu.Unlock()
-
- if ok {
- o.ch <- response{t: now}
- }
-}
-
-// Send sends an ICMP Echo Request packet to the destination, waits for a
-// response, and returns the duration between when the request was sent and
-// when the reply was received.
-//
-// If provided, "data" is sent with the packet and is compared upon receiving a
-// reply.
-func (p *Pinger) Send(ctx context.Context, dest net.Addr, data []byte) (time.Duration, error) {
- // Use sequential sequence numbers on the assumption that we will not
- // wrap around when using a single Pinger instance
- p.mu.Lock()
- p.seq++
- seq := p.seq
- p.mu.Unlock()
-
- // Check whether the address is IPv4 or IPv6 to
- // determine the icmp.Type and conn to use.
- var conn net.PacketConn
- var icmpType icmp.Type = ipv4.ICMPTypeEcho
- ap, err := netip.ParseAddr(dest.String())
- if err != nil {
- return 0, err
- }
- if ap.Is6() {
- icmpType = ipv6.ICMPTypeEchoRequest
- conn, err = p.getConn(ctx, v6Type)
- } else {
- conn, err = p.getConn(ctx, v4Type)
- }
- if err != nil {
- return 0, err
- }
-
- m := icmp.Message{
- Type: icmpType,
- Code: 0,
- Body: &icmp.Echo{
- ID: int(p.id),
- Seq: int(seq),
- Data: data,
- },
- }
- b, err := m.Marshal(nil)
- if err != nil {
- return 0, err
- }
-
- // Register our response before sending since we could otherwise race a
- // quick reply.
- ch := make(chan response, 1)
- p.mu.Lock()
- p.pings[seq] = outstanding{ch: ch, data: data}
- p.mu.Unlock()
-
- start := p.timeNow()
- n, err := conn.WriteTo(b, dest)
- if err != nil {
- return 0, err
- } else if n != len(b) {
- return 0, fmt.Errorf("conn.WriteTo: got %v; want %v", n, len(b))
- }
-
- select {
- case resp := <-ch:
- if resp.err != nil {
- return 0, resp.err
- }
- return resp.t.Sub(start), nil
-
- case <-ctx.Done():
- return 0, ctx.Err()
- }
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Package ping allows sending ICMP echo requests to a host in order to
+// determine network latency.
+package ping
+
+import (
+ "bytes"
+ "context"
+ "crypto/rand"
+ "encoding/binary"
+ "fmt"
+ "io"
+ "log"
+ "net"
+ "net/netip"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "golang.org/x/net/icmp"
+ "golang.org/x/net/ipv4"
+ "golang.org/x/net/ipv6"
+ "tailscale.com/types/logger"
+ "tailscale.com/util/mak"
+ "tailscale.com/util/multierr"
+)
+
+const (
+ v4Type = "ip4:icmp"
+ v6Type = "ip6:icmp"
+)
+
+type response struct {
+ t time.Time
+ err error
+}
+
+type outstanding struct {
+ ch chan response
+ data []byte
+}
+
+// PacketListener defines the interface required to listen to packages
+// on an address.
+type ListenPacketer interface {
+ ListenPacket(ctx context.Context, typ string, addr string) (net.PacketConn, error)
+}
+
+// Pinger represents a set of ICMP echo requests to be sent at a single time.
+//
+// A new instance should be created for each concurrent set of ping requests;
+// this type should not be reused.
+type Pinger struct {
+ lp ListenPacketer
+
+ // closed guards against send incrementing the waitgroup concurrently with close.
+ closed atomic.Bool
+ Logf logger.Logf
+ Verbose bool
+ timeNow func() time.Time
+ id uint16 // uint16 per RFC 792
+ wg sync.WaitGroup
+
+ // Following fields protected by mu
+ mu sync.Mutex
+ // conns is a map of "type" to net.PacketConn, type is either
+ // "ip4:icmp" or "ip6:icmp"
+ conns map[string]net.PacketConn
+ seq uint16 // uint16 per RFC 792
+ pings map[uint16]outstanding
+}
+
+// New creates a new Pinger. The Context provided will be used to create
+// network listeners, and to set an absolute deadline (if any) on the net.Conn
+func New(ctx context.Context, logf logger.Logf, lp ListenPacketer) *Pinger {
+ var id [2]byte
+ if _, err := io.ReadFull(rand.Reader, id[:]); err != nil {
+ panic("net/ping: New:" + err.Error())
+ }
+
+ return &Pinger{
+ lp: lp,
+ Logf: logf,
+ timeNow: time.Now,
+ id: binary.LittleEndian.Uint16(id[:]),
+ pings: make(map[uint16]outstanding),
+ }
+}
+
+func (p *Pinger) mkconn(ctx context.Context, typ, addr string) (net.PacketConn, error) {
+ if p.closed.Load() {
+ return nil, net.ErrClosed
+ }
+
+ c, err := p.lp.ListenPacket(ctx, typ, addr)
+ if err != nil {
+ return nil, err
+ }
+
+ // Start by setting the deadline from the context; note that this
+ // applies to all future I/O, so we only need to do it once.
+ deadline, ok := ctx.Deadline()
+ if ok {
+ if err := c.SetReadDeadline(deadline); err != nil {
+ return nil, err
+ }
+ }
+
+ p.wg.Add(1)
+ go p.run(ctx, c, typ)
+
+ return c, err
+}
+
+// getConn creates or returns a conn matching typ which is ip4:icmp
+// or ip6:icmp.
+func (p *Pinger) getConn(ctx context.Context, typ string) (net.PacketConn, error) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ if c, ok := p.conns[typ]; ok {
+ return c, nil
+ }
+
+ var addr = "0.0.0.0"
+ if typ == v6Type {
+ addr = "::"
+ }
+ c, err := p.mkconn(ctx, typ, addr)
+ if err != nil {
+ return nil, err
+ }
+ mak.Set(&p.conns, typ, c)
+ return c, nil
+}
+
+func (p *Pinger) logf(format string, a ...any) {
+ if p.Logf != nil {
+ p.Logf(format, a...)
+ } else {
+ log.Printf(format, a...)
+ }
+}
+
+func (p *Pinger) vlogf(format string, a ...any) {
+ if p.Verbose {
+ p.logf(format, a...)
+ }
+}
+
+func (p *Pinger) Close() error {
+ p.closed.Store(true)
+
+ p.mu.Lock()
+ conns := p.conns
+ p.conns = nil
+ p.mu.Unlock()
+
+ var errors []error
+ for _, c := range conns {
+ if err := c.Close(); err != nil {
+ errors = append(errors, err)
+ }
+ }
+
+ p.wg.Wait()
+ p.cleanupOutstanding()
+
+ return multierr.New(errors...)
+}
+
+func (p *Pinger) run(ctx context.Context, conn net.PacketConn, typ string) {
+ defer p.wg.Done()
+ defer func() {
+ conn.Close()
+ p.mu.Lock()
+ delete(p.conns, typ)
+ p.mu.Unlock()
+ }()
+ buf := make([]byte, 1500)
+
+loop:
+ for {
+ select {
+ case <-ctx.Done():
+ break loop
+ default:
+ }
+
+ n, _, err := conn.ReadFrom(buf)
+ if err != nil {
+ // Ignore temporary errors; everything else is fatal
+ if netErr, ok := err.(net.Error); !ok || !netErr.Temporary() {
+ break
+ }
+ continue
+ }
+
+ p.handleResponse(buf[:n], p.timeNow(), typ)
+ }
+}
+
+func (p *Pinger) cleanupOutstanding() {
+ // Complete outstanding requests
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ for _, o := range p.pings {
+ o.ch <- response{err: net.ErrClosed}
+ }
+}
+
+func (p *Pinger) handleResponse(buf []byte, now time.Time, typ string) {
+ // We need to handle responding to both IPv4
+ // and IPv6.
+ var icmpType icmp.Type
+ switch typ {
+ case v4Type:
+ icmpType = ipv4.ICMPTypeEchoReply
+ case v6Type:
+ icmpType = ipv6.ICMPTypeEchoReply
+ default:
+ p.vlogf("handleResponse: unknown icmp.Type")
+ return
+ }
+
+ m, err := icmp.ParseMessage(icmpType.Protocol(), buf)
+ if err != nil {
+ p.vlogf("handleResponse: invalid packet: %v", err)
+ return
+ }
+
+ if m.Type != icmpType {
+ p.vlogf("handleResponse: wanted m.Type=%d; got %d", icmpType, m.Type)
+ return
+ }
+
+ resp, ok := m.Body.(*icmp.Echo)
+ if !ok || resp == nil {
+ p.vlogf("handleResponse: wanted body=*icmp.Echo; got %v", m.Body)
+ return
+ }
+
+ // We assume we sent this if the ID in the response is ours.
+ if uint16(resp.ID) != p.id {
+ p.vlogf("handleResponse: wanted ID=%d; got %d", p.id, resp.ID)
+ return
+ }
+
+ // Search for existing running echo request
+ var o outstanding
+ p.mu.Lock()
+ if o, ok = p.pings[uint16(resp.Seq)]; ok {
+ // Ensure that the data matches before we delete from our map,
+ // so a future correct packet will be handled correctly.
+ if bytes.Equal(resp.Data, o.data) {
+ delete(p.pings, uint16(resp.Seq))
+ } else {
+ p.vlogf("handleResponse: got response for Seq %d with mismatched data", resp.Seq)
+ ok = false
+ }
+ } else {
+ p.vlogf("handleResponse: got response for unknown Seq %d", resp.Seq)
+ }
+ p.mu.Unlock()
+
+ if ok {
+ o.ch <- response{t: now}
+ }
+}
+
+// Send sends an ICMP Echo Request packet to the destination, waits for a
+// response, and returns the duration between when the request was sent and
+// when the reply was received.
+//
+// If provided, "data" is sent with the packet and is compared upon receiving a
+// reply.
+func (p *Pinger) Send(ctx context.Context, dest net.Addr, data []byte) (time.Duration, error) {
+ // Use sequential sequence numbers on the assumption that we will not
+ // wrap around when using a single Pinger instance
+ p.mu.Lock()
+ p.seq++
+ seq := p.seq
+ p.mu.Unlock()
+
+ // Check whether the address is IPv4 or IPv6 to
+ // determine the icmp.Type and conn to use.
+ var conn net.PacketConn
+ var icmpType icmp.Type = ipv4.ICMPTypeEcho
+ ap, err := netip.ParseAddr(dest.String())
+ if err != nil {
+ return 0, err
+ }
+ if ap.Is6() {
+ icmpType = ipv6.ICMPTypeEchoRequest
+ conn, err = p.getConn(ctx, v6Type)
+ } else {
+ conn, err = p.getConn(ctx, v4Type)
+ }
+ if err != nil {
+ return 0, err
+ }
+
+ m := icmp.Message{
+ Type: icmpType,
+ Code: 0,
+ Body: &icmp.Echo{
+ ID: int(p.id),
+ Seq: int(seq),
+ Data: data,
+ },
+ }
+ b, err := m.Marshal(nil)
+ if err != nil {
+ return 0, err
+ }
+
+ // Register our response before sending since we could otherwise race a
+ // quick reply.
+ ch := make(chan response, 1)
+ p.mu.Lock()
+ p.pings[seq] = outstanding{ch: ch, data: data}
+ p.mu.Unlock()
+
+ start := p.timeNow()
+ n, err := conn.WriteTo(b, dest)
+ if err != nil {
+ return 0, err
+ } else if n != len(b) {
+ return 0, fmt.Errorf("conn.WriteTo: got %v; want %v", n, len(b))
+ }
+
+ select {
+ case resp := <-ch:
+ if resp.err != nil {
+ return 0, resp.err
+ }
+ return resp.t.Sub(start), nil
+
+ case <-ctx.Done():
+ return 0, ctx.Err()
+ }
+}
diff --git a/net/ping/ping_test.go b/net/ping/ping_test.go
index bbedbcad8..5232f6ada 100644
--- a/net/ping/ping_test.go
+++ b/net/ping/ping_test.go
@@ -1,350 +1,350 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-package ping
-
-import (
- "context"
- "errors"
- "fmt"
- "net"
- "testing"
- "time"
-
- "golang.org/x/net/icmp"
- "golang.org/x/net/ipv4"
- "golang.org/x/net/ipv6"
- "tailscale.com/tstest"
- "tailscale.com/util/mak"
-)
-
-var (
- localhost = &net.IPAddr{IP: net.IPv4(127, 0, 0, 1)}
-)
-
-func TestPinger(t *testing.T) {
- clock := &tstest.Clock{}
-
- ctx := context.Background()
- ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
- defer cancel()
-
- p, closeP := mockPinger(t, clock)
- defer closeP()
-
- bodyData := []byte("data goes here")
-
- // Start a ping in the background
- r := make(chan time.Duration, 1)
- go func() {
- dur, err := p.Send(ctx, localhost, bodyData)
- if err != nil {
- t.Errorf("p.Send: %v", err)
- r <- 0
- } else {
- r <- dur
- }
- }()
-
- p.waitOutstanding(t, ctx, 1)
-
- // Fake a response from ourself
- fakeResponse := mustMarshal(t, &icmp.Message{
- Type: ipv4.ICMPTypeEchoReply,
- Code: ipv4.ICMPTypeEchoReply.Protocol(),
- Body: &icmp.Echo{
- ID: 1234,
- Seq: 1,
- Data: bodyData,
- },
- })
-
- const fakeDuration = 100 * time.Millisecond
- p.handleResponse(fakeResponse, clock.Now().Add(fakeDuration), v4Type)
-
- select {
- case dur := <-r:
- want := fakeDuration
- if dur != want {
- t.Errorf("wanted ping response time = %d; got %d", want, dur)
- }
- case <-ctx.Done():
- t.Fatal("did not get response by timeout")
- }
-}
-
-func TestV6Pinger(t *testing.T) {
- if c, err := net.ListenPacket("udp6", "::1"); err != nil {
- // skip test if we can't use IPv6.
- t.Skipf("IPv6 not supported: %s", err)
- } else {
- c.Close()
- }
-
- clock := &tstest.Clock{}
-
- ctx := context.Background()
- ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
- defer cancel()
-
- p, closeP := mockPinger(t, clock)
- defer closeP()
-
- bodyData := []byte("data goes here")
-
- // Start a ping in the background
- r := make(chan time.Duration, 1)
- go func() {
- dur, err := p.Send(ctx, &net.IPAddr{IP: net.ParseIP("::")}, bodyData)
- if err != nil {
- t.Errorf("p.Send: %v", err)
- r <- 0
- } else {
- r <- dur
- }
- }()
-
- p.waitOutstanding(t, ctx, 1)
-
- // Fake a response from ourself
- fakeResponse := mustMarshal(t, &icmp.Message{
- Type: ipv6.ICMPTypeEchoReply,
- Code: ipv6.ICMPTypeEchoReply.Protocol(),
- Body: &icmp.Echo{
- ID: 1234,
- Seq: 1,
- Data: bodyData,
- },
- })
-
- const fakeDuration = 100 * time.Millisecond
- p.handleResponse(fakeResponse, clock.Now().Add(fakeDuration), v6Type)
-
- select {
- case dur := <-r:
- want := fakeDuration
- if dur != want {
- t.Errorf("wanted ping response time = %d; got %d", want, dur)
- }
- case <-ctx.Done():
- t.Fatal("did not get response by timeout")
- }
-}
-
-func TestPingerTimeout(t *testing.T) {
- ctx := context.Background()
- ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
- defer cancel()
-
- clock := &tstest.Clock{}
- p, closeP := mockPinger(t, clock)
- defer closeP()
-
- // Send a ping in the background
- r := make(chan error, 1)
- go func() {
- _, err := p.Send(ctx, localhost, []byte("data goes here"))
- r <- err
- }()
-
- // Wait until we're blocking
- p.waitOutstanding(t, ctx, 1)
-
- // Close everything down
- p.cleanupOutstanding()
-
- // Should have got an error from the ping
- err := <-r
- if !errors.Is(err, net.ErrClosed) {
- t.Errorf("wanted errors.Is(err, net.ErrClosed); got=%v", err)
- }
-}
-
-func TestPingerMismatch(t *testing.T) {
- clock := &tstest.Clock{}
-
- ctx := context.Background()
- ctx, cancel := context.WithTimeout(ctx, 1*time.Second) // intentionally short
- defer cancel()
-
- p, closeP := mockPinger(t, clock)
- defer closeP()
-
- bodyData := []byte("data goes here")
-
- // Start a ping in the background
- r := make(chan time.Duration, 1)
- go func() {
- dur, err := p.Send(ctx, localhost, bodyData)
- if err != nil && !errors.Is(err, context.DeadlineExceeded) {
- t.Errorf("p.Send: %v", err)
- r <- 0
- } else {
- r <- dur
- }
- }()
-
- p.waitOutstanding(t, ctx, 1)
-
- // "Receive" a bunch of intentionally malformed packets that should not
- // result in the Send call above returning
- badPackets := []struct {
- name string
- pkt *icmp.Message
- }{
- {
- name: "wrong type",
- pkt: &icmp.Message{
- Type: ipv4.ICMPTypeDestinationUnreachable,
- Code: 0,
- Body: &icmp.DstUnreach{},
- },
- },
- {
- name: "wrong id",
- pkt: &icmp.Message{
- Type: ipv4.ICMPTypeEchoReply,
- Code: 0,
- Body: &icmp.Echo{
- ID: 9999,
- Seq: 1,
- Data: bodyData,
- },
- },
- },
- {
- name: "wrong seq",
- pkt: &icmp.Message{
- Type: ipv4.ICMPTypeEchoReply,
- Code: 0,
- Body: &icmp.Echo{
- ID: 1234,
- Seq: 5,
- Data: bodyData,
- },
- },
- },
- {
- name: "bad body",
- pkt: &icmp.Message{
- Type: ipv4.ICMPTypeEchoReply,
- Code: 0,
- Body: &icmp.Echo{
- ID: 1234,
- Seq: 1,
-
- // Intentionally missing first byte
- Data: bodyData[1:],
- },
- },
- },
- }
-
- const fakeDuration = 100 * time.Millisecond
- tm := clock.Now().Add(fakeDuration)
-
- for _, tt := range badPackets {
- fakeResponse := mustMarshal(t, tt.pkt)
- p.handleResponse(fakeResponse, tm, v4Type)
- }
-
- // Also "receive" a packet that does not unmarshal as an ICMP packet
- p.handleResponse([]byte("foo"), tm, v4Type)
-
- select {
- case <-r:
- t.Fatal("wanted timeout")
- case <-ctx.Done():
- t.Logf("test correctly timed out")
- }
-}
-
-// udpingPacketConn will convert potentially ICMP destination addrs to UDP
-// destination addrs in WriteTo so that a test that is intending to send ICMP
-// traffic will instead send UDP traffic, without the higher level Pinger being
-// aware of this difference.
-type udpingPacketConn struct {
- net.PacketConn
- // destPort will be configured by the test to be the peer expected to respond to a ping.
- destPort uint16
-}
-
-func (u *udpingPacketConn) WriteTo(body []byte, dest net.Addr) (int, error) {
- switch d := dest.(type) {
- case *net.IPAddr:
- udpAddr := &net.UDPAddr{
- IP: d.IP,
- Port: int(u.destPort),
- Zone: d.Zone,
- }
- return u.PacketConn.WriteTo(body, udpAddr)
- }
- return 0, fmt.Errorf("unimplemented udpingPacketConn for %T", dest)
-}
-
-func mockPinger(t *testing.T, clock *tstest.Clock) (*Pinger, func()) {
- p := New(context.Background(), t.Logf, nil)
- p.timeNow = clock.Now
- p.Verbose = true
- p.id = 1234
-
- // In tests, we use UDP so that we can test without being root; this
- // doesn't matter because we mock out the ICMP reply below to be a real
- // ICMP echo reply packet.
- conn4, err := net.ListenPacket("udp4", "127.0.0.1:0")
- if err != nil {
- t.Fatalf("net.ListenPacket: %v", err)
- }
-
- conn6, err := net.ListenPacket("udp6", "[::]:0")
- if err != nil {
- t.Fatalf("net.ListenPacket: %v", err)
- }
-
- conn4 = &udpingPacketConn{
- destPort: 12345,
- PacketConn: conn4,
- }
- conn6 = &udpingPacketConn{
- PacketConn: conn6,
- destPort: 12345,
- }
-
- mak.Set(&p.conns, v4Type, conn4)
- mak.Set(&p.conns, v6Type, conn6)
- done := func() {
- if err := p.Close(); err != nil {
- t.Errorf("error on close: %v", err)
- }
- }
- return p, done
-}
-
-func mustMarshal(t *testing.T, m *icmp.Message) []byte {
- t.Helper()
-
- b, err := m.Marshal(nil)
- if err != nil {
- t.Fatal(err)
- }
- return b
-}
-
-func (p *Pinger) waitOutstanding(t *testing.T, ctx context.Context, count int) {
- // This is a bit janky, but... we busy-loop to wait for the Send call
- // to write to our map so we know that a response will be handled.
- var haveMapEntry bool
- for !haveMapEntry {
- time.Sleep(10 * time.Millisecond)
- select {
- case <-ctx.Done():
- t.Error("no entry in ping map before timeout")
- return
- default:
- }
-
- p.mu.Lock()
- haveMapEntry = len(p.pings) == count
- p.mu.Unlock()
- }
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package ping
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "net"
+ "testing"
+ "time"
+
+ "golang.org/x/net/icmp"
+ "golang.org/x/net/ipv4"
+ "golang.org/x/net/ipv6"
+ "tailscale.com/tstest"
+ "tailscale.com/util/mak"
+)
+
+var (
+ localhost = &net.IPAddr{IP: net.IPv4(127, 0, 0, 1)}
+)
+
+func TestPinger(t *testing.T) {
+ clock := &tstest.Clock{}
+
+ ctx := context.Background()
+ ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
+ defer cancel()
+
+ p, closeP := mockPinger(t, clock)
+ defer closeP()
+
+ bodyData := []byte("data goes here")
+
+ // Start a ping in the background
+ r := make(chan time.Duration, 1)
+ go func() {
+ dur, err := p.Send(ctx, localhost, bodyData)
+ if err != nil {
+ t.Errorf("p.Send: %v", err)
+ r <- 0
+ } else {
+ r <- dur
+ }
+ }()
+
+ p.waitOutstanding(t, ctx, 1)
+
+ // Fake a response from ourself
+ fakeResponse := mustMarshal(t, &icmp.Message{
+ Type: ipv4.ICMPTypeEchoReply,
+ Code: ipv4.ICMPTypeEchoReply.Protocol(),
+ Body: &icmp.Echo{
+ ID: 1234,
+ Seq: 1,
+ Data: bodyData,
+ },
+ })
+
+ const fakeDuration = 100 * time.Millisecond
+ p.handleResponse(fakeResponse, clock.Now().Add(fakeDuration), v4Type)
+
+ select {
+ case dur := <-r:
+ want := fakeDuration
+ if dur != want {
+ t.Errorf("wanted ping response time = %d; got %d", want, dur)
+ }
+ case <-ctx.Done():
+ t.Fatal("did not get response by timeout")
+ }
+}
+
+func TestV6Pinger(t *testing.T) {
+ if c, err := net.ListenPacket("udp6", "::1"); err != nil {
+ // skip test if we can't use IPv6.
+ t.Skipf("IPv6 not supported: %s", err)
+ } else {
+ c.Close()
+ }
+
+ clock := &tstest.Clock{}
+
+ ctx := context.Background()
+ ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
+ defer cancel()
+
+ p, closeP := mockPinger(t, clock)
+ defer closeP()
+
+ bodyData := []byte("data goes here")
+
+ // Start a ping in the background
+ r := make(chan time.Duration, 1)
+ go func() {
+ dur, err := p.Send(ctx, &net.IPAddr{IP: net.ParseIP("::")}, bodyData)
+ if err != nil {
+ t.Errorf("p.Send: %v", err)
+ r <- 0
+ } else {
+ r <- dur
+ }
+ }()
+
+ p.waitOutstanding(t, ctx, 1)
+
+ // Fake a response from ourself
+ fakeResponse := mustMarshal(t, &icmp.Message{
+ Type: ipv6.ICMPTypeEchoReply,
+ Code: ipv6.ICMPTypeEchoReply.Protocol(),
+ Body: &icmp.Echo{
+ ID: 1234,
+ Seq: 1,
+ Data: bodyData,
+ },
+ })
+
+ const fakeDuration = 100 * time.Millisecond
+ p.handleResponse(fakeResponse, clock.Now().Add(fakeDuration), v6Type)
+
+ select {
+ case dur := <-r:
+ want := fakeDuration
+ if dur != want {
+ t.Errorf("wanted ping response time = %d; got %d", want, dur)
+ }
+ case <-ctx.Done():
+ t.Fatal("did not get response by timeout")
+ }
+}
+
+func TestPingerTimeout(t *testing.T) {
+ ctx := context.Background()
+ ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
+ defer cancel()
+
+ clock := &tstest.Clock{}
+ p, closeP := mockPinger(t, clock)
+ defer closeP()
+
+ // Send a ping in the background
+ r := make(chan error, 1)
+ go func() {
+ _, err := p.Send(ctx, localhost, []byte("data goes here"))
+ r <- err
+ }()
+
+ // Wait until we're blocking
+ p.waitOutstanding(t, ctx, 1)
+
+ // Close everything down
+ p.cleanupOutstanding()
+
+ // Should have got an error from the ping
+ err := <-r
+ if !errors.Is(err, net.ErrClosed) {
+ t.Errorf("wanted errors.Is(err, net.ErrClosed); got=%v", err)
+ }
+}
+
+func TestPingerMismatch(t *testing.T) {
+ clock := &tstest.Clock{}
+
+ ctx := context.Background()
+ ctx, cancel := context.WithTimeout(ctx, 1*time.Second) // intentionally short
+ defer cancel()
+
+ p, closeP := mockPinger(t, clock)
+ defer closeP()
+
+ bodyData := []byte("data goes here")
+
+ // Start a ping in the background
+ r := make(chan time.Duration, 1)
+ go func() {
+ dur, err := p.Send(ctx, localhost, bodyData)
+ if err != nil && !errors.Is(err, context.DeadlineExceeded) {
+ t.Errorf("p.Send: %v", err)
+ r <- 0
+ } else {
+ r <- dur
+ }
+ }()
+
+ p.waitOutstanding(t, ctx, 1)
+
+ // "Receive" a bunch of intentionally malformed packets that should not
+ // result in the Send call above returning
+ badPackets := []struct {
+ name string
+ pkt *icmp.Message
+ }{
+ {
+ name: "wrong type",
+ pkt: &icmp.Message{
+ Type: ipv4.ICMPTypeDestinationUnreachable,
+ Code: 0,
+ Body: &icmp.DstUnreach{},
+ },
+ },
+ {
+ name: "wrong id",
+ pkt: &icmp.Message{
+ Type: ipv4.ICMPTypeEchoReply,
+ Code: 0,
+ Body: &icmp.Echo{
+ ID: 9999,
+ Seq: 1,
+ Data: bodyData,
+ },
+ },
+ },
+ {
+ name: "wrong seq",
+ pkt: &icmp.Message{
+ Type: ipv4.ICMPTypeEchoReply,
+ Code: 0,
+ Body: &icmp.Echo{
+ ID: 1234,
+ Seq: 5,
+ Data: bodyData,
+ },
+ },
+ },
+ {
+ name: "bad body",
+ pkt: &icmp.Message{
+ Type: ipv4.ICMPTypeEchoReply,
+ Code: 0,
+ Body: &icmp.Echo{
+ ID: 1234,
+ Seq: 1,
+
+ // Intentionally missing first byte
+ Data: bodyData[1:],
+ },
+ },
+ },
+ }
+
+ const fakeDuration = 100 * time.Millisecond
+ tm := clock.Now().Add(fakeDuration)
+
+ for _, tt := range badPackets {
+ fakeResponse := mustMarshal(t, tt.pkt)
+ p.handleResponse(fakeResponse, tm, v4Type)
+ }
+
+ // Also "receive" a packet that does not unmarshal as an ICMP packet
+ p.handleResponse([]byte("foo"), tm, v4Type)
+
+ select {
+ case <-r:
+ t.Fatal("wanted timeout")
+ case <-ctx.Done():
+ t.Logf("test correctly timed out")
+ }
+}
+
+// udpingPacketConn will convert potentially ICMP destination addrs to UDP
+// destination addrs in WriteTo so that a test that is intending to send ICMP
+// traffic will instead send UDP traffic, without the higher level Pinger being
+// aware of this difference.
+type udpingPacketConn struct {
+ net.PacketConn
+ // destPort will be configured by the test to be the peer expected to respond to a ping.
+ destPort uint16
+}
+
+func (u *udpingPacketConn) WriteTo(body []byte, dest net.Addr) (int, error) {
+ switch d := dest.(type) {
+ case *net.IPAddr:
+ udpAddr := &net.UDPAddr{
+ IP: d.IP,
+ Port: int(u.destPort),
+ Zone: d.Zone,
+ }
+ return u.PacketConn.WriteTo(body, udpAddr)
+ }
+ return 0, fmt.Errorf("unimplemented udpingPacketConn for %T", dest)
+}
+
+func mockPinger(t *testing.T, clock *tstest.Clock) (*Pinger, func()) {
+ p := New(context.Background(), t.Logf, nil)
+ p.timeNow = clock.Now
+ p.Verbose = true
+ p.id = 1234
+
+ // In tests, we use UDP so that we can test without being root; this
+ // doesn't matter because we mock out the ICMP reply below to be a real
+ // ICMP echo reply packet.
+ conn4, err := net.ListenPacket("udp4", "127.0.0.1:0")
+ if err != nil {
+ t.Fatalf("net.ListenPacket: %v", err)
+ }
+
+ conn6, err := net.ListenPacket("udp6", "[::]:0")
+ if err != nil {
+ t.Fatalf("net.ListenPacket: %v", err)
+ }
+
+ conn4 = &udpingPacketConn{
+ destPort: 12345,
+ PacketConn: conn4,
+ }
+ conn6 = &udpingPacketConn{
+ PacketConn: conn6,
+ destPort: 12345,
+ }
+
+ mak.Set(&p.conns, v4Type, conn4)
+ mak.Set(&p.conns, v6Type, conn6)
+ done := func() {
+ if err := p.Close(); err != nil {
+ t.Errorf("error on close: %v", err)
+ }
+ }
+ return p, done
+}
+
+func mustMarshal(t *testing.T, m *icmp.Message) []byte {
+ t.Helper()
+
+ b, err := m.Marshal(nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ return b
+}
+
+func (p *Pinger) waitOutstanding(t *testing.T, ctx context.Context, count int) {
+ // This is a bit janky, but... we busy-loop to wait for the Send call
+ // to write to our map so we know that a response will be handled.
+ var haveMapEntry bool
+ for !haveMapEntry {
+ time.Sleep(10 * time.Millisecond)
+ select {
+ case <-ctx.Done():
+ t.Error("no entry in ping map before timeout")
+ return
+ default:
+ }
+
+ p.mu.Lock()
+ haveMapEntry = len(p.pings) == count
+ p.mu.Unlock()
+ }
+}
diff --git a/net/portmapper/pcp_test.go b/net/portmapper/pcp_test.go
index 8f8eef3ef..3dece7236 100644
--- a/net/portmapper/pcp_test.go
+++ b/net/portmapper/pcp_test.go
@@ -1,62 +1,62 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-package portmapper
-
-import (
- "encoding/binary"
- "net/netip"
- "testing"
-
- "tailscale.com/net/netaddr"
-)
-
-var examplePCPMapResponse = []byte{2, 129, 0, 0, 0, 0, 28, 32, 0, 2, 155, 237, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 129, 112, 9, 24, 241, 208, 251, 45, 157, 76, 10, 188, 17, 0, 0, 0, 4, 210, 4, 210, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 135, 180, 175, 246}
-
-func TestParsePCPMapResponse(t *testing.T) {
- mapping, err := parsePCPMapResponse(examplePCPMapResponse)
- if err != nil {
- t.Fatalf("failed to parse PCP Map Response: %v", err)
- }
- if mapping == nil {
- t.Fatalf("got nil mapping when expected non-nil")
- }
- expectedAddr := netip.MustParseAddrPort("135.180.175.246:1234")
- if mapping.external != expectedAddr {
- t.Errorf("mismatched external address, got: %v, want: %v", mapping.external, expectedAddr)
- }
-}
-
-const (
- serverResponseBit = 1 << 7
- fakeLifetimeSec = 1<<31 - 1
-)
-
-func buildPCPDiscoResponse(req []byte) []byte {
- out := make([]byte, 24)
- out[0] = pcpVersion
- out[1] = req[1] | serverResponseBit
- out[3] = 0
- // Do not put an epoch time in 8:12, when we start using it, tests that use it should fail.
- return out
-}
-
-func buildPCPMapResponse(req []byte) []byte {
- out := make([]byte, 24+36)
- out[0] = pcpVersion
- out[1] = req[1] | serverResponseBit
- out[3] = 0
- binary.BigEndian.PutUint32(out[4:8], 1<<30)
- // Do not put an epoch time in 8:12, when we start using it, tests that use it should fail.
- mapResp := out[24:]
- mapReq := req[24:]
- // copy nonce, protocol and internal port
- copy(mapResp[:13], mapReq[:13])
- copy(mapResp[16:18], mapReq[16:18])
- // assign external port
- binary.BigEndian.PutUint16(mapResp[18:20], 4242)
- assignedIP := netaddr.IPv4(127, 0, 0, 1)
- assignedIP16 := assignedIP.As16()
- copy(mapResp[20:36], assignedIP16[:])
- return out
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package portmapper
+
+import (
+ "encoding/binary"
+ "net/netip"
+ "testing"
+
+ "tailscale.com/net/netaddr"
+)
+
+var examplePCPMapResponse = []byte{2, 129, 0, 0, 0, 0, 28, 32, 0, 2, 155, 237, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 129, 112, 9, 24, 241, 208, 251, 45, 157, 76, 10, 188, 17, 0, 0, 0, 4, 210, 4, 210, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 135, 180, 175, 246}
+
+func TestParsePCPMapResponse(t *testing.T) {
+ mapping, err := parsePCPMapResponse(examplePCPMapResponse)
+ if err != nil {
+ t.Fatalf("failed to parse PCP Map Response: %v", err)
+ }
+ if mapping == nil {
+ t.Fatalf("got nil mapping when expected non-nil")
+ }
+ expectedAddr := netip.MustParseAddrPort("135.180.175.246:1234")
+ if mapping.external != expectedAddr {
+ t.Errorf("mismatched external address, got: %v, want: %v", mapping.external, expectedAddr)
+ }
+}
+
+const (
+ serverResponseBit = 1 << 7
+ fakeLifetimeSec = 1<<31 - 1
+)
+
+func buildPCPDiscoResponse(req []byte) []byte {
+ out := make([]byte, 24)
+ out[0] = pcpVersion
+ out[1] = req[1] | serverResponseBit
+ out[3] = 0
+ // Do not put an epoch time in 8:12, when we start using it, tests that use it should fail.
+ return out
+}
+
+func buildPCPMapResponse(req []byte) []byte {
+ out := make([]byte, 24+36)
+ out[0] = pcpVersion
+ out[1] = req[1] | serverResponseBit
+ out[3] = 0
+ binary.BigEndian.PutUint32(out[4:8], 1<<30)
+ // Do not put an epoch time in 8:12, when we start using it, tests that use it should fail.
+ mapResp := out[24:]
+ mapReq := req[24:]
+ // copy nonce, protocol and internal port
+ copy(mapResp[:13], mapReq[:13])
+ copy(mapResp[16:18], mapReq[16:18])
+ // assign external port
+ binary.BigEndian.PutUint16(mapResp[18:20], 4242)
+ assignedIP := netaddr.IPv4(127, 0, 0, 1)
+ assignedIP16 := assignedIP.As16()
+ copy(mapResp[20:36], assignedIP16[:])
+ return out
+}
diff --git a/net/proxymux/mux.go b/net/proxymux/mux.go
index ff5aaff3b..12c3107de 100644
--- a/net/proxymux/mux.go
+++ b/net/proxymux/mux.go
@@ -1,144 +1,144 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-// Package proxymux splits a net.Listener in two, routing SOCKS5
-// connections to one and HTTP requests to the other.
-//
-// It allows for hosting both a SOCKS5 proxy and an HTTP proxy on the
-// same listener.
-package proxymux
-
-import (
- "io"
- "net"
- "sync"
- "time"
-)
-
-// SplitSOCKSAndHTTP accepts connections on ln and passes connections
-// through to either socksListener or httpListener, depending the
-// first byte sent by the client.
-func SplitSOCKSAndHTTP(ln net.Listener) (socksListener, httpListener net.Listener) {
- sl := &listener{
- addr: ln.Addr(),
- c: make(chan net.Conn),
- closed: make(chan struct{}),
- }
- hl := &listener{
- addr: ln.Addr(),
- c: make(chan net.Conn),
- closed: make(chan struct{}),
- }
-
- go splitSOCKSAndHTTPListener(ln, sl, hl)
-
- return sl, hl
-}
-
-func splitSOCKSAndHTTPListener(ln net.Listener, sl, hl *listener) {
- for {
- conn, err := ln.Accept()
- if err != nil {
- sl.Close()
- hl.Close()
- return
- }
- go routeConn(conn, sl, hl)
- }
-}
-
-func routeConn(c net.Conn, socksListener, httpListener *listener) {
- if err := c.SetReadDeadline(time.Now().Add(15 * time.Second)); err != nil {
- c.Close()
- return
- }
-
- var b [1]byte
- if _, err := io.ReadFull(c, b[:]); err != nil {
- c.Close()
- return
- }
-
- if err := c.SetReadDeadline(time.Time{}); err != nil {
- c.Close()
- return
- }
-
- conn := &connWithOneByte{
- Conn: c,
- b: b[0],
- }
-
- // First byte of a SOCKS5 session is a version byte set to 5.
- var ln *listener
- if b[0] == 5 {
- ln = socksListener
- } else {
- ln = httpListener
- }
- select {
- case ln.c <- conn:
- case <-ln.closed:
- c.Close()
- }
-}
-
-type listener struct {
- addr net.Addr
- c chan net.Conn
- mu sync.Mutex // serializes close() on closed. It's okay to receive on closed without locking.
- closed chan struct{}
-}
-
-func (ln *listener) Accept() (net.Conn, error) {
- // Once closed, reliably stay closed, don't race with attempts at
- // further connections.
- select {
- case <-ln.closed:
- return nil, net.ErrClosed
- default:
- }
- select {
- case ret := <-ln.c:
- return ret, nil
- case <-ln.closed:
- return nil, net.ErrClosed
- }
-}
-
-func (ln *listener) Close() error {
- ln.mu.Lock()
- defer ln.mu.Unlock()
- select {
- case <-ln.closed:
- // Already closed
- default:
- close(ln.closed)
- }
- return nil
-}
-
-func (ln *listener) Addr() net.Addr {
- return ln.addr
-}
-
-// connWithOneByte is a net.Conn that returns b for the first read
-// request, then forwards everything else to Conn.
-type connWithOneByte struct {
- net.Conn
-
- b byte
- bRead bool
-}
-
-func (c *connWithOneByte) Read(bs []byte) (int, error) {
- if c.bRead {
- return c.Conn.Read(bs)
- }
- if len(bs) == 0 {
- return 0, nil
- }
- c.bRead = true
- bs[0] = c.b
- return 1, nil
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Package proxymux splits a net.Listener in two, routing SOCKS5
+// connections to one and HTTP requests to the other.
+//
+// It allows for hosting both a SOCKS5 proxy and an HTTP proxy on the
+// same listener.
+package proxymux
+
+import (
+ "io"
+ "net"
+ "sync"
+ "time"
+)
+
+// SplitSOCKSAndHTTP accepts connections on ln and passes connections
+// through to either socksListener or httpListener, depending the
+// first byte sent by the client.
+func SplitSOCKSAndHTTP(ln net.Listener) (socksListener, httpListener net.Listener) {
+ sl := &listener{
+ addr: ln.Addr(),
+ c: make(chan net.Conn),
+ closed: make(chan struct{}),
+ }
+ hl := &listener{
+ addr: ln.Addr(),
+ c: make(chan net.Conn),
+ closed: make(chan struct{}),
+ }
+
+ go splitSOCKSAndHTTPListener(ln, sl, hl)
+
+ return sl, hl
+}
+
+func splitSOCKSAndHTTPListener(ln net.Listener, sl, hl *listener) {
+ for {
+ conn, err := ln.Accept()
+ if err != nil {
+ sl.Close()
+ hl.Close()
+ return
+ }
+ go routeConn(conn, sl, hl)
+ }
+}
+
+func routeConn(c net.Conn, socksListener, httpListener *listener) {
+ if err := c.SetReadDeadline(time.Now().Add(15 * time.Second)); err != nil {
+ c.Close()
+ return
+ }
+
+ var b [1]byte
+ if _, err := io.ReadFull(c, b[:]); err != nil {
+ c.Close()
+ return
+ }
+
+ if err := c.SetReadDeadline(time.Time{}); err != nil {
+ c.Close()
+ return
+ }
+
+ conn := &connWithOneByte{
+ Conn: c,
+ b: b[0],
+ }
+
+ // First byte of a SOCKS5 session is a version byte set to 5.
+ var ln *listener
+ if b[0] == 5 {
+ ln = socksListener
+ } else {
+ ln = httpListener
+ }
+ select {
+ case ln.c <- conn:
+ case <-ln.closed:
+ c.Close()
+ }
+}
+
+type listener struct {
+ addr net.Addr
+ c chan net.Conn
+ mu sync.Mutex // serializes close() on closed. It's okay to receive on closed without locking.
+ closed chan struct{}
+}
+
+func (ln *listener) Accept() (net.Conn, error) {
+ // Once closed, reliably stay closed, don't race with attempts at
+ // further connections.
+ select {
+ case <-ln.closed:
+ return nil, net.ErrClosed
+ default:
+ }
+ select {
+ case ret := <-ln.c:
+ return ret, nil
+ case <-ln.closed:
+ return nil, net.ErrClosed
+ }
+}
+
+func (ln *listener) Close() error {
+ ln.mu.Lock()
+ defer ln.mu.Unlock()
+ select {
+ case <-ln.closed:
+ // Already closed
+ default:
+ close(ln.closed)
+ }
+ return nil
+}
+
+func (ln *listener) Addr() net.Addr {
+ return ln.addr
+}
+
+// connWithOneByte is a net.Conn that returns b for the first read
+// request, then forwards everything else to Conn.
+type connWithOneByte struct {
+ net.Conn
+
+ b byte
+ bRead bool
+}
+
+func (c *connWithOneByte) Read(bs []byte) (int, error) {
+ if c.bRead {
+ return c.Conn.Read(bs)
+ }
+ if len(bs) == 0 {
+ return 0, nil
+ }
+ c.bRead = true
+ bs[0] = c.b
+ return 1, nil
+}
diff --git a/net/routetable/routetable_darwin.go b/net/routetable/routetable_darwin.go
index 7f525ae32..7de80a662 100644
--- a/net/routetable/routetable_darwin.go
+++ b/net/routetable/routetable_darwin.go
@@ -1,36 +1,36 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-//go:build darwin
-
-package routetable
-
-import "golang.org/x/sys/unix"
-
-const (
- ribType = unix.NET_RT_DUMP2
- parseType = unix.NET_RT_IFLIST2
- rmExpectedType = unix.RTM_GET2
-
- // Skip routes that were cloned from a parent
- skipFlags = unix.RTF_WASCLONED
-)
-
-var flags = map[int]string{
- unix.RTF_BLACKHOLE: "blackhole",
- unix.RTF_BROADCAST: "broadcast",
- unix.RTF_GATEWAY: "gateway",
- unix.RTF_GLOBAL: "global",
- unix.RTF_HOST: "host",
- unix.RTF_IFSCOPE: "ifscope",
- unix.RTF_LOCAL: "local",
- unix.RTF_MULTICAST: "multicast",
- unix.RTF_REJECT: "reject",
- unix.RTF_ROUTER: "router",
- unix.RTF_STATIC: "static",
- unix.RTF_UP: "up",
- // More obscure flags, just to have full coverage.
- unix.RTF_LLINFO: "{RTF_LLINFO}",
- unix.RTF_PRCLONING: "{RTF_PRCLONING}",
- unix.RTF_CLONING: "{RTF_CLONING}",
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build darwin
+
+package routetable
+
+import "golang.org/x/sys/unix"
+
+const (
+ ribType = unix.NET_RT_DUMP2
+ parseType = unix.NET_RT_IFLIST2
+ rmExpectedType = unix.RTM_GET2
+
+ // Skip routes that were cloned from a parent
+ skipFlags = unix.RTF_WASCLONED
+)
+
+var flags = map[int]string{
+ unix.RTF_BLACKHOLE: "blackhole",
+ unix.RTF_BROADCAST: "broadcast",
+ unix.RTF_GATEWAY: "gateway",
+ unix.RTF_GLOBAL: "global",
+ unix.RTF_HOST: "host",
+ unix.RTF_IFSCOPE: "ifscope",
+ unix.RTF_LOCAL: "local",
+ unix.RTF_MULTICAST: "multicast",
+ unix.RTF_REJECT: "reject",
+ unix.RTF_ROUTER: "router",
+ unix.RTF_STATIC: "static",
+ unix.RTF_UP: "up",
+ // More obscure flags, just to have full coverage.
+ unix.RTF_LLINFO: "{RTF_LLINFO}",
+ unix.RTF_PRCLONING: "{RTF_PRCLONING}",
+ unix.RTF_CLONING: "{RTF_CLONING}",
+}
diff --git a/net/routetable/routetable_freebsd.go b/net/routetable/routetable_freebsd.go
index 8e57a3302..aa4e03c41 100644
--- a/net/routetable/routetable_freebsd.go
+++ b/net/routetable/routetable_freebsd.go
@@ -1,28 +1,28 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-//go:build freebsd
-
-package routetable
-
-import "golang.org/x/sys/unix"
-
-const (
- ribType = unix.NET_RT_DUMP
- parseType = unix.NET_RT_IFLIST
- rmExpectedType = unix.RTM_GET
-
- // Nothing to skip
- skipFlags = 0
-)
-
-var flags = map[int]string{
- unix.RTF_BLACKHOLE: "blackhole",
- unix.RTF_BROADCAST: "broadcast",
- unix.RTF_GATEWAY: "gateway",
- unix.RTF_HOST: "host",
- unix.RTF_MULTICAST: "multicast",
- unix.RTF_REJECT: "reject",
- unix.RTF_STATIC: "static",
- unix.RTF_UP: "up",
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build freebsd
+
+package routetable
+
+import "golang.org/x/sys/unix"
+
+const (
+ ribType = unix.NET_RT_DUMP
+ parseType = unix.NET_RT_IFLIST
+ rmExpectedType = unix.RTM_GET
+
+ // Nothing to skip
+ skipFlags = 0
+)
+
+var flags = map[int]string{
+ unix.RTF_BLACKHOLE: "blackhole",
+ unix.RTF_BROADCAST: "broadcast",
+ unix.RTF_GATEWAY: "gateway",
+ unix.RTF_HOST: "host",
+ unix.RTF_MULTICAST: "multicast",
+ unix.RTF_REJECT: "reject",
+ unix.RTF_STATIC: "static",
+ unix.RTF_UP: "up",
+}
diff --git a/net/routetable/routetable_other.go b/net/routetable/routetable_other.go
index 35c83e374..521fe1911 100644
--- a/net/routetable/routetable_other.go
+++ b/net/routetable/routetable_other.go
@@ -1,17 +1,17 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-//go:build !linux && !darwin && !freebsd
-
-package routetable
-
-import (
- "errors"
- "runtime"
-)
-
-var errUnsupported = errors.New("cannot get route table on platform " + runtime.GOOS)
-
-func Get(max int) ([]RouteEntry, error) {
- return nil, errUnsupported
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build !linux && !darwin && !freebsd
+
+package routetable
+
+import (
+ "errors"
+ "runtime"
+)
+
+var errUnsupported = errors.New("cannot get route table on platform " + runtime.GOOS)
+
+func Get(max int) ([]RouteEntry, error) {
+ return nil, errUnsupported
+}
diff --git a/net/sockstats/sockstats.go b/net/sockstats/sockstats.go
index 715c1ee06..fb524a5c5 100644
--- a/net/sockstats/sockstats.go
+++ b/net/sockstats/sockstats.go
@@ -1,121 +1,121 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-// Package sockstats collects statistics about network sockets used by
-// the Tailscale client. The context where sockets are used must be
-// instrumented with the WithSockStats() function.
-//
-// Only available on POSIX platforms when built with Tailscale's fork of Go.
-package sockstats
-
-import (
- "context"
-
- "tailscale.com/net/netmon"
- "tailscale.com/types/logger"
-)
-
-// SockStats contains statistics for sockets instrumented with the
-// WithSockStats() function
-type SockStats struct {
- Stats map[Label]SockStat
- CurrentInterfaceCellular bool
-}
-
-// SockStat contains the sent and received bytes for a socket instrumented with
-// the WithSockStats() function.
-type SockStat struct {
- TxBytes uint64
- RxBytes uint64
-}
-
-// Label is an identifier for a socket that stats are collected for. A finite
-// set of values that may be used to label a socket to encourage grouping and
-// to make storage more efficient.
-type Label uint8
-
-//go:generate go run golang.org/x/tools/cmd/stringer -type Label -trimprefix Label
-
-// Labels are named after the package and function/struct that uses the socket.
-// Values may be persisted and thus existing entries should not be re-numbered.
-const (
- LabelControlClientAuto Label = 0 // control/controlclient/auto.go
- LabelControlClientDialer Label = 1 // control/controlhttp/client.go
- LabelDERPHTTPClient Label = 2 // derp/derphttp/derphttp_client.go
- LabelLogtailLogger Label = 3 // logtail/logtail.go
- LabelDNSForwarderDoH Label = 4 // net/dns/resolver/forwarder.go
- LabelDNSForwarderUDP Label = 5 // net/dns/resolver/forwarder.go
- LabelNetcheckClient Label = 6 // net/netcheck/netcheck.go
- LabelPortmapperClient Label = 7 // net/portmapper/portmapper.go
- LabelMagicsockConnUDP4 Label = 8 // wgengine/magicsock/magicsock.go
- LabelMagicsockConnUDP6 Label = 9 // wgengine/magicsock/magicsock.go
- LabelNetlogLogger Label = 10 // wgengine/netlog/logger.go
- LabelSockstatlogLogger Label = 11 // log/sockstatlog/logger.go
- LabelDNSForwarderTCP Label = 12 // net/dns/resolver/forwarder.go
-)
-
-// WithSockStats instruments a context so that sockets created with it will
-// have their statistics collected.
-func WithSockStats(ctx context.Context, label Label, logf logger.Logf) context.Context {
- return withSockStats(ctx, label, logf)
-}
-
-// Get returns the current socket statistics.
-func Get() *SockStats {
- return get()
-}
-
-// InterfaceSockStats contains statistics for sockets instrumented with the
-// WithSockStats() function, broken down by interface. The statistics may be a
-// subset of the total if interfaces were added after the instrumented socket
-// was created.
-type InterfaceSockStats struct {
- Stats map[Label]InterfaceSockStat
- Interfaces []string
-}
-
-// InterfaceSockStat contains the per-interface sent and received bytes for a
-// socket instrumented with the WithSockStats() function.
-type InterfaceSockStat struct {
- TxBytesByInterface map[string]uint64
- RxBytesByInterface map[string]uint64
-}
-
-// GetWithInterfaces is a variant of Get that returns the current socket
-// statistics broken down by interface. It is slightly more expensive than Get.
-func GetInterfaces() *InterfaceSockStats {
- return getInterfaces()
-}
-
-// ValidationSockStats contains external validation numbers for sockets
-// instrumented with WithSockStats. It may be a subset of the all sockets,
-// depending on what externa measurement mechanisms the platform supports.
-type ValidationSockStats struct {
- Stats map[Label]ValidationSockStat
-}
-
-// ValidationSockStat contains the validation bytes for a socket instrumented
-// with WithSockStats.
-type ValidationSockStat struct {
- TxBytes uint64
- RxBytes uint64
-}
-
-// GetValidation is a variant of Get that returns external validation numbers
-// for stats. It is more expensive than Get and should be used in debug
-// interfaces only.
-func GetValidation() *ValidationSockStats {
- return getValidation()
-}
-
-// SetNetMon configures the sockstats package to monitor the active
-// interface, so that per-interface stats can be collected.
-func SetNetMon(netMon *netmon.Monitor) {
- setNetMon(netMon)
-}
-
-// DebugInfo returns a string containing debug information about the tracked
-// statistics.
-func DebugInfo() string {
- return debugInfo()
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Package sockstats collects statistics about network sockets used by
+// the Tailscale client. The context where sockets are used must be
+// instrumented with the WithSockStats() function.
+//
+// Only available on POSIX platforms when built with Tailscale's fork of Go.
+package sockstats
+
+import (
+ "context"
+
+ "tailscale.com/net/netmon"
+ "tailscale.com/types/logger"
+)
+
+// SockStats contains statistics for sockets instrumented with the
+// WithSockStats() function
+type SockStats struct {
+ Stats map[Label]SockStat
+ CurrentInterfaceCellular bool
+}
+
+// SockStat contains the sent and received bytes for a socket instrumented with
+// the WithSockStats() function.
+type SockStat struct {
+ TxBytes uint64
+ RxBytes uint64
+}
+
+// Label is an identifier for a socket that stats are collected for. A finite
+// set of values that may be used to label a socket to encourage grouping and
+// to make storage more efficient.
+type Label uint8
+
+//go:generate go run golang.org/x/tools/cmd/stringer -type Label -trimprefix Label
+
+// Labels are named after the package and function/struct that uses the socket.
+// Values may be persisted and thus existing entries should not be re-numbered.
+const (
+ LabelControlClientAuto Label = 0 // control/controlclient/auto.go
+ LabelControlClientDialer Label = 1 // control/controlhttp/client.go
+ LabelDERPHTTPClient Label = 2 // derp/derphttp/derphttp_client.go
+ LabelLogtailLogger Label = 3 // logtail/logtail.go
+ LabelDNSForwarderDoH Label = 4 // net/dns/resolver/forwarder.go
+ LabelDNSForwarderUDP Label = 5 // net/dns/resolver/forwarder.go
+ LabelNetcheckClient Label = 6 // net/netcheck/netcheck.go
+ LabelPortmapperClient Label = 7 // net/portmapper/portmapper.go
+ LabelMagicsockConnUDP4 Label = 8 // wgengine/magicsock/magicsock.go
+ LabelMagicsockConnUDP6 Label = 9 // wgengine/magicsock/magicsock.go
+ LabelNetlogLogger Label = 10 // wgengine/netlog/logger.go
+ LabelSockstatlogLogger Label = 11 // log/sockstatlog/logger.go
+ LabelDNSForwarderTCP Label = 12 // net/dns/resolver/forwarder.go
+)
+
+// WithSockStats instruments a context so that sockets created with it will
+// have their statistics collected.
+func WithSockStats(ctx context.Context, label Label, logf logger.Logf) context.Context {
+ return withSockStats(ctx, label, logf)
+}
+
+// Get returns the current socket statistics.
+func Get() *SockStats {
+ return get()
+}
+
+// InterfaceSockStats contains statistics for sockets instrumented with the
+// WithSockStats() function, broken down by interface. The statistics may be a
+// subset of the total if interfaces were added after the instrumented socket
+// was created.
+type InterfaceSockStats struct {
+ Stats map[Label]InterfaceSockStat
+ Interfaces []string
+}
+
+// InterfaceSockStat contains the per-interface sent and received bytes for a
+// socket instrumented with the WithSockStats() function.
+type InterfaceSockStat struct {
+ TxBytesByInterface map[string]uint64
+ RxBytesByInterface map[string]uint64
+}
+
+// GetWithInterfaces is a variant of Get that returns the current socket
+// statistics broken down by interface. It is slightly more expensive than Get.
+func GetInterfaces() *InterfaceSockStats {
+ return getInterfaces()
+}
+
+// ValidationSockStats contains external validation numbers for sockets
+// instrumented with WithSockStats. It may be a subset of the all sockets,
+// depending on what externa measurement mechanisms the platform supports.
+type ValidationSockStats struct {
+ Stats map[Label]ValidationSockStat
+}
+
+// ValidationSockStat contains the validation bytes for a socket instrumented
+// with WithSockStats.
+type ValidationSockStat struct {
+ TxBytes uint64
+ RxBytes uint64
+}
+
+// GetValidation is a variant of Get that returns external validation numbers
+// for stats. It is more expensive than Get and should be used in debug
+// interfaces only.
+func GetValidation() *ValidationSockStats {
+ return getValidation()
+}
+
+// SetNetMon configures the sockstats package to monitor the active
+// interface, so that per-interface stats can be collected.
+func SetNetMon(netMon *netmon.Monitor) {
+ setNetMon(netMon)
+}
+
+// DebugInfo returns a string containing debug information about the tracked
+// statistics.
+func DebugInfo() string {
+ return debugInfo()
+}
diff --git a/net/sockstats/sockstats_noop.go b/net/sockstats/sockstats_noop.go
index 96723111a..797fdc42b 100644
--- a/net/sockstats/sockstats_noop.go
+++ b/net/sockstats/sockstats_noop.go
@@ -1,38 +1,38 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-//go:build !tailscale_go || !(darwin || ios || android || ts_enable_sockstats)
-
-package sockstats
-
-import (
- "context"
-
- "tailscale.com/net/netmon"
- "tailscale.com/types/logger"
-)
-
-const IsAvailable = false
-
-func withSockStats(ctx context.Context, label Label, logf logger.Logf) context.Context {
- return ctx
-}
-
-func get() *SockStats {
- return nil
-}
-
-func getInterfaces() *InterfaceSockStats {
- return nil
-}
-
-func getValidation() *ValidationSockStats {
- return nil
-}
-
-func setNetMon(netMon *netmon.Monitor) {
-}
-
-func debugInfo() string {
- return ""
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build !tailscale_go || !(darwin || ios || android || ts_enable_sockstats)
+
+package sockstats
+
+import (
+ "context"
+
+ "tailscale.com/net/netmon"
+ "tailscale.com/types/logger"
+)
+
+const IsAvailable = false
+
+func withSockStats(ctx context.Context, label Label, logf logger.Logf) context.Context {
+ return ctx
+}
+
+func get() *SockStats {
+ return nil
+}
+
+func getInterfaces() *InterfaceSockStats {
+ return nil
+}
+
+func getValidation() *ValidationSockStats {
+ return nil
+}
+
+func setNetMon(netMon *netmon.Monitor) {
+}
+
+func debugInfo() string {
+ return ""
+}
diff --git a/net/sockstats/sockstats_tsgo_darwin.go b/net/sockstats/sockstats_tsgo_darwin.go
index 321d32e04..4b03ed616 100644
--- a/net/sockstats/sockstats_tsgo_darwin.go
+++ b/net/sockstats/sockstats_tsgo_darwin.go
@@ -1,30 +1,30 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-//go:build tailscale_go && (darwin || ios)
-
-package sockstats
-
-import (
- "syscall"
-
- "golang.org/x/sys/unix"
-)
-
-func init() {
- tcpConnStats = darwinTcpConnStats
-}
-
-func darwinTcpConnStats(c syscall.RawConn) (tx, rx uint64) {
- c.Control(func(fd uintptr) {
- if rawInfo, err := unix.GetsockoptTCPConnectionInfo(
- int(fd),
- unix.IPPROTO_TCP,
- unix.TCP_CONNECTION_INFO,
- ); err == nil {
- tx = uint64(rawInfo.Txbytes)
- rx = uint64(rawInfo.Rxbytes)
- }
- })
- return
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build tailscale_go && (darwin || ios)
+
+package sockstats
+
+import (
+ "syscall"
+
+ "golang.org/x/sys/unix"
+)
+
+func init() {
+ tcpConnStats = darwinTcpConnStats
+}
+
+func darwinTcpConnStats(c syscall.RawConn) (tx, rx uint64) {
+ c.Control(func(fd uintptr) {
+ if rawInfo, err := unix.GetsockoptTCPConnectionInfo(
+ int(fd),
+ unix.IPPROTO_TCP,
+ unix.TCP_CONNECTION_INFO,
+ ); err == nil {
+ tx = uint64(rawInfo.Txbytes)
+ rx = uint64(rawInfo.Rxbytes)
+ }
+ })
+ return
+}
diff --git a/net/speedtest/speedtest.go b/net/speedtest/speedtest.go
index 7ab0881cc..89639c12d 100644
--- a/net/speedtest/speedtest.go
+++ b/net/speedtest/speedtest.go
@@ -1,87 +1,87 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-// Package speedtest contains both server and client code for
-// running speedtests between tailscale nodes.
-package speedtest
-
-import (
- "time"
-)
-
-const (
- blockSize = 2 * 1024 * 1024 // size of the block of data to send
- MinDuration = 5 * time.Second // minimum duration for a test
- DefaultDuration = MinDuration // default duration for a test
- MaxDuration = 30 * time.Second // maximum duration for a test
- version = 2 // value used when comparing client and server versions
- increment = time.Second // increment to display results for, in seconds
- minInterval = 10 * time.Millisecond // minimum interval length for a result to be included
- DefaultPort = 20333
-)
-
-// config is the initial message sent to the server, that contains information on how to
-// conduct the test.
-type config struct {
- Version int `json:"version"`
- TestDuration time.Duration `json:"time"`
- Direction Direction `json:"direction"`
-}
-
-// configResponse is the response to the testConfig message. If the server has an
-// error with the config, the Error variable will hold that error value.
-type configResponse struct {
- Error string `json:"error,omitempty"`
-}
-
-// This represents the Result of a speedtest within a specific interval
-type Result struct {
- Bytes int // number of bytes sent/received during the interval
- IntervalStart time.Time // start of the interval
- IntervalEnd time.Time // end of the interval
- Total bool // if true, this result struct represents the entire test, rather than a segment of the test
-}
-
-func (r Result) MBitsPerSecond() float64 {
- return r.MegaBits() / r.IntervalEnd.Sub(r.IntervalStart).Seconds()
-}
-
-func (r Result) MegaBytes() float64 {
- return float64(r.Bytes) / 1000000.0
-}
-
-func (r Result) MegaBits() float64 {
- return r.MegaBytes() * 8.0
-}
-
-func (r Result) Interval() time.Duration {
- return r.IntervalEnd.Sub(r.IntervalStart)
-}
-
-type Direction int
-
-const (
- Download Direction = iota
- Upload
-)
-
-func (d Direction) String() string {
- switch d {
- case Upload:
- return "upload"
- case Download:
- return "download"
- default:
- return ""
- }
-}
-
-func (d *Direction) Reverse() {
- switch *d {
- case Upload:
- *d = Download
- case Download:
- *d = Upload
- default:
- }
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Package speedtest contains both server and client code for
+// running speedtests between tailscale nodes.
+package speedtest
+
+import (
+ "time"
+)
+
+const (
+ blockSize = 2 * 1024 * 1024 // size of the block of data to send
+ MinDuration = 5 * time.Second // minimum duration for a test
+ DefaultDuration = MinDuration // default duration for a test
+ MaxDuration = 30 * time.Second // maximum duration for a test
+ version = 2 // value used when comparing client and server versions
+ increment = time.Second // increment to display results for, in seconds
+ minInterval = 10 * time.Millisecond // minimum interval length for a result to be included
+ DefaultPort = 20333
+)
+
+// config is the initial message sent to the server, that contains information on how to
+// conduct the test.
+type config struct {
+ Version int `json:"version"`
+ TestDuration time.Duration `json:"time"`
+ Direction Direction `json:"direction"`
+}
+
+// configResponse is the response to the testConfig message. If the server has an
+// error with the config, the Error variable will hold that error value.
+type configResponse struct {
+ Error string `json:"error,omitempty"`
+}
+
+// This represents the Result of a speedtest within a specific interval
+type Result struct {
+ Bytes int // number of bytes sent/received during the interval
+ IntervalStart time.Time // start of the interval
+ IntervalEnd time.Time // end of the interval
+ Total bool // if true, this result struct represents the entire test, rather than a segment of the test
+}
+
+func (r Result) MBitsPerSecond() float64 {
+ return r.MegaBits() / r.IntervalEnd.Sub(r.IntervalStart).Seconds()
+}
+
+func (r Result) MegaBytes() float64 {
+ return float64(r.Bytes) / 1000000.0
+}
+
+func (r Result) MegaBits() float64 {
+ return r.MegaBytes() * 8.0
+}
+
+func (r Result) Interval() time.Duration {
+ return r.IntervalEnd.Sub(r.IntervalStart)
+}
+
+type Direction int
+
+const (
+ Download Direction = iota
+ Upload
+)
+
+func (d Direction) String() string {
+ switch d {
+ case Upload:
+ return "upload"
+ case Download:
+ return "download"
+ default:
+ return ""
+ }
+}
+
+func (d *Direction) Reverse() {
+ switch *d {
+ case Upload:
+ *d = Download
+ case Download:
+ *d = Upload
+ default:
+ }
+}
diff --git a/net/speedtest/speedtest_client.go b/net/speedtest/speedtest_client.go
index 299a12a8d..cc34c468c 100644
--- a/net/speedtest/speedtest_client.go
+++ b/net/speedtest/speedtest_client.go
@@ -1,41 +1,41 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-package speedtest
-
-import (
- "encoding/json"
- "errors"
- "net"
- "time"
-)
-
-// RunClient dials the given address and starts a speedtest.
-// It returns any errors that come up in the tests.
-// If there are no errors in the test, it returns a slice of results.
-func RunClient(direction Direction, duration time.Duration, host string) ([]Result, error) {
- conn, err := net.Dial("tcp", host)
- if err != nil {
- return nil, err
- }
-
- conf := config{TestDuration: duration, Version: version, Direction: direction}
-
- defer conn.Close()
- encoder := json.NewEncoder(conn)
-
- if err = encoder.Encode(conf); err != nil {
- return nil, err
- }
-
- var response configResponse
- decoder := json.NewDecoder(conn)
- if err = decoder.Decode(&response); err != nil {
- return nil, err
- }
- if response.Error != "" {
- return nil, errors.New(response.Error)
- }
-
- return doTest(conn, conf)
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package speedtest
+
+import (
+ "encoding/json"
+ "errors"
+ "net"
+ "time"
+)
+
+// RunClient dials the given address and starts a speedtest.
+// It returns any errors that come up in the tests.
+// If there are no errors in the test, it returns a slice of results.
+func RunClient(direction Direction, duration time.Duration, host string) ([]Result, error) {
+ conn, err := net.Dial("tcp", host)
+ if err != nil {
+ return nil, err
+ }
+
+ conf := config{TestDuration: duration, Version: version, Direction: direction}
+
+ defer conn.Close()
+ encoder := json.NewEncoder(conn)
+
+ if err = encoder.Encode(conf); err != nil {
+ return nil, err
+ }
+
+ var response configResponse
+ decoder := json.NewDecoder(conn)
+ if err = decoder.Decode(&response); err != nil {
+ return nil, err
+ }
+ if response.Error != "" {
+ return nil, errors.New(response.Error)
+ }
+
+ return doTest(conn, conf)
+}
diff --git a/net/speedtest/speedtest_server.go b/net/speedtest/speedtest_server.go
index 9dd78b195..d2673464e 100644
--- a/net/speedtest/speedtest_server.go
+++ b/net/speedtest/speedtest_server.go
@@ -1,146 +1,146 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-package speedtest
-
-import (
- "crypto/rand"
- "encoding/json"
- "errors"
- "fmt"
- "io"
- "net"
- "time"
-)
-
-// Serve starts up the server on a given host and port pair. It starts to listen for
-// connections and handles each one in a goroutine. Because it runs in an infinite loop,
-// this function only returns if any of the speedtests return with errors, or if the
-// listener is closed.
-func Serve(l net.Listener) error {
- for {
- conn, err := l.Accept()
- if errors.Is(err, net.ErrClosed) {
- return nil
- }
- if err != nil {
- return err
- }
- err = handleConnection(conn)
- if err != nil {
- return err
- }
- }
-}
-
-// handleConnection handles the initial exchange between the server and the client.
-// It reads the testconfig message into a config struct. If any errors occur with
-// the testconfig (specifically, if there is a version mismatch), it will return those
-// errors to the client with a configResponse. After the exchange, it will start
-// the speed test.
-func handleConnection(conn net.Conn) error {
- defer conn.Close()
- var conf config
-
- decoder := json.NewDecoder(conn)
- err := decoder.Decode(&conf)
- encoder := json.NewEncoder(conn)
-
- // Both return and encode errors that occurred before the test started.
- if err != nil {
- encoder.Encode(configResponse{Error: err.Error()})
- return err
- }
-
- // The server should always be doing the opposite of what the client is doing.
- conf.Direction.Reverse()
-
- if conf.Version != version {
- err = fmt.Errorf("version mismatch! Server is version %d, client is version %d", version, conf.Version)
- encoder.Encode(configResponse{Error: err.Error()})
- return err
- }
-
- // Start the test
- encoder.Encode(configResponse{})
- _, err = doTest(conn, conf)
- return err
-}
-
-// TODO include code to detect whether the code is direct vs DERP
-
-// doTest contains the code to run both the upload and download speedtest.
-// the direction value in the config parameter determines which test to run.
-func doTest(conn net.Conn, conf config) ([]Result, error) {
- bufferData := make([]byte, blockSize)
-
- intervalBytes := 0
- totalBytes := 0
-
- var currentTime time.Time
- var results []Result
-
- if conf.Direction == Download {
- conn.SetReadDeadline(time.Now().Add(conf.TestDuration).Add(5 * time.Second))
- } else {
- _, err := rand.Read(bufferData)
- if err != nil {
- return nil, err
- }
-
- }
-
- startTime := time.Now()
- lastCalculated := startTime
-
-SpeedTestLoop:
- for {
- var n int
- var err error
-
- if conf.Direction == Download {
- n, err = io.ReadFull(conn, bufferData)
- switch err {
- case io.EOF, io.ErrUnexpectedEOF:
- break SpeedTestLoop
- case nil:
- // successful read
- default:
- return nil, fmt.Errorf("unexpected error has occurred: %w", err)
- }
- } else {
- n, err = conn.Write(bufferData)
- if err != nil {
- // If the write failed, there is most likely something wrong with the connection.
- return nil, fmt.Errorf("upload failed: %w", err)
- }
- }
- intervalBytes += n
-
- currentTime = time.Now()
- // checks if the current time is more or equal to the lastCalculated time plus the increment
- if currentTime.Sub(lastCalculated) >= increment {
- results = append(results, Result{Bytes: intervalBytes, IntervalStart: lastCalculated, IntervalEnd: currentTime, Total: false})
- lastCalculated = currentTime
- totalBytes += intervalBytes
- intervalBytes = 0
- }
-
- if conf.Direction == Upload && currentTime.Sub(startTime) > conf.TestDuration {
- break SpeedTestLoop
- }
- }
-
- // get last segment
- if currentTime.Sub(lastCalculated) > minInterval {
- results = append(results, Result{Bytes: intervalBytes, IntervalStart: lastCalculated, IntervalEnd: currentTime, Total: false})
- }
-
- // get total
- totalBytes += intervalBytes
- if currentTime.Sub(startTime) > minInterval {
- results = append(results, Result{Bytes: totalBytes, IntervalStart: startTime, IntervalEnd: currentTime, Total: true})
- }
-
- return results, nil
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package speedtest
+
+import (
+ "crypto/rand"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "time"
+)
+
+// Serve starts up the server on a given host and port pair. It starts to listen for
+// connections and handles each one in a goroutine. Because it runs in an infinite loop,
+// this function only returns if any of the speedtests return with errors, or if the
+// listener is closed.
+func Serve(l net.Listener) error {
+ for {
+ conn, err := l.Accept()
+ if errors.Is(err, net.ErrClosed) {
+ return nil
+ }
+ if err != nil {
+ return err
+ }
+ err = handleConnection(conn)
+ if err != nil {
+ return err
+ }
+ }
+}
+
+// handleConnection handles the initial exchange between the server and the client.
+// It reads the testconfig message into a config struct. If any errors occur with
+// the testconfig (specifically, if there is a version mismatch), it will return those
+// errors to the client with a configResponse. After the exchange, it will start
+// the speed test.
+func handleConnection(conn net.Conn) error {
+ defer conn.Close()
+ var conf config
+
+ decoder := json.NewDecoder(conn)
+ err := decoder.Decode(&conf)
+ encoder := json.NewEncoder(conn)
+
+ // Both return and encode errors that occurred before the test started.
+ if err != nil {
+ encoder.Encode(configResponse{Error: err.Error()})
+ return err
+ }
+
+ // The server should always be doing the opposite of what the client is doing.
+ conf.Direction.Reverse()
+
+ if conf.Version != version {
+ err = fmt.Errorf("version mismatch! Server is version %d, client is version %d", version, conf.Version)
+ encoder.Encode(configResponse{Error: err.Error()})
+ return err
+ }
+
+ // Start the test
+ encoder.Encode(configResponse{})
+ _, err = doTest(conn, conf)
+ return err
+}
+
+// TODO include code to detect whether the code is direct vs DERP
+
+// doTest contains the code to run both the upload and download speedtest.
+// the direction value in the config parameter determines which test to run.
+func doTest(conn net.Conn, conf config) ([]Result, error) {
+ bufferData := make([]byte, blockSize)
+
+ intervalBytes := 0
+ totalBytes := 0
+
+ var currentTime time.Time
+ var results []Result
+
+ if conf.Direction == Download {
+ conn.SetReadDeadline(time.Now().Add(conf.TestDuration).Add(5 * time.Second))
+ } else {
+ _, err := rand.Read(bufferData)
+ if err != nil {
+ return nil, err
+ }
+
+ }
+
+ startTime := time.Now()
+ lastCalculated := startTime
+
+SpeedTestLoop:
+ for {
+ var n int
+ var err error
+
+ if conf.Direction == Download {
+ n, err = io.ReadFull(conn, bufferData)
+ switch err {
+ case io.EOF, io.ErrUnexpectedEOF:
+ break SpeedTestLoop
+ case nil:
+ // successful read
+ default:
+ return nil, fmt.Errorf("unexpected error has occurred: %w", err)
+ }
+ } else {
+ n, err = conn.Write(bufferData)
+ if err != nil {
+ // If the write failed, there is most likely something wrong with the connection.
+ return nil, fmt.Errorf("upload failed: %w", err)
+ }
+ }
+ intervalBytes += n
+
+ currentTime = time.Now()
+ // checks if the current time is more or equal to the lastCalculated time plus the increment
+ if currentTime.Sub(lastCalculated) >= increment {
+ results = append(results, Result{Bytes: intervalBytes, IntervalStart: lastCalculated, IntervalEnd: currentTime, Total: false})
+ lastCalculated = currentTime
+ totalBytes += intervalBytes
+ intervalBytes = 0
+ }
+
+ if conf.Direction == Upload && currentTime.Sub(startTime) > conf.TestDuration {
+ break SpeedTestLoop
+ }
+ }
+
+ // get last segment
+ if currentTime.Sub(lastCalculated) > minInterval {
+ results = append(results, Result{Bytes: intervalBytes, IntervalStart: lastCalculated, IntervalEnd: currentTime, Total: false})
+ }
+
+ // get total
+ totalBytes += intervalBytes
+ if currentTime.Sub(startTime) > minInterval {
+ results = append(results, Result{Bytes: totalBytes, IntervalStart: startTime, IntervalEnd: currentTime, Total: true})
+ }
+
+ return results, nil
+}
diff --git a/net/speedtest/speedtest_test.go b/net/speedtest/speedtest_test.go
index 55dcbeea1..a413e9efa 100644
--- a/net/speedtest/speedtest_test.go
+++ b/net/speedtest/speedtest_test.go
@@ -1,83 +1,83 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-package speedtest
-
-import (
- "net"
- "testing"
- "time"
-)
-
-func TestDownload(t *testing.T) {
- // start a listener and find the port where the server will be listening.
- l, err := net.Listen("tcp", ":0")
- if err != nil {
- t.Fatal(err)
- }
- t.Cleanup(func() { l.Close() })
-
- serverIP := l.Addr().String()
- t.Log("server IP found:", serverIP)
-
- type state struct {
- err error
- }
- displayResult := func(t *testing.T, r Result, start time.Time) {
- t.Helper()
- t.Logf("{ Megabytes: %.2f, Start: %.1f, End: %.1f, Total: %t }", r.MegaBytes(), r.IntervalStart.Sub(start).Seconds(), r.IntervalEnd.Sub(start).Seconds(), r.Total)
- }
- stateChan := make(chan state, 1)
-
- go func() {
- err := Serve(l)
- stateChan <- state{err: err}
- }()
-
- // ensure that the test returns an appropriate number of Result structs
- expectedLen := int(DefaultDuration.Seconds()) + 1
-
- t.Run("download test", func(t *testing.T) {
- // conduct a download test
- results, err := RunClient(Download, DefaultDuration, serverIP)
-
- if err != nil {
- t.Fatal("download test failed:", err)
- }
-
- if len(results) < expectedLen {
- t.Fatalf("download results: expected length: %d, actual length: %d", expectedLen, len(results))
- }
-
- start := results[0].IntervalStart
- for _, result := range results {
- displayResult(t, result, start)
- }
- })
-
- t.Run("upload test", func(t *testing.T) {
- // conduct an upload test
- results, err := RunClient(Upload, DefaultDuration, serverIP)
-
- if err != nil {
- t.Fatal("upload test failed:", err)
- }
-
- if len(results) < expectedLen {
- t.Fatalf("upload results: expected length: %d, actual length: %d", expectedLen, len(results))
- }
-
- start := results[0].IntervalStart
- for _, result := range results {
- displayResult(t, result, start)
- }
- })
-
- // causes the server goroutine to finish
- l.Close()
-
- testState := <-stateChan
- if testState.err != nil {
- t.Error("server error:", err)
- }
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package speedtest
+
+import (
+ "net"
+ "testing"
+ "time"
+)
+
+func TestDownload(t *testing.T) {
+ // start a listener and find the port where the server will be listening.
+ l, err := net.Listen("tcp", ":0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ t.Cleanup(func() { l.Close() })
+
+ serverIP := l.Addr().String()
+ t.Log("server IP found:", serverIP)
+
+ type state struct {
+ err error
+ }
+ displayResult := func(t *testing.T, r Result, start time.Time) {
+ t.Helper()
+ t.Logf("{ Megabytes: %.2f, Start: %.1f, End: %.1f, Total: %t }", r.MegaBytes(), r.IntervalStart.Sub(start).Seconds(), r.IntervalEnd.Sub(start).Seconds(), r.Total)
+ }
+ stateChan := make(chan state, 1)
+
+ go func() {
+ err := Serve(l)
+ stateChan <- state{err: err}
+ }()
+
+ // ensure that the test returns an appropriate number of Result structs
+ expectedLen := int(DefaultDuration.Seconds()) + 1
+
+ t.Run("download test", func(t *testing.T) {
+ // conduct a download test
+ results, err := RunClient(Download, DefaultDuration, serverIP)
+
+ if err != nil {
+ t.Fatal("download test failed:", err)
+ }
+
+ if len(results) < expectedLen {
+ t.Fatalf("download results: expected length: %d, actual length: %d", expectedLen, len(results))
+ }
+
+ start := results[0].IntervalStart
+ for _, result := range results {
+ displayResult(t, result, start)
+ }
+ })
+
+ t.Run("upload test", func(t *testing.T) {
+ // conduct an upload test
+ results, err := RunClient(Upload, DefaultDuration, serverIP)
+
+ if err != nil {
+ t.Fatal("upload test failed:", err)
+ }
+
+ if len(results) < expectedLen {
+ t.Fatalf("upload results: expected length: %d, actual length: %d", expectedLen, len(results))
+ }
+
+ start := results[0].IntervalStart
+ for _, result := range results {
+ displayResult(t, result, start)
+ }
+ })
+
+ // causes the server goroutine to finish
+ l.Close()
+
+ testState := <-stateChan
+ if testState.err != nil {
+ t.Error("server error:", err)
+ }
+}
diff --git a/net/stun/stun.go b/net/stun/stun.go
index eeac23cbb..81cf9b608 100644
--- a/net/stun/stun.go
+++ b/net/stun/stun.go
@@ -1,312 +1,312 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-// Package STUN generates STUN request packets and parses response packets.
-package stun
-
-import (
- "bytes"
- crand "crypto/rand"
- "encoding/binary"
- "errors"
- "hash/crc32"
- "net"
- "net/netip"
-)
-
-const (
- attrNumSoftware = 0x8022
- attrNumFingerprint = 0x8028
- attrMappedAddress = 0x0001
- attrXorMappedAddress = 0x0020
- // This alternative attribute type is not
- // mentioned in the RFC, but the shift into
- // the "comprehension-optional" range seems
- // like an easy mistake for a server to make.
- // And servers appear to send it.
- attrXorMappedAddressAlt = 0x8020
-
- software = "tailnode" // notably: 8 bytes long, so no padding
- bindingRequest = "\x00\x01"
- magicCookie = "\x21\x12\xa4\x42"
- lenFingerprint = 8 // 2+byte header + 2-byte length + 4-byte crc32
- headerLen = 20
-)
-
-// TxID is a transaction ID.
-type TxID [12]byte
-
-// NewTxID returns a new random TxID.
-func NewTxID() TxID {
- var tx TxID
- if _, err := crand.Read(tx[:]); err != nil {
- panic(err)
- }
- return tx
-}
-
-// Request generates a binding request STUN packet.
-// The transaction ID, tID, should be a random sequence of bytes.
-func Request(tID TxID) []byte {
- // STUN header, RFC5389 Section 6.
- const lenAttrSoftware = 4 + len(software)
- b := make([]byte, 0, headerLen+lenAttrSoftware+lenFingerprint)
- b = append(b, bindingRequest...)
- b = appendU16(b, uint16(lenAttrSoftware+lenFingerprint)) // number of bytes following header
- b = append(b, magicCookie...)
- b = append(b, tID[:]...)
-
- // Attribute SOFTWARE, RFC5389 Section 15.5.
- b = appendU16(b, attrNumSoftware)
- b = appendU16(b, uint16(len(software)))
- b = append(b, software...)
-
- // Attribute FINGERPRINT, RFC5389 Section 15.5.
- fp := fingerPrint(b)
- b = appendU16(b, attrNumFingerprint)
- b = appendU16(b, 4)
- b = appendU32(b, fp)
-
- return b
-}
-
-func fingerPrint(b []byte) uint32 { return crc32.ChecksumIEEE(b) ^ 0x5354554e }
-
-func appendU16(b []byte, v uint16) []byte {
- return append(b, byte(v>>8), byte(v))
-}
-
-func appendU32(b []byte, v uint32) []byte {
- return append(b, byte(v>>24), byte(v>>16), byte(v>>8), byte(v))
-}
-
-// ParseBindingRequest parses a STUN binding request.
-//
-// It returns an error unless it advertises that it came from
-// Tailscale.
-func ParseBindingRequest(b []byte) (TxID, error) {
- if !Is(b) {
- return TxID{}, ErrNotSTUN
- }
- if string(b[:len(bindingRequest)]) != bindingRequest {
- return TxID{}, ErrNotBindingRequest
- }
- var txID TxID
- copy(txID[:], b[8:8+len(txID)])
- var softwareOK bool
- var lastAttr uint16
- var gotFP uint32
- if err := foreachAttr(b[headerLen:], func(attrType uint16, a []byte) error {
- lastAttr = attrType
- if attrType == attrNumSoftware && string(a) == software {
- softwareOK = true
- }
- if attrType == attrNumFingerprint && len(a) == 4 {
- gotFP = binary.BigEndian.Uint32(a)
- }
- return nil
- }); err != nil {
- return TxID{}, err
- }
- if !softwareOK {
- return TxID{}, ErrWrongSoftware
- }
- if lastAttr != attrNumFingerprint {
- return TxID{}, ErrNoFingerprint
- }
- wantFP := fingerPrint(b[:len(b)-lenFingerprint])
- if gotFP != wantFP {
- return TxID{}, ErrWrongFingerprint
- }
- return txID, nil
-}
-
-var (
- ErrNotSTUN = errors.New("response is not a STUN packet")
- ErrNotSuccessResponse = errors.New("STUN packet is not a response")
- ErrMalformedAttrs = errors.New("STUN response has malformed attributes")
- ErrNotBindingRequest = errors.New("STUN request not a binding request")
- ErrWrongSoftware = errors.New("STUN request came from non-Tailscale software")
- ErrNoFingerprint = errors.New("STUN request didn't end in fingerprint")
- ErrWrongFingerprint = errors.New("STUN request had bogus fingerprint")
-)
-
-func foreachAttr(b []byte, fn func(attrType uint16, a []byte) error) error {
- for len(b) > 0 {
- if len(b) < 4 {
- return ErrMalformedAttrs
- }
- attrType := binary.BigEndian.Uint16(b[:2])
- attrLen := int(binary.BigEndian.Uint16(b[2:4]))
- attrLenWithPad := (attrLen + 3) &^ 3
- b = b[4:]
- if attrLenWithPad > len(b) {
- return ErrMalformedAttrs
- }
- if err := fn(attrType, b[:attrLen]); err != nil {
- return err
- }
- b = b[attrLenWithPad:]
- }
- return nil
-}
-
-// Response generates a binding response.
-func Response(txID TxID, addrPort netip.AddrPort) []byte {
- addr := addrPort.Addr()
-
- var fam byte
- if addr.Is4() {
- fam = 1
- } else if addr.Is6() {
- fam = 2
- } else {
- return nil
- }
- attrsLen := 8 + addr.BitLen()/8
- b := make([]byte, 0, headerLen+attrsLen)
-
- // Header
- b = append(b, 0x01, 0x01) // success
- b = appendU16(b, uint16(attrsLen))
- b = append(b, magicCookie...)
- b = append(b, txID[:]...)
-
- // Attributes (well, one)
- b = appendU16(b, attrXorMappedAddress)
- b = appendU16(b, uint16(4+addr.BitLen()/8))
- b = append(b,
- 0, // unused byte
- fam)
- b = appendU16(b, addrPort.Port()^0x2112) // first half of magicCookie
- ipa := addr.As16()
- for i, o := range ipa[16-addr.BitLen()/8:] {
- if i < 4 {
- b = append(b, o^magicCookie[i])
- } else {
- b = append(b, o^txID[i-len(magicCookie)])
- }
- }
- return b
-}
-
-// ParseResponse parses a successful binding response STUN packet.
-// The IP address is extracted from the XOR-MAPPED-ADDRESS attribute.
-func ParseResponse(b []byte) (tID TxID, addr netip.AddrPort, err error) {
- if !Is(b) {
- return tID, netip.AddrPort{}, ErrNotSTUN
- }
- copy(tID[:], b[8:8+len(tID)])
- if b[0] != 0x01 || b[1] != 0x01 {
- return tID, netip.AddrPort{}, ErrNotSuccessResponse
- }
- attrsLen := int(binary.BigEndian.Uint16(b[2:4]))
- b = b[headerLen:] // remove STUN header
- if attrsLen > len(b) {
- return tID, netip.AddrPort{}, ErrMalformedAttrs
- } else if len(b) > attrsLen {
- b = b[:attrsLen] // trim trailing packet bytes
- }
-
- var fallbackAddr netip.AddrPort
-
- // Read through the attributes.
- // The the addr+port reported by XOR-MAPPED-ADDRESS
- // as the canonical value. If the attribute is not
- // present but the STUN server responds with
- // MAPPED-ADDRESS we fall back to it.
- if err := foreachAttr(b, func(attrType uint16, attr []byte) error {
- switch attrType {
- case attrXorMappedAddress, attrXorMappedAddressAlt:
- ipSlice, port, err := xorMappedAddress(tID, attr)
- if err != nil {
- return err
- }
- if ip, ok := netip.AddrFromSlice(ipSlice); ok {
- addr = netip.AddrPortFrom(ip.Unmap(), port)
- }
- case attrMappedAddress:
- ipSlice, port, err := mappedAddress(attr)
- if err != nil {
- return ErrMalformedAttrs
- }
- if ip, ok := netip.AddrFromSlice(ipSlice); ok {
- fallbackAddr = netip.AddrPortFrom(ip.Unmap(), port)
- }
- }
- return nil
-
- }); err != nil {
- return TxID{}, netip.AddrPort{}, err
- }
-
- if addr.IsValid() {
- return tID, addr, nil
- }
- if fallbackAddr.IsValid() {
- return tID, fallbackAddr, nil
- }
- return tID, netip.AddrPort{}, ErrMalformedAttrs
-}
-
-func xorMappedAddress(tID TxID, b []byte) (addr []byte, port uint16, err error) {
- // XOR-MAPPED-ADDRESS attribute, RFC5389 Section 15.2
- if len(b) < 4 {
- return nil, 0, ErrMalformedAttrs
- }
- xorPort := binary.BigEndian.Uint16(b[2:4])
- addrField := b[4:]
- port = xorPort ^ 0x2112 // first half of magicCookie
-
- addrLen := familyAddrLen(b[1])
- if addrLen == 0 {
- return nil, 0, ErrMalformedAttrs
- }
- if len(addrField) < addrLen {
- return nil, 0, ErrMalformedAttrs
- }
- xorAddr := addrField[:addrLen]
- addr = make([]byte, addrLen)
- for i := range xorAddr {
- if i < len(magicCookie) {
- addr[i] = xorAddr[i] ^ magicCookie[i]
- } else {
- addr[i] = xorAddr[i] ^ tID[i-len(magicCookie)]
- }
- }
- return addr, port, nil
-}
-
-func familyAddrLen(fam byte) int {
- switch fam {
- case 0x01: // IPv4
- return net.IPv4len
- case 0x02: // IPv6
- return net.IPv6len
- default:
- return 0
- }
-}
-
-func mappedAddress(b []byte) (addr []byte, port uint16, err error) {
- if len(b) < 4 {
- return nil, 0, ErrMalformedAttrs
- }
- port = uint16(b[2])<<8 | uint16(b[3])
- addrField := b[4:]
- addrLen := familyAddrLen(b[1])
- if addrLen == 0 {
- return nil, 0, ErrMalformedAttrs
- }
- if len(addrField) < addrLen {
- return nil, 0, ErrMalformedAttrs
- }
- return bytes.Clone(addrField[:addrLen]), port, nil
-}
-
-// Is reports whether b is a STUN message.
-func Is(b []byte) bool {
- return len(b) >= headerLen &&
- b[0]&0b11000000 == 0 && // top two bits must be zero
- string(b[4:8]) == magicCookie
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Package STUN generates STUN request packets and parses response packets.
+package stun
+
+import (
+ "bytes"
+ crand "crypto/rand"
+ "encoding/binary"
+ "errors"
+ "hash/crc32"
+ "net"
+ "net/netip"
+)
+
+const (
+ attrNumSoftware = 0x8022
+ attrNumFingerprint = 0x8028
+ attrMappedAddress = 0x0001
+ attrXorMappedAddress = 0x0020
+ // This alternative attribute type is not
+ // mentioned in the RFC, but the shift into
+ // the "comprehension-optional" range seems
+ // like an easy mistake for a server to make.
+ // And servers appear to send it.
+ attrXorMappedAddressAlt = 0x8020
+
+ software = "tailnode" // notably: 8 bytes long, so no padding
+ bindingRequest = "\x00\x01"
+ magicCookie = "\x21\x12\xa4\x42"
+ lenFingerprint = 8 // 2+byte header + 2-byte length + 4-byte crc32
+ headerLen = 20
+)
+
+// TxID is a transaction ID.
+type TxID [12]byte
+
+// NewTxID returns a new random TxID.
+func NewTxID() TxID {
+ var tx TxID
+ if _, err := crand.Read(tx[:]); err != nil {
+ panic(err)
+ }
+ return tx
+}
+
+// Request generates a binding request STUN packet.
+// The transaction ID, tID, should be a random sequence of bytes.
+func Request(tID TxID) []byte {
+ // STUN header, RFC5389 Section 6.
+ const lenAttrSoftware = 4 + len(software)
+ b := make([]byte, 0, headerLen+lenAttrSoftware+lenFingerprint)
+ b = append(b, bindingRequest...)
+ b = appendU16(b, uint16(lenAttrSoftware+lenFingerprint)) // number of bytes following header
+ b = append(b, magicCookie...)
+ b = append(b, tID[:]...)
+
+ // Attribute SOFTWARE, RFC5389 Section 15.5.
+ b = appendU16(b, attrNumSoftware)
+ b = appendU16(b, uint16(len(software)))
+ b = append(b, software...)
+
+ // Attribute FINGERPRINT, RFC5389 Section 15.5.
+ fp := fingerPrint(b)
+ b = appendU16(b, attrNumFingerprint)
+ b = appendU16(b, 4)
+ b = appendU32(b, fp)
+
+ return b
+}
+
+func fingerPrint(b []byte) uint32 { return crc32.ChecksumIEEE(b) ^ 0x5354554e }
+
+func appendU16(b []byte, v uint16) []byte {
+ return append(b, byte(v>>8), byte(v))
+}
+
+func appendU32(b []byte, v uint32) []byte {
+ return append(b, byte(v>>24), byte(v>>16), byte(v>>8), byte(v))
+}
+
+// ParseBindingRequest parses a STUN binding request.
+//
+// It returns an error unless it advertises that it came from
+// Tailscale.
+func ParseBindingRequest(b []byte) (TxID, error) {
+ if !Is(b) {
+ return TxID{}, ErrNotSTUN
+ }
+ if string(b[:len(bindingRequest)]) != bindingRequest {
+ return TxID{}, ErrNotBindingRequest
+ }
+ var txID TxID
+ copy(txID[:], b[8:8+len(txID)])
+ var softwareOK bool
+ var lastAttr uint16
+ var gotFP uint32
+ if err := foreachAttr(b[headerLen:], func(attrType uint16, a []byte) error {
+ lastAttr = attrType
+ if attrType == attrNumSoftware && string(a) == software {
+ softwareOK = true
+ }
+ if attrType == attrNumFingerprint && len(a) == 4 {
+ gotFP = binary.BigEndian.Uint32(a)
+ }
+ return nil
+ }); err != nil {
+ return TxID{}, err
+ }
+ if !softwareOK {
+ return TxID{}, ErrWrongSoftware
+ }
+ if lastAttr != attrNumFingerprint {
+ return TxID{}, ErrNoFingerprint
+ }
+ wantFP := fingerPrint(b[:len(b)-lenFingerprint])
+ if gotFP != wantFP {
+ return TxID{}, ErrWrongFingerprint
+ }
+ return txID, nil
+}
+
+var (
+ ErrNotSTUN = errors.New("response is not a STUN packet")
+ ErrNotSuccessResponse = errors.New("STUN packet is not a response")
+ ErrMalformedAttrs = errors.New("STUN response has malformed attributes")
+ ErrNotBindingRequest = errors.New("STUN request not a binding request")
+ ErrWrongSoftware = errors.New("STUN request came from non-Tailscale software")
+ ErrNoFingerprint = errors.New("STUN request didn't end in fingerprint")
+ ErrWrongFingerprint = errors.New("STUN request had bogus fingerprint")
+)
+
+func foreachAttr(b []byte, fn func(attrType uint16, a []byte) error) error {
+ for len(b) > 0 {
+ if len(b) < 4 {
+ return ErrMalformedAttrs
+ }
+ attrType := binary.BigEndian.Uint16(b[:2])
+ attrLen := int(binary.BigEndian.Uint16(b[2:4]))
+ attrLenWithPad := (attrLen + 3) &^ 3
+ b = b[4:]
+ if attrLenWithPad > len(b) {
+ return ErrMalformedAttrs
+ }
+ if err := fn(attrType, b[:attrLen]); err != nil {
+ return err
+ }
+ b = b[attrLenWithPad:]
+ }
+ return nil
+}
+
+// Response generates a binding response.
+func Response(txID TxID, addrPort netip.AddrPort) []byte {
+ addr := addrPort.Addr()
+
+ var fam byte
+ if addr.Is4() {
+ fam = 1
+ } else if addr.Is6() {
+ fam = 2
+ } else {
+ return nil
+ }
+ attrsLen := 8 + addr.BitLen()/8
+ b := make([]byte, 0, headerLen+attrsLen)
+
+ // Header
+ b = append(b, 0x01, 0x01) // success
+ b = appendU16(b, uint16(attrsLen))
+ b = append(b, magicCookie...)
+ b = append(b, txID[:]...)
+
+ // Attributes (well, one)
+ b = appendU16(b, attrXorMappedAddress)
+ b = appendU16(b, uint16(4+addr.BitLen()/8))
+ b = append(b,
+ 0, // unused byte
+ fam)
+ b = appendU16(b, addrPort.Port()^0x2112) // first half of magicCookie
+ ipa := addr.As16()
+ for i, o := range ipa[16-addr.BitLen()/8:] {
+ if i < 4 {
+ b = append(b, o^magicCookie[i])
+ } else {
+ b = append(b, o^txID[i-len(magicCookie)])
+ }
+ }
+ return b
+}
+
+// ParseResponse parses a successful binding response STUN packet.
+// The IP address is extracted from the XOR-MAPPED-ADDRESS attribute.
+func ParseResponse(b []byte) (tID TxID, addr netip.AddrPort, err error) {
+ if !Is(b) {
+ return tID, netip.AddrPort{}, ErrNotSTUN
+ }
+ copy(tID[:], b[8:8+len(tID)])
+ if b[0] != 0x01 || b[1] != 0x01 {
+ return tID, netip.AddrPort{}, ErrNotSuccessResponse
+ }
+ attrsLen := int(binary.BigEndian.Uint16(b[2:4]))
+ b = b[headerLen:] // remove STUN header
+ if attrsLen > len(b) {
+ return tID, netip.AddrPort{}, ErrMalformedAttrs
+ } else if len(b) > attrsLen {
+ b = b[:attrsLen] // trim trailing packet bytes
+ }
+
+ var fallbackAddr netip.AddrPort
+
+ // Read through the attributes.
+ // The the addr+port reported by XOR-MAPPED-ADDRESS
+ // as the canonical value. If the attribute is not
+ // present but the STUN server responds with
+ // MAPPED-ADDRESS we fall back to it.
+ if err := foreachAttr(b, func(attrType uint16, attr []byte) error {
+ switch attrType {
+ case attrXorMappedAddress, attrXorMappedAddressAlt:
+ ipSlice, port, err := xorMappedAddress(tID, attr)
+ if err != nil {
+ return err
+ }
+ if ip, ok := netip.AddrFromSlice(ipSlice); ok {
+ addr = netip.AddrPortFrom(ip.Unmap(), port)
+ }
+ case attrMappedAddress:
+ ipSlice, port, err := mappedAddress(attr)
+ if err != nil {
+ return ErrMalformedAttrs
+ }
+ if ip, ok := netip.AddrFromSlice(ipSlice); ok {
+ fallbackAddr = netip.AddrPortFrom(ip.Unmap(), port)
+ }
+ }
+ return nil
+
+ }); err != nil {
+ return TxID{}, netip.AddrPort{}, err
+ }
+
+ if addr.IsValid() {
+ return tID, addr, nil
+ }
+ if fallbackAddr.IsValid() {
+ return tID, fallbackAddr, nil
+ }
+ return tID, netip.AddrPort{}, ErrMalformedAttrs
+}
+
+func xorMappedAddress(tID TxID, b []byte) (addr []byte, port uint16, err error) {
+ // XOR-MAPPED-ADDRESS attribute, RFC5389 Section 15.2
+ if len(b) < 4 {
+ return nil, 0, ErrMalformedAttrs
+ }
+ xorPort := binary.BigEndian.Uint16(b[2:4])
+ addrField := b[4:]
+ port = xorPort ^ 0x2112 // first half of magicCookie
+
+ addrLen := familyAddrLen(b[1])
+ if addrLen == 0 {
+ return nil, 0, ErrMalformedAttrs
+ }
+ if len(addrField) < addrLen {
+ return nil, 0, ErrMalformedAttrs
+ }
+ xorAddr := addrField[:addrLen]
+ addr = make([]byte, addrLen)
+ for i := range xorAddr {
+ if i < len(magicCookie) {
+ addr[i] = xorAddr[i] ^ magicCookie[i]
+ } else {
+ addr[i] = xorAddr[i] ^ tID[i-len(magicCookie)]
+ }
+ }
+ return addr, port, nil
+}
+
+func familyAddrLen(fam byte) int {
+ switch fam {
+ case 0x01: // IPv4
+ return net.IPv4len
+ case 0x02: // IPv6
+ return net.IPv6len
+ default:
+ return 0
+ }
+}
+
+func mappedAddress(b []byte) (addr []byte, port uint16, err error) {
+ if len(b) < 4 {
+ return nil, 0, ErrMalformedAttrs
+ }
+ port = uint16(b[2])<<8 | uint16(b[3])
+ addrField := b[4:]
+ addrLen := familyAddrLen(b[1])
+ if addrLen == 0 {
+ return nil, 0, ErrMalformedAttrs
+ }
+ if len(addrField) < addrLen {
+ return nil, 0, ErrMalformedAttrs
+ }
+ return bytes.Clone(addrField[:addrLen]), port, nil
+}
+
+// Is reports whether b is a STUN message.
+func Is(b []byte) bool {
+ return len(b) >= headerLen &&
+ b[0]&0b11000000 == 0 && // top two bits must be zero
+ string(b[4:8]) == magicCookie
+}
diff --git a/net/stun/stun_fuzzer.go b/net/stun/stun_fuzzer.go
index 6f0c9e3b0..9ddb41895 100644
--- a/net/stun/stun_fuzzer.go
+++ b/net/stun/stun_fuzzer.go
@@ -1,12 +1,12 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-//go:build gofuzz
-
-package stun
-
-func FuzzStunParser(data []byte) int {
- _, _, _ = ParseResponse(data)
-
- _, _ = ParseBindingRequest(data)
- return 1
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+//go:build gofuzz
+
+package stun
+
+func FuzzStunParser(data []byte) int {
+ _, _, _ = ParseResponse(data)
+
+ _, _ = ParseBindingRequest(data)
+ return 1
+}
diff --git a/net/tcpinfo/tcpinfo.go b/net/tcpinfo/tcpinfo.go
index a757add9f..adc40ca37 100644
--- a/net/tcpinfo/tcpinfo.go
+++ b/net/tcpinfo/tcpinfo.go
@@ -1,51 +1,51 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-// Package tcpinfo provides platform-agnostic accessors to information about a
-// TCP connection (e.g. RTT, MSS, etc.).
-package tcpinfo
-
-import (
- "errors"
- "net"
- "time"
-)
-
-var (
- ErrNotTCP = errors.New("tcpinfo: not a TCP conn")
- ErrUnimplemented = errors.New("tcpinfo: unimplemented")
-)
-
-// RTT returns the RTT for the given net.Conn.
-//
-// If the net.Conn is not a *net.TCPConn and cannot be unwrapped into one, then
-// ErrNotTCP will be returned. If retrieving the RTT is not supported on the
-// current platform, ErrUnimplemented will be returned.
-func RTT(conn net.Conn) (time.Duration, error) {
- tcpConn, err := unwrap(conn)
- if err != nil {
- return 0, err
- }
-
- return rttImpl(tcpConn)
-}
-
-// netConner is implemented by crypto/tls.Conn to unwrap into an underlying
-// net.Conn.
-type netConner interface {
- NetConn() net.Conn
-}
-
-// unwrap attempts to unwrap a net.Conn into an underlying *net.TCPConn
-func unwrap(nc net.Conn) (*net.TCPConn, error) {
- for {
- switch v := nc.(type) {
- case *net.TCPConn:
- return v, nil
- case netConner:
- nc = v.NetConn()
- default:
- return nil, ErrNotTCP
- }
- }
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Package tcpinfo provides platform-agnostic accessors to information about a
+// TCP connection (e.g. RTT, MSS, etc.).
+package tcpinfo
+
+import (
+ "errors"
+ "net"
+ "time"
+)
+
+var (
+ ErrNotTCP = errors.New("tcpinfo: not a TCP conn")
+ ErrUnimplemented = errors.New("tcpinfo: unimplemented")
+)
+
+// RTT returns the RTT for the given net.Conn.
+//
+// If the net.Conn is not a *net.TCPConn and cannot be unwrapped into one, then
+// ErrNotTCP will be returned. If retrieving the RTT is not supported on the
+// current platform, ErrUnimplemented will be returned.
+func RTT(conn net.Conn) (time.Duration, error) {
+ tcpConn, err := unwrap(conn)
+ if err != nil {
+ return 0, err
+ }
+
+ return rttImpl(tcpConn)
+}
+
+// netConner is implemented by crypto/tls.Conn to unwrap into an underlying
+// net.Conn.
+type netConner interface {
+ NetConn() net.Conn
+}
+
+// unwrap attempts to unwrap a net.Conn into an underlying *net.TCPConn
+func unwrap(nc net.Conn) (*net.TCPConn, error) {
+ for {
+ switch v := nc.(type) {
+ case *net.TCPConn:
+ return v, nil
+ case netConner:
+ nc = v.NetConn()
+ default:
+ return nil, ErrNotTCP
+ }
+ }
+}
diff --git a/net/tcpinfo/tcpinfo_darwin.go b/net/tcpinfo/tcpinfo_darwin.go
index 53fa22fbf..bc4ac08b3 100644
--- a/net/tcpinfo/tcpinfo_darwin.go
+++ b/net/tcpinfo/tcpinfo_darwin.go
@@ -1,33 +1,33 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-package tcpinfo
-
-import (
- "net"
- "time"
-
- "golang.org/x/sys/unix"
-)
-
-func rttImpl(conn *net.TCPConn) (time.Duration, error) {
- rawConn, err := conn.SyscallConn()
- if err != nil {
- return 0, err
- }
-
- var (
- tcpInfo *unix.TCPConnectionInfo
- sysErr error
- )
- err = rawConn.Control(func(fd uintptr) {
- tcpInfo, sysErr = unix.GetsockoptTCPConnectionInfo(int(fd), unix.IPPROTO_TCP, unix.TCP_CONNECTION_INFO)
- })
- if err != nil {
- return 0, err
- } else if sysErr != nil {
- return 0, sysErr
- }
-
- return time.Duration(tcpInfo.Rttcur) * time.Millisecond, nil
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package tcpinfo
+
+import (
+ "net"
+ "time"
+
+ "golang.org/x/sys/unix"
+)
+
+func rttImpl(conn *net.TCPConn) (time.Duration, error) {
+ rawConn, err := conn.SyscallConn()
+ if err != nil {
+ return 0, err
+ }
+
+ var (
+ tcpInfo *unix.TCPConnectionInfo
+ sysErr error
+ )
+ err = rawConn.Control(func(fd uintptr) {
+ tcpInfo, sysErr = unix.GetsockoptTCPConnectionInfo(int(fd), unix.IPPROTO_TCP, unix.TCP_CONNECTION_INFO)
+ })
+ if err != nil {
+ return 0, err
+ } else if sysErr != nil {
+ return 0, sysErr
+ }
+
+ return time.Duration(tcpInfo.Rttcur) * time.Millisecond, nil
+}
diff --git a/net/tcpinfo/tcpinfo_linux.go b/net/tcpinfo/tcpinfo_linux.go
index 885d462c9..5d86055bb 100644
--- a/net/tcpinfo/tcpinfo_linux.go
+++ b/net/tcpinfo/tcpinfo_linux.go
@@ -1,33 +1,33 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-package tcpinfo
-
-import (
- "net"
- "time"
-
- "golang.org/x/sys/unix"
-)
-
-func rttImpl(conn *net.TCPConn) (time.Duration, error) {
- rawConn, err := conn.SyscallConn()
- if err != nil {
- return 0, err
- }
-
- var (
- tcpInfo *unix.TCPInfo
- sysErr error
- )
- err = rawConn.Control(func(fd uintptr) {
- tcpInfo, sysErr = unix.GetsockoptTCPInfo(int(fd), unix.IPPROTO_TCP, unix.TCP_INFO)
- })
- if err != nil {
- return 0, err
- } else if sysErr != nil {
- return 0, sysErr
- }
-
- return time.Duration(tcpInfo.Rtt) * time.Microsecond, nil
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package tcpinfo
+
+import (
+ "net"
+ "time"
+
+ "golang.org/x/sys/unix"
+)
+
+func rttImpl(conn *net.TCPConn) (time.Duration, error) {
+ rawConn, err := conn.SyscallConn()
+ if err != nil {
+ return 0, err
+ }
+
+ var (
+ tcpInfo *unix.TCPInfo
+ sysErr error
+ )
+ err = rawConn.Control(func(fd uintptr) {
+ tcpInfo, sysErr = unix.GetsockoptTCPInfo(int(fd), unix.IPPROTO_TCP, unix.TCP_INFO)
+ })
+ if err != nil {
+ return 0, err
+ } else if sysErr != nil {
+ return 0, sysErr
+ }
+
+ return time.Duration(tcpInfo.Rtt) * time.Microsecond, nil
+}
diff --git a/net/tcpinfo/tcpinfo_other.go b/net/tcpinfo/tcpinfo_other.go
index be45523ae..f219cda1b 100644
--- a/net/tcpinfo/tcpinfo_other.go
+++ b/net/tcpinfo/tcpinfo_other.go
@@ -1,15 +1,15 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-//go:build !linux && !darwin
-
-package tcpinfo
-
-import (
- "net"
- "time"
-)
-
-func rttImpl(conn *net.TCPConn) (time.Duration, error) {
- return 0, ErrUnimplemented
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build !linux && !darwin
+
+package tcpinfo
+
+import (
+ "net"
+ "time"
+)
+
+func rttImpl(conn *net.TCPConn) (time.Duration, error) {
+ return 0, ErrUnimplemented
+}
diff --git a/net/tlsdial/deps_test.go b/net/tlsdial/deps_test.go
index 7a93899c2..750cb300a 100644
--- a/net/tlsdial/deps_test.go
+++ b/net/tlsdial/deps_test.go
@@ -1,8 +1,8 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-//go:build for_go_mod_tidy_only
-
-package tlsdial
-
-import _ "filippo.io/mkcert"
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build for_go_mod_tidy_only
+
+package tlsdial
+
+import _ "filippo.io/mkcert"
diff --git a/net/tsdial/dnsmap_test.go b/net/tsdial/dnsmap_test.go
index 43461a135..f846b853e 100644
--- a/net/tsdial/dnsmap_test.go
+++ b/net/tsdial/dnsmap_test.go
@@ -1,125 +1,125 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-package tsdial
-
-import (
- "net/netip"
- "reflect"
- "testing"
-
- "tailscale.com/tailcfg"
- "tailscale.com/types/netmap"
-)
-
-func nodeViews(v []*tailcfg.Node) []tailcfg.NodeView {
- nv := make([]tailcfg.NodeView, len(v))
- for i, n := range v {
- nv[i] = n.View()
- }
- return nv
-}
-
-func TestDNSMapFromNetworkMap(t *testing.T) {
- pfx := netip.MustParsePrefix
- ip := netip.MustParseAddr
- tests := []struct {
- name string
- nm *netmap.NetworkMap
- want dnsMap
- }{
- {
- name: "self",
- nm: &netmap.NetworkMap{
- Name: "foo.tailnet",
- SelfNode: (&tailcfg.Node{
- Addresses: []netip.Prefix{
- pfx("100.102.103.104/32"),
- pfx("100::123/128"),
- },
- }).View(),
- },
- want: dnsMap{
- "foo": ip("100.102.103.104"),
- "foo.tailnet": ip("100.102.103.104"),
- },
- },
- {
- name: "self_and_peers",
- nm: &netmap.NetworkMap{
- Name: "foo.tailnet",
- SelfNode: (&tailcfg.Node{
- Addresses: []netip.Prefix{
- pfx("100.102.103.104/32"),
- pfx("100::123/128"),
- },
- }).View(),
- Peers: []tailcfg.NodeView{
- (&tailcfg.Node{
- Name: "a.tailnet",
- Addresses: []netip.Prefix{
- pfx("100.0.0.201/32"),
- pfx("100::201/128"),
- },
- }).View(),
- (&tailcfg.Node{
- Name: "b.tailnet",
- Addresses: []netip.Prefix{
- pfx("100::202/128"),
- },
- }).View(),
- },
- },
- want: dnsMap{
- "foo": ip("100.102.103.104"),
- "foo.tailnet": ip("100.102.103.104"),
- "a": ip("100.0.0.201"),
- "a.tailnet": ip("100.0.0.201"),
- "b": ip("100::202"),
- "b.tailnet": ip("100::202"),
- },
- },
- {
- name: "self_has_v6_only",
- nm: &netmap.NetworkMap{
- Name: "foo.tailnet",
- SelfNode: (&tailcfg.Node{
- Addresses: []netip.Prefix{
- pfx("100::123/128"),
- },
- }).View(),
- Peers: nodeViews([]*tailcfg.Node{
- {
- Name: "a.tailnet",
- Addresses: []netip.Prefix{
- pfx("100.0.0.201/32"),
- pfx("100::201/128"),
- },
- },
- {
- Name: "b.tailnet",
- Addresses: []netip.Prefix{
- pfx("100::202/128"),
- },
- },
- }),
- },
- want: dnsMap{
- "foo": ip("100::123"),
- "foo.tailnet": ip("100::123"),
- "a": ip("100::201"),
- "a.tailnet": ip("100::201"),
- "b": ip("100::202"),
- "b.tailnet": ip("100::202"),
- },
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- got := dnsMapFromNetworkMap(tt.nm)
- if !reflect.DeepEqual(got, tt.want) {
- t.Errorf("mismatch:\n got %v\nwant %v\n", got, tt.want)
- }
- })
- }
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package tsdial
+
+import (
+ "net/netip"
+ "reflect"
+ "testing"
+
+ "tailscale.com/tailcfg"
+ "tailscale.com/types/netmap"
+)
+
+func nodeViews(v []*tailcfg.Node) []tailcfg.NodeView {
+ nv := make([]tailcfg.NodeView, len(v))
+ for i, n := range v {
+ nv[i] = n.View()
+ }
+ return nv
+}
+
+func TestDNSMapFromNetworkMap(t *testing.T) {
+ pfx := netip.MustParsePrefix
+ ip := netip.MustParseAddr
+ tests := []struct {
+ name string
+ nm *netmap.NetworkMap
+ want dnsMap
+ }{
+ {
+ name: "self",
+ nm: &netmap.NetworkMap{
+ Name: "foo.tailnet",
+ SelfNode: (&tailcfg.Node{
+ Addresses: []netip.Prefix{
+ pfx("100.102.103.104/32"),
+ pfx("100::123/128"),
+ },
+ }).View(),
+ },
+ want: dnsMap{
+ "foo": ip("100.102.103.104"),
+ "foo.tailnet": ip("100.102.103.104"),
+ },
+ },
+ {
+ name: "self_and_peers",
+ nm: &netmap.NetworkMap{
+ Name: "foo.tailnet",
+ SelfNode: (&tailcfg.Node{
+ Addresses: []netip.Prefix{
+ pfx("100.102.103.104/32"),
+ pfx("100::123/128"),
+ },
+ }).View(),
+ Peers: []tailcfg.NodeView{
+ (&tailcfg.Node{
+ Name: "a.tailnet",
+ Addresses: []netip.Prefix{
+ pfx("100.0.0.201/32"),
+ pfx("100::201/128"),
+ },
+ }).View(),
+ (&tailcfg.Node{
+ Name: "b.tailnet",
+ Addresses: []netip.Prefix{
+ pfx("100::202/128"),
+ },
+ }).View(),
+ },
+ },
+ want: dnsMap{
+ "foo": ip("100.102.103.104"),
+ "foo.tailnet": ip("100.102.103.104"),
+ "a": ip("100.0.0.201"),
+ "a.tailnet": ip("100.0.0.201"),
+ "b": ip("100::202"),
+ "b.tailnet": ip("100::202"),
+ },
+ },
+ {
+ name: "self_has_v6_only",
+ nm: &netmap.NetworkMap{
+ Name: "foo.tailnet",
+ SelfNode: (&tailcfg.Node{
+ Addresses: []netip.Prefix{
+ pfx("100::123/128"),
+ },
+ }).View(),
+ Peers: nodeViews([]*tailcfg.Node{
+ {
+ Name: "a.tailnet",
+ Addresses: []netip.Prefix{
+ pfx("100.0.0.201/32"),
+ pfx("100::201/128"),
+ },
+ },
+ {
+ Name: "b.tailnet",
+ Addresses: []netip.Prefix{
+ pfx("100::202/128"),
+ },
+ },
+ }),
+ },
+ want: dnsMap{
+ "foo": ip("100::123"),
+ "foo.tailnet": ip("100::123"),
+ "a": ip("100::201"),
+ "a.tailnet": ip("100::201"),
+ "b": ip("100::202"),
+ "b.tailnet": ip("100::202"),
+ },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := dnsMapFromNetworkMap(tt.nm)
+ if !reflect.DeepEqual(got, tt.want) {
+ t.Errorf("mismatch:\n got %v\nwant %v\n", got, tt.want)
+ }
+ })
+ }
+}
diff --git a/net/tsdial/dohclient.go b/net/tsdial/dohclient.go
index d830398cd..64c127fd3 100644
--- a/net/tsdial/dohclient.go
+++ b/net/tsdial/dohclient.go
@@ -1,100 +1,100 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-package tsdial
-
-import (
- "bytes"
- "context"
- "errors"
- "fmt"
- "io"
- "net"
- "net/http"
- "time"
-
- "tailscale.com/net/dnscache"
-)
-
-// dohConn is a net.PacketConn suitable for returning from
-// net.Dialer.Dial to send DNS queries over PeerAPI to exit nodes'
-// ExitDNS DoH proxy service.
-type dohConn struct {
- ctx context.Context
- baseURL string
- hc *http.Client // if nil, default is used
- dnsCache *dnscache.MessageCache
-
- rbuf bytes.Buffer
-}
-
-var (
- _ net.Conn = (*dohConn)(nil)
- _ net.PacketConn = (*dohConn)(nil) // be a PacketConn to change net.Resolver semantics
-)
-
-func (*dohConn) Close() error { return nil }
-func (*dohConn) LocalAddr() net.Addr { return todoAddr{} }
-func (*dohConn) RemoteAddr() net.Addr { return todoAddr{} }
-func (*dohConn) SetDeadline(t time.Time) error { return nil }
-func (*dohConn) SetReadDeadline(t time.Time) error { return nil }
-func (*dohConn) SetWriteDeadline(t time.Time) error { return nil }
-
-func (c *dohConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
- return c.Write(p)
-}
-
-func (c *dohConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
- n, err = c.Read(p)
- return n, todoAddr{}, err
-}
-
-func (c *dohConn) Read(p []byte) (n int, err error) {
- return c.rbuf.Read(p)
-}
-
-func (c *dohConn) Write(packet []byte) (n int, err error) {
- if c.dnsCache != nil {
- err := c.dnsCache.ReplyFromCache(&c.rbuf, packet)
- if err == nil {
- // Cache hit.
- // TODO(bradfitz): add clientmetric
- return len(packet), nil
- }
- c.rbuf.Reset()
- }
- req, err := http.NewRequestWithContext(c.ctx, "POST", c.baseURL, bytes.NewReader(packet))
- if err != nil {
- return 0, err
- }
- const dohType = "application/dns-message"
- req.Header.Set("Content-Type", dohType)
- hc := c.hc
- if hc == nil {
- hc = http.DefaultClient
- }
- hres, err := hc.Do(req)
- if err != nil {
- return 0, err
- }
- defer hres.Body.Close()
- if hres.StatusCode != 200 {
- return 0, errors.New(hres.Status)
- }
- if ct := hres.Header.Get("Content-Type"); ct != dohType {
- return 0, fmt.Errorf("unexpected response Content-Type %q", ct)
- }
- _, err = io.Copy(&c.rbuf, hres.Body)
- if err != nil {
- return 0, err
- }
- if c.dnsCache != nil {
- c.dnsCache.AddCacheEntry(packet, c.rbuf.Bytes())
- }
- return len(packet), nil
-}
-
-type todoAddr struct{}
-
-func (todoAddr) Network() string { return "unused" }
-func (todoAddr) String() string { return "unused-todoAddr" }
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package tsdial
+
+import (
+ "bytes"
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "net/http"
+ "time"
+
+ "tailscale.com/net/dnscache"
+)
+
+// dohConn is a net.PacketConn suitable for returning from
+// net.Dialer.Dial to send DNS queries over PeerAPI to exit nodes'
+// ExitDNS DoH proxy service.
+type dohConn struct {
+ ctx context.Context
+ baseURL string
+ hc *http.Client // if nil, default is used
+ dnsCache *dnscache.MessageCache
+
+ rbuf bytes.Buffer
+}
+
+var (
+ _ net.Conn = (*dohConn)(nil)
+ _ net.PacketConn = (*dohConn)(nil) // be a PacketConn to change net.Resolver semantics
+)
+
+func (*dohConn) Close() error { return nil }
+func (*dohConn) LocalAddr() net.Addr { return todoAddr{} }
+func (*dohConn) RemoteAddr() net.Addr { return todoAddr{} }
+func (*dohConn) SetDeadline(t time.Time) error { return nil }
+func (*dohConn) SetReadDeadline(t time.Time) error { return nil }
+func (*dohConn) SetWriteDeadline(t time.Time) error { return nil }
+
+func (c *dohConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
+ return c.Write(p)
+}
+
+func (c *dohConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
+ n, err = c.Read(p)
+ return n, todoAddr{}, err
+}
+
+func (c *dohConn) Read(p []byte) (n int, err error) {
+ return c.rbuf.Read(p)
+}
+
+func (c *dohConn) Write(packet []byte) (n int, err error) {
+ if c.dnsCache != nil {
+ err := c.dnsCache.ReplyFromCache(&c.rbuf, packet)
+ if err == nil {
+ // Cache hit.
+ // TODO(bradfitz): add clientmetric
+ return len(packet), nil
+ }
+ c.rbuf.Reset()
+ }
+ req, err := http.NewRequestWithContext(c.ctx, "POST", c.baseURL, bytes.NewReader(packet))
+ if err != nil {
+ return 0, err
+ }
+ const dohType = "application/dns-message"
+ req.Header.Set("Content-Type", dohType)
+ hc := c.hc
+ if hc == nil {
+ hc = http.DefaultClient
+ }
+ hres, err := hc.Do(req)
+ if err != nil {
+ return 0, err
+ }
+ defer hres.Body.Close()
+ if hres.StatusCode != 200 {
+ return 0, errors.New(hres.Status)
+ }
+ if ct := hres.Header.Get("Content-Type"); ct != dohType {
+ return 0, fmt.Errorf("unexpected response Content-Type %q", ct)
+ }
+ _, err = io.Copy(&c.rbuf, hres.Body)
+ if err != nil {
+ return 0, err
+ }
+ if c.dnsCache != nil {
+ c.dnsCache.AddCacheEntry(packet, c.rbuf.Bytes())
+ }
+ return len(packet), nil
+}
+
+type todoAddr struct{}
+
+func (todoAddr) Network() string { return "unused" }
+func (todoAddr) String() string { return "unused-todoAddr" }
diff --git a/net/tsdial/dohclient_test.go b/net/tsdial/dohclient_test.go
index 23255769f..41a66f8f7 100644
--- a/net/tsdial/dohclient_test.go
+++ b/net/tsdial/dohclient_test.go
@@ -1,31 +1,31 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-package tsdial
-
-import (
- "context"
- "flag"
- "net"
- "testing"
- "time"
-)
-
-var dohBase = flag.String("doh-base", "", "DoH base URL for manual DoH tests; e.g. \"http://100.68.82.120:47830/dns-query\"")
-
-func TestDoHResolve(t *testing.T) {
- if *dohBase == "" {
- t.Skip("skipping manual test without --doh-base= set")
- }
- ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
- defer cancel()
- var r net.Resolver
- r.Dial = func(ctx context.Context, network, address string) (net.Conn, error) {
- return &dohConn{ctx: ctx, baseURL: *dohBase}, nil
- }
- addrs, err := r.LookupIP(ctx, "ip4", "google.com.")
- if err != nil {
- t.Fatal(err)
- }
- t.Logf("Got: %q", addrs)
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package tsdial
+
+import (
+ "context"
+ "flag"
+ "net"
+ "testing"
+ "time"
+)
+
+var dohBase = flag.String("doh-base", "", "DoH base URL for manual DoH tests; e.g. \"http://100.68.82.120:47830/dns-query\"")
+
+func TestDoHResolve(t *testing.T) {
+ if *dohBase == "" {
+ t.Skip("skipping manual test without --doh-base= set")
+ }
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+ defer cancel()
+ var r net.Resolver
+ r.Dial = func(ctx context.Context, network, address string) (net.Conn, error) {
+ return &dohConn{ctx: ctx, baseURL: *dohBase}, nil
+ }
+ addrs, err := r.LookupIP(ctx, "ip4", "google.com.")
+ if err != nil {
+ t.Fatal(err)
+ }
+ t.Logf("Got: %q", addrs)
+}
diff --git a/net/tshttpproxy/mksyscall.go b/net/tshttpproxy/mksyscall.go
index f8fdae89b..467dc4917 100644
--- a/net/tshttpproxy/mksyscall.go
+++ b/net/tshttpproxy/mksyscall.go
@@ -1,11 +1,11 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-package tshttpproxy
-
-//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go mksyscall.go
-
-//sys globalFree(hglobal winHGlobal) (err error) [failretval==0] = kernel32.GlobalFree
-//sys winHTTPCloseHandle(whi winHTTPInternet) (err error) [failretval==0] = winhttp.WinHttpCloseHandle
-//sys winHTTPGetProxyForURL(whi winHTTPInternet, url *uint16, options *winHTTPAutoProxyOptions, proxyInfo *winHTTPProxyInfo) (err error) [failretval==0] = winhttp.WinHttpGetProxyForUrl
-//sys winHTTPOpen(agent *uint16, accessType uint32, proxy *uint16, proxyBypass *uint16, flags uint32) (whi winHTTPInternet, err error) [failretval==0] = winhttp.WinHttpOpen
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package tshttpproxy
+
+//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go mksyscall.go
+
+//sys globalFree(hglobal winHGlobal) (err error) [failretval==0] = kernel32.GlobalFree
+//sys winHTTPCloseHandle(whi winHTTPInternet) (err error) [failretval==0] = winhttp.WinHttpCloseHandle
+//sys winHTTPGetProxyForURL(whi winHTTPInternet, url *uint16, options *winHTTPAutoProxyOptions, proxyInfo *winHTTPProxyInfo) (err error) [failretval==0] = winhttp.WinHttpGetProxyForUrl
+//sys winHTTPOpen(agent *uint16, accessType uint32, proxy *uint16, proxyBypass *uint16, flags uint32) (whi winHTTPInternet, err error) [failretval==0] = winhttp.WinHttpOpen
diff --git a/net/tshttpproxy/tshttpproxy_linux.go b/net/tshttpproxy/tshttpproxy_linux.go
index b241c256d..09019893a 100644
--- a/net/tshttpproxy/tshttpproxy_linux.go
+++ b/net/tshttpproxy/tshttpproxy_linux.go
@@ -1,24 +1,24 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-//go:build linux
-
-package tshttpproxy
-
-import (
- "net/http"
- "net/url"
-
- "tailscale.com/version/distro"
-)
-
-func init() {
- sysProxyFromEnv = linuxSysProxyFromEnv
-}
-
-func linuxSysProxyFromEnv(req *http.Request) (*url.URL, error) {
- if distro.Get() == distro.Synology {
- return synologyProxyFromConfigCached(req)
- }
- return nil, nil
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build linux
+
+package tshttpproxy
+
+import (
+ "net/http"
+ "net/url"
+
+ "tailscale.com/version/distro"
+)
+
+func init() {
+ sysProxyFromEnv = linuxSysProxyFromEnv
+}
+
+func linuxSysProxyFromEnv(req *http.Request) (*url.URL, error) {
+ if distro.Get() == distro.Synology {
+ return synologyProxyFromConfigCached(req)
+ }
+ return nil, nil
+}
diff --git a/net/tshttpproxy/tshttpproxy_synology_test.go b/net/tshttpproxy/tshttpproxy_synology_test.go
index 3061740f3..e11c9d059 100644
--- a/net/tshttpproxy/tshttpproxy_synology_test.go
+++ b/net/tshttpproxy/tshttpproxy_synology_test.go
@@ -1,376 +1,376 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-//go:build linux
-
-package tshttpproxy
-
-import (
- "errors"
- "fmt"
- "io"
- "net/http"
- "net/url"
- "os"
- "path/filepath"
- "strings"
- "testing"
- "time"
-
- "tailscale.com/tstest"
-)
-
-func TestSynologyProxyFromConfigCached(t *testing.T) {
- req, err := http.NewRequest("GET", "http://example.org/", nil)
- if err != nil {
- t.Fatal(err)
- }
-
- tstest.Replace(t, &synologyProxyConfigPath, filepath.Join(t.TempDir(), "proxy.conf"))
-
- t.Run("no config file", func(t *testing.T) {
- if _, err := os.Stat(synologyProxyConfigPath); err == nil {
- t.Fatalf("%s must not exist for this test", synologyProxyConfigPath)
- }
-
- cache.updated = time.Time{}
- cache.httpProxy = nil
- cache.httpsProxy = nil
-
- if val, err := synologyProxyFromConfigCached(req); val != nil || err != nil {
- t.Fatalf("got %s, %v; want nil, nil", val, err)
- }
-
- if got, want := cache.updated, time.Unix(0, 0); got != want {
- t.Fatalf("got %s, want %s", got, want)
- }
- if cache.httpProxy != nil {
- t.Fatalf("got %s, want nil", cache.httpProxy)
- }
- if cache.httpsProxy != nil {
- t.Fatalf("got %s, want nil", cache.httpsProxy)
- }
- })
-
- t.Run("config file updated", func(t *testing.T) {
- cache.updated = time.Now()
- cache.httpProxy = nil
- cache.httpsProxy = nil
-
- if err := os.WriteFile(synologyProxyConfigPath, []byte(`
-proxy_enabled=yes
-http_host=10.0.0.55
-http_port=80
-https_host=10.0.0.66
-https_port=443
- `), 0600); err != nil {
- t.Fatal(err)
- }
-
- val, err := synologyProxyFromConfigCached(req)
- if err != nil {
- t.Fatal(err)
- }
-
- if cache.httpProxy == nil {
- t.Fatal("http proxy was not cached")
- }
- if cache.httpsProxy == nil {
- t.Fatal("https proxy was not cached")
- }
-
- if want := urlMustParse("http://10.0.0.55:80"); val.String() != want.String() {
- t.Fatalf("got %s; want %s", val, want)
- }
- })
-
- t.Run("config file removed", func(t *testing.T) {
- cache.updated = time.Now()
- cache.httpProxy = urlMustParse("http://127.0.0.1/")
- cache.httpsProxy = urlMustParse("http://127.0.0.1/")
-
- if err := os.Remove(synologyProxyConfigPath); err != nil && !os.IsNotExist(err) {
- t.Fatal(err)
- }
-
- val, err := synologyProxyFromConfigCached(req)
- if err != nil {
- t.Fatal(err)
- }
- if val != nil {
- t.Fatalf("got %s; want nil", val)
- }
- if cache.httpProxy != nil {
- t.Fatalf("got %s, want nil", cache.httpProxy)
- }
- if cache.httpsProxy != nil {
- t.Fatalf("got %s, want nil", cache.httpsProxy)
- }
- })
-
- t.Run("picks proxy from request scheme", func(t *testing.T) {
- cache.updated = time.Now()
- cache.httpProxy = nil
- cache.httpsProxy = nil
-
- if err := os.WriteFile(synologyProxyConfigPath, []byte(`
-proxy_enabled=yes
-http_host=10.0.0.55
-http_port=80
-https_host=10.0.0.66
-https_port=443
- `), 0600); err != nil {
- t.Fatal(err)
- }
-
- httpReq, err := http.NewRequest("GET", "http://example.com", nil)
- if err != nil {
- t.Fatal(err)
- }
- val, err := synologyProxyFromConfigCached(httpReq)
- if err != nil {
- t.Fatal(err)
- }
- if val == nil {
- t.Fatalf("got nil, want an http URL")
- }
- if got, want := val.String(), "http://10.0.0.55:80"; got != want {
- t.Fatalf("got %q, want %q", got, want)
- }
-
- httpsReq, err := http.NewRequest("GET", "https://example.com", nil)
- if err != nil {
- t.Fatal(err)
- }
- val, err = synologyProxyFromConfigCached(httpsReq)
- if err != nil {
- t.Fatal(err)
- }
- if val == nil {
- t.Fatalf("got nil, want an http URL")
- }
- if got, want := val.String(), "http://10.0.0.66:443"; got != want {
- t.Fatalf("got %q, want %q", got, want)
- }
- })
-}
-
-func TestSynologyProxiesFromConfig(t *testing.T) {
- var (
- openReader io.ReadCloser
- openErr error
- )
- tstest.Replace(t, &openSynologyProxyConf, func() (io.ReadCloser, error) {
- return openReader, openErr
- })
-
- t.Run("with config", func(t *testing.T) {
- mc := &mustCloser{Reader: strings.NewReader(`
-proxy_user=foo
-proxy_pwd=bar
-proxy_enabled=yes
-adv_enabled=yes
-bypass_enabled=yes
-auth_enabled=yes
-https_host=10.0.0.66
-https_port=8443
-http_host=10.0.0.55
-http_port=80
- `)}
- defer mc.check(t)
- openReader = mc
-
- httpProxy, httpsProxy, err := synologyProxiesFromConfig()
-
- if got, want := err, openErr; got != want {
- t.Fatalf("got %s, want %s", got, want)
- }
-
- if got, want := httpsProxy, urlMustParse("http://foo:bar@10.0.0.66:8443"); got.String() != want.String() {
- t.Fatalf("got %s, want %s", got, want)
- }
-
- if got, want := err, openErr; got != want {
- t.Fatalf("got %s, want %s", got, want)
- }
-
- if got, want := httpProxy, urlMustParse("http://foo:bar@10.0.0.55:80"); got.String() != want.String() {
- t.Fatalf("got %s, want %s", got, want)
- }
-
- })
-
- t.Run("nonexistent config", func(t *testing.T) {
- openReader = nil
- openErr = os.ErrNotExist
-
- httpProxy, httpsProxy, err := synologyProxiesFromConfig()
- if err != nil {
- t.Fatalf("expected no error, got %s", err)
- }
- if httpProxy != nil {
- t.Fatalf("expected no url, got %s", httpProxy)
- }
- if httpsProxy != nil {
- t.Fatalf("expected no url, got %s", httpsProxy)
- }
- })
-
- t.Run("error opening config", func(t *testing.T) {
- openReader = nil
- openErr = errors.New("example error")
-
- httpProxy, httpsProxy, err := synologyProxiesFromConfig()
- if err != openErr {
- t.Fatalf("expected %s, got %s", openErr, err)
- }
- if httpProxy != nil {
- t.Fatalf("expected no url, got %s", httpProxy)
- }
- if httpsProxy != nil {
- t.Fatalf("expected no url, got %s", httpsProxy)
- }
- })
-
-}
-
-func TestParseSynologyConfig(t *testing.T) {
- cases := map[string]struct {
- input string
- httpProxy *url.URL
- httpsProxy *url.URL
- err error
- }{
- "populated": {
- input: `
-proxy_user=foo
-proxy_pwd=bar
-proxy_enabled=yes
-adv_enabled=yes
-bypass_enabled=yes
-auth_enabled=yes
-https_host=10.0.0.66
-https_port=8443
-http_host=10.0.0.55
-http_port=80
-`,
- httpProxy: urlMustParse("http://foo:bar@10.0.0.55:80"),
- httpsProxy: urlMustParse("http://foo:bar@10.0.0.66:8443"),
- err: nil,
- },
- "no-auth": {
- input: `
-proxy_user=foo
-proxy_pwd=bar
-proxy_enabled=yes
-adv_enabled=yes
-bypass_enabled=yes
-auth_enabled=no
-https_host=10.0.0.66
-https_port=8443
-http_host=10.0.0.55
-http_port=80
-`,
- httpProxy: urlMustParse("http://10.0.0.55:80"),
- httpsProxy: urlMustParse("http://10.0.0.66:8443"),
- err: nil,
- },
- "http-only": {
- input: `
-proxy_user=foo
-proxy_pwd=bar
-proxy_enabled=yes
-adv_enabled=yes
-bypass_enabled=yes
-auth_enabled=yes
-https_host=
-https_port=8443
-http_host=10.0.0.55
-http_port=80
-`,
- httpProxy: urlMustParse("http://foo:bar@10.0.0.55:80"),
- httpsProxy: nil,
- err: nil,
- },
- "empty": {
- input: `
-proxy_user=
-proxy_pwd=
-proxy_enabled=
-adv_enabled=
-bypass_enabled=
-auth_enabled=
-https_host=
-https_port=
-http_host=
-http_port=
-`,
- httpProxy: nil,
- httpsProxy: nil,
- err: nil,
- },
- }
-
- for name, example := range cases {
- t.Run(name, func(t *testing.T) {
- httpProxy, httpsProxy, err := parseSynologyConfig(strings.NewReader(example.input))
- if err != example.err {
- t.Fatal(err)
- }
- if example.err != nil {
- return
- }
-
- if example.httpProxy == nil && httpProxy != nil {
- t.Fatalf("got %s, want nil", httpProxy)
- }
-
- if example.httpProxy != nil {
- if httpProxy == nil {
- t.Fatalf("got nil, want %s", example.httpProxy)
- }
-
- if got, want := example.httpProxy.String(), httpProxy.String(); got != want {
- t.Fatalf("got %s, want %s", got, want)
- }
- }
-
- if example.httpsProxy == nil && httpsProxy != nil {
- t.Fatalf("got %s, want nil", httpProxy)
- }
-
- if example.httpsProxy != nil {
- if httpsProxy == nil {
- t.Fatalf("got nil, want %s", example.httpsProxy)
- }
-
- if got, want := example.httpsProxy.String(), httpsProxy.String(); got != want {
- t.Fatalf("got %s, want %s", got, want)
- }
- }
- })
- }
-}
-func urlMustParse(u string) *url.URL {
- r, err := url.Parse(u)
- if err != nil {
- panic(fmt.Sprintf("urlMustParse: %s", err))
- }
- return r
-}
-
-type mustCloser struct {
- io.Reader
- closed bool
-}
-
-func (m *mustCloser) Close() error {
- m.closed = true
- return nil
-}
-
-func (m *mustCloser) check(t *testing.T) {
- if !m.closed {
- t.Errorf("mustCloser wrapping %#v was not closed at time of check", m.Reader)
- }
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build linux
+
+package tshttpproxy
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "net/url"
+ "os"
+ "path/filepath"
+ "strings"
+ "testing"
+ "time"
+
+ "tailscale.com/tstest"
+)
+
+func TestSynologyProxyFromConfigCached(t *testing.T) {
+ req, err := http.NewRequest("GET", "http://example.org/", nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ tstest.Replace(t, &synologyProxyConfigPath, filepath.Join(t.TempDir(), "proxy.conf"))
+
+ t.Run("no config file", func(t *testing.T) {
+ if _, err := os.Stat(synologyProxyConfigPath); err == nil {
+ t.Fatalf("%s must not exist for this test", synologyProxyConfigPath)
+ }
+
+ cache.updated = time.Time{}
+ cache.httpProxy = nil
+ cache.httpsProxy = nil
+
+ if val, err := synologyProxyFromConfigCached(req); val != nil || err != nil {
+ t.Fatalf("got %s, %v; want nil, nil", val, err)
+ }
+
+ if got, want := cache.updated, time.Unix(0, 0); got != want {
+ t.Fatalf("got %s, want %s", got, want)
+ }
+ if cache.httpProxy != nil {
+ t.Fatalf("got %s, want nil", cache.httpProxy)
+ }
+ if cache.httpsProxy != nil {
+ t.Fatalf("got %s, want nil", cache.httpsProxy)
+ }
+ })
+
+ t.Run("config file updated", func(t *testing.T) {
+ cache.updated = time.Now()
+ cache.httpProxy = nil
+ cache.httpsProxy = nil
+
+ if err := os.WriteFile(synologyProxyConfigPath, []byte(`
+proxy_enabled=yes
+http_host=10.0.0.55
+http_port=80
+https_host=10.0.0.66
+https_port=443
+ `), 0600); err != nil {
+ t.Fatal(err)
+ }
+
+ val, err := synologyProxyFromConfigCached(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if cache.httpProxy == nil {
+ t.Fatal("http proxy was not cached")
+ }
+ if cache.httpsProxy == nil {
+ t.Fatal("https proxy was not cached")
+ }
+
+ if want := urlMustParse("http://10.0.0.55:80"); val.String() != want.String() {
+ t.Fatalf("got %s; want %s", val, want)
+ }
+ })
+
+ t.Run("config file removed", func(t *testing.T) {
+ cache.updated = time.Now()
+ cache.httpProxy = urlMustParse("http://127.0.0.1/")
+ cache.httpsProxy = urlMustParse("http://127.0.0.1/")
+
+ if err := os.Remove(synologyProxyConfigPath); err != nil && !os.IsNotExist(err) {
+ t.Fatal(err)
+ }
+
+ val, err := synologyProxyFromConfigCached(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if val != nil {
+ t.Fatalf("got %s; want nil", val)
+ }
+ if cache.httpProxy != nil {
+ t.Fatalf("got %s, want nil", cache.httpProxy)
+ }
+ if cache.httpsProxy != nil {
+ t.Fatalf("got %s, want nil", cache.httpsProxy)
+ }
+ })
+
+ t.Run("picks proxy from request scheme", func(t *testing.T) {
+ cache.updated = time.Now()
+ cache.httpProxy = nil
+ cache.httpsProxy = nil
+
+ if err := os.WriteFile(synologyProxyConfigPath, []byte(`
+proxy_enabled=yes
+http_host=10.0.0.55
+http_port=80
+https_host=10.0.0.66
+https_port=443
+ `), 0600); err != nil {
+ t.Fatal(err)
+ }
+
+ httpReq, err := http.NewRequest("GET", "http://example.com", nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ val, err := synologyProxyFromConfigCached(httpReq)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if val == nil {
+ t.Fatalf("got nil, want an http URL")
+ }
+ if got, want := val.String(), "http://10.0.0.55:80"; got != want {
+ t.Fatalf("got %q, want %q", got, want)
+ }
+
+ httpsReq, err := http.NewRequest("GET", "https://example.com", nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ val, err = synologyProxyFromConfigCached(httpsReq)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if val == nil {
+ t.Fatalf("got nil, want an http URL")
+ }
+ if got, want := val.String(), "http://10.0.0.66:443"; got != want {
+ t.Fatalf("got %q, want %q", got, want)
+ }
+ })
+}
+
+func TestSynologyProxiesFromConfig(t *testing.T) {
+ var (
+ openReader io.ReadCloser
+ openErr error
+ )
+ tstest.Replace(t, &openSynologyProxyConf, func() (io.ReadCloser, error) {
+ return openReader, openErr
+ })
+
+ t.Run("with config", func(t *testing.T) {
+ mc := &mustCloser{Reader: strings.NewReader(`
+proxy_user=foo
+proxy_pwd=bar
+proxy_enabled=yes
+adv_enabled=yes
+bypass_enabled=yes
+auth_enabled=yes
+https_host=10.0.0.66
+https_port=8443
+http_host=10.0.0.55
+http_port=80
+ `)}
+ defer mc.check(t)
+ openReader = mc
+
+ httpProxy, httpsProxy, err := synologyProxiesFromConfig()
+
+ if got, want := err, openErr; got != want {
+ t.Fatalf("got %s, want %s", got, want)
+ }
+
+ if got, want := httpsProxy, urlMustParse("http://foo:bar@10.0.0.66:8443"); got.String() != want.String() {
+ t.Fatalf("got %s, want %s", got, want)
+ }
+
+ if got, want := err, openErr; got != want {
+ t.Fatalf("got %s, want %s", got, want)
+ }
+
+ if got, want := httpProxy, urlMustParse("http://foo:bar@10.0.0.55:80"); got.String() != want.String() {
+ t.Fatalf("got %s, want %s", got, want)
+ }
+
+ })
+
+ t.Run("nonexistent config", func(t *testing.T) {
+ openReader = nil
+ openErr = os.ErrNotExist
+
+ httpProxy, httpsProxy, err := synologyProxiesFromConfig()
+ if err != nil {
+ t.Fatalf("expected no error, got %s", err)
+ }
+ if httpProxy != nil {
+ t.Fatalf("expected no url, got %s", httpProxy)
+ }
+ if httpsProxy != nil {
+ t.Fatalf("expected no url, got %s", httpsProxy)
+ }
+ })
+
+ t.Run("error opening config", func(t *testing.T) {
+ openReader = nil
+ openErr = errors.New("example error")
+
+ httpProxy, httpsProxy, err := synologyProxiesFromConfig()
+ if err != openErr {
+ t.Fatalf("expected %s, got %s", openErr, err)
+ }
+ if httpProxy != nil {
+ t.Fatalf("expected no url, got %s", httpProxy)
+ }
+ if httpsProxy != nil {
+ t.Fatalf("expected no url, got %s", httpsProxy)
+ }
+ })
+
+}
+
+func TestParseSynologyConfig(t *testing.T) {
+ cases := map[string]struct {
+ input string
+ httpProxy *url.URL
+ httpsProxy *url.URL
+ err error
+ }{
+ "populated": {
+ input: `
+proxy_user=foo
+proxy_pwd=bar
+proxy_enabled=yes
+adv_enabled=yes
+bypass_enabled=yes
+auth_enabled=yes
+https_host=10.0.0.66
+https_port=8443
+http_host=10.0.0.55
+http_port=80
+`,
+ httpProxy: urlMustParse("http://foo:bar@10.0.0.55:80"),
+ httpsProxy: urlMustParse("http://foo:bar@10.0.0.66:8443"),
+ err: nil,
+ },
+ "no-auth": {
+ input: `
+proxy_user=foo
+proxy_pwd=bar
+proxy_enabled=yes
+adv_enabled=yes
+bypass_enabled=yes
+auth_enabled=no
+https_host=10.0.0.66
+https_port=8443
+http_host=10.0.0.55
+http_port=80
+`,
+ httpProxy: urlMustParse("http://10.0.0.55:80"),
+ httpsProxy: urlMustParse("http://10.0.0.66:8443"),
+ err: nil,
+ },
+ "http-only": {
+ input: `
+proxy_user=foo
+proxy_pwd=bar
+proxy_enabled=yes
+adv_enabled=yes
+bypass_enabled=yes
+auth_enabled=yes
+https_host=
+https_port=8443
+http_host=10.0.0.55
+http_port=80
+`,
+ httpProxy: urlMustParse("http://foo:bar@10.0.0.55:80"),
+ httpsProxy: nil,
+ err: nil,
+ },
+ "empty": {
+ input: `
+proxy_user=
+proxy_pwd=
+proxy_enabled=
+adv_enabled=
+bypass_enabled=
+auth_enabled=
+https_host=
+https_port=
+http_host=
+http_port=
+`,
+ httpProxy: nil,
+ httpsProxy: nil,
+ err: nil,
+ },
+ }
+
+ for name, example := range cases {
+ t.Run(name, func(t *testing.T) {
+ httpProxy, httpsProxy, err := parseSynologyConfig(strings.NewReader(example.input))
+ if err != example.err {
+ t.Fatal(err)
+ }
+ if example.err != nil {
+ return
+ }
+
+ if example.httpProxy == nil && httpProxy != nil {
+ t.Fatalf("got %s, want nil", httpProxy)
+ }
+
+ if example.httpProxy != nil {
+ if httpProxy == nil {
+ t.Fatalf("got nil, want %s", example.httpProxy)
+ }
+
+ if got, want := example.httpProxy.String(), httpProxy.String(); got != want {
+ t.Fatalf("got %s, want %s", got, want)
+ }
+ }
+
+ if example.httpsProxy == nil && httpsProxy != nil {
+ t.Fatalf("got %s, want nil", httpProxy)
+ }
+
+ if example.httpsProxy != nil {
+ if httpsProxy == nil {
+ t.Fatalf("got nil, want %s", example.httpsProxy)
+ }
+
+ if got, want := example.httpsProxy.String(), httpsProxy.String(); got != want {
+ t.Fatalf("got %s, want %s", got, want)
+ }
+ }
+ })
+ }
+}
+func urlMustParse(u string) *url.URL {
+ r, err := url.Parse(u)
+ if err != nil {
+ panic(fmt.Sprintf("urlMustParse: %s", err))
+ }
+ return r
+}
+
+type mustCloser struct {
+ io.Reader
+ closed bool
+}
+
+func (m *mustCloser) Close() error {
+ m.closed = true
+ return nil
+}
+
+func (m *mustCloser) check(t *testing.T) {
+ if !m.closed {
+ t.Errorf("mustCloser wrapping %#v was not closed at time of check", m.Reader)
+ }
+}
diff --git a/net/tshttpproxy/tshttpproxy_windows.go b/net/tshttpproxy/tshttpproxy_windows.go
index 06a1f5ae4..cb6b24c83 100644
--- a/net/tshttpproxy/tshttpproxy_windows.go
+++ b/net/tshttpproxy/tshttpproxy_windows.go
@@ -1,276 +1,276 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-package tshttpproxy
-
-import (
- "context"
- "encoding/base64"
- "fmt"
- "log"
- "net/http"
- "net/url"
- "runtime"
- "strings"
- "sync"
- "syscall"
- "time"
- "unsafe"
-
- "github.com/alexbrainman/sspi/negotiate"
- "golang.org/x/sys/windows"
- "tailscale.com/hostinfo"
- "tailscale.com/syncs"
- "tailscale.com/types/logger"
- "tailscale.com/util/clientmetric"
- "tailscale.com/util/cmpver"
-)
-
-func init() {
- sysProxyFromEnv = proxyFromWinHTTPOrCache
- sysAuthHeader = sysAuthHeaderWindows
-}
-
-var cachedProxy struct {
- sync.Mutex
- val *url.URL
-}
-
-// proxyErrorf is a rate-limited logger specifically for errors asking
-// WinHTTP for the proxy information. We don't want to log about
-// errors often, otherwise the log message itself will generate a new
-// HTTP request which ultimately will call back into us to log again,
-// forever. So for errors, we only log a bit.
-var proxyErrorf = logger.RateLimitedFn(log.Printf, 10*time.Minute, 2 /* burst*/, 10 /* maxCache */)
-
-var (
- metricSuccess = clientmetric.NewCounter("winhttp_proxy_success")
- metricErrDetectionFailed = clientmetric.NewCounter("winhttp_proxy_err_detection_failed")
- metricErrInvalidParameters = clientmetric.NewCounter("winhttp_proxy_err_invalid_param")
- metricErrDownloadScript = clientmetric.NewCounter("winhttp_proxy_err_download_script")
- metricErrTimeout = clientmetric.NewCounter("winhttp_proxy_err_timeout")
- metricErrOther = clientmetric.NewCounter("winhttp_proxy_err_other")
-)
-
-func proxyFromWinHTTPOrCache(req *http.Request) (*url.URL, error) {
- if req.URL == nil {
- return nil, nil
- }
- urlStr := req.URL.String()
-
- ctx, cancel := context.WithTimeout(req.Context(), 5*time.Second)
- defer cancel()
-
- type result struct {
- proxy *url.URL
- err error
- }
- resc := make(chan result, 1)
- go func() {
- proxy, err := proxyFromWinHTTP(ctx, urlStr)
- resc <- result{proxy, err}
- }()
-
- select {
- case res := <-resc:
- err := res.err
- if err == nil {
- metricSuccess.Add(1)
- cachedProxy.Lock()
- defer cachedProxy.Unlock()
- if was, now := fmt.Sprint(cachedProxy.val), fmt.Sprint(res.proxy); was != now {
- log.Printf("tshttpproxy: winhttp: updating cached proxy setting from %v to %v", was, now)
- }
- cachedProxy.val = res.proxy
- return res.proxy, nil
- }
-
- // See https://docs.microsoft.com/en-us/windows/win32/winhttp/error-messages
- const (
- ERROR_WINHTTP_AUTODETECTION_FAILED = 12180
- ERROR_WINHTTP_UNABLE_TO_DOWNLOAD_SCRIPT = 12167
- )
- if err == syscall.Errno(ERROR_WINHTTP_AUTODETECTION_FAILED) {
- metricErrDetectionFailed.Add(1)
- setNoProxyUntil(10 * time.Second)
- return nil, nil
- }
- if err == windows.ERROR_INVALID_PARAMETER {
- metricErrInvalidParameters.Add(1)
- // Seen on Windows 8.1. (https://github.com/tailscale/tailscale/issues/879)
- // TODO(bradfitz): figure this out.
- setNoProxyUntil(time.Hour)
- proxyErrorf("tshttpproxy: winhttp: GetProxyForURL(%q): ERROR_INVALID_PARAMETER [unexpected]", urlStr)
- return nil, nil
- }
- proxyErrorf("tshttpproxy: winhttp: GetProxyForURL(%q): %v/%#v", urlStr, err, err)
- if err == syscall.Errno(ERROR_WINHTTP_UNABLE_TO_DOWNLOAD_SCRIPT) {
- metricErrDownloadScript.Add(1)
- setNoProxyUntil(10 * time.Second)
- return nil, nil
- }
- metricErrOther.Add(1)
- return nil, err
- case <-ctx.Done():
- metricErrTimeout.Add(1)
- cachedProxy.Lock()
- defer cachedProxy.Unlock()
- proxyErrorf("tshttpproxy: winhttp: GetProxyForURL(%q): timeout; using cached proxy %v", urlStr, cachedProxy.val)
- return cachedProxy.val, nil
- }
-}
-
-func proxyFromWinHTTP(ctx context.Context, urlStr string) (proxy *url.URL, err error) {
- runtime.LockOSThread()
- defer runtime.UnlockOSThread()
-
- whi, err := httpOpen()
- if err != nil {
- proxyErrorf("winhttp: Open: %v", err)
- return nil, err
- }
- defer whi.Close()
-
- t0 := time.Now()
- v, err := whi.GetProxyForURL(urlStr)
- td := time.Since(t0).Round(time.Millisecond)
- if err := ctx.Err(); err != nil {
- log.Printf("tshttpproxy: winhttp: context canceled, ignoring GetProxyForURL(%q) after %v", urlStr, td)
- return nil, err
- }
- if err != nil {
- return nil, err
- }
- if v == "" {
- return nil, nil
- }
- // Discard all but first proxy value for now.
- if i := strings.Index(v, ";"); i != -1 {
- v = v[:i]
- }
- if !strings.HasPrefix(v, "https://") {
- v = "http://" + v
- }
- return url.Parse(v)
-}
-
-var userAgent = windows.StringToUTF16Ptr("Tailscale")
-
-const (
- winHTTP_ACCESS_TYPE_DEFAULT_PROXY = 0
- winHTTP_ACCESS_TYPE_AUTOMATIC_PROXY = 4
- winHTTP_AUTOPROXY_ALLOW_AUTOCONFIG = 0x00000100
- winHTTP_AUTOPROXY_AUTO_DETECT = 1
- winHTTP_AUTO_DETECT_TYPE_DHCP = 0x00000001
- winHTTP_AUTO_DETECT_TYPE_DNS_A = 0x00000002
-)
-
-// Windows 8.1 is actually Windows 6.3 under the hood. Yay, marketing!
-const win8dot1Ver = "6.3"
-
-// accessType is the flag we must pass to WinHttpOpen for proxy resolution
-// depending on whether or not we're running Windows < 8.1
-var accessType syncs.AtomicValue[uint32]
-
-func getAccessFlag() uint32 {
- if flag, ok := accessType.LoadOk(); ok {
- return flag
- }
- var flag uint32
- if cmpver.Compare(hostinfo.GetOSVersion(), win8dot1Ver) < 0 {
- flag = winHTTP_ACCESS_TYPE_DEFAULT_PROXY
- } else {
- flag = winHTTP_ACCESS_TYPE_AUTOMATIC_PROXY
- }
- accessType.Store(flag)
- return flag
-}
-
-func httpOpen() (winHTTPInternet, error) {
- return winHTTPOpen(
- userAgent,
- getAccessFlag(),
- nil, /* WINHTTP_NO_PROXY_NAME */
- nil, /* WINHTTP_NO_PROXY_BYPASS */
- 0,
- )
-}
-
-type winHTTPInternet windows.Handle
-
-func (hi winHTTPInternet) Close() error {
- return winHTTPCloseHandle(hi)
-}
-
-// WINHTTP_AUTOPROXY_OPTIONS
-// https://docs.microsoft.com/en-us/windows/win32/api/winhttp/ns-winhttp-winhttp_autoproxy_options
-type winHTTPAutoProxyOptions struct {
- DwFlags uint32
- DwAutoDetectFlags uint32
- AutoConfigUrl *uint16
- _ uintptr
- _ uint32
- FAutoLogonIfChallenged int32 // BOOL
-}
-
-// WINHTTP_PROXY_INFO
-// https://docs.microsoft.com/en-us/windows/win32/api/winhttp/ns-winhttp-winhttp_proxy_info
-type winHTTPProxyInfo struct {
- AccessType uint32
- Proxy *uint16
- ProxyBypass *uint16
-}
-
-type winHGlobal windows.Handle
-
-func globalFreeUTF16Ptr(p *uint16) error {
- return globalFree((winHGlobal)(unsafe.Pointer(p)))
-}
-
-func (pi *winHTTPProxyInfo) free() {
- if pi.Proxy != nil {
- globalFreeUTF16Ptr(pi.Proxy)
- pi.Proxy = nil
- }
- if pi.ProxyBypass != nil {
- globalFreeUTF16Ptr(pi.ProxyBypass)
- pi.ProxyBypass = nil
- }
-}
-
-var proxyForURLOpts = &winHTTPAutoProxyOptions{
- DwFlags: winHTTP_AUTOPROXY_ALLOW_AUTOCONFIG | winHTTP_AUTOPROXY_AUTO_DETECT,
- DwAutoDetectFlags: winHTTP_AUTO_DETECT_TYPE_DHCP, // | winHTTP_AUTO_DETECT_TYPE_DNS_A,
-}
-
-func (hi winHTTPInternet) GetProxyForURL(urlStr string) (string, error) {
- var out winHTTPProxyInfo
- err := winHTTPGetProxyForURL(
- hi,
- windows.StringToUTF16Ptr(urlStr),
- proxyForURLOpts,
- &out,
- )
- if err != nil {
- return "", err
- }
- defer out.free()
- return windows.UTF16PtrToString(out.Proxy), nil
-}
-
-func sysAuthHeaderWindows(u *url.URL) (string, error) {
- spn := "HTTP/" + u.Hostname()
- creds, err := negotiate.AcquireCurrentUserCredentials()
- if err != nil {
- return "", fmt.Errorf("negotiate.AcquireCurrentUserCredentials: %w", err)
- }
- defer creds.Release()
-
- secCtx, token, err := negotiate.NewClientContext(creds, spn)
- if err != nil {
- return "", fmt.Errorf("negotiate.NewClientContext: %w", err)
- }
- defer secCtx.Release()
-
- return "Negotiate " + base64.StdEncoding.EncodeToString(token), nil
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package tshttpproxy
+
+import (
+ "context"
+ "encoding/base64"
+ "fmt"
+ "log"
+ "net/http"
+ "net/url"
+ "runtime"
+ "strings"
+ "sync"
+ "syscall"
+ "time"
+ "unsafe"
+
+ "github.com/alexbrainman/sspi/negotiate"
+ "golang.org/x/sys/windows"
+ "tailscale.com/hostinfo"
+ "tailscale.com/syncs"
+ "tailscale.com/types/logger"
+ "tailscale.com/util/clientmetric"
+ "tailscale.com/util/cmpver"
+)
+
+func init() {
+ sysProxyFromEnv = proxyFromWinHTTPOrCache
+ sysAuthHeader = sysAuthHeaderWindows
+}
+
+var cachedProxy struct {
+ sync.Mutex
+ val *url.URL
+}
+
+// proxyErrorf is a rate-limited logger specifically for errors asking
+// WinHTTP for the proxy information. We don't want to log about
+// errors often, otherwise the log message itself will generate a new
+// HTTP request which ultimately will call back into us to log again,
+// forever. So for errors, we only log a bit.
+var proxyErrorf = logger.RateLimitedFn(log.Printf, 10*time.Minute, 2 /* burst*/, 10 /* maxCache */)
+
+var (
+ metricSuccess = clientmetric.NewCounter("winhttp_proxy_success")
+ metricErrDetectionFailed = clientmetric.NewCounter("winhttp_proxy_err_detection_failed")
+ metricErrInvalidParameters = clientmetric.NewCounter("winhttp_proxy_err_invalid_param")
+ metricErrDownloadScript = clientmetric.NewCounter("winhttp_proxy_err_download_script")
+ metricErrTimeout = clientmetric.NewCounter("winhttp_proxy_err_timeout")
+ metricErrOther = clientmetric.NewCounter("winhttp_proxy_err_other")
+)
+
+func proxyFromWinHTTPOrCache(req *http.Request) (*url.URL, error) {
+ if req.URL == nil {
+ return nil, nil
+ }
+ urlStr := req.URL.String()
+
+ ctx, cancel := context.WithTimeout(req.Context(), 5*time.Second)
+ defer cancel()
+
+ type result struct {
+ proxy *url.URL
+ err error
+ }
+ resc := make(chan result, 1)
+ go func() {
+ proxy, err := proxyFromWinHTTP(ctx, urlStr)
+ resc <- result{proxy, err}
+ }()
+
+ select {
+ case res := <-resc:
+ err := res.err
+ if err == nil {
+ metricSuccess.Add(1)
+ cachedProxy.Lock()
+ defer cachedProxy.Unlock()
+ if was, now := fmt.Sprint(cachedProxy.val), fmt.Sprint(res.proxy); was != now {
+ log.Printf("tshttpproxy: winhttp: updating cached proxy setting from %v to %v", was, now)
+ }
+ cachedProxy.val = res.proxy
+ return res.proxy, nil
+ }
+
+ // See https://docs.microsoft.com/en-us/windows/win32/winhttp/error-messages
+ const (
+ ERROR_WINHTTP_AUTODETECTION_FAILED = 12180
+ ERROR_WINHTTP_UNABLE_TO_DOWNLOAD_SCRIPT = 12167
+ )
+ if err == syscall.Errno(ERROR_WINHTTP_AUTODETECTION_FAILED) {
+ metricErrDetectionFailed.Add(1)
+ setNoProxyUntil(10 * time.Second)
+ return nil, nil
+ }
+ if err == windows.ERROR_INVALID_PARAMETER {
+ metricErrInvalidParameters.Add(1)
+ // Seen on Windows 8.1. (https://github.com/tailscale/tailscale/issues/879)
+ // TODO(bradfitz): figure this out.
+ setNoProxyUntil(time.Hour)
+ proxyErrorf("tshttpproxy: winhttp: GetProxyForURL(%q): ERROR_INVALID_PARAMETER [unexpected]", urlStr)
+ return nil, nil
+ }
+ proxyErrorf("tshttpproxy: winhttp: GetProxyForURL(%q): %v/%#v", urlStr, err, err)
+ if err == syscall.Errno(ERROR_WINHTTP_UNABLE_TO_DOWNLOAD_SCRIPT) {
+ metricErrDownloadScript.Add(1)
+ setNoProxyUntil(10 * time.Second)
+ return nil, nil
+ }
+ metricErrOther.Add(1)
+ return nil, err
+ case <-ctx.Done():
+ metricErrTimeout.Add(1)
+ cachedProxy.Lock()
+ defer cachedProxy.Unlock()
+ proxyErrorf("tshttpproxy: winhttp: GetProxyForURL(%q): timeout; using cached proxy %v", urlStr, cachedProxy.val)
+ return cachedProxy.val, nil
+ }
+}
+
+func proxyFromWinHTTP(ctx context.Context, urlStr string) (proxy *url.URL, err error) {
+ runtime.LockOSThread()
+ defer runtime.UnlockOSThread()
+
+ whi, err := httpOpen()
+ if err != nil {
+ proxyErrorf("winhttp: Open: %v", err)
+ return nil, err
+ }
+ defer whi.Close()
+
+ t0 := time.Now()
+ v, err := whi.GetProxyForURL(urlStr)
+ td := time.Since(t0).Round(time.Millisecond)
+ if err := ctx.Err(); err != nil {
+ log.Printf("tshttpproxy: winhttp: context canceled, ignoring GetProxyForURL(%q) after %v", urlStr, td)
+ return nil, err
+ }
+ if err != nil {
+ return nil, err
+ }
+ if v == "" {
+ return nil, nil
+ }
+ // Discard all but first proxy value for now.
+ if i := strings.Index(v, ";"); i != -1 {
+ v = v[:i]
+ }
+ if !strings.HasPrefix(v, "https://") {
+ v = "http://" + v
+ }
+ return url.Parse(v)
+}
+
+var userAgent = windows.StringToUTF16Ptr("Tailscale")
+
+const (
+ winHTTP_ACCESS_TYPE_DEFAULT_PROXY = 0
+ winHTTP_ACCESS_TYPE_AUTOMATIC_PROXY = 4
+ winHTTP_AUTOPROXY_ALLOW_AUTOCONFIG = 0x00000100
+ winHTTP_AUTOPROXY_AUTO_DETECT = 1
+ winHTTP_AUTO_DETECT_TYPE_DHCP = 0x00000001
+ winHTTP_AUTO_DETECT_TYPE_DNS_A = 0x00000002
+)
+
+// Windows 8.1 is actually Windows 6.3 under the hood. Yay, marketing!
+const win8dot1Ver = "6.3"
+
+// accessType is the flag we must pass to WinHttpOpen for proxy resolution
+// depending on whether or not we're running Windows < 8.1
+var accessType syncs.AtomicValue[uint32]
+
+func getAccessFlag() uint32 {
+ if flag, ok := accessType.LoadOk(); ok {
+ return flag
+ }
+ var flag uint32
+ if cmpver.Compare(hostinfo.GetOSVersion(), win8dot1Ver) < 0 {
+ flag = winHTTP_ACCESS_TYPE_DEFAULT_PROXY
+ } else {
+ flag = winHTTP_ACCESS_TYPE_AUTOMATIC_PROXY
+ }
+ accessType.Store(flag)
+ return flag
+}
+
+func httpOpen() (winHTTPInternet, error) {
+ return winHTTPOpen(
+ userAgent,
+ getAccessFlag(),
+ nil, /* WINHTTP_NO_PROXY_NAME */
+ nil, /* WINHTTP_NO_PROXY_BYPASS */
+ 0,
+ )
+}
+
+type winHTTPInternet windows.Handle
+
+func (hi winHTTPInternet) Close() error {
+ return winHTTPCloseHandle(hi)
+}
+
+// WINHTTP_AUTOPROXY_OPTIONS
+// https://docs.microsoft.com/en-us/windows/win32/api/winhttp/ns-winhttp-winhttp_autoproxy_options
+type winHTTPAutoProxyOptions struct {
+ DwFlags uint32
+ DwAutoDetectFlags uint32
+ AutoConfigUrl *uint16
+ _ uintptr
+ _ uint32
+ FAutoLogonIfChallenged int32 // BOOL
+}
+
+// WINHTTP_PROXY_INFO
+// https://docs.microsoft.com/en-us/windows/win32/api/winhttp/ns-winhttp-winhttp_proxy_info
+type winHTTPProxyInfo struct {
+ AccessType uint32
+ Proxy *uint16
+ ProxyBypass *uint16
+}
+
+type winHGlobal windows.Handle
+
+func globalFreeUTF16Ptr(p *uint16) error {
+ return globalFree((winHGlobal)(unsafe.Pointer(p)))
+}
+
+func (pi *winHTTPProxyInfo) free() {
+ if pi.Proxy != nil {
+ globalFreeUTF16Ptr(pi.Proxy)
+ pi.Proxy = nil
+ }
+ if pi.ProxyBypass != nil {
+ globalFreeUTF16Ptr(pi.ProxyBypass)
+ pi.ProxyBypass = nil
+ }
+}
+
+var proxyForURLOpts = &winHTTPAutoProxyOptions{
+ DwFlags: winHTTP_AUTOPROXY_ALLOW_AUTOCONFIG | winHTTP_AUTOPROXY_AUTO_DETECT,
+ DwAutoDetectFlags: winHTTP_AUTO_DETECT_TYPE_DHCP, // | winHTTP_AUTO_DETECT_TYPE_DNS_A,
+}
+
+func (hi winHTTPInternet) GetProxyForURL(urlStr string) (string, error) {
+ var out winHTTPProxyInfo
+ err := winHTTPGetProxyForURL(
+ hi,
+ windows.StringToUTF16Ptr(urlStr),
+ proxyForURLOpts,
+ &out,
+ )
+ if err != nil {
+ return "", err
+ }
+ defer out.free()
+ return windows.UTF16PtrToString(out.Proxy), nil
+}
+
+func sysAuthHeaderWindows(u *url.URL) (string, error) {
+ spn := "HTTP/" + u.Hostname()
+ creds, err := negotiate.AcquireCurrentUserCredentials()
+ if err != nil {
+ return "", fmt.Errorf("negotiate.AcquireCurrentUserCredentials: %w", err)
+ }
+ defer creds.Release()
+
+ secCtx, token, err := negotiate.NewClientContext(creds, spn)
+ if err != nil {
+ return "", fmt.Errorf("negotiate.NewClientContext: %w", err)
+ }
+ defer secCtx.Release()
+
+ return "Negotiate " + base64.StdEncoding.EncodeToString(token), nil
+}
diff --git a/net/tstun/fake.go b/net/tstun/fake.go
index 3d86bb3df..a002952a3 100644
--- a/net/tstun/fake.go
+++ b/net/tstun/fake.go
@@ -1,58 +1,58 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-package tstun
-
-import (
- "io"
- "os"
-
- "github.com/tailscale/wireguard-go/tun"
-)
-
-type fakeTUN struct {
- evchan chan tun.Event
- closechan chan struct{}
-}
-
-// NewFake returns a tun.Device that does nothing.
-func NewFake() tun.Device {
- return &fakeTUN{
- evchan: make(chan tun.Event),
- closechan: make(chan struct{}),
- }
-}
-
-func (t *fakeTUN) File() *os.File {
- panic("fakeTUN.File() called, which makes no sense")
-}
-
-func (t *fakeTUN) Close() error {
- close(t.closechan)
- close(t.evchan)
- return nil
-}
-
-func (t *fakeTUN) Read(out [][]byte, sizes []int, offset int) (int, error) {
- <-t.closechan
- return 0, io.EOF
-}
-
-func (t *fakeTUN) Write(b [][]byte, n int) (int, error) {
- select {
- case <-t.closechan:
- return 0, ErrClosed
- default:
- }
- return 1, nil
-}
-
-// FakeTUNName is the name of the fake TUN device.
-const FakeTUNName = "FakeTUN"
-
-func (t *fakeTUN) Flush() error { return nil }
-func (t *fakeTUN) MTU() (int, error) { return 1500, nil }
-func (t *fakeTUN) Name() (string, error) { return FakeTUNName, nil }
-func (t *fakeTUN) Events() <-chan tun.Event { return t.evchan }
-func (t *fakeTUN) BatchSize() int { return 1 }
-func (t *fakeTUN) IsFakeTun() bool { return true }
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package tstun
+
+import (
+ "io"
+ "os"
+
+ "github.com/tailscale/wireguard-go/tun"
+)
+
+type fakeTUN struct {
+ evchan chan tun.Event
+ closechan chan struct{}
+}
+
+// NewFake returns a tun.Device that does nothing.
+func NewFake() tun.Device {
+ return &fakeTUN{
+ evchan: make(chan tun.Event),
+ closechan: make(chan struct{}),
+ }
+}
+
+func (t *fakeTUN) File() *os.File {
+ panic("fakeTUN.File() called, which makes no sense")
+}
+
+func (t *fakeTUN) Close() error {
+ close(t.closechan)
+ close(t.evchan)
+ return nil
+}
+
+func (t *fakeTUN) Read(out [][]byte, sizes []int, offset int) (int, error) {
+ <-t.closechan
+ return 0, io.EOF
+}
+
+func (t *fakeTUN) Write(b [][]byte, n int) (int, error) {
+ select {
+ case <-t.closechan:
+ return 0, ErrClosed
+ default:
+ }
+ return 1, nil
+}
+
+// FakeTUNName is the name of the fake TUN device.
+const FakeTUNName = "FakeTUN"
+
+func (t *fakeTUN) Flush() error { return nil }
+func (t *fakeTUN) MTU() (int, error) { return 1500, nil }
+func (t *fakeTUN) Name() (string, error) { return FakeTUNName, nil }
+func (t *fakeTUN) Events() <-chan tun.Event { return t.evchan }
+func (t *fakeTUN) BatchSize() int { return 1 }
+func (t *fakeTUN) IsFakeTun() bool { return true }
diff --git a/net/tstun/ifstatus_noop.go b/net/tstun/ifstatus_noop.go
index 8cf569f98..4d453b72c 100644
--- a/net/tstun/ifstatus_noop.go
+++ b/net/tstun/ifstatus_noop.go
@@ -1,18 +1,18 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-//go:build !windows
-
-package tstun
-
-import (
- "time"
-
- "github.com/tailscale/wireguard-go/tun"
- "tailscale.com/types/logger"
-)
-
-// Dummy implementation that does nothing.
-func waitInterfaceUp(iface tun.Device, timeout time.Duration, logf logger.Logf) error {
- return nil
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build !windows
+
+package tstun
+
+import (
+ "time"
+
+ "github.com/tailscale/wireguard-go/tun"
+ "tailscale.com/types/logger"
+)
+
+// Dummy implementation that does nothing.
+func waitInterfaceUp(iface tun.Device, timeout time.Duration, logf logger.Logf) error {
+ return nil
+}
diff --git a/net/tstun/ifstatus_windows.go b/net/tstun/ifstatus_windows.go
index fd9fc2112..6c6377bb4 100644
--- a/net/tstun/ifstatus_windows.go
+++ b/net/tstun/ifstatus_windows.go
@@ -1,109 +1,109 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-package tstun
-
-import (
- "fmt"
- "sync"
- "time"
-
- "github.com/tailscale/wireguard-go/tun"
- "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
- "tailscale.com/types/logger"
-)
-
-// ifaceWatcher waits for an interface to be up.
-type ifaceWatcher struct {
- logf logger.Logf
- luid winipcfg.LUID
-
- mu sync.Mutex // guards following
- done bool
- sig chan bool
-}
-
-// callback is the callback we register with Windows to call when IP interface changes.
-func (iw *ifaceWatcher) callback(notificationType winipcfg.MibNotificationType, iface *winipcfg.MibIPInterfaceRow) {
- // Probably should check only when MibParameterNotification, but just in case included MibAddInstance also.
- if notificationType == winipcfg.MibParameterNotification || notificationType == winipcfg.MibAddInstance {
- // Out of paranoia, start a goroutine to finish our work, to return to Windows out of this callback.
- go iw.isUp()
- }
-}
-
-func (iw *ifaceWatcher) isUp() bool {
- iw.mu.Lock()
- defer iw.mu.Unlock()
-
- if iw.done {
- // We already know that it's up
- return true
- }
-
- if iw.getOperStatus() != winipcfg.IfOperStatusUp {
- return false
- }
-
- iw.done = true
- iw.sig <- true
- return true
-}
-
-func (iw *ifaceWatcher) getOperStatus() winipcfg.IfOperStatus {
- ifc, err := iw.luid.Interface()
- if err != nil {
- iw.logf("iw.luid.Interface error: %v", err)
- return 0
- }
- return ifc.OperStatus
-}
-
-func waitInterfaceUp(iface tun.Device, timeout time.Duration, logf logger.Logf) error {
- iw := &ifaceWatcher{
- luid: winipcfg.LUID(iface.(*tun.NativeTun).LUID()),
- logf: logger.WithPrefix(logf, "waitInterfaceUp: "),
- }
-
- // Just in case check the status first
- if iw.getOperStatus() == winipcfg.IfOperStatusUp {
- iw.logf("TUN interface already up; no need to wait")
- return nil
- }
-
- iw.sig = make(chan bool, 1)
- cb, err := winipcfg.RegisterInterfaceChangeCallback(iw.callback)
- if err != nil {
- iw.logf("RegisterInterfaceChangeCallback error: %v", err)
- return err
- }
- defer cb.Unregister()
-
- t0 := time.Now()
- expires := t0.Add(timeout)
- ticker := time.NewTicker(10 * time.Second)
- defer ticker.Stop()
-
- for {
- iw.logf("waiting for TUN interface to come up...")
-
- select {
- case <-iw.sig:
- iw.logf("TUN interface is up after %v", time.Since(t0))
- return nil
- case <-ticker.C:
- }
-
- if iw.isUp() {
- // Very unlikely to happen - either NotifyIpInterfaceChange doesn't work
- // or it came up in the same moment as tick. Indicate this in the log message.
- iw.logf("TUN interface is up after %v (on poll, without notification)", time.Since(t0))
- return nil
- }
-
- if expires.Before(time.Now()) {
- iw.logf("timeout waiting %v for TUN interface to come up", timeout)
- return fmt.Errorf("timeout waiting for TUN interface to come up")
- }
- }
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package tstun
+
+import (
+ "fmt"
+ "sync"
+ "time"
+
+ "github.com/tailscale/wireguard-go/tun"
+ "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
+ "tailscale.com/types/logger"
+)
+
+// ifaceWatcher waits for an interface to be up.
+type ifaceWatcher struct {
+ logf logger.Logf
+ luid winipcfg.LUID
+
+ mu sync.Mutex // guards following
+ done bool
+ sig chan bool
+}
+
+// callback is the callback we register with Windows to call when IP interface changes.
+func (iw *ifaceWatcher) callback(notificationType winipcfg.MibNotificationType, iface *winipcfg.MibIPInterfaceRow) {
+ // Probably should check only when MibParameterNotification, but just in case included MibAddInstance also.
+ if notificationType == winipcfg.MibParameterNotification || notificationType == winipcfg.MibAddInstance {
+ // Out of paranoia, start a goroutine to finish our work, to return to Windows out of this callback.
+ go iw.isUp()
+ }
+}
+
+func (iw *ifaceWatcher) isUp() bool {
+ iw.mu.Lock()
+ defer iw.mu.Unlock()
+
+ if iw.done {
+ // We already know that it's up
+ return true
+ }
+
+ if iw.getOperStatus() != winipcfg.IfOperStatusUp {
+ return false
+ }
+
+ iw.done = true
+ iw.sig <- true
+ return true
+}
+
+func (iw *ifaceWatcher) getOperStatus() winipcfg.IfOperStatus {
+ ifc, err := iw.luid.Interface()
+ if err != nil {
+ iw.logf("iw.luid.Interface error: %v", err)
+ return 0
+ }
+ return ifc.OperStatus
+}
+
+func waitInterfaceUp(iface tun.Device, timeout time.Duration, logf logger.Logf) error {
+ iw := &ifaceWatcher{
+ luid: winipcfg.LUID(iface.(*tun.NativeTun).LUID()),
+ logf: logger.WithPrefix(logf, "waitInterfaceUp: "),
+ }
+
+ // Just in case check the status first
+ if iw.getOperStatus() == winipcfg.IfOperStatusUp {
+ iw.logf("TUN interface already up; no need to wait")
+ return nil
+ }
+
+ iw.sig = make(chan bool, 1)
+ cb, err := winipcfg.RegisterInterfaceChangeCallback(iw.callback)
+ if err != nil {
+ iw.logf("RegisterInterfaceChangeCallback error: %v", err)
+ return err
+ }
+ defer cb.Unregister()
+
+ t0 := time.Now()
+ expires := t0.Add(timeout)
+ ticker := time.NewTicker(10 * time.Second)
+ defer ticker.Stop()
+
+ for {
+ iw.logf("waiting for TUN interface to come up...")
+
+ select {
+ case <-iw.sig:
+ iw.logf("TUN interface is up after %v", time.Since(t0))
+ return nil
+ case <-ticker.C:
+ }
+
+ if iw.isUp() {
+ // Very unlikely to happen - either NotifyIpInterfaceChange doesn't work
+ // or it came up in the same moment as tick. Indicate this in the log message.
+ iw.logf("TUN interface is up after %v (on poll, without notification)", time.Since(t0))
+ return nil
+ }
+
+ if expires.Before(time.Now()) {
+ iw.logf("timeout waiting %v for TUN interface to come up", timeout)
+ return fmt.Errorf("timeout waiting for TUN interface to come up")
+ }
+ }
+}
diff --git a/net/tstun/linkattrs_linux.go b/net/tstun/linkattrs_linux.go
index 681e79269..7f5461109 100644
--- a/net/tstun/linkattrs_linux.go
+++ b/net/tstun/linkattrs_linux.go
@@ -1,63 +1,63 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-package tstun
-
-import (
- "github.com/mdlayher/genetlink"
- "github.com/mdlayher/netlink"
- "github.com/tailscale/wireguard-go/tun"
- "golang.org/x/sys/unix"
-)
-
-// setLinkSpeed sets the advertised link speed of the TUN interface.
-func setLinkSpeed(iface tun.Device, mbps int) error {
- name, err := iface.Name()
- if err != nil {
- return err
- }
-
- conn, err := genetlink.Dial(&netlink.Config{Strict: true})
- if err != nil {
- return err
- }
-
- defer conn.Close()
-
- f, err := conn.GetFamily(unix.ETHTOOL_GENL_NAME)
- if err != nil {
- return err
- }
-
- ae := netlink.NewAttributeEncoder()
- ae.Nested(unix.ETHTOOL_A_LINKMODES_HEADER, func(nae *netlink.AttributeEncoder) error {
- nae.String(unix.ETHTOOL_A_HEADER_DEV_NAME, name)
- return nil
- })
- ae.Uint32(unix.ETHTOOL_A_LINKMODES_SPEED, uint32(mbps))
-
- b, err := ae.Encode()
- if err != nil {
- return err
- }
-
- _, err = conn.Execute(
- genetlink.Message{
- Header: genetlink.Header{
- Command: unix.ETHTOOL_MSG_LINKMODES_SET,
- Version: unix.ETHTOOL_GENL_VERSION,
- },
- Data: b,
- },
- f.ID,
- netlink.Request|netlink.Acknowledge,
- )
- return err
-}
-
-// setLinkAttrs sets up link attributes that can be queried by external tools.
-// Its failure is non-fatal to interface bringup.
-func setLinkAttrs(iface tun.Device) error {
- // By default the link speed is 10Mbps, which is easily exceeded and causes monitoring tools to complain (#3933).
- return setLinkSpeed(iface, unix.SPEED_UNKNOWN)
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package tstun
+
+import (
+ "github.com/mdlayher/genetlink"
+ "github.com/mdlayher/netlink"
+ "github.com/tailscale/wireguard-go/tun"
+ "golang.org/x/sys/unix"
+)
+
+// setLinkSpeed sets the advertised link speed of the TUN interface.
+func setLinkSpeed(iface tun.Device, mbps int) error {
+ name, err := iface.Name()
+ if err != nil {
+ return err
+ }
+
+ conn, err := genetlink.Dial(&netlink.Config{Strict: true})
+ if err != nil {
+ return err
+ }
+
+ defer conn.Close()
+
+ f, err := conn.GetFamily(unix.ETHTOOL_GENL_NAME)
+ if err != nil {
+ return err
+ }
+
+ ae := netlink.NewAttributeEncoder()
+ ae.Nested(unix.ETHTOOL_A_LINKMODES_HEADER, func(nae *netlink.AttributeEncoder) error {
+ nae.String(unix.ETHTOOL_A_HEADER_DEV_NAME, name)
+ return nil
+ })
+ ae.Uint32(unix.ETHTOOL_A_LINKMODES_SPEED, uint32(mbps))
+
+ b, err := ae.Encode()
+ if err != nil {
+ return err
+ }
+
+ _, err = conn.Execute(
+ genetlink.Message{
+ Header: genetlink.Header{
+ Command: unix.ETHTOOL_MSG_LINKMODES_SET,
+ Version: unix.ETHTOOL_GENL_VERSION,
+ },
+ Data: b,
+ },
+ f.ID,
+ netlink.Request|netlink.Acknowledge,
+ )
+ return err
+}
+
+// setLinkAttrs sets up link attributes that can be queried by external tools.
+// Its failure is non-fatal to interface bringup.
+func setLinkAttrs(iface tun.Device) error {
+ // By default the link speed is 10Mbps, which is easily exceeded and causes monitoring tools to complain (#3933).
+ return setLinkSpeed(iface, unix.SPEED_UNKNOWN)
+}
diff --git a/net/tstun/linkattrs_notlinux.go b/net/tstun/linkattrs_notlinux.go
index 7a7b40fc2..45dd000b3 100644
--- a/net/tstun/linkattrs_notlinux.go
+++ b/net/tstun/linkattrs_notlinux.go
@@ -1,12 +1,12 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-//go:build !linux
-
-package tstun
-
-import "github.com/tailscale/wireguard-go/tun"
-
-func setLinkAttrs(iface tun.Device) error {
- return nil
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build !linux
+
+package tstun
+
+import "github.com/tailscale/wireguard-go/tun"
+
+func setLinkAttrs(iface tun.Device) error {
+ return nil
+}
diff --git a/net/tstun/mtu.go b/net/tstun/mtu.go
index 004529c20..b72a19bde 100644
--- a/net/tstun/mtu.go
+++ b/net/tstun/mtu.go
@@ -1,161 +1,161 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-package tstun
-
-import (
- "tailscale.com/envknob"
-)
-
-// The MTU (Maximum Transmission Unit) of a network interface is the largest
-// packet that can be sent or received through that interface, including all
-// headers above the link layer (e.g. IP headers, UDP headers, Wireguard
-// headers, etc.). We have to think about several different values of MTU:
-//
-// Wire MTU: The MTU of an interface underneath the tailscale TUN, e.g. an
-// Ethernet network card will default to a 1500 byte MTU. The user may change
-// this MTU at any time.
-//
-// TUN MTU: The current MTU of the tailscale TUN. This MTU is adjusted downward
-// to make room for the wireguard/tailscale headers. For example, if the
-// underlying network interface's MTU is 1500 bytes, the maximum size of a
-// packet entering the tailscale TUN is 1420 bytes. The user may change this MTU
-// at any time via the OS's tools (ifconfig, ip, etc.).
-//
-// User configured initial MTU: The MTU the tailscale TUN should be created
-// with, set by the user via TS_DEBUG_MTU. It should be adjusted down from the
-// underlying interface MTU by 80 bytes to make room for the wireguard
-// headers. This envknob is mostly for debugging. This value is used once at TUN
-// creation and ignored thereafter.
-//
-// User configured current MTU: The MTU set via the OS's tools (ifconfig, ip,
-// etc.). This MTU can change at any time. Setting the MTU this way goes through
-// the MTU() method of tailscale's TUN wrapper.
-//
-// Maximum probed MTU: This is the largest MTU size that we send probe packets
-// for.
-//
-// Safe MTU: If the tailscale TUN MTU is set to this value, almost all packets
-// will get to their destination. Tailscale defaults to this MTU in the absence
-// of path MTU probe information or user MTU configuration. We may occasionally
-// find a path that needs a smaller MTU but it is very rare.
-//
-// Peer MTU: This is the path MTU to a peer's current best endpoint. It defaults
-// to the Safe MTU unless we have path MTU probe results that tell us otherwise.
-//
-// Initial MTU: This is the MTU tailscaled creates the TUN with. In order of
-// priority, it is:
-//
-// 1. If set, the value of TS_DEBUG_MTU clamped to a maximum of 65536
-// 2. If TS_DEBUG_ENABLE_PMTUD is set, the maximum size MTU we probe, minus wg
-// overhead
-// 3. If TS_DEBUG_ENABLE_PMTUD is not set, the Safe MTU
-//
-// Current MTU: This the MTU of the tailscale TUN at any given moment
-// after TUN creation. In order of priority, it is:
-//
-// 1. The MTU set by the user via the OS, if it has ever been set
-// 2. If TS_DEBUG_ENABLE_PMTUD is set, the maximum size MTU we probe, minus wg
-// overhead
-// 4. If TS_DEBUG_ENABLE_PMTUD is not set, the Safe MTU
-
-// TUNMTU is the MTU for the tailscale TUN.
-type TUNMTU uint32
-
-// WireMTU is the MTU for the underlying network devices.
-type WireMTU uint32
-
-const (
- // maxTUNMTU is the largest MTU we will consider for the Tailscale
- // TUN. This is inherited from wireguard-go and can be surprisingly
- // small; on Windows it is currently 2048 - 32 bytes and iOS it is 1700
- // - 32 bytes.
- // TODO(val,raggi): On Windows this seems to derive from RIO driver
- // constraints in Wireguard but we don't use RIO so could probably make
- // this bigger.
- maxTUNMTU TUNMTU = TUNMTU(MaxPacketSize)
- // safeTUNMTU is the default "safe" MTU for the Tailscale TUN that we
- // use in the absence of other information such as path MTU probes.
- safeTUNMTU TUNMTU = 1280
-)
-
-// WireMTUsToProbe is a list of the on-the-wire MTUs we want to probe. Each time
-// magicsock discovery begins, it will send a set of pings, one of each size
-// listed below.
-var WireMTUsToProbe = []WireMTU{
- WireMTU(safeTUNMTU), // Tailscale over Tailscale :)
- TUNToWireMTU(safeTUNMTU), // Smallest MTU allowed for IPv6, current default
- 1400, // Most common MTU minus a few bytes for tunnels
- 1500, // Most common MTU
- 8000, // Should fit inside all jumbo frame sizes
- 9000, // Most jumbo frames are this size or larger
-}
-
-// wgHeaderLen is the length of all the headers Wireguard adds to a packet
-// in the worst case (IPv6). This constant is for use when we can't or
-// shouldn't use information about the IP version of a specific packet
-// (e.g., calculating the MTU for the Tailscale interface.
-//
-// A Wireguard header includes:
-//
-// - 20-byte IPv4 header or 40-byte IPv6 header
-// - 8-byte UDP header
-// - 4-byte type
-// - 4-byte key index
-// - 8-byte nonce
-// - 16-byte authentication tag
-const wgHeaderLen = 40 + 8 + 4 + 4 + 8 + 16
-
-// TUNToWireMTU takes the MTU that the Tailscale TUN presents to the user and
-// returns the on-the-wire MTU necessary to transmit the largest packet that
-// will fit through the TUN, given that we have to add wireguard headers.
-func TUNToWireMTU(t TUNMTU) WireMTU {
- return WireMTU(t + wgHeaderLen)
-}
-
-// WireToTUNMTU takes the MTU of an underlying network device and returns the
-// largest possible MTU for a Tailscale TUN operating on top of that device,
-// given that we have to add wireguard headers.
-func WireToTUNMTU(w WireMTU) TUNMTU {
- if w < wgHeaderLen {
- return 0
- }
- return TUNMTU(w - wgHeaderLen)
-}
-
-// DefaultTUNMTU returns the MTU we use to set the Tailscale TUN
-// MTU. It is also the path MTU that we default to if we have no
-// information about the path to a peer.
-//
-// 1. If set, the value of TS_DEBUG_MTU clamped to a maximum of MaxTUNMTU
-// 2. If TS_DEBUG_ENABLE_PMTUD is set, the maximum size MTU we probe, minus wg overhead
-// 3. If TS_DEBUG_ENABLE_PMTUD is not set, the Safe MTU
-func DefaultTUNMTU() TUNMTU {
- if m, ok := envknob.LookupUintSized("TS_DEBUG_MTU", 10, 32); ok {
- return min(TUNMTU(m), maxTUNMTU)
- }
-
- debugPMTUD, _ := envknob.LookupBool("TS_DEBUG_ENABLE_PMTUD")
- if debugPMTUD {
- // TODO: While we are just probing MTU but not generating PTB,
- // this has to continue to return the safe MTU. When we add the
- // code to generate PTB, this will be:
- //
- // return WireToTUNMTU(maxProbedWireMTU)
- return safeTUNMTU
- }
-
- return safeTUNMTU
-}
-
-// SafeWireMTU returns the wire MTU that is safe to use if we have no
-// information about the path MTU to this peer.
-func SafeWireMTU() WireMTU {
- return TUNToWireMTU(safeTUNMTU)
-}
-
-// DefaultWireMTU returns the default TUN MTU, adjusted for wireguard
-// overhead.
-func DefaultWireMTU() WireMTU {
- return TUNToWireMTU(DefaultTUNMTU())
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package tstun
+
+import (
+ "tailscale.com/envknob"
+)
+
+// The MTU (Maximum Transmission Unit) of a network interface is the largest
+// packet that can be sent or received through that interface, including all
+// headers above the link layer (e.g. IP headers, UDP headers, Wireguard
+// headers, etc.). We have to think about several different values of MTU:
+//
+// Wire MTU: The MTU of an interface underneath the tailscale TUN, e.g. an
+// Ethernet network card will default to a 1500 byte MTU. The user may change
+// this MTU at any time.
+//
+// TUN MTU: The current MTU of the tailscale TUN. This MTU is adjusted downward
+// to make room for the wireguard/tailscale headers. For example, if the
+// underlying network interface's MTU is 1500 bytes, the maximum size of a
+// packet entering the tailscale TUN is 1420 bytes. The user may change this MTU
+// at any time via the OS's tools (ifconfig, ip, etc.).
+//
+// User configured initial MTU: The MTU the tailscale TUN should be created
+// with, set by the user via TS_DEBUG_MTU. It should be adjusted down from the
+// underlying interface MTU by 80 bytes to make room for the wireguard
+// headers. This envknob is mostly for debugging. This value is used once at TUN
+// creation and ignored thereafter.
+//
+// User configured current MTU: The MTU set via the OS's tools (ifconfig, ip,
+// etc.). This MTU can change at any time. Setting the MTU this way goes through
+// the MTU() method of tailscale's TUN wrapper.
+//
+// Maximum probed MTU: This is the largest MTU size that we send probe packets
+// for.
+//
+// Safe MTU: If the tailscale TUN MTU is set to this value, almost all packets
+// will get to their destination. Tailscale defaults to this MTU in the absence
+// of path MTU probe information or user MTU configuration. We may occasionally
+// find a path that needs a smaller MTU but it is very rare.
+//
+// Peer MTU: This is the path MTU to a peer's current best endpoint. It defaults
+// to the Safe MTU unless we have path MTU probe results that tell us otherwise.
+//
+// Initial MTU: This is the MTU tailscaled creates the TUN with. In order of
+// priority, it is:
+//
+// 1. If set, the value of TS_DEBUG_MTU clamped to a maximum of 65536
+// 2. If TS_DEBUG_ENABLE_PMTUD is set, the maximum size MTU we probe, minus wg
+// overhead
+// 3. If TS_DEBUG_ENABLE_PMTUD is not set, the Safe MTU
+//
+// Current MTU: This the MTU of the tailscale TUN at any given moment
+// after TUN creation. In order of priority, it is:
+//
+// 1. The MTU set by the user via the OS, if it has ever been set
+// 2. If TS_DEBUG_ENABLE_PMTUD is set, the maximum size MTU we probe, minus wg
+// overhead
+// 4. If TS_DEBUG_ENABLE_PMTUD is not set, the Safe MTU
+
+// TUNMTU is the MTU for the tailscale TUN.
+type TUNMTU uint32
+
+// WireMTU is the MTU for the underlying network devices.
+type WireMTU uint32
+
+const (
+ // maxTUNMTU is the largest MTU we will consider for the Tailscale
+ // TUN. This is inherited from wireguard-go and can be surprisingly
+ // small; on Windows it is currently 2048 - 32 bytes and iOS it is 1700
+ // - 32 bytes.
+ // TODO(val,raggi): On Windows this seems to derive from RIO driver
+ // constraints in Wireguard but we don't use RIO so could probably make
+ // this bigger.
+ maxTUNMTU TUNMTU = TUNMTU(MaxPacketSize)
+ // safeTUNMTU is the default "safe" MTU for the Tailscale TUN that we
+ // use in the absence of other information such as path MTU probes.
+ safeTUNMTU TUNMTU = 1280
+)
+
+// WireMTUsToProbe is a list of the on-the-wire MTUs we want to probe. Each time
+// magicsock discovery begins, it will send a set of pings, one of each size
+// listed below.
+var WireMTUsToProbe = []WireMTU{
+ WireMTU(safeTUNMTU), // Tailscale over Tailscale :)
+ TUNToWireMTU(safeTUNMTU), // Smallest MTU allowed for IPv6, current default
+ 1400, // Most common MTU minus a few bytes for tunnels
+ 1500, // Most common MTU
+ 8000, // Should fit inside all jumbo frame sizes
+ 9000, // Most jumbo frames are this size or larger
+}
+
+// wgHeaderLen is the length of all the headers Wireguard adds to a packet
+// in the worst case (IPv6). This constant is for use when we can't or
+// shouldn't use information about the IP version of a specific packet
+// (e.g., calculating the MTU for the Tailscale interface.
+//
+// A Wireguard header includes:
+//
+// - 20-byte IPv4 header or 40-byte IPv6 header
+// - 8-byte UDP header
+// - 4-byte type
+// - 4-byte key index
+// - 8-byte nonce
+// - 16-byte authentication tag
+const wgHeaderLen = 40 + 8 + 4 + 4 + 8 + 16
+
+// TUNToWireMTU takes the MTU that the Tailscale TUN presents to the user and
+// returns the on-the-wire MTU necessary to transmit the largest packet that
+// will fit through the TUN, given that we have to add wireguard headers.
+func TUNToWireMTU(t TUNMTU) WireMTU {
+ return WireMTU(t + wgHeaderLen)
+}
+
+// WireToTUNMTU takes the MTU of an underlying network device and returns the
+// largest possible MTU for a Tailscale TUN operating on top of that device,
+// given that we have to add wireguard headers.
+func WireToTUNMTU(w WireMTU) TUNMTU {
+ if w < wgHeaderLen {
+ return 0
+ }
+ return TUNMTU(w - wgHeaderLen)
+}
+
+// DefaultTUNMTU returns the MTU we use to set the Tailscale TUN
+// MTU. It is also the path MTU that we default to if we have no
+// information about the path to a peer.
+//
+// 1. If set, the value of TS_DEBUG_MTU clamped to a maximum of MaxTUNMTU
+// 2. If TS_DEBUG_ENABLE_PMTUD is set, the maximum size MTU we probe, minus wg overhead
+// 3. If TS_DEBUG_ENABLE_PMTUD is not set, the Safe MTU
+func DefaultTUNMTU() TUNMTU {
+ if m, ok := envknob.LookupUintSized("TS_DEBUG_MTU", 10, 32); ok {
+ return min(TUNMTU(m), maxTUNMTU)
+ }
+
+ debugPMTUD, _ := envknob.LookupBool("TS_DEBUG_ENABLE_PMTUD")
+ if debugPMTUD {
+ // TODO: While we are just probing MTU but not generating PTB,
+ // this has to continue to return the safe MTU. When we add the
+ // code to generate PTB, this will be:
+ //
+ // return WireToTUNMTU(maxProbedWireMTU)
+ return safeTUNMTU
+ }
+
+ return safeTUNMTU
+}
+
+// SafeWireMTU returns the wire MTU that is safe to use if we have no
+// information about the path MTU to this peer.
+func SafeWireMTU() WireMTU {
+ return TUNToWireMTU(safeTUNMTU)
+}
+
+// DefaultWireMTU returns the default TUN MTU, adjusted for wireguard
+// overhead.
+func DefaultWireMTU() WireMTU {
+ return TUNToWireMTU(DefaultTUNMTU())
+}
diff --git a/net/tstun/mtu_test.go b/net/tstun/mtu_test.go
index 8d165bfd3..fc5274ae1 100644
--- a/net/tstun/mtu_test.go
+++ b/net/tstun/mtu_test.go
@@ -1,99 +1,99 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-package tstun
-
-import (
- "os"
- "strconv"
- "testing"
-)
-
-// Test the default MTU in the presence of various envknobs.
-func TestDefaultTunMTU(t *testing.T) {
- // Save and restore the envknobs we will be changing.
-
- // TS_DEBUG_MTU sets the MTU to a specific value.
- defer os.Setenv("TS_DEBUG_MTU", os.Getenv("TS_DEBUG_MTU"))
- os.Setenv("TS_DEBUG_MTU", "")
-
- // TS_DEBUG_ENABLE_PMTUD enables path MTU discovery.
- defer os.Setenv("TS_DEBUG_ENABLE_PMTUD", os.Getenv("TS_DEBUG_ENABLE_PMTUD"))
- os.Setenv("TS_DEBUG_ENABLE_PMTUD", "")
-
- // With no MTU envknobs set, we should get the conservative MTU.
- if DefaultTUNMTU() != safeTUNMTU {
- t.Errorf("default TUN MTU = %d, want %d", DefaultTUNMTU(), safeTUNMTU)
- }
-
- // If set, TS_DEBUG_MTU should set the MTU.
- mtu := maxTUNMTU - 1
- os.Setenv("TS_DEBUG_MTU", strconv.Itoa(int(mtu)))
- if DefaultTUNMTU() != mtu {
- t.Errorf("default TUN MTU = %d, want %d, TS_DEBUG_MTU ignored", DefaultTUNMTU(), mtu)
- }
-
- // MTU should be clamped to maxTunMTU.
- mtu = maxTUNMTU + 1
- os.Setenv("TS_DEBUG_MTU", strconv.Itoa(int(mtu)))
- if DefaultTUNMTU() != maxTUNMTU {
- t.Errorf("default TUN MTU = %d, want %d, clamping failed", DefaultTUNMTU(), maxTUNMTU)
- }
-
- // If PMTUD is enabled, the MTU should default to the safe MTU, but only
- // if the user hasn't requested a specific MTU.
- //
- // TODO: When PMTUD is generating PTB responses, this will become the
- // largest MTU we probe.
- os.Setenv("TS_DEBUG_MTU", "")
- os.Setenv("TS_DEBUG_ENABLE_PMTUD", "true")
- if DefaultTUNMTU() != safeTUNMTU {
- t.Errorf("default TUN MTU = %d, want %d", DefaultTUNMTU(), safeTUNMTU)
- }
- // TS_DEBUG_MTU should take precedence over TS_DEBUG_ENABLE_PMTUD.
- mtu = WireToTUNMTU(MaxPacketSize - 1)
- os.Setenv("TS_DEBUG_MTU", strconv.Itoa(int(mtu)))
- if DefaultTUNMTU() != mtu {
- t.Errorf("default TUN MTU = %d, want %d", DefaultTUNMTU(), mtu)
- }
-}
-
-// Test the conversion of wire MTU to/from Tailscale TUN MTU corner cases.
-func TestMTUConversion(t *testing.T) {
- tests := []struct {
- w WireMTU
- t TUNMTU
- }{
- {w: 0, t: 0},
- {w: wgHeaderLen - 1, t: 0},
- {w: wgHeaderLen, t: 0},
- {w: wgHeaderLen + 1, t: 1},
- {w: 1360, t: 1280},
- {w: 1500, t: 1420},
- {w: 9000, t: 8920},
- }
-
- for _, tt := range tests {
- m := WireToTUNMTU(tt.w)
- if m != tt.t {
- t.Errorf("conversion of wire MTU %v to TUN MTU = %v, want %v", tt.w, m, tt.t)
- }
- }
-
- tests2 := []struct {
- t TUNMTU
- w WireMTU
- }{
- {t: 0, w: wgHeaderLen},
- {t: 1, w: wgHeaderLen + 1},
- {t: 1280, w: 1360},
- {t: 1420, w: 1500},
- {t: 8920, w: 9000},
- }
-
- for _, tt := range tests2 {
- m := TUNToWireMTU(tt.t)
- if m != tt.w {
- t.Errorf("conversion of TUN MTU %v to wire MTU = %v, want %v", tt.t, m, tt.w)
- }
- }
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+package tstun
+
+import (
+ "os"
+ "strconv"
+ "testing"
+)
+
+// Test the default MTU in the presence of various envknobs.
+func TestDefaultTunMTU(t *testing.T) {
+ // Save and restore the envknobs we will be changing.
+
+ // TS_DEBUG_MTU sets the MTU to a specific value.
+ defer os.Setenv("TS_DEBUG_MTU", os.Getenv("TS_DEBUG_MTU"))
+ os.Setenv("TS_DEBUG_MTU", "")
+
+ // TS_DEBUG_ENABLE_PMTUD enables path MTU discovery.
+ defer os.Setenv("TS_DEBUG_ENABLE_PMTUD", os.Getenv("TS_DEBUG_ENABLE_PMTUD"))
+ os.Setenv("TS_DEBUG_ENABLE_PMTUD", "")
+
+ // With no MTU envknobs set, we should get the conservative MTU.
+ if DefaultTUNMTU() != safeTUNMTU {
+ t.Errorf("default TUN MTU = %d, want %d", DefaultTUNMTU(), safeTUNMTU)
+ }
+
+ // If set, TS_DEBUG_MTU should set the MTU.
+ mtu := maxTUNMTU - 1
+ os.Setenv("TS_DEBUG_MTU", strconv.Itoa(int(mtu)))
+ if DefaultTUNMTU() != mtu {
+ t.Errorf("default TUN MTU = %d, want %d, TS_DEBUG_MTU ignored", DefaultTUNMTU(), mtu)
+ }
+
+ // MTU should be clamped to maxTunMTU.
+ mtu = maxTUNMTU + 1
+ os.Setenv("TS_DEBUG_MTU", strconv.Itoa(int(mtu)))
+ if DefaultTUNMTU() != maxTUNMTU {
+ t.Errorf("default TUN MTU = %d, want %d, clamping failed", DefaultTUNMTU(), maxTUNMTU)
+ }
+
+ // If PMTUD is enabled, the MTU should default to the safe MTU, but only
+ // if the user hasn't requested a specific MTU.
+ //
+ // TODO: When PMTUD is generating PTB responses, this will become the
+ // largest MTU we probe.
+ os.Setenv("TS_DEBUG_MTU", "")
+ os.Setenv("TS_DEBUG_ENABLE_PMTUD", "true")
+ if DefaultTUNMTU() != safeTUNMTU {
+ t.Errorf("default TUN MTU = %d, want %d", DefaultTUNMTU(), safeTUNMTU)
+ }
+ // TS_DEBUG_MTU should take precedence over TS_DEBUG_ENABLE_PMTUD.
+ mtu = WireToTUNMTU(MaxPacketSize - 1)
+ os.Setenv("TS_DEBUG_MTU", strconv.Itoa(int(mtu)))
+ if DefaultTUNMTU() != mtu {
+ t.Errorf("default TUN MTU = %d, want %d", DefaultTUNMTU(), mtu)
+ }
+}
+
+// Test the conversion of wire MTU to/from Tailscale TUN MTU corner cases.
+func TestMTUConversion(t *testing.T) {
+ tests := []struct {
+ w WireMTU
+ t TUNMTU
+ }{
+ {w: 0, t: 0},
+ {w: wgHeaderLen - 1, t: 0},
+ {w: wgHeaderLen, t: 0},
+ {w: wgHeaderLen + 1, t: 1},
+ {w: 1360, t: 1280},
+ {w: 1500, t: 1420},
+ {w: 9000, t: 8920},
+ }
+
+ for _, tt := range tests {
+ m := WireToTUNMTU(tt.w)
+ if m != tt.t {
+ t.Errorf("conversion of wire MTU %v to TUN MTU = %v, want %v", tt.w, m, tt.t)
+ }
+ }
+
+ tests2 := []struct {
+ t TUNMTU
+ w WireMTU
+ }{
+ {t: 0, w: wgHeaderLen},
+ {t: 1, w: wgHeaderLen + 1},
+ {t: 1280, w: 1360},
+ {t: 1420, w: 1500},
+ {t: 8920, w: 9000},
+ }
+
+ for _, tt := range tests2 {
+ m := TUNToWireMTU(tt.t)
+ if m != tt.w {
+ t.Errorf("conversion of TUN MTU %v to wire MTU = %v, want %v", tt.t, m, tt.w)
+ }
+ }
+}
diff --git a/net/tstun/tun_linux.go b/net/tstun/tun_linux.go
index 9600ceb77..e08f12bc1 100644
--- a/net/tstun/tun_linux.go
+++ b/net/tstun/tun_linux.go
@@ -1,103 +1,103 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-package tstun
-
-import (
- "bytes"
- "errors"
- "os"
- "os/exec"
- "strings"
- "syscall"
-
- "tailscale.com/types/logger"
- "tailscale.com/version/distro"
-)
-
-func init() {
- tunDiagnoseFailure = diagnoseLinuxTUNFailure
-}
-
-func diagnoseLinuxTUNFailure(tunName string, logf logger.Logf, createErr error) {
- if errors.Is(createErr, syscall.EBUSY) {
- logf("TUN device %s is busy; another process probably still has it open (from old version of Tailscale that had a bug)", tunName)
- logf("To fix, kill the process that has it open. Find with:\n\n$ sudo lsof -n /dev/net/tun\n\n")
- logf("... and then kill those PID(s)")
- return
- }
-
- var un syscall.Utsname
- err := syscall.Uname(&un)
- if err != nil {
- logf("no TUN, and failed to look up kernel version: %v", err)
- return
- }
- kernel := utsReleaseField(&un)
- logf("Linux kernel version: %s", kernel)
-
- modprobeOut, err := exec.Command("/sbin/modprobe", "tun").CombinedOutput()
- if err == nil {
- logf("'modprobe tun' successful")
- // Either tun is currently loaded, or it's statically
- // compiled into the kernel (which modprobe checks
- // with /lib/modules/$(uname -r)/modules.builtin)
- //
- // So if there's a problem at this point, it's
- // probably because /dev/net/tun doesn't exist.
- const dev = "/dev/net/tun"
- if fi, err := os.Stat(dev); err != nil {
- logf("tun module loaded in kernel, but %s does not exist", dev)
- } else {
- logf("%s: %v", dev, fi.Mode())
- }
-
- // We failed to find why it failed. Just let our
- // caller report the error it got from wireguard-go.
- return
- }
- logf("is CONFIG_TUN enabled in your kernel? `modprobe tun` failed with: %s", modprobeOut)
-
- switch distro.Get() {
- case distro.Debian:
- dpkgOut, err := exec.Command("dpkg", "-S", "kernel/drivers/net/tun.ko").CombinedOutput()
- if len(bytes.TrimSpace(dpkgOut)) == 0 || err != nil {
- logf("tun module not loaded nor found on disk")
- return
- }
- if !bytes.Contains(dpkgOut, []byte(kernel)) {
- logf("kernel/drivers/net/tun.ko found on disk, but not for current kernel; are you in middle of a system update and haven't rebooted? found: %s", dpkgOut)
- }
- case distro.Arch:
- findOut, err := exec.Command("find", "/lib/modules/", "-path", "*/net/tun.ko*").CombinedOutput()
- if len(bytes.TrimSpace(findOut)) == 0 || err != nil {
- logf("tun module not loaded nor found on disk")
- return
- }
- if !bytes.Contains(findOut, []byte(kernel)) {
- logf("kernel/drivers/net/tun.ko found on disk, but not for current kernel; are you in middle of a system update and haven't rebooted? found: %s", findOut)
- }
- case distro.OpenWrt:
- out, err := exec.Command("opkg", "list-installed").CombinedOutput()
- if err != nil {
- logf("error querying OpenWrt installed packages: %s", out)
- return
- }
- for _, pkg := range []string{"kmod-tun", "ca-bundle"} {
- if !bytes.Contains(out, []byte(pkg+" - ")) {
- logf("Missing required package %s; run: opkg install %s", pkg, pkg)
- }
- }
- }
-}
-
-func utsReleaseField(u *syscall.Utsname) string {
- var sb strings.Builder
- for _, v := range u.Release {
- if v == 0 {
- break
- }
- sb.WriteByte(byte(v))
- }
- return strings.TrimSpace(sb.String())
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package tstun
+
+import (
+ "bytes"
+ "errors"
+ "os"
+ "os/exec"
+ "strings"
+ "syscall"
+
+ "tailscale.com/types/logger"
+ "tailscale.com/version/distro"
+)
+
+func init() {
+ tunDiagnoseFailure = diagnoseLinuxTUNFailure
+}
+
+func diagnoseLinuxTUNFailure(tunName string, logf logger.Logf, createErr error) {
+ if errors.Is(createErr, syscall.EBUSY) {
+ logf("TUN device %s is busy; another process probably still has it open (from old version of Tailscale that had a bug)", tunName)
+ logf("To fix, kill the process that has it open. Find with:\n\n$ sudo lsof -n /dev/net/tun\n\n")
+ logf("... and then kill those PID(s)")
+ return
+ }
+
+ var un syscall.Utsname
+ err := syscall.Uname(&un)
+ if err != nil {
+ logf("no TUN, and failed to look up kernel version: %v", err)
+ return
+ }
+ kernel := utsReleaseField(&un)
+ logf("Linux kernel version: %s", kernel)
+
+ modprobeOut, err := exec.Command("/sbin/modprobe", "tun").CombinedOutput()
+ if err == nil {
+ logf("'modprobe tun' successful")
+ // Either tun is currently loaded, or it's statically
+ // compiled into the kernel (which modprobe checks
+ // with /lib/modules/$(uname -r)/modules.builtin)
+ //
+ // So if there's a problem at this point, it's
+ // probably because /dev/net/tun doesn't exist.
+ const dev = "/dev/net/tun"
+ if fi, err := os.Stat(dev); err != nil {
+ logf("tun module loaded in kernel, but %s does not exist", dev)
+ } else {
+ logf("%s: %v", dev, fi.Mode())
+ }
+
+ // We failed to find why it failed. Just let our
+ // caller report the error it got from wireguard-go.
+ return
+ }
+ logf("is CONFIG_TUN enabled in your kernel? `modprobe tun` failed with: %s", modprobeOut)
+
+ switch distro.Get() {
+ case distro.Debian:
+ dpkgOut, err := exec.Command("dpkg", "-S", "kernel/drivers/net/tun.ko").CombinedOutput()
+ if len(bytes.TrimSpace(dpkgOut)) == 0 || err != nil {
+ logf("tun module not loaded nor found on disk")
+ return
+ }
+ if !bytes.Contains(dpkgOut, []byte(kernel)) {
+ logf("kernel/drivers/net/tun.ko found on disk, but not for current kernel; are you in middle of a system update and haven't rebooted? found: %s", dpkgOut)
+ }
+ case distro.Arch:
+ findOut, err := exec.Command("find", "/lib/modules/", "-path", "*/net/tun.ko*").CombinedOutput()
+ if len(bytes.TrimSpace(findOut)) == 0 || err != nil {
+ logf("tun module not loaded nor found on disk")
+ return
+ }
+ if !bytes.Contains(findOut, []byte(kernel)) {
+ logf("kernel/drivers/net/tun.ko found on disk, but not for current kernel; are you in middle of a system update and haven't rebooted? found: %s", findOut)
+ }
+ case distro.OpenWrt:
+ out, err := exec.Command("opkg", "list-installed").CombinedOutput()
+ if err != nil {
+ logf("error querying OpenWrt installed packages: %s", out)
+ return
+ }
+ for _, pkg := range []string{"kmod-tun", "ca-bundle"} {
+ if !bytes.Contains(out, []byte(pkg+" - ")) {
+ logf("Missing required package %s; run: opkg install %s", pkg, pkg)
+ }
+ }
+ }
+}
+
+func utsReleaseField(u *syscall.Utsname) string {
+ var sb strings.Builder
+ for _, v := range u.Release {
+ if v == 0 {
+ break
+ }
+ sb.WriteByte(byte(v))
+ }
+ return strings.TrimSpace(sb.String())
+}
diff --git a/net/tstun/tun_macos.go b/net/tstun/tun_macos.go
index 3506f05b1..f71494f0b 100644
--- a/net/tstun/tun_macos.go
+++ b/net/tstun/tun_macos.go
@@ -1,25 +1,25 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-//go:build darwin && !ios
-
-package tstun
-
-import (
- "os"
-
- "tailscale.com/types/logger"
-)
-
-func init() {
- tunDiagnoseFailure = diagnoseDarwinTUNFailure
-}
-
-func diagnoseDarwinTUNFailure(tunName string, logf logger.Logf, err error) {
- if os.Getuid() != 0 {
- logf("failed to create TUN device as non-root user; use 'sudo tailscaled', or run under launchd with 'sudo tailscaled install-system-daemon'")
- }
- if tunName != "utun" {
- logf("failed to create TUN device %q; try using tun device \"utun\" instead for automatic selection", tunName)
- }
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build darwin && !ios
+
+package tstun
+
+import (
+ "os"
+
+ "tailscale.com/types/logger"
+)
+
+func init() {
+ tunDiagnoseFailure = diagnoseDarwinTUNFailure
+}
+
+func diagnoseDarwinTUNFailure(tunName string, logf logger.Logf, err error) {
+ if os.Getuid() != 0 {
+ logf("failed to create TUN device as non-root user; use 'sudo tailscaled', or run under launchd with 'sudo tailscaled install-system-daemon'")
+ }
+ if tunName != "utun" {
+ logf("failed to create TUN device %q; try using tun device \"utun\" instead for automatic selection", tunName)
+ }
+}
diff --git a/net/tstun/tun_notwindows.go b/net/tstun/tun_notwindows.go
index 087fcd4ee..60f1c62ba 100644
--- a/net/tstun/tun_notwindows.go
+++ b/net/tstun/tun_notwindows.go
@@ -1,12 +1,12 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-//go:build !windows
-
-package tstun
-
-import "github.com/tailscale/wireguard-go/tun"
-
-func interfaceName(dev tun.Device) (string, error) {
- return dev.Name()
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build !windows
+
+package tstun
+
+import "github.com/tailscale/wireguard-go/tun"
+
+func interfaceName(dev tun.Device) (string, error) {
+ return dev.Name()
+}