diff options
| author | Nick Khyl <nickk@tailscale.com> | 2024-12-05 13:16:48 -0600 |
|---|---|---|
| committer | Nick Khyl <nickk@tailscale.com> | 2024-12-05 13:16:48 -0600 |
| commit | 0267fe83b200f1702a2fa0a395442c02a053fadb (patch) | |
| tree | 63654c55225eeb834de59a5a0bc8d19033c6145b /net | |
| parent | 87546a5edf6b6503a87eeb2d666baba57398a066 (diff) | |
| download | tailscale-1.78.0.tar.xz tailscale-1.78.0.zip | |
VERSION.txt: this is v1.78.0v1.78.0
Signed-off-by: Nick Khyl <nickk@tailscale.com>
Diffstat (limited to 'net')
95 files changed, 8117 insertions, 8117 deletions
diff --git a/net/art/art_test.go b/net/art/art_test.go index daf8553ca..e3a427107 100644 --- a/net/art/art_test.go +++ b/net/art/art_test.go @@ -1,20 +1,20 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package art - -import ( - "os" - "testing" - - "tailscale.com/util/cibuild" -) - -func TestMain(m *testing.M) { - if cibuild.On() { - // Skip CI on GitHub for now - // TODO: https://github.com/tailscale/tailscale/issues/7866 - os.Exit(0) - } - os.Exit(m.Run()) -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package art
+
+import (
+ "os"
+ "testing"
+
+ "tailscale.com/util/cibuild"
+)
+
+func TestMain(m *testing.M) {
+ if cibuild.On() {
+ // Skip CI on GitHub for now
+ // TODO: https://github.com/tailscale/tailscale/issues/7866
+ os.Exit(0)
+ }
+ os.Exit(m.Run())
+}
diff --git a/net/art/table.go b/net/art/table.go index fa3975778..2e130d82f 100644 --- a/net/art/table.go +++ b/net/art/table.go @@ -1,641 +1,641 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package art provides a routing table that implements the Allotment Routing -// Table (ART) algorithm by Donald Knuth, as described in the paper by Yoichi -// Hariguchi. -// -// ART outperforms the traditional radix tree implementations for route lookups, -// insertions, and deletions. -// -// For more information, see Yoichi Hariguchi's paper: -// https://cseweb.ucsd.edu//~varghese/TEACH/cs228/artlookup.pdf -package art - -import ( - "bytes" - "encoding/binary" - "fmt" - "io" - "math/bits" - "net/netip" - "strings" - "sync" -) - -const ( - debugInsert = false - debugDelete = false -) - -// Table is an IPv4 and IPv6 routing table. -type Table[T any] struct { - v4 strideTable[T] - v6 strideTable[T] - initOnce sync.Once -} - -func (t *Table[T]) init() { - t.initOnce.Do(func() { - t.v4.prefix = netip.PrefixFrom(netip.IPv4Unspecified(), 0) - t.v6.prefix = netip.PrefixFrom(netip.IPv6Unspecified(), 0) - }) -} - -func (t *Table[T]) tableForAddr(addr netip.Addr) *strideTable[T] { - if addr.Is6() { - return &t.v6 - } - return &t.v4 -} - -// Get does a route lookup for addr and returns the associated value, or nil if -// no route matched. -func (t *Table[T]) Get(addr netip.Addr) (ret T, ok bool) { - t.init() - - // Ideally we would use addr.AsSlice here, but AsSlice is just - // barely complex enough that it can't be inlined, and that in - // turn causes the slice to escape to the heap. Using As16 and - // manual slicing here helps the compiler keep Get alloc-free. - st := t.tableForAddr(addr) - rawAddr := addr.As16() - bs := rawAddr[:] - if addr.Is4() { - bs = bs[12:] - } - - i := 0 - // With path compression, we might skip over some address bits while walking - // to a strideTable leaf. This means the leaf answer we find might not be - // correct, because path compression took us down the wrong subtree. When - // that happens, we have to backtrack and figure out which most specific - // route further up the tree is relevant to addr, and return that. - // - // So, as we walk down the stride tables, each time we find a non-nil route - // result, we have to remember it and the associated strideTable prefix. - // - // We could also deal with this edge case of path compression by checking - // the strideTable prefix on each table as we descend, but that means we - // have to pay N prefix.Contains checks on every route lookup (where N is - // the number of strideTables in the path), rather than only paying M prefix - // comparisons in the edge case (where M is the number of strideTables in - // the path with a non-nil route of their own). - const maxDepth = 16 - type prefixAndRoute struct { - prefix netip.Prefix - route T - } - strideMatch := make([]prefixAndRoute, 0, maxDepth) -findLeaf: - for { - rt, rtOK, child := st.getValAndChild(bs[i]) - if rtOK { - // This strideTable contains a route that may be relevant to our - // search, remember it. - strideMatch = append(strideMatch, prefixAndRoute{st.prefix, rt}) - } - if child == nil { - // No sub-routes further down, the last thing we recorded - // in strideRoutes is tentatively the result, barring - // misdirection from path compression. - break findLeaf - } - st = child - // Path compression means we may be skipping over some intermediate - // tables. We have to skip forward to whatever depth st now references. - i = st.prefix.Bits() / 8 - } - - // Walk backwards through the hits we recorded in strideRoutes and - // stridePrefixes, returning the first one whose subtree matches addr. - // - // In the common case where path compression did not mislead us, we'll - // return on the first loop iteration because the last route we recorded was - // the correct most-specific route. - for i := len(strideMatch) - 1; i >= 0; i-- { - if m := strideMatch[i]; m.prefix.Contains(addr) { - return m.route, true - } - } - - // We either found no route hits at all (both previous loops terminated - // immediately), or we went on a wild goose chase down a compressed path for - // the wrong prefix, and also found no usable routes on the way back up to - // the root. This is a miss. - return ret, false -} - -// Insert adds pfx to the table, with value val. -// If pfx is already present in the table, its value is set to val. -func (t *Table[T]) Insert(pfx netip.Prefix, val T) { - t.init() - - // The standard library doesn't enforce normalized prefixes (where - // the non-prefix bits are all zero). These algorithms require - // normalized prefixes, so do it upfront. - pfx = pfx.Masked() - - if debugInsert { - defer func() { - fmt.Printf("%s", t.debugSummary()) - }() - fmt.Printf("\ninsert: start pfx=%s\n", pfx) - } - - st := t.tableForAddr(pfx.Addr()) - - // This algorithm is full of off-by-one headaches that boil down - // to the fact that pfx.Bits() has (2^n)+1 values, rather than - // just 2^n. For example, an IPv4 prefix length can be 0 through - // 32, which is 33 values. - // - // This extra possible value creates a lot of problems as we do - // bits and bytes math to traverse strideTables below. So, we - // treat the default route 0/0 specially here, that way the rest - // of the logic goes back to having 2^n values to reason about, - // which can be done in a nice and regular fashion with no edge - // cases. - if pfx.Bits() == 0 { - if debugInsert { - fmt.Printf("insert: default route\n") - } - st.insert(0, 0, val) - return - } - - // No matter what we do as we traverse strideTables, our final - // action will be to insert the last 1-8 bits of pfx into a - // strideTable somewhere. - // - // We calculate upfront the byte position of the end of the - // prefix; the number of bits within that byte that contain prefix - // data; and the prefix of the strideTable into which we'll - // eventually insert. - // - // We need this in a couple different branches of the code below, - // and because the possible values are 1-indexed (1 through 32 for - // ipv4, 1 through 128 for ipv6), the math is very slightly - // unusual to account for the off-by-one indexing. Do it once up - // here, with this large comment, rather than reproduce the subtle - // math in multiple places further down. - finalByteIdx := (pfx.Bits() - 1) / 8 - finalBits := pfx.Bits() - (finalByteIdx * 8) - finalStridePrefix, err := pfx.Addr().Prefix(finalByteIdx * 8) - if err != nil { - panic(fmt.Sprintf("invalid prefix requested: %s/%d", pfx.Addr(), finalByteIdx*8)) - } - if debugInsert { - fmt.Printf("insert: finalByteIdx=%d finalBits=%d finalStridePrefix=%s\n", finalByteIdx, finalBits, finalStridePrefix) - } - - // The strideTable we want to insert into is potentially at the - // end of a chain of strideTables, each one encoding 8 bits of the - // prefix. - // - // We're expecting to walk down a path of tables, although with - // prefix compression we may end up skipping some links in the - // chain, or taking wrong turns and having to course correct. - // - // As we walk down the tree, byteIdx is the byte of bs we're - // currently examining to choose our next step, and numBits is the - // number of bits that remain in pfx, starting with the byte at - // byteIdx inclusive. - bs := pfx.Addr().AsSlice() - byteIdx := 0 - numBits := pfx.Bits() - for { - if debugInsert { - fmt.Printf("insert: loop byteIdx=%d numBits=%d st.prefix=%s\n", byteIdx, numBits, st.prefix) - } - if numBits <= 8 { - if debugInsert { - fmt.Printf("insert: existing leaf st.prefix=%s addr=%d/%d\n", st.prefix, bs[finalByteIdx], finalBits) - } - // We've reached the end of the prefix, whichever - // strideTable we're looking at now is the place where we - // need to insert. - st.insert(bs[finalByteIdx], finalBits, val) - return - } - - // Otherwise, we need to go down at least one more level of - // strideTables. With prefix compression, each level of - // descent can have one of three outcomes: we find a place - // where prefix compression is possible; a place where prefix - // compression made us take a "wrong turn"; or a point along - // our intended path that we have to keep following. - child, created := st.getOrCreateChild(bs[byteIdx]) - switch { - case created: - // The subtree we need for pfx doesn't exist yet. The rest - // of the path, if we were to create it, will consist of a - // bunch of strideTables with a single child each. We can - // use path compression to elide those intermediates, and - // jump straight to the final strideTable that hosts this - // prefix. - child.prefix = finalStridePrefix - child.insert(bs[finalByteIdx], finalBits, val) - if debugInsert { - fmt.Printf("insert: new leaf st.prefix=%s child.prefix=%s addr=%d/%d\n", st.prefix, child.prefix, bs[finalByteIdx], finalBits) - } - return - case !prefixStrictlyContains(child.prefix, pfx): - // child already exists, but its prefix does not contain - // our destination. This means that the path between st - // and child was compressed by a previous insertion, and - // somewhere in the (implicit) compressed path we took a - // wrong turn, into the wrong part of st's subtree. - // - // This is okay, because pfx and child.prefix must have a - // common ancestor node somewhere between st and child. We - // can figure out what node that is, and materialize it. - // - // Once we've done that, we can immediately complete the - // remainder of the insertion in one of two ways, without - // further traversal. See a little further down for what - // those are. - if debugInsert { - fmt.Printf("insert: wrong turn, pfx=%s child.prefix=%s\n", pfx, child.prefix) - } - intermediatePrefix, addrOfExisting, addrOfNew := computePrefixSplit(child.prefix, pfx) - intermediate := &strideTable[T]{prefix: intermediatePrefix} // TODO: make this whole thing be st.AddIntermediate or something? - st.setChild(bs[byteIdx], intermediate) - intermediate.setChild(addrOfExisting, child) - - if debugInsert { - fmt.Printf("insert: new intermediate st.prefix=%s intermediate.prefix=%s child.prefix=%s\n", st.prefix, intermediate.prefix, child.prefix) - } - - // Now, we have a chain of st -> intermediate -> child. - // - // pfx either lives in a different child of intermediate, - // or in intermediate itself. For example, if we created - // the intermediate 1.2.0.0/16, pfx=1.2.3.4/32 would have - // to go into a new child of intermediate, but - // pfx=1.2.0.0/18 would go into intermediate directly. - if remain := pfx.Bits() - intermediate.prefix.Bits(); remain <= 8 { - // pfx lives in intermediate. - if debugInsert { - fmt.Printf("insert: into intermediate intermediate.prefix=%s addr=%d/%d\n", intermediate.prefix, bs[finalByteIdx], finalBits) - } - intermediate.insert(bs[finalByteIdx], finalBits, val) - } else { - // pfx lives in a different child subtree of - // intermediate. By definition this subtree doesn't - // exist at all, otherwise we'd never have entered - // this entire "wrong turn" codepath in the first - // place. - // - // This means we can apply prefix compression as we - // create this new child, and we're done. - st, created = intermediate.getOrCreateChild(addrOfNew) - if !created { - panic("new child path unexpectedly exists during path decompression") - } - st.prefix = finalStridePrefix - st.insert(bs[finalByteIdx], finalBits, val) - if debugInsert { - fmt.Printf("insert: new child st.prefix=%s addr=%d/%d\n", st.prefix, bs[finalByteIdx], finalBits) - } - } - - return - default: - // An expected child table exists along pfx's - // path. Continue traversing downwards. - st = child - byteIdx = child.prefix.Bits() / 8 - numBits = pfx.Bits() - child.prefix.Bits() - if debugInsert { - fmt.Printf("insert: descend st.prefix=%s\n", st.prefix) - } - } - } -} - -// Delete removes pfx from the table, if it is present. -func (t *Table[T]) Delete(pfx netip.Prefix) { - t.init() - - // The standard library doesn't enforce normalized prefixes (where - // the non-prefix bits are all zero). These algorithms require - // normalized prefixes, so do it upfront. - pfx = pfx.Masked() - - if debugDelete { - defer func() { - fmt.Printf("%s", t.debugSummary()) - }() - fmt.Printf("\ndelete: start pfx=%s table:\n%s", pfx, t.debugSummary()) - } - - st := t.tableForAddr(pfx.Addr()) - - // This algorithm is full of off-by-one headaches, just like - // Insert. See the comment in Insert for more details. Bottom - // line: we handle the default route as a special case, and that - // simplifies the rest of the code slightly. - if pfx.Bits() == 0 { - if debugDelete { - fmt.Printf("delete: default route\n") - } - st.delete(0, 0) - return - } - - // Deletion may drive the refcount of some strideTables down to - // zero. We need to clean up these dangling tables, so we have to - // keep track of which tables we touch on the way down, and which - // strideEntry index each child is registered in. - // - // Note that the strideIndex and strideTables entries are off-by-one. - // The child table pointer is recorded at i+1, but it is referenced by a - // particular index in the parent table, at index i. - // - // In other words: entry number strideIndexes[0] in - // strideTables[0] is the same pointer as strideTables[1]. - // - // This results in some slightly odd array accesses further down - // in this code, because in a single loop iteration we have to - // write to strideTables[N] and strideIndexes[N-1]. - strideIdx := 0 - strideTables := [16]*strideTable[T]{st} - strideIndexes := [15]uint8{} - - // Similar to Insert, navigate down the tree of strideTables, - // looking for the one that houses this prefix. This part is - // easier than with insertion, since we can bail if the path ends - // early or takes an unexpected detour. However, unlike - // insertion, there's a whole post-deletion cleanup phase later - // on. - // - // As we walk down the tree, byteIdx is the byte of bs we're - // currently examining to choose our next step, and numBits is the - // number of bits that remain in pfx, starting with the byte at - // byteIdx inclusive. - bs := pfx.Addr().AsSlice() - byteIdx := 0 - numBits := pfx.Bits() - for numBits > 8 { - if debugDelete { - fmt.Printf("delete: loop byteIdx=%d numBits=%d st.prefix=%s\n", byteIdx, numBits, st.prefix) - } - child := st.getChild(bs[byteIdx]) - if child == nil { - // Prefix can't exist in the table, because one of the - // necessary strideTables doesn't exist. - if debugDelete { - fmt.Printf("delete: missing necessary child pfx=%s\n", pfx) - } - return - } - strideIndexes[strideIdx] = bs[byteIdx] - strideTables[strideIdx+1] = child - strideIdx++ - - // Path compression means byteIdx can jump forwards - // unpredictably. Recompute the next byte to look at from the - // child we just found. - byteIdx = child.prefix.Bits() / 8 - numBits = pfx.Bits() - child.prefix.Bits() - st = child - - if debugDelete { - fmt.Printf("delete: descend st.prefix=%s\n", st.prefix) - } - } - - // We reached a leaf stride table that seems to be in the right - // spot. But path compression might have led us to the wrong - // table. - if !prefixStrictlyContains(st.prefix, pfx) { - // Wrong table, the requested prefix can't exist since its - // path led us to the wrong place. - if debugDelete { - fmt.Printf("delete: wrong leaf table pfx=%s\n", pfx) - } - return - } - if debugDelete { - fmt.Printf("delete: delete from st.prefix=%s addr=%d/%d\n", st.prefix, bs[byteIdx], numBits) - } - if routeExisted := st.delete(bs[byteIdx], numBits); !routeExisted { - // We're in the right strideTable, but pfx wasn't in - // it. Refcounts haven't changed, so we can skip cleanup. - if debugDelete { - fmt.Printf("delete: prefix not present pfx=%s\n", pfx) - } - return - } - - // st.delete reduced st's refcount by one. This table may now be - // reclaimable, and depending on how we can reclaim it, the parent - // tables may also need to be reclaimed. This loop ends as soon as - // an iteration takes no action, or takes an action that doesn't - // alter the parent table's refcounts. - // - // We start our walk back at strideTables[strideIdx], which - // contains st. - for strideIdx > 0 { - cur := strideTables[strideIdx] - if debugDelete { - fmt.Printf("delete: GC? strideIdx=%d st.prefix=%s\n", strideIdx, cur.prefix) - } - if cur.routeRefs > 0 { - // the strideTable has other route entries, it cannot be - // deleted or compacted. - if debugDelete { - fmt.Printf("delete: has other routes st.prefix=%s\n", cur.prefix) - } - return - } - switch cur.childRefs { - case 0: - // no routeRefs and no childRefs, this table can be - // deleted. This will alter the parent table's refcount, - // so we'll have to look at it as well (in the next loop - // iteration). - if debugDelete { - fmt.Printf("delete: remove st.prefix=%s\n", cur.prefix) - } - strideTables[strideIdx-1].deleteChild(strideIndexes[strideIdx-1]) - strideIdx-- - case 1: - // This table has no routes, and a single child. Compact - // this table out of existence by making the parent point - // directly at the one child. This does not affect the - // parent's refcounts, so the parent can't be eligible for - // deletion or compaction, and we can stop. - child := strideTables[strideIdx].findFirstChild() // only 1 child exists, by definition - parent := strideTables[strideIdx-1] - if debugDelete { - fmt.Printf("delete: compact parent.prefix=%s st.prefix=%s child.prefix=%s\n", parent.prefix, cur.prefix, child.prefix) - } - strideTables[strideIdx-1].setChild(strideIndexes[strideIdx-1], child) - return - default: - // This table has two or more children, so it's acting as a "fork in - // the road" between two prefix subtrees. It cannot be deleted, and - // thus no further cleanups are possible. - if debugDelete { - fmt.Printf("delete: fork table st.prefix=%s\n", cur.prefix) - } - return - } - } -} - -// debugSummary prints the tree of allocated strideTables in t, with each -// strideTable's refcount. -func (t *Table[T]) debugSummary() string { - t.init() - var ret bytes.Buffer - fmt.Fprintf(&ret, "v4: ") - strideSummary(&ret, &t.v4, 4) - fmt.Fprintf(&ret, "v6: ") - strideSummary(&ret, &t.v6, 4) - return ret.String() -} - -func strideSummary[T any](w io.Writer, st *strideTable[T], indent int) { - fmt.Fprintf(w, "%s: %d routes, %d children\n", st.prefix, st.routeRefs, st.childRefs) - indent += 4 - st.treeDebugStringRec(w, 1, indent) - for addr, child := range st.children { - if child == nil { - continue - } - fmt.Fprintf(w, "%s%d/8 (%02x/8): ", strings.Repeat(" ", indent), addr, addr) - strideSummary(w, child, indent) - } -} - -// prefixStrictlyContains reports whether child is a prefix within -// parent, but not parent itself. -func prefixStrictlyContains(parent, child netip.Prefix) bool { - return parent.Overlaps(child) && parent.Bits() < child.Bits() -} - -// computePrefixSplit returns the smallest common prefix that contains -// both a and b. lastCommon is 8-bit aligned, with aStride and bStride -// indicating the value of the 8-bit stride immediately following -// lastCommon. -// -// computePrefixSplit is used in constructing an intermediate -// strideTable when a new prefix needs to be inserted in a compressed -// table. It can be read as: given that a is already in the table, and -// b is being inserted, what is the prefix of the new intermediate -// strideTable that needs to be created, and at what addresses in that -// new strideTable should a and b's subsequent strideTables be -// attached? -// -// Note as a special case, this can be called with a==b. An example of -// when this happens: -// - We want to insert the prefix 1.2.0.0/16 -// - A strideTable exists for 1.2.0.0/16, because another child -// prefix already exists (e.g. 1.2.3.4/32) -// - The 1.0.0.0/8 strideTable does not exist, because path -// compression removed it. -// -// In this scenario, the caller of computePrefixSplit ends up making a -// "wrong turn" while traversing strideTables: it was looking for the -// 1.0.0.0/8 table, but ended up at the 1.2.0.0/16 table. When this -// happens, it will invoke computePrefixSplit(1.2.0.0/16, 1.2.0.0/16), -// and we return 1.0.0.0/8 as the missing intermediate. -func computePrefixSplit(a, b netip.Prefix) (lastCommon netip.Prefix, aStride, bStride uint8) { - a = a.Masked() - b = b.Masked() - if a.Bits() == 0 || b.Bits() == 0 { - panic("computePrefixSplit called with a default route") - } - if a.Addr().Is4() != b.Addr().Is4() { - panic("computePrefixSplit called with mismatched address families") - } - - minPrefixLen := a.Bits() - if b.Bits() < minPrefixLen { - minPrefixLen = b.Bits() - } - - commonBits := commonBits(a.Addr(), b.Addr(), minPrefixLen) - // We want to know how many 8-bit strides are shared between a and - // b. Naively, this would be commonBits/8, but this introduces an - // off-by-one error. This is due to the way our ART stores - // prefixes whose length falls exactly on a stride boundary. - // - // Consider 192.168.1.0/24 and 192.168.0.0/16. commonBits - // correctly reports that these prefixes have their first 16 bits - // in common. However, in the ART they only share 1 common stride: - // they both use the 192.0.0.0/8 strideTable, but 192.168.0.0/16 - // is stored as 168/8 within that table, and not as 0/0 in the - // 192.168.0.0/16 table. - // - // So, when commonBits matches the length of one of the inputs and - // falls on a boundary between strides, the strideTable one - // further up from commonBits/8 is the one we need to create, - // which means we have to adjust the stride count down by one. - if commonBits == minPrefixLen { - commonBits-- - } - commonStrides := commonBits / 8 - lastCommon, err := a.Addr().Prefix(commonStrides * 8) - if err != nil { - panic(fmt.Sprintf("computePrefixSplit constructing common prefix: %v", err)) - } - if a.Addr().Is4() { - aStride = a.Addr().As4()[commonStrides] - bStride = b.Addr().As4()[commonStrides] - } else { - aStride = a.Addr().As16()[commonStrides] - bStride = b.Addr().As16()[commonStrides] - } - return lastCommon, aStride, bStride -} - -// commonBits returns the number of common leading bits of a and b. -// If the number of common bits exceeds maxBits, it returns maxBits -// instead. -func commonBits(a, b netip.Addr, maxBits int) int { - if a.Is4() != b.Is4() { - panic("commonStrides called with mismatched address families") - } - var common int - // The following implements an old bit-twiddling trick to compute - // the number of common leading bits: if you XOR two numbers - // together, equal bits become 0 and unequal bits become 1. You - // can then count the number of leading zeros (which is a single - // instruction on modern CPUs) to get the answer. - // - // This code is a little more complex than just XOR + count - // leading zeros, because IPv4 and IPv6 are different sizes, and - // for IPv6 we have to do the math in two 64-bit chunks because Go - // lacks a uint128 type. - if a.Is4() { - aNum, bNum := ipv4AsUint(a), ipv4AsUint(b) - common = bits.LeadingZeros32(aNum ^ bNum) - } else { - aNumHi, aNumLo := ipv6AsUint(a) - bNumHi, bNumLo := ipv6AsUint(b) - common = bits.LeadingZeros64(aNumHi ^ bNumHi) - if common == 64 { - common += bits.LeadingZeros64(aNumLo ^ bNumLo) - } - } - if common > maxBits { - common = maxBits - } - return common -} - -// ipv4AsUint returns ip as a uint32. -func ipv4AsUint(ip netip.Addr) uint32 { - bs := ip.As4() - return binary.BigEndian.Uint32(bs[:]) -} - -// ipv6AsUint returns ip as a pair of uint64s. -func ipv6AsUint(ip netip.Addr) (uint64, uint64) { - bs := ip.As16() - return binary.BigEndian.Uint64(bs[:8]), binary.BigEndian.Uint64(bs[8:]) -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Package art provides a routing table that implements the Allotment Routing
+// Table (ART) algorithm by Donald Knuth, as described in the paper by Yoichi
+// Hariguchi.
+//
+// ART outperforms the traditional radix tree implementations for route lookups,
+// insertions, and deletions.
+//
+// For more information, see Yoichi Hariguchi's paper:
+// https://cseweb.ucsd.edu//~varghese/TEACH/cs228/artlookup.pdf
+package art
+
+import (
+ "bytes"
+ "encoding/binary"
+ "fmt"
+ "io"
+ "math/bits"
+ "net/netip"
+ "strings"
+ "sync"
+)
+
+const (
+ debugInsert = false
+ debugDelete = false
+)
+
+// Table is an IPv4 and IPv6 routing table.
+type Table[T any] struct {
+ v4 strideTable[T]
+ v6 strideTable[T]
+ initOnce sync.Once
+}
+
+func (t *Table[T]) init() {
+ t.initOnce.Do(func() {
+ t.v4.prefix = netip.PrefixFrom(netip.IPv4Unspecified(), 0)
+ t.v6.prefix = netip.PrefixFrom(netip.IPv6Unspecified(), 0)
+ })
+}
+
+func (t *Table[T]) tableForAddr(addr netip.Addr) *strideTable[T] {
+ if addr.Is6() {
+ return &t.v6
+ }
+ return &t.v4
+}
+
+// Get does a route lookup for addr and returns the associated value, or nil if
+// no route matched.
+func (t *Table[T]) Get(addr netip.Addr) (ret T, ok bool) {
+ t.init()
+
+ // Ideally we would use addr.AsSlice here, but AsSlice is just
+ // barely complex enough that it can't be inlined, and that in
+ // turn causes the slice to escape to the heap. Using As16 and
+ // manual slicing here helps the compiler keep Get alloc-free.
+ st := t.tableForAddr(addr)
+ rawAddr := addr.As16()
+ bs := rawAddr[:]
+ if addr.Is4() {
+ bs = bs[12:]
+ }
+
+ i := 0
+ // With path compression, we might skip over some address bits while walking
+ // to a strideTable leaf. This means the leaf answer we find might not be
+ // correct, because path compression took us down the wrong subtree. When
+ // that happens, we have to backtrack and figure out which most specific
+ // route further up the tree is relevant to addr, and return that.
+ //
+ // So, as we walk down the stride tables, each time we find a non-nil route
+ // result, we have to remember it and the associated strideTable prefix.
+ //
+ // We could also deal with this edge case of path compression by checking
+ // the strideTable prefix on each table as we descend, but that means we
+ // have to pay N prefix.Contains checks on every route lookup (where N is
+ // the number of strideTables in the path), rather than only paying M prefix
+ // comparisons in the edge case (where M is the number of strideTables in
+ // the path with a non-nil route of their own).
+ const maxDepth = 16
+ type prefixAndRoute struct {
+ prefix netip.Prefix
+ route T
+ }
+ strideMatch := make([]prefixAndRoute, 0, maxDepth)
+findLeaf:
+ for {
+ rt, rtOK, child := st.getValAndChild(bs[i])
+ if rtOK {
+ // This strideTable contains a route that may be relevant to our
+ // search, remember it.
+ strideMatch = append(strideMatch, prefixAndRoute{st.prefix, rt})
+ }
+ if child == nil {
+ // No sub-routes further down, the last thing we recorded
+ // in strideRoutes is tentatively the result, barring
+ // misdirection from path compression.
+ break findLeaf
+ }
+ st = child
+ // Path compression means we may be skipping over some intermediate
+ // tables. We have to skip forward to whatever depth st now references.
+ i = st.prefix.Bits() / 8
+ }
+
+ // Walk backwards through the hits we recorded in strideRoutes and
+ // stridePrefixes, returning the first one whose subtree matches addr.
+ //
+ // In the common case where path compression did not mislead us, we'll
+ // return on the first loop iteration because the last route we recorded was
+ // the correct most-specific route.
+ for i := len(strideMatch) - 1; i >= 0; i-- {
+ if m := strideMatch[i]; m.prefix.Contains(addr) {
+ return m.route, true
+ }
+ }
+
+ // We either found no route hits at all (both previous loops terminated
+ // immediately), or we went on a wild goose chase down a compressed path for
+ // the wrong prefix, and also found no usable routes on the way back up to
+ // the root. This is a miss.
+ return ret, false
+}
+
+// Insert adds pfx to the table, with value val.
+// If pfx is already present in the table, its value is set to val.
+func (t *Table[T]) Insert(pfx netip.Prefix, val T) {
+ t.init()
+
+ // The standard library doesn't enforce normalized prefixes (where
+ // the non-prefix bits are all zero). These algorithms require
+ // normalized prefixes, so do it upfront.
+ pfx = pfx.Masked()
+
+ if debugInsert {
+ defer func() {
+ fmt.Printf("%s", t.debugSummary())
+ }()
+ fmt.Printf("\ninsert: start pfx=%s\n", pfx)
+ }
+
+ st := t.tableForAddr(pfx.Addr())
+
+ // This algorithm is full of off-by-one headaches that boil down
+ // to the fact that pfx.Bits() has (2^n)+1 values, rather than
+ // just 2^n. For example, an IPv4 prefix length can be 0 through
+ // 32, which is 33 values.
+ //
+ // This extra possible value creates a lot of problems as we do
+ // bits and bytes math to traverse strideTables below. So, we
+ // treat the default route 0/0 specially here, that way the rest
+ // of the logic goes back to having 2^n values to reason about,
+ // which can be done in a nice and regular fashion with no edge
+ // cases.
+ if pfx.Bits() == 0 {
+ if debugInsert {
+ fmt.Printf("insert: default route\n")
+ }
+ st.insert(0, 0, val)
+ return
+ }
+
+ // No matter what we do as we traverse strideTables, our final
+ // action will be to insert the last 1-8 bits of pfx into a
+ // strideTable somewhere.
+ //
+ // We calculate upfront the byte position of the end of the
+ // prefix; the number of bits within that byte that contain prefix
+ // data; and the prefix of the strideTable into which we'll
+ // eventually insert.
+ //
+ // We need this in a couple different branches of the code below,
+ // and because the possible values are 1-indexed (1 through 32 for
+ // ipv4, 1 through 128 for ipv6), the math is very slightly
+ // unusual to account for the off-by-one indexing. Do it once up
+ // here, with this large comment, rather than reproduce the subtle
+ // math in multiple places further down.
+ finalByteIdx := (pfx.Bits() - 1) / 8
+ finalBits := pfx.Bits() - (finalByteIdx * 8)
+ finalStridePrefix, err := pfx.Addr().Prefix(finalByteIdx * 8)
+ if err != nil {
+ panic(fmt.Sprintf("invalid prefix requested: %s/%d", pfx.Addr(), finalByteIdx*8))
+ }
+ if debugInsert {
+ fmt.Printf("insert: finalByteIdx=%d finalBits=%d finalStridePrefix=%s\n", finalByteIdx, finalBits, finalStridePrefix)
+ }
+
+ // The strideTable we want to insert into is potentially at the
+ // end of a chain of strideTables, each one encoding 8 bits of the
+ // prefix.
+ //
+ // We're expecting to walk down a path of tables, although with
+ // prefix compression we may end up skipping some links in the
+ // chain, or taking wrong turns and having to course correct.
+ //
+ // As we walk down the tree, byteIdx is the byte of bs we're
+ // currently examining to choose our next step, and numBits is the
+ // number of bits that remain in pfx, starting with the byte at
+ // byteIdx inclusive.
+ bs := pfx.Addr().AsSlice()
+ byteIdx := 0
+ numBits := pfx.Bits()
+ for {
+ if debugInsert {
+ fmt.Printf("insert: loop byteIdx=%d numBits=%d st.prefix=%s\n", byteIdx, numBits, st.prefix)
+ }
+ if numBits <= 8 {
+ if debugInsert {
+ fmt.Printf("insert: existing leaf st.prefix=%s addr=%d/%d\n", st.prefix, bs[finalByteIdx], finalBits)
+ }
+ // We've reached the end of the prefix, whichever
+ // strideTable we're looking at now is the place where we
+ // need to insert.
+ st.insert(bs[finalByteIdx], finalBits, val)
+ return
+ }
+
+ // Otherwise, we need to go down at least one more level of
+ // strideTables. With prefix compression, each level of
+ // descent can have one of three outcomes: we find a place
+ // where prefix compression is possible; a place where prefix
+ // compression made us take a "wrong turn"; or a point along
+ // our intended path that we have to keep following.
+ child, created := st.getOrCreateChild(bs[byteIdx])
+ switch {
+ case created:
+ // The subtree we need for pfx doesn't exist yet. The rest
+ // of the path, if we were to create it, will consist of a
+ // bunch of strideTables with a single child each. We can
+ // use path compression to elide those intermediates, and
+ // jump straight to the final strideTable that hosts this
+ // prefix.
+ child.prefix = finalStridePrefix
+ child.insert(bs[finalByteIdx], finalBits, val)
+ if debugInsert {
+ fmt.Printf("insert: new leaf st.prefix=%s child.prefix=%s addr=%d/%d\n", st.prefix, child.prefix, bs[finalByteIdx], finalBits)
+ }
+ return
+ case !prefixStrictlyContains(child.prefix, pfx):
+ // child already exists, but its prefix does not contain
+ // our destination. This means that the path between st
+ // and child was compressed by a previous insertion, and
+ // somewhere in the (implicit) compressed path we took a
+ // wrong turn, into the wrong part of st's subtree.
+ //
+ // This is okay, because pfx and child.prefix must have a
+ // common ancestor node somewhere between st and child. We
+ // can figure out what node that is, and materialize it.
+ //
+ // Once we've done that, we can immediately complete the
+ // remainder of the insertion in one of two ways, without
+ // further traversal. See a little further down for what
+ // those are.
+ if debugInsert {
+ fmt.Printf("insert: wrong turn, pfx=%s child.prefix=%s\n", pfx, child.prefix)
+ }
+ intermediatePrefix, addrOfExisting, addrOfNew := computePrefixSplit(child.prefix, pfx)
+ intermediate := &strideTable[T]{prefix: intermediatePrefix} // TODO: make this whole thing be st.AddIntermediate or something?
+ st.setChild(bs[byteIdx], intermediate)
+ intermediate.setChild(addrOfExisting, child)
+
+ if debugInsert {
+ fmt.Printf("insert: new intermediate st.prefix=%s intermediate.prefix=%s child.prefix=%s\n", st.prefix, intermediate.prefix, child.prefix)
+ }
+
+ // Now, we have a chain of st -> intermediate -> child.
+ //
+ // pfx either lives in a different child of intermediate,
+ // or in intermediate itself. For example, if we created
+ // the intermediate 1.2.0.0/16, pfx=1.2.3.4/32 would have
+ // to go into a new child of intermediate, but
+ // pfx=1.2.0.0/18 would go into intermediate directly.
+ if remain := pfx.Bits() - intermediate.prefix.Bits(); remain <= 8 {
+ // pfx lives in intermediate.
+ if debugInsert {
+ fmt.Printf("insert: into intermediate intermediate.prefix=%s addr=%d/%d\n", intermediate.prefix, bs[finalByteIdx], finalBits)
+ }
+ intermediate.insert(bs[finalByteIdx], finalBits, val)
+ } else {
+ // pfx lives in a different child subtree of
+ // intermediate. By definition this subtree doesn't
+ // exist at all, otherwise we'd never have entered
+ // this entire "wrong turn" codepath in the first
+ // place.
+ //
+ // This means we can apply prefix compression as we
+ // create this new child, and we're done.
+ st, created = intermediate.getOrCreateChild(addrOfNew)
+ if !created {
+ panic("new child path unexpectedly exists during path decompression")
+ }
+ st.prefix = finalStridePrefix
+ st.insert(bs[finalByteIdx], finalBits, val)
+ if debugInsert {
+ fmt.Printf("insert: new child st.prefix=%s addr=%d/%d\n", st.prefix, bs[finalByteIdx], finalBits)
+ }
+ }
+
+ return
+ default:
+ // An expected child table exists along pfx's
+ // path. Continue traversing downwards.
+ st = child
+ byteIdx = child.prefix.Bits() / 8
+ numBits = pfx.Bits() - child.prefix.Bits()
+ if debugInsert {
+ fmt.Printf("insert: descend st.prefix=%s\n", st.prefix)
+ }
+ }
+ }
+}
+
+// Delete removes pfx from the table, if it is present.
+func (t *Table[T]) Delete(pfx netip.Prefix) {
+ t.init()
+
+ // The standard library doesn't enforce normalized prefixes (where
+ // the non-prefix bits are all zero). These algorithms require
+ // normalized prefixes, so do it upfront.
+ pfx = pfx.Masked()
+
+ if debugDelete {
+ defer func() {
+ fmt.Printf("%s", t.debugSummary())
+ }()
+ fmt.Printf("\ndelete: start pfx=%s table:\n%s", pfx, t.debugSummary())
+ }
+
+ st := t.tableForAddr(pfx.Addr())
+
+ // This algorithm is full of off-by-one headaches, just like
+ // Insert. See the comment in Insert for more details. Bottom
+ // line: we handle the default route as a special case, and that
+ // simplifies the rest of the code slightly.
+ if pfx.Bits() == 0 {
+ if debugDelete {
+ fmt.Printf("delete: default route\n")
+ }
+ st.delete(0, 0)
+ return
+ }
+
+ // Deletion may drive the refcount of some strideTables down to
+ // zero. We need to clean up these dangling tables, so we have to
+ // keep track of which tables we touch on the way down, and which
+ // strideEntry index each child is registered in.
+ //
+ // Note that the strideIndex and strideTables entries are off-by-one.
+ // The child table pointer is recorded at i+1, but it is referenced by a
+ // particular index in the parent table, at index i.
+ //
+ // In other words: entry number strideIndexes[0] in
+ // strideTables[0] is the same pointer as strideTables[1].
+ //
+ // This results in some slightly odd array accesses further down
+ // in this code, because in a single loop iteration we have to
+ // write to strideTables[N] and strideIndexes[N-1].
+ strideIdx := 0
+ strideTables := [16]*strideTable[T]{st}
+ strideIndexes := [15]uint8{}
+
+ // Similar to Insert, navigate down the tree of strideTables,
+ // looking for the one that houses this prefix. This part is
+ // easier than with insertion, since we can bail if the path ends
+ // early or takes an unexpected detour. However, unlike
+ // insertion, there's a whole post-deletion cleanup phase later
+ // on.
+ //
+ // As we walk down the tree, byteIdx is the byte of bs we're
+ // currently examining to choose our next step, and numBits is the
+ // number of bits that remain in pfx, starting with the byte at
+ // byteIdx inclusive.
+ bs := pfx.Addr().AsSlice()
+ byteIdx := 0
+ numBits := pfx.Bits()
+ for numBits > 8 {
+ if debugDelete {
+ fmt.Printf("delete: loop byteIdx=%d numBits=%d st.prefix=%s\n", byteIdx, numBits, st.prefix)
+ }
+ child := st.getChild(bs[byteIdx])
+ if child == nil {
+ // Prefix can't exist in the table, because one of the
+ // necessary strideTables doesn't exist.
+ if debugDelete {
+ fmt.Printf("delete: missing necessary child pfx=%s\n", pfx)
+ }
+ return
+ }
+ strideIndexes[strideIdx] = bs[byteIdx]
+ strideTables[strideIdx+1] = child
+ strideIdx++
+
+ // Path compression means byteIdx can jump forwards
+ // unpredictably. Recompute the next byte to look at from the
+ // child we just found.
+ byteIdx = child.prefix.Bits() / 8
+ numBits = pfx.Bits() - child.prefix.Bits()
+ st = child
+
+ if debugDelete {
+ fmt.Printf("delete: descend st.prefix=%s\n", st.prefix)
+ }
+ }
+
+ // We reached a leaf stride table that seems to be in the right
+ // spot. But path compression might have led us to the wrong
+ // table.
+ if !prefixStrictlyContains(st.prefix, pfx) {
+ // Wrong table, the requested prefix can't exist since its
+ // path led us to the wrong place.
+ if debugDelete {
+ fmt.Printf("delete: wrong leaf table pfx=%s\n", pfx)
+ }
+ return
+ }
+ if debugDelete {
+ fmt.Printf("delete: delete from st.prefix=%s addr=%d/%d\n", st.prefix, bs[byteIdx], numBits)
+ }
+ if routeExisted := st.delete(bs[byteIdx], numBits); !routeExisted {
+ // We're in the right strideTable, but pfx wasn't in
+ // it. Refcounts haven't changed, so we can skip cleanup.
+ if debugDelete {
+ fmt.Printf("delete: prefix not present pfx=%s\n", pfx)
+ }
+ return
+ }
+
+ // st.delete reduced st's refcount by one. This table may now be
+ // reclaimable, and depending on how we can reclaim it, the parent
+ // tables may also need to be reclaimed. This loop ends as soon as
+ // an iteration takes no action, or takes an action that doesn't
+ // alter the parent table's refcounts.
+ //
+ // We start our walk back at strideTables[strideIdx], which
+ // contains st.
+ for strideIdx > 0 {
+ cur := strideTables[strideIdx]
+ if debugDelete {
+ fmt.Printf("delete: GC? strideIdx=%d st.prefix=%s\n", strideIdx, cur.prefix)
+ }
+ if cur.routeRefs > 0 {
+ // the strideTable has other route entries, it cannot be
+ // deleted or compacted.
+ if debugDelete {
+ fmt.Printf("delete: has other routes st.prefix=%s\n", cur.prefix)
+ }
+ return
+ }
+ switch cur.childRefs {
+ case 0:
+ // no routeRefs and no childRefs, this table can be
+ // deleted. This will alter the parent table's refcount,
+ // so we'll have to look at it as well (in the next loop
+ // iteration).
+ if debugDelete {
+ fmt.Printf("delete: remove st.prefix=%s\n", cur.prefix)
+ }
+ strideTables[strideIdx-1].deleteChild(strideIndexes[strideIdx-1])
+ strideIdx--
+ case 1:
+ // This table has no routes, and a single child. Compact
+ // this table out of existence by making the parent point
+ // directly at the one child. This does not affect the
+ // parent's refcounts, so the parent can't be eligible for
+ // deletion or compaction, and we can stop.
+ child := strideTables[strideIdx].findFirstChild() // only 1 child exists, by definition
+ parent := strideTables[strideIdx-1]
+ if debugDelete {
+ fmt.Printf("delete: compact parent.prefix=%s st.prefix=%s child.prefix=%s\n", parent.prefix, cur.prefix, child.prefix)
+ }
+ strideTables[strideIdx-1].setChild(strideIndexes[strideIdx-1], child)
+ return
+ default:
+ // This table has two or more children, so it's acting as a "fork in
+ // the road" between two prefix subtrees. It cannot be deleted, and
+ // thus no further cleanups are possible.
+ if debugDelete {
+ fmt.Printf("delete: fork table st.prefix=%s\n", cur.prefix)
+ }
+ return
+ }
+ }
+}
+
+// debugSummary prints the tree of allocated strideTables in t, with each
+// strideTable's refcount.
+func (t *Table[T]) debugSummary() string {
+ t.init()
+ var ret bytes.Buffer
+ fmt.Fprintf(&ret, "v4: ")
+ strideSummary(&ret, &t.v4, 4)
+ fmt.Fprintf(&ret, "v6: ")
+ strideSummary(&ret, &t.v6, 4)
+ return ret.String()
+}
+
+func strideSummary[T any](w io.Writer, st *strideTable[T], indent int) {
+ fmt.Fprintf(w, "%s: %d routes, %d children\n", st.prefix, st.routeRefs, st.childRefs)
+ indent += 4
+ st.treeDebugStringRec(w, 1, indent)
+ for addr, child := range st.children {
+ if child == nil {
+ continue
+ }
+ fmt.Fprintf(w, "%s%d/8 (%02x/8): ", strings.Repeat(" ", indent), addr, addr)
+ strideSummary(w, child, indent)
+ }
+}
+
+// prefixStrictlyContains reports whether child is a prefix within
+// parent, but not parent itself.
+func prefixStrictlyContains(parent, child netip.Prefix) bool {
+ return parent.Overlaps(child) && parent.Bits() < child.Bits()
+}
+
+// computePrefixSplit returns the smallest common prefix that contains
+// both a and b. lastCommon is 8-bit aligned, with aStride and bStride
+// indicating the value of the 8-bit stride immediately following
+// lastCommon.
+//
+// computePrefixSplit is used in constructing an intermediate
+// strideTable when a new prefix needs to be inserted in a compressed
+// table. It can be read as: given that a is already in the table, and
+// b is being inserted, what is the prefix of the new intermediate
+// strideTable that needs to be created, and at what addresses in that
+// new strideTable should a and b's subsequent strideTables be
+// attached?
+//
+// Note as a special case, this can be called with a==b. An example of
+// when this happens:
+// - We want to insert the prefix 1.2.0.0/16
+// - A strideTable exists for 1.2.0.0/16, because another child
+// prefix already exists (e.g. 1.2.3.4/32)
+// - The 1.0.0.0/8 strideTable does not exist, because path
+// compression removed it.
+//
+// In this scenario, the caller of computePrefixSplit ends up making a
+// "wrong turn" while traversing strideTables: it was looking for the
+// 1.0.0.0/8 table, but ended up at the 1.2.0.0/16 table. When this
+// happens, it will invoke computePrefixSplit(1.2.0.0/16, 1.2.0.0/16),
+// and we return 1.0.0.0/8 as the missing intermediate.
+func computePrefixSplit(a, b netip.Prefix) (lastCommon netip.Prefix, aStride, bStride uint8) {
+ a = a.Masked()
+ b = b.Masked()
+ if a.Bits() == 0 || b.Bits() == 0 {
+ panic("computePrefixSplit called with a default route")
+ }
+ if a.Addr().Is4() != b.Addr().Is4() {
+ panic("computePrefixSplit called with mismatched address families")
+ }
+
+ minPrefixLen := a.Bits()
+ if b.Bits() < minPrefixLen {
+ minPrefixLen = b.Bits()
+ }
+
+ commonBits := commonBits(a.Addr(), b.Addr(), minPrefixLen)
+ // We want to know how many 8-bit strides are shared between a and
+ // b. Naively, this would be commonBits/8, but this introduces an
+ // off-by-one error. This is due to the way our ART stores
+ // prefixes whose length falls exactly on a stride boundary.
+ //
+ // Consider 192.168.1.0/24 and 192.168.0.0/16. commonBits
+ // correctly reports that these prefixes have their first 16 bits
+ // in common. However, in the ART they only share 1 common stride:
+ // they both use the 192.0.0.0/8 strideTable, but 192.168.0.0/16
+ // is stored as 168/8 within that table, and not as 0/0 in the
+ // 192.168.0.0/16 table.
+ //
+ // So, when commonBits matches the length of one of the inputs and
+ // falls on a boundary between strides, the strideTable one
+ // further up from commonBits/8 is the one we need to create,
+ // which means we have to adjust the stride count down by one.
+ if commonBits == minPrefixLen {
+ commonBits--
+ }
+ commonStrides := commonBits / 8
+ lastCommon, err := a.Addr().Prefix(commonStrides * 8)
+ if err != nil {
+ panic(fmt.Sprintf("computePrefixSplit constructing common prefix: %v", err))
+ }
+ if a.Addr().Is4() {
+ aStride = a.Addr().As4()[commonStrides]
+ bStride = b.Addr().As4()[commonStrides]
+ } else {
+ aStride = a.Addr().As16()[commonStrides]
+ bStride = b.Addr().As16()[commonStrides]
+ }
+ return lastCommon, aStride, bStride
+}
+
+// commonBits returns the number of common leading bits of a and b.
+// If the number of common bits exceeds maxBits, it returns maxBits
+// instead.
+func commonBits(a, b netip.Addr, maxBits int) int {
+ if a.Is4() != b.Is4() {
+ panic("commonStrides called with mismatched address families")
+ }
+ var common int
+ // The following implements an old bit-twiddling trick to compute
+ // the number of common leading bits: if you XOR two numbers
+ // together, equal bits become 0 and unequal bits become 1. You
+ // can then count the number of leading zeros (which is a single
+ // instruction on modern CPUs) to get the answer.
+ //
+ // This code is a little more complex than just XOR + count
+ // leading zeros, because IPv4 and IPv6 are different sizes, and
+ // for IPv6 we have to do the math in two 64-bit chunks because Go
+ // lacks a uint128 type.
+ if a.Is4() {
+ aNum, bNum := ipv4AsUint(a), ipv4AsUint(b)
+ common = bits.LeadingZeros32(aNum ^ bNum)
+ } else {
+ aNumHi, aNumLo := ipv6AsUint(a)
+ bNumHi, bNumLo := ipv6AsUint(b)
+ common = bits.LeadingZeros64(aNumHi ^ bNumHi)
+ if common == 64 {
+ common += bits.LeadingZeros64(aNumLo ^ bNumLo)
+ }
+ }
+ if common > maxBits {
+ common = maxBits
+ }
+ return common
+}
+
+// ipv4AsUint returns ip as a uint32.
+func ipv4AsUint(ip netip.Addr) uint32 {
+ bs := ip.As4()
+ return binary.BigEndian.Uint32(bs[:])
+}
+
+// ipv6AsUint returns ip as a pair of uint64s.
+func ipv6AsUint(ip netip.Addr) (uint64, uint64) {
+ bs := ip.As16()
+ return binary.BigEndian.Uint64(bs[:8]), binary.BigEndian.Uint64(bs[8:])
+}
diff --git a/net/dns/debian_resolvconf.go b/net/dns/debian_resolvconf.go index 3ffc796e0..2a1fb18de 100644 --- a/net/dns/debian_resolvconf.go +++ b/net/dns/debian_resolvconf.go @@ -1,184 +1,184 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build linux || freebsd || openbsd - -package dns - -import ( - "bufio" - "bytes" - _ "embed" - "fmt" - "os" - "os/exec" - "path/filepath" - - "tailscale.com/atomicfile" - "tailscale.com/types/logger" -) - -//go:embed resolvconf-workaround.sh -var workaroundScript []byte - -// resolvconfConfigName is the name of the config submitted to -// resolvconf. -// The name starts with 'tun' in order to match the hardcoded -// interface order in debian resolvconf, which will place this -// configuration ahead of regular network links. In theory, this -// doesn't matter because we then fix things up to ensure our config -// is the only one in use, but in case that fails, this will make our -// configuration slightly preferred. -// The 'inet' suffix has no specific meaning, but conventionally -// resolvconf implementations encourage adding a suffix roughly -// indicating where the config came from, and "inet" is the "none of -// the above" value (rather than, say, "ppp" or "dhcp"). -const resolvconfConfigName = "tun-tailscale.inet" - -// resolvconfLibcHookPath is the directory containing libc update -// scripts, which are run by Debian resolvconf when /etc/resolv.conf -// has been updated. -const resolvconfLibcHookPath = "/etc/resolvconf/update-libc.d" - -// resolvconfHookPath is the name of the libc hook script we install -// to force Tailscale's DNS config to take effect. -var resolvconfHookPath = filepath.Join(resolvconfLibcHookPath, "tailscale") - -// resolvconfManager manages DNS configuration using the Debian -// implementation of the `resolvconf` program, written by Thomas Hood. -type resolvconfManager struct { - logf logger.Logf - listRecordsPath string - interfacesDir string - scriptInstalled bool // libc update script has been installed -} - -func newDebianResolvconfManager(logf logger.Logf) (*resolvconfManager, error) { - ret := &resolvconfManager{ - logf: logf, - listRecordsPath: "/lib/resolvconf/list-records", - interfacesDir: "/etc/resolvconf/run/interface", // panic fallback if nothing seems to work - } - - if _, err := os.Stat(ret.listRecordsPath); os.IsNotExist(err) { - // This might be a Debian system from before the big /usr - // merge, try /usr instead. - ret.listRecordsPath = "/usr" + ret.listRecordsPath - } - // The runtime directory is currently (2020-04) canonically - // /etc/resolvconf/run, but the manpage is making noise about - // switching to /run/resolvconf and dropping the /etc path. So, - // let's probe the possible directories and use the first one - // that works. - for _, path := range []string{ - "/etc/resolvconf/run/interface", - "/run/resolvconf/interface", - "/var/run/resolvconf/interface", - } { - if _, err := os.Stat(path); err == nil { - ret.interfacesDir = path - break - } - } - if ret.interfacesDir == "" { - // None of the paths seem to work, use the canonical location - // that the current manpage says to use. - ret.interfacesDir = "/etc/resolvconf/run/interfaces" - } - - return ret, nil -} - -func (m *resolvconfManager) deleteTailscaleConfig() error { - cmd := exec.Command("resolvconf", "-d", resolvconfConfigName) - out, err := cmd.CombinedOutput() - if err != nil { - return fmt.Errorf("running %s: %s", cmd, out) - } - return nil -} - -func (m *resolvconfManager) SetDNS(config OSConfig) error { - if !m.scriptInstalled { - m.logf("injecting resolvconf workaround script") - if err := os.MkdirAll(resolvconfLibcHookPath, 0755); err != nil { - return err - } - if err := atomicfile.WriteFile(resolvconfHookPath, workaroundScript, 0755); err != nil { - return err - } - m.scriptInstalled = true - } - - if config.IsZero() { - if err := m.deleteTailscaleConfig(); err != nil { - return err - } - } else { - stdin := new(bytes.Buffer) - writeResolvConf(stdin, config.Nameservers, config.SearchDomains) // dns_direct.go - - // This resolvconf implementation doesn't support exclusive - // mode or interface priorities, so it will end up blending - // our configuration with other sources. However, this will - // get fixed up by the script we injected above. - cmd := exec.Command("resolvconf", "-a", resolvconfConfigName) - cmd.Stdin = stdin - out, err := cmd.CombinedOutput() - if err != nil { - return fmt.Errorf("running %s: %s", cmd, out) - } - } - - return nil -} - -func (m *resolvconfManager) SupportsSplitDNS() bool { - return false -} - -func (m *resolvconfManager) GetBaseConfig() (OSConfig, error) { - var bs bytes.Buffer - - cmd := exec.Command(m.listRecordsPath) - // list-records assumes it's being run with CWD set to the - // interfaces runtime dir, and returns nonsense otherwise. - cmd.Dir = m.interfacesDir - cmd.Stdout = &bs - if err := cmd.Run(); err != nil { - return OSConfig{}, err - } - - var conf bytes.Buffer - sc := bufio.NewScanner(&bs) - for sc.Scan() { - if sc.Text() == resolvconfConfigName { - continue - } - bs, err := os.ReadFile(filepath.Join(m.interfacesDir, sc.Text())) - if err != nil { - if os.IsNotExist(err) { - // Probably raced with a deletion, that's okay. - continue - } - return OSConfig{}, err - } - conf.Write(bs) - conf.WriteByte('\n') - } - - return readResolv(&conf) -} - -func (m *resolvconfManager) Close() error { - if err := m.deleteTailscaleConfig(); err != nil { - return err - } - - if m.scriptInstalled { - m.logf("removing resolvconf workaround script") - os.Remove(resolvconfHookPath) // Best-effort - } - - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build linux || freebsd || openbsd
+
+package dns
+
+import (
+ "bufio"
+ "bytes"
+ _ "embed"
+ "fmt"
+ "os"
+ "os/exec"
+ "path/filepath"
+
+ "tailscale.com/atomicfile"
+ "tailscale.com/types/logger"
+)
+
+//go:embed resolvconf-workaround.sh
+var workaroundScript []byte
+
+// resolvconfConfigName is the name of the config submitted to
+// resolvconf.
+// The name starts with 'tun' in order to match the hardcoded
+// interface order in debian resolvconf, which will place this
+// configuration ahead of regular network links. In theory, this
+// doesn't matter because we then fix things up to ensure our config
+// is the only one in use, but in case that fails, this will make our
+// configuration slightly preferred.
+// The 'inet' suffix has no specific meaning, but conventionally
+// resolvconf implementations encourage adding a suffix roughly
+// indicating where the config came from, and "inet" is the "none of
+// the above" value (rather than, say, "ppp" or "dhcp").
+const resolvconfConfigName = "tun-tailscale.inet"
+
+// resolvconfLibcHookPath is the directory containing libc update
+// scripts, which are run by Debian resolvconf when /etc/resolv.conf
+// has been updated.
+const resolvconfLibcHookPath = "/etc/resolvconf/update-libc.d"
+
+// resolvconfHookPath is the name of the libc hook script we install
+// to force Tailscale's DNS config to take effect.
+var resolvconfHookPath = filepath.Join(resolvconfLibcHookPath, "tailscale")
+
+// resolvconfManager manages DNS configuration using the Debian
+// implementation of the `resolvconf` program, written by Thomas Hood.
+type resolvconfManager struct {
+ logf logger.Logf
+ listRecordsPath string
+ interfacesDir string
+ scriptInstalled bool // libc update script has been installed
+}
+
+func newDebianResolvconfManager(logf logger.Logf) (*resolvconfManager, error) {
+ ret := &resolvconfManager{
+ logf: logf,
+ listRecordsPath: "/lib/resolvconf/list-records",
+ interfacesDir: "/etc/resolvconf/run/interface", // panic fallback if nothing seems to work
+ }
+
+ if _, err := os.Stat(ret.listRecordsPath); os.IsNotExist(err) {
+ // This might be a Debian system from before the big /usr
+ // merge, try /usr instead.
+ ret.listRecordsPath = "/usr" + ret.listRecordsPath
+ }
+ // The runtime directory is currently (2020-04) canonically
+ // /etc/resolvconf/run, but the manpage is making noise about
+ // switching to /run/resolvconf and dropping the /etc path. So,
+ // let's probe the possible directories and use the first one
+ // that works.
+ for _, path := range []string{
+ "/etc/resolvconf/run/interface",
+ "/run/resolvconf/interface",
+ "/var/run/resolvconf/interface",
+ } {
+ if _, err := os.Stat(path); err == nil {
+ ret.interfacesDir = path
+ break
+ }
+ }
+ if ret.interfacesDir == "" {
+ // None of the paths seem to work, use the canonical location
+ // that the current manpage says to use.
+ ret.interfacesDir = "/etc/resolvconf/run/interfaces"
+ }
+
+ return ret, nil
+}
+
+func (m *resolvconfManager) deleteTailscaleConfig() error {
+ cmd := exec.Command("resolvconf", "-d", resolvconfConfigName)
+ out, err := cmd.CombinedOutput()
+ if err != nil {
+ return fmt.Errorf("running %s: %s", cmd, out)
+ }
+ return nil
+}
+
+func (m *resolvconfManager) SetDNS(config OSConfig) error {
+ if !m.scriptInstalled {
+ m.logf("injecting resolvconf workaround script")
+ if err := os.MkdirAll(resolvconfLibcHookPath, 0755); err != nil {
+ return err
+ }
+ if err := atomicfile.WriteFile(resolvconfHookPath, workaroundScript, 0755); err != nil {
+ return err
+ }
+ m.scriptInstalled = true
+ }
+
+ if config.IsZero() {
+ if err := m.deleteTailscaleConfig(); err != nil {
+ return err
+ }
+ } else {
+ stdin := new(bytes.Buffer)
+ writeResolvConf(stdin, config.Nameservers, config.SearchDomains) // dns_direct.go
+
+ // This resolvconf implementation doesn't support exclusive
+ // mode or interface priorities, so it will end up blending
+ // our configuration with other sources. However, this will
+ // get fixed up by the script we injected above.
+ cmd := exec.Command("resolvconf", "-a", resolvconfConfigName)
+ cmd.Stdin = stdin
+ out, err := cmd.CombinedOutput()
+ if err != nil {
+ return fmt.Errorf("running %s: %s", cmd, out)
+ }
+ }
+
+ return nil
+}
+
+func (m *resolvconfManager) SupportsSplitDNS() bool {
+ return false
+}
+
+func (m *resolvconfManager) GetBaseConfig() (OSConfig, error) {
+ var bs bytes.Buffer
+
+ cmd := exec.Command(m.listRecordsPath)
+ // list-records assumes it's being run with CWD set to the
+ // interfaces runtime dir, and returns nonsense otherwise.
+ cmd.Dir = m.interfacesDir
+ cmd.Stdout = &bs
+ if err := cmd.Run(); err != nil {
+ return OSConfig{}, err
+ }
+
+ var conf bytes.Buffer
+ sc := bufio.NewScanner(&bs)
+ for sc.Scan() {
+ if sc.Text() == resolvconfConfigName {
+ continue
+ }
+ bs, err := os.ReadFile(filepath.Join(m.interfacesDir, sc.Text()))
+ if err != nil {
+ if os.IsNotExist(err) {
+ // Probably raced with a deletion, that's okay.
+ continue
+ }
+ return OSConfig{}, err
+ }
+ conf.Write(bs)
+ conf.WriteByte('\n')
+ }
+
+ return readResolv(&conf)
+}
+
+func (m *resolvconfManager) Close() error {
+ if err := m.deleteTailscaleConfig(); err != nil {
+ return err
+ }
+
+ if m.scriptInstalled {
+ m.logf("removing resolvconf workaround script")
+ os.Remove(resolvconfHookPath) // Best-effort
+ }
+
+ return nil
+}
diff --git a/net/dns/direct_notlinux.go b/net/dns/direct_notlinux.go index c221ca1be..5bd8093d6 100644 --- a/net/dns/direct_notlinux.go +++ b/net/dns/direct_notlinux.go @@ -1,10 +1,10 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !linux - -package dns - -func (m *directManager) runFileWatcher() { - // Not implemented on other platforms. Maybe it could resort to polling. -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build !linux
+
+package dns
+
+func (m *directManager) runFileWatcher() {
+ // Not implemented on other platforms. Maybe it could resort to polling.
+}
diff --git a/net/dns/flush_default.go b/net/dns/flush_default.go index eb6d9da41..73e446389 100644 --- a/net/dns/flush_default.go +++ b/net/dns/flush_default.go @@ -1,10 +1,10 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !windows - -package dns - -func flushCaches() error { - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build !windows
+
+package dns
+
+func flushCaches() error {
+ return nil
+}
diff --git a/net/dns/ini.go b/net/dns/ini.go index 1e47d606e..deec04019 100644 --- a/net/dns/ini.go +++ b/net/dns/ini.go @@ -1,30 +1,30 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build windows - -package dns - -import ( - "regexp" - "strings" -) - -// parseIni parses a basic .ini file, used for wsl.conf. -func parseIni(data string) map[string]map[string]string { - sectionRE := regexp.MustCompile(`^\[([^]]+)\]`) - kvRE := regexp.MustCompile(`^\s*(\w+)\s*=\s*([^#]*)`) - - ini := map[string]map[string]string{} - var section string - for _, line := range strings.Split(data, "\n") { - if res := sectionRE.FindStringSubmatch(line); len(res) > 1 { - section = res[1] - ini[section] = map[string]string{} - } else if res := kvRE.FindStringSubmatch(line); len(res) > 2 { - k, v := strings.TrimSpace(res[1]), strings.TrimSpace(res[2]) - ini[section][k] = v - } - } - return ini -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build windows
+
+package dns
+
+import (
+ "regexp"
+ "strings"
+)
+
+// parseIni parses a basic .ini file, used for wsl.conf.
+func parseIni(data string) map[string]map[string]string {
+ sectionRE := regexp.MustCompile(`^\[([^]]+)\]`)
+ kvRE := regexp.MustCompile(`^\s*(\w+)\s*=\s*([^#]*)`)
+
+ ini := map[string]map[string]string{}
+ var section string
+ for _, line := range strings.Split(data, "\n") {
+ if res := sectionRE.FindStringSubmatch(line); len(res) > 1 {
+ section = res[1]
+ ini[section] = map[string]string{}
+ } else if res := kvRE.FindStringSubmatch(line); len(res) > 2 {
+ k, v := strings.TrimSpace(res[1]), strings.TrimSpace(res[2])
+ ini[section][k] = v
+ }
+ }
+ return ini
+}
diff --git a/net/dns/ini_test.go b/net/dns/ini_test.go index 3afe7009c..0e9eaa672 100644 --- a/net/dns/ini_test.go +++ b/net/dns/ini_test.go @@ -1,38 +1,38 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build windows - -package dns - -import ( - "reflect" - "testing" -) - -func TestParseIni(t *testing.T) { - var tests = []struct { - src string - want map[string]map[string]string - }{ - { - src: `# appended wsl.conf file -[automount] - enabled = true - root=/mnt/ -# added by tailscale -[network] # trailing comment -generateResolvConf = false # trailing comment`, - want: map[string]map[string]string{ - "automount": {"enabled": "true", "root": "/mnt/"}, - "network": {"generateResolvConf": "false"}, - }, - }, - } - for _, test := range tests { - got := parseIni(test.src) - if !reflect.DeepEqual(got, test.want) { - t.Errorf("for:\n%s\ngot: %v\nwant: %v", test.src, got, test.want) - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build windows
+
+package dns
+
+import (
+ "reflect"
+ "testing"
+)
+
+func TestParseIni(t *testing.T) {
+ var tests = []struct {
+ src string
+ want map[string]map[string]string
+ }{
+ {
+ src: `# appended wsl.conf file
+[automount]
+ enabled = true
+ root=/mnt/
+# added by tailscale
+[network] # trailing comment
+generateResolvConf = false # trailing comment`,
+ want: map[string]map[string]string{
+ "automount": {"enabled": "true", "root": "/mnt/"},
+ "network": {"generateResolvConf": "false"},
+ },
+ },
+ }
+ for _, test := range tests {
+ got := parseIni(test.src)
+ if !reflect.DeepEqual(got, test.want) {
+ t.Errorf("for:\n%s\ngot: %v\nwant: %v", test.src, got, test.want)
+ }
+ }
+}
diff --git a/net/dns/noop.go b/net/dns/noop.go index 9466b57a0..c90162668 100644 --- a/net/dns/noop.go +++ b/net/dns/noop.go @@ -1,17 +1,17 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package dns - -type noopManager struct{} - -func (m noopManager) SetDNS(OSConfig) error { return nil } -func (m noopManager) SupportsSplitDNS() bool { return false } -func (m noopManager) Close() error { return nil } -func (m noopManager) GetBaseConfig() (OSConfig, error) { - return OSConfig{}, ErrGetBaseConfigNotSupported -} - -func NewNoopManager() (noopManager, error) { - return noopManager{}, nil -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package dns
+
+type noopManager struct{}
+
+func (m noopManager) SetDNS(OSConfig) error { return nil }
+func (m noopManager) SupportsSplitDNS() bool { return false }
+func (m noopManager) Close() error { return nil }
+func (m noopManager) GetBaseConfig() (OSConfig, error) {
+ return OSConfig{}, ErrGetBaseConfigNotSupported
+}
+
+func NewNoopManager() (noopManager, error) {
+ return noopManager{}, nil
+}
diff --git a/net/dns/resolvconf-workaround.sh b/net/dns/resolvconf-workaround.sh index aec6708a0..254b3949b 100644 --- a/net/dns/resolvconf-workaround.sh +++ b/net/dns/resolvconf-workaround.sh @@ -1,62 +1,62 @@ -#!/bin/sh -# Copyright (c) Tailscale Inc & AUTHORS -# SPDX-License-Identifier: BSD-3-Clause -# -# This script is a workaround for a vpn-unfriendly behavior of the -# original resolvconf by Thomas Hood. Unlike the `openresolv` -# implementation (whose binary is also called resolvconf, -# confusingly), the original resolvconf lacks a way to specify -# "exclusive mode" for a provider configuration. In practice, this -# means that if Tailscale wants to install a DNS configuration, that -# config will get "blended" with the configs from other sources, -# rather than override those other sources. -# -# This script gets installed at /etc/resolvconf/update-libc.d, which -# is a directory of hook scripts that get run after resolvconf's libc -# helper has finished rewriting /etc/resolv.conf. It's meant to notify -# consumers of resolv.conf of a new configuration. -# -# Instead, we use that hook mechanism to reach into resolvconf's -# stuff, and rewrite the libc-generated resolv.conf to exclusively -# contain Tailscale's configuration - effectively implementing -# exclusive mode ourselves in post-production. - -set -e - -if [ -n "$TAILSCALE_RESOLVCONF_HOOK_LOOP" ]; then - # Hook script being invoked by itself, skip. - exit 0 -fi - -if [ ! -f tun-tailscale.inet ]; then - # Tailscale isn't trying to manage DNS, do nothing. - exit 0 -fi - -if ! grep resolvconf /etc/resolv.conf >/dev/null; then - # resolvconf isn't managing /etc/resolv.conf, do nothing. - exit 0 -fi - -# Write out a modified /etc/resolv.conf containing just our config. -( - if [ -f /etc/resolvconf/resolv.conf.d/head ]; then - cat /etc/resolvconf/resolv.conf.d/head - fi - echo "# Tailscale workaround applied to set exclusive DNS configuration." - cat tun-tailscale.inet - if [ -f /etc/resolvconf/resolv.conf.d/base ]; then - # Keep options and sortlist, discard other base things since - # they're the things we're trying to override. - grep -e 'sortlist ' -e 'options ' /etc/resolvconf/resolv.conf.d/base || true - fi - if [ -f /etc/resolvconf/resolv.conf.d/tail ]; then - cat /etc/resolvconf/resolv.conf.d/tail - fi -) >/etc/resolv.conf - -if [ -d /etc/resolvconf/update-libc.d ] ; then - # Re-notify libc watchers that we've changed resolv.conf again. - export TAILSCALE_RESOLVCONF_HOOK_LOOP=1 - exec run-parts /etc/resolvconf/update-libc.d -fi +#!/bin/sh
+# Copyright (c) Tailscale Inc & AUTHORS
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# This script is a workaround for a vpn-unfriendly behavior of the
+# original resolvconf by Thomas Hood. Unlike the `openresolv`
+# implementation (whose binary is also called resolvconf,
+# confusingly), the original resolvconf lacks a way to specify
+# "exclusive mode" for a provider configuration. In practice, this
+# means that if Tailscale wants to install a DNS configuration, that
+# config will get "blended" with the configs from other sources,
+# rather than override those other sources.
+#
+# This script gets installed at /etc/resolvconf/update-libc.d, which
+# is a directory of hook scripts that get run after resolvconf's libc
+# helper has finished rewriting /etc/resolv.conf. It's meant to notify
+# consumers of resolv.conf of a new configuration.
+#
+# Instead, we use that hook mechanism to reach into resolvconf's
+# stuff, and rewrite the libc-generated resolv.conf to exclusively
+# contain Tailscale's configuration - effectively implementing
+# exclusive mode ourselves in post-production.
+
+set -e
+
+if [ -n "$TAILSCALE_RESOLVCONF_HOOK_LOOP" ]; then
+ # Hook script being invoked by itself, skip.
+ exit 0
+fi
+
+if [ ! -f tun-tailscale.inet ]; then
+ # Tailscale isn't trying to manage DNS, do nothing.
+ exit 0
+fi
+
+if ! grep resolvconf /etc/resolv.conf >/dev/null; then
+ # resolvconf isn't managing /etc/resolv.conf, do nothing.
+ exit 0
+fi
+
+# Write out a modified /etc/resolv.conf containing just our config.
+(
+ if [ -f /etc/resolvconf/resolv.conf.d/head ]; then
+ cat /etc/resolvconf/resolv.conf.d/head
+ fi
+ echo "# Tailscale workaround applied to set exclusive DNS configuration."
+ cat tun-tailscale.inet
+ if [ -f /etc/resolvconf/resolv.conf.d/base ]; then
+ # Keep options and sortlist, discard other base things since
+ # they're the things we're trying to override.
+ grep -e 'sortlist ' -e 'options ' /etc/resolvconf/resolv.conf.d/base || true
+ fi
+ if [ -f /etc/resolvconf/resolv.conf.d/tail ]; then
+ cat /etc/resolvconf/resolv.conf.d/tail
+ fi
+) >/etc/resolv.conf
+
+if [ -d /etc/resolvconf/update-libc.d ] ; then
+ # Re-notify libc watchers that we've changed resolv.conf again.
+ export TAILSCALE_RESOLVCONF_HOOK_LOOP=1
+ exec run-parts /etc/resolvconf/update-libc.d
+fi
diff --git a/net/dns/resolvconf.go b/net/dns/resolvconf.go index ca584ffcc..9e2a41c4a 100644 --- a/net/dns/resolvconf.go +++ b/net/dns/resolvconf.go @@ -1,30 +1,30 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build linux || freebsd || openbsd - -package dns - -import ( - "bytes" - "os/exec" -) - -func resolvconfStyle() string { - if _, err := exec.LookPath("resolvconf"); err != nil { - return "" - } - output, err := exec.Command("resolvconf", "--version").CombinedOutput() - if err != nil { - // Debian resolvconf doesn't understand --version, and - // exits with a specific error code. - if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() == 99 { - return "debian" - } - } - if bytes.HasPrefix(output, []byte("Debian resolvconf")) { - return "debian" - } - // Treat everything else as openresolv, by far the more popular implementation. - return "openresolv" -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build linux || freebsd || openbsd
+
+package dns
+
+import (
+ "bytes"
+ "os/exec"
+)
+
+func resolvconfStyle() string {
+ if _, err := exec.LookPath("resolvconf"); err != nil {
+ return ""
+ }
+ output, err := exec.Command("resolvconf", "--version").CombinedOutput()
+ if err != nil {
+ // Debian resolvconf doesn't understand --version, and
+ // exits with a specific error code.
+ if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() == 99 {
+ return "debian"
+ }
+ }
+ if bytes.HasPrefix(output, []byte("Debian resolvconf")) {
+ return "debian"
+ }
+ // Treat everything else as openresolv, by far the more popular implementation.
+ return "openresolv"
+}
diff --git a/net/dns/resolvconffile/resolvconffile.go b/net/dns/resolvconffile/resolvconffile.go index 753000f6d..66c1600d8 100644 --- a/net/dns/resolvconffile/resolvconffile.go +++ b/net/dns/resolvconffile/resolvconffile.go @@ -1,124 +1,124 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package resolvconffile parses & serializes /etc/resolv.conf-style files. -// -// It's a leaf package so both net/dns and net/dns/resolver can depend -// on it and we can unify a handful of implementations. -// -// The package is verbosely named to disambiguate it from resolvconf -// the daemon, which Tailscale also supports. -package resolvconffile - -import ( - "bufio" - "bytes" - "fmt" - "io" - "net/netip" - "os" - "strings" - - "tailscale.com/util/dnsname" -) - -// Path is the canonical location of resolv.conf. -const Path = "/etc/resolv.conf" - -// Config represents a resolv.conf(5) file. -type Config struct { - // Nameservers are the IP addresses of the nameservers to use. - Nameservers []netip.Addr - - // SearchDomains are the domain suffixes to use when expanding - // single-label name queries. SearchDomains is additive to - // whatever non-Tailscale search domains the OS has. - SearchDomains []dnsname.FQDN -} - -// Write writes c to w. It does so in one Write call. -func (c *Config) Write(w io.Writer) error { - buf := new(bytes.Buffer) - io.WriteString(buf, "# resolv.conf(5) file generated by tailscale\n") - io.WriteString(buf, "# For more info, see https://tailscale.com/s/resolvconf-overwrite\n") - io.WriteString(buf, "# DO NOT EDIT THIS FILE BY HAND -- CHANGES WILL BE OVERWRITTEN\n\n") - for _, ns := range c.Nameservers { - io.WriteString(buf, "nameserver ") - io.WriteString(buf, ns.String()) - io.WriteString(buf, "\n") - } - if len(c.SearchDomains) > 0 { - io.WriteString(buf, "search") - for _, domain := range c.SearchDomains { - io.WriteString(buf, " ") - io.WriteString(buf, domain.WithoutTrailingDot()) - } - io.WriteString(buf, "\n") - } - _, err := w.Write(buf.Bytes()) - return err -} - -// Parse parses a resolv.conf file from r. -func Parse(r io.Reader) (*Config, error) { - config := new(Config) - scanner := bufio.NewScanner(r) - for scanner.Scan() { - line := scanner.Text() - line, _, _ = strings.Cut(line, "#") // remove any comments - line = strings.TrimSpace(line) - - if s, ok := strings.CutPrefix(line, "nameserver"); ok { - nameserver := strings.TrimSpace(s) - if len(nameserver) == len(s) { - return nil, fmt.Errorf("missing space after \"nameserver\" in %q", line) - } - ip, err := netip.ParseAddr(nameserver) - if err != nil { - return nil, err - } - config.Nameservers = append(config.Nameservers, ip) - continue - } - - if s, ok := strings.CutPrefix(line, "search"); ok { - domains := strings.TrimSpace(s) - if len(domains) == len(s) { - // No leading space?! - return nil, fmt.Errorf("missing space after \"search\" in %q", line) - } - for len(domains) > 0 { - domain := domains - i := strings.IndexAny(domain, " \t") - if i != -1 { - domain = domain[:i] - domains = strings.TrimSpace(domains[i+1:]) - } else { - domains = "" - } - fqdn, err := dnsname.ToFQDN(domain) - if err != nil { - return nil, fmt.Errorf("parsing search domain %q in %q: %w", domain, line, err) - } - config.SearchDomains = append(config.SearchDomains, fqdn) - } - } - } - return config, nil -} - -// ParseFile parses the named resolv.conf file. -func ParseFile(name string) (*Config, error) { - fi, err := os.Stat(name) - if err != nil { - return nil, err - } - if n := fi.Size(); n > 10<<10 { - return nil, fmt.Errorf("unexpectedly large %q file: %d bytes", name, n) - } - all, err := os.ReadFile(name) - if err != nil { - return nil, err - } - return Parse(bytes.NewReader(all)) -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Package resolvconffile parses & serializes /etc/resolv.conf-style files.
+//
+// It's a leaf package so both net/dns and net/dns/resolver can depend
+// on it and we can unify a handful of implementations.
+//
+// The package is verbosely named to disambiguate it from resolvconf
+// the daemon, which Tailscale also supports.
+package resolvconffile
+
+import (
+ "bufio"
+ "bytes"
+ "fmt"
+ "io"
+ "net/netip"
+ "os"
+ "strings"
+
+ "tailscale.com/util/dnsname"
+)
+
+// Path is the canonical location of resolv.conf.
+const Path = "/etc/resolv.conf"
+
+// Config represents a resolv.conf(5) file.
+type Config struct {
+ // Nameservers are the IP addresses of the nameservers to use.
+ Nameservers []netip.Addr
+
+ // SearchDomains are the domain suffixes to use when expanding
+ // single-label name queries. SearchDomains is additive to
+ // whatever non-Tailscale search domains the OS has.
+ SearchDomains []dnsname.FQDN
+}
+
+// Write writes c to w. It does so in one Write call.
+func (c *Config) Write(w io.Writer) error {
+ buf := new(bytes.Buffer)
+ io.WriteString(buf, "# resolv.conf(5) file generated by tailscale\n")
+ io.WriteString(buf, "# For more info, see https://tailscale.com/s/resolvconf-overwrite\n")
+ io.WriteString(buf, "# DO NOT EDIT THIS FILE BY HAND -- CHANGES WILL BE OVERWRITTEN\n\n")
+ for _, ns := range c.Nameservers {
+ io.WriteString(buf, "nameserver ")
+ io.WriteString(buf, ns.String())
+ io.WriteString(buf, "\n")
+ }
+ if len(c.SearchDomains) > 0 {
+ io.WriteString(buf, "search")
+ for _, domain := range c.SearchDomains {
+ io.WriteString(buf, " ")
+ io.WriteString(buf, domain.WithoutTrailingDot())
+ }
+ io.WriteString(buf, "\n")
+ }
+ _, err := w.Write(buf.Bytes())
+ return err
+}
+
+// Parse parses a resolv.conf file from r.
+func Parse(r io.Reader) (*Config, error) {
+ config := new(Config)
+ scanner := bufio.NewScanner(r)
+ for scanner.Scan() {
+ line := scanner.Text()
+ line, _, _ = strings.Cut(line, "#") // remove any comments
+ line = strings.TrimSpace(line)
+
+ if s, ok := strings.CutPrefix(line, "nameserver"); ok {
+ nameserver := strings.TrimSpace(s)
+ if len(nameserver) == len(s) {
+ return nil, fmt.Errorf("missing space after \"nameserver\" in %q", line)
+ }
+ ip, err := netip.ParseAddr(nameserver)
+ if err != nil {
+ return nil, err
+ }
+ config.Nameservers = append(config.Nameservers, ip)
+ continue
+ }
+
+ if s, ok := strings.CutPrefix(line, "search"); ok {
+ domains := strings.TrimSpace(s)
+ if len(domains) == len(s) {
+ // No leading space?!
+ return nil, fmt.Errorf("missing space after \"search\" in %q", line)
+ }
+ for len(domains) > 0 {
+ domain := domains
+ i := strings.IndexAny(domain, " \t")
+ if i != -1 {
+ domain = domain[:i]
+ domains = strings.TrimSpace(domains[i+1:])
+ } else {
+ domains = ""
+ }
+ fqdn, err := dnsname.ToFQDN(domain)
+ if err != nil {
+ return nil, fmt.Errorf("parsing search domain %q in %q: %w", domain, line, err)
+ }
+ config.SearchDomains = append(config.SearchDomains, fqdn)
+ }
+ }
+ }
+ return config, nil
+}
+
+// ParseFile parses the named resolv.conf file.
+func ParseFile(name string) (*Config, error) {
+ fi, err := os.Stat(name)
+ if err != nil {
+ return nil, err
+ }
+ if n := fi.Size(); n > 10<<10 {
+ return nil, fmt.Errorf("unexpectedly large %q file: %d bytes", name, n)
+ }
+ all, err := os.ReadFile(name)
+ if err != nil {
+ return nil, err
+ }
+ return Parse(bytes.NewReader(all))
+}
diff --git a/net/dns/resolvconfpath_default.go b/net/dns/resolvconfpath_default.go index 57e82c4c7..02f24a0cf 100644 --- a/net/dns/resolvconfpath_default.go +++ b/net/dns/resolvconfpath_default.go @@ -1,11 +1,11 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !gokrazy - -package dns - -const ( - resolvConf = "/etc/resolv.conf" - backupConf = "/etc/resolv.pre-tailscale-backup.conf" -) +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build !gokrazy
+
+package dns
+
+const (
+ resolvConf = "/etc/resolv.conf"
+ backupConf = "/etc/resolv.pre-tailscale-backup.conf"
+)
diff --git a/net/dns/resolvconfpath_gokrazy.go b/net/dns/resolvconfpath_gokrazy.go index f0759b0e3..6315596d2 100644 --- a/net/dns/resolvconfpath_gokrazy.go +++ b/net/dns/resolvconfpath_gokrazy.go @@ -1,11 +1,11 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build gokrazy - -package dns - -const ( - resolvConf = "/tmp/resolv.conf" - backupConf = "/tmp/resolv.pre-tailscale-backup.conf" -) +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build gokrazy
+
+package dns
+
+const (
+ resolvConf = "/tmp/resolv.conf"
+ backupConf = "/tmp/resolv.pre-tailscale-backup.conf"
+)
diff --git a/net/dns/resolver/doh_test.go b/net/dns/resolver/doh_test.go index a9c284761..d9ef970c2 100644 --- a/net/dns/resolver/doh_test.go +++ b/net/dns/resolver/doh_test.go @@ -1,99 +1,99 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package resolver - -import ( - "context" - "flag" - "net/http" - "testing" - - "golang.org/x/net/dns/dnsmessage" - "tailscale.com/net/dns/publicdns" -) - -var testDoH = flag.Bool("test-doh", false, "do real DoH tests against the network") - -const someDNSID = 123 // something non-zero as a test; in violation of spec's SHOULD of 0 - -func someDNSQuestion(t testing.TB) []byte { - b := dnsmessage.NewBuilder(nil, dnsmessage.Header{ - OpCode: 0, // query - RecursionDesired: true, - ID: someDNSID, - }) - b.StartQuestions() // err - b.Question(dnsmessage.Question{ - Name: dnsmessage.MustNewName("tailscale.com."), - Type: dnsmessage.TypeA, - Class: dnsmessage.ClassINET, - }) - msg, err := b.Finish() - if err != nil { - t.Fatal(err) - } - return msg -} - -func TestDoH(t *testing.T) { - if !*testDoH { - t.Skip("skipping manual test without --test-doh flag") - } - prefixes := publicdns.KnownDoHPrefixes() - if len(prefixes) == 0 { - t.Fatal("no known DoH") - } - - f := &forwarder{} - - for _, urlBase := range prefixes { - t.Run(urlBase, func(t *testing.T) { - c, ok := f.getKnownDoHClientForProvider(urlBase) - if !ok { - t.Fatal("expected DoH") - } - res, err := f.sendDoH(context.Background(), urlBase, c, someDNSQuestion(t)) - if err != nil { - t.Fatal(err) - } - c.Transport.(*http.Transport).CloseIdleConnections() - - var p dnsmessage.Parser - h, err := p.Start(res) - if err != nil { - t.Fatal(err) - } - if h.ID != someDNSID { - t.Errorf("response DNS ID = %v; want %v", h.ID, someDNSID) - } - - p.SkipAllQuestions() - aa, err := p.AllAnswers() - if err != nil { - t.Fatal(err) - } - if len(aa) == 0 { - t.Fatal("no answers") - } - for _, r := range aa { - t.Logf("got: %v", r.GoString()) - } - }) - } -} - -func TestDoHV6Fallback(t *testing.T) { - for _, base := range publicdns.KnownDoHPrefixes() { - for _, ip := range publicdns.DoHIPsOfBase(base) { - if ip.Is4() { - ip6, ok := publicdns.DoHV6(base) - if !ok { - t.Errorf("no v6 DoH known for %v", ip) - } else if !ip6.Is6() { - t.Errorf("dohV6(%q) returned non-v6 address %v", base, ip6) - } - } - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package resolver
+
+import (
+ "context"
+ "flag"
+ "net/http"
+ "testing"
+
+ "golang.org/x/net/dns/dnsmessage"
+ "tailscale.com/net/dns/publicdns"
+)
+
+var testDoH = flag.Bool("test-doh", false, "do real DoH tests against the network")
+
+const someDNSID = 123 // something non-zero as a test; in violation of spec's SHOULD of 0
+
+func someDNSQuestion(t testing.TB) []byte {
+ b := dnsmessage.NewBuilder(nil, dnsmessage.Header{
+ OpCode: 0, // query
+ RecursionDesired: true,
+ ID: someDNSID,
+ })
+ b.StartQuestions() // err
+ b.Question(dnsmessage.Question{
+ Name: dnsmessage.MustNewName("tailscale.com."),
+ Type: dnsmessage.TypeA,
+ Class: dnsmessage.ClassINET,
+ })
+ msg, err := b.Finish()
+ if err != nil {
+ t.Fatal(err)
+ }
+ return msg
+}
+
+func TestDoH(t *testing.T) {
+ if !*testDoH {
+ t.Skip("skipping manual test without --test-doh flag")
+ }
+ prefixes := publicdns.KnownDoHPrefixes()
+ if len(prefixes) == 0 {
+ t.Fatal("no known DoH")
+ }
+
+ f := &forwarder{}
+
+ for _, urlBase := range prefixes {
+ t.Run(urlBase, func(t *testing.T) {
+ c, ok := f.getKnownDoHClientForProvider(urlBase)
+ if !ok {
+ t.Fatal("expected DoH")
+ }
+ res, err := f.sendDoH(context.Background(), urlBase, c, someDNSQuestion(t))
+ if err != nil {
+ t.Fatal(err)
+ }
+ c.Transport.(*http.Transport).CloseIdleConnections()
+
+ var p dnsmessage.Parser
+ h, err := p.Start(res)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if h.ID != someDNSID {
+ t.Errorf("response DNS ID = %v; want %v", h.ID, someDNSID)
+ }
+
+ p.SkipAllQuestions()
+ aa, err := p.AllAnswers()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(aa) == 0 {
+ t.Fatal("no answers")
+ }
+ for _, r := range aa {
+ t.Logf("got: %v", r.GoString())
+ }
+ })
+ }
+}
+
+func TestDoHV6Fallback(t *testing.T) {
+ for _, base := range publicdns.KnownDoHPrefixes() {
+ for _, ip := range publicdns.DoHIPsOfBase(base) {
+ if ip.Is4() {
+ ip6, ok := publicdns.DoHV6(base)
+ if !ok {
+ t.Errorf("no v6 DoH known for %v", ip)
+ } else if !ip6.Is6() {
+ t.Errorf("dohV6(%q) returned non-v6 address %v", base, ip6)
+ }
+ }
+ }
+ }
+}
diff --git a/net/dns/resolver/macios_ext.go b/net/dns/resolver/macios_ext.go index e3f979c19..37cccc7f0 100644 --- a/net/dns/resolver/macios_ext.go +++ b/net/dns/resolver/macios_ext.go @@ -1,26 +1,26 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build ts_macext && (darwin || ios) - -package resolver - -import ( - "errors" - "net" - - "tailscale.com/net/netmon" - "tailscale.com/net/netns" -) - -func init() { - initListenConfig = initListenConfigNetworkExtension -} - -func initListenConfigNetworkExtension(nc *net.ListenConfig, netMon *netmon.Monitor, tunName string) error { - nif, ok := netMon.InterfaceState().Interface[tunName] - if !ok { - return errors.New("utun not found") - } - return netns.SetListenConfigInterfaceIndex(nc, nif.Interface.Index) -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build ts_macext && (darwin || ios)
+
+package resolver
+
+import (
+ "errors"
+ "net"
+
+ "tailscale.com/net/netmon"
+ "tailscale.com/net/netns"
+)
+
+func init() {
+ initListenConfig = initListenConfigNetworkExtension
+}
+
+func initListenConfigNetworkExtension(nc *net.ListenConfig, netMon *netmon.Monitor, tunName string) error {
+ nif, ok := netMon.InterfaceState().Interface[tunName]
+ if !ok {
+ return errors.New("utun not found")
+ }
+ return netns.SetListenConfigInterfaceIndex(nc, nif.Interface.Index)
+}
diff --git a/net/dns/resolver/tsdns_server_test.go b/net/dns/resolver/tsdns_server_test.go index 82fd3bebf..be47cdfbc 100644 --- a/net/dns/resolver/tsdns_server_test.go +++ b/net/dns/resolver/tsdns_server_test.go @@ -1,333 +1,333 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package resolver - -import ( - "fmt" - "net" - "net/netip" - "strings" - "testing" - - "github.com/miekg/dns" -) - -// This file exists to isolate the test infrastructure -// that depends on github.com/miekg/dns -// from the rest, which only depends on dnsmessage. - -// resolveToIP returns a handler function which responds -// to queries of type A it receives with an A record containing ipv4, -// to queries of type AAAA with an AAAA record containing ipv6, -// to queries of type NS with an NS record containing name. -func resolveToIP(ipv4, ipv6 netip.Addr, ns string) dns.HandlerFunc { - return func(w dns.ResponseWriter, req *dns.Msg) { - m := new(dns.Msg) - m.SetReply(req) - - if len(req.Question) != 1 { - panic("not a single-question request") - } - question := req.Question[0] - - var ans dns.RR - switch question.Qtype { - case dns.TypeA: - ans = &dns.A{ - Hdr: dns.RR_Header{ - Name: question.Name, - Rrtype: dns.TypeA, - Class: dns.ClassINET, - }, - A: ipv4.AsSlice(), - } - case dns.TypeAAAA: - ans = &dns.AAAA{ - Hdr: dns.RR_Header{ - Name: question.Name, - Rrtype: dns.TypeAAAA, - Class: dns.ClassINET, - }, - AAAA: ipv6.AsSlice(), - } - case dns.TypeNS: - ans = &dns.NS{ - Hdr: dns.RR_Header{ - Name: question.Name, - Rrtype: dns.TypeNS, - Class: dns.ClassINET, - }, - Ns: ns, - } - } - - m.Answer = append(m.Answer, ans) - w.WriteMsg(m) - } -} - -// resolveToIPLowercase returns a handler function which canonicalizes responses -// by lowercasing the question and answer names, and responds -// to queries of type A it receives with an A record containing ipv4, -// to queries of type AAAA with an AAAA record containing ipv6, -// to queries of type NS with an NS record containing name. -func resolveToIPLowercase(ipv4, ipv6 netip.Addr, ns string) dns.HandlerFunc { - return func(w dns.ResponseWriter, req *dns.Msg) { - m := new(dns.Msg) - m.SetReply(req) - - if len(req.Question) != 1 { - panic("not a single-question request") - } - m.Question[0].Name = strings.ToLower(m.Question[0].Name) - question := req.Question[0] - - var ans dns.RR - switch question.Qtype { - case dns.TypeA: - ans = &dns.A{ - Hdr: dns.RR_Header{ - Name: question.Name, - Rrtype: dns.TypeA, - Class: dns.ClassINET, - }, - A: ipv4.AsSlice(), - } - case dns.TypeAAAA: - ans = &dns.AAAA{ - Hdr: dns.RR_Header{ - Name: question.Name, - Rrtype: dns.TypeAAAA, - Class: dns.ClassINET, - }, - AAAA: ipv6.AsSlice(), - } - case dns.TypeNS: - ans = &dns.NS{ - Hdr: dns.RR_Header{ - Name: question.Name, - Rrtype: dns.TypeNS, - Class: dns.ClassINET, - }, - Ns: ns, - } - } - - m.Answer = append(m.Answer, ans) - w.WriteMsg(m) - } -} - -// resolveToTXT returns a handler function which responds to queries of type TXT -// it receives with the strings in txts. -func resolveToTXT(txts []string, ednsMaxSize uint16) dns.HandlerFunc { - return func(w dns.ResponseWriter, req *dns.Msg) { - m := new(dns.Msg) - m.SetReply(req) - - if len(req.Question) != 1 { - panic("not a single-question request") - } - question := req.Question[0] - - if question.Qtype != dns.TypeTXT { - w.WriteMsg(m) - return - } - - ans := &dns.TXT{ - Hdr: dns.RR_Header{ - Name: question.Name, - Rrtype: dns.TypeTXT, - Class: dns.ClassINET, - }, - Txt: txts, - } - - m.Answer = append(m.Answer, ans) - - queryInfo := &dns.TXT{ - Hdr: dns.RR_Header{ - Name: "query-info.test.", - Rrtype: dns.TypeTXT, - Class: dns.ClassINET, - }, - } - - if edns := req.IsEdns0(); edns == nil { - queryInfo.Txt = []string{"EDNS=false"} - } else { - queryInfo.Txt = []string{"EDNS=true", fmt.Sprintf("maxSize=%v", edns.UDPSize())} - } - - m.Extra = append(m.Extra, queryInfo) - - if ednsMaxSize > 0 { - m.SetEdns0(ednsMaxSize, false) - } - - if err := w.WriteMsg(m); err != nil { - panic(err) - } - } -} - -var resolveToNXDOMAIN = dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) { - m := new(dns.Msg) - m.SetRcode(req, dns.RcodeNameError) - w.WriteMsg(m) -}) - -// weirdoGoCNAMEHandler returns a DNS handler that satisfies -// Go's weird Resolver.LookupCNAME (read its godoc carefully!). -// -// This doesn't even return a CNAME record, because that's not -// what Go looks for. -func weirdoGoCNAMEHandler(target string) dns.HandlerFunc { - return func(w dns.ResponseWriter, req *dns.Msg) { - m := new(dns.Msg) - m.SetReply(req) - question := req.Question[0] - - switch question.Qtype { - case dns.TypeA: - m.Answer = append(m.Answer, &dns.CNAME{ - Hdr: dns.RR_Header{ - Name: target, - Rrtype: dns.TypeCNAME, - Class: dns.ClassINET, - Ttl: 600, - }, - Target: target, - }) - case dns.TypeAAAA: - m.Answer = append(m.Answer, &dns.AAAA{ - Hdr: dns.RR_Header{ - Name: target, - Rrtype: dns.TypeAAAA, - Class: dns.ClassINET, - Ttl: 600, - }, - AAAA: net.ParseIP("1::2"), - }) - } - w.WriteMsg(m) - } -} - -// dnsHandler returns a handler that replies with the answers/options -// provided. -// -// Types supported: netip.Addr. -func dnsHandler(answers ...any) dns.HandlerFunc { - return func(w dns.ResponseWriter, req *dns.Msg) { - m := new(dns.Msg) - m.SetReply(req) - if len(req.Question) != 1 { - panic("not a single-question request") - } - m.RecursionAvailable = true // to stop net package's errLameReferral on empty replies - - question := req.Question[0] - for _, a := range answers { - switch a := a.(type) { - default: - panic(fmt.Sprintf("unsupported dnsHandler arg %T", a)) - case netip.Addr: - ip := a - if ip.Is4() { - m.Answer = append(m.Answer, &dns.A{ - Hdr: dns.RR_Header{ - Name: question.Name, - Rrtype: dns.TypeA, - Class: dns.ClassINET, - }, - A: ip.AsSlice(), - }) - } else if ip.Is6() { - m.Answer = append(m.Answer, &dns.AAAA{ - Hdr: dns.RR_Header{ - Name: question.Name, - Rrtype: dns.TypeAAAA, - Class: dns.ClassINET, - }, - AAAA: ip.AsSlice(), - }) - } - case dns.PTR: - ptr := a - ptr.Hdr = dns.RR_Header{ - Name: question.Name, - Rrtype: dns.TypePTR, - Class: dns.ClassINET, - } - m.Answer = append(m.Answer, &ptr) - case dns.CNAME: - c := a - c.Hdr = dns.RR_Header{ - Name: question.Name, - Rrtype: dns.TypeCNAME, - Class: dns.ClassINET, - Ttl: 600, - } - m.Answer = append(m.Answer, &c) - case dns.TXT: - txt := a - txt.Hdr = dns.RR_Header{ - Name: question.Name, - Rrtype: dns.TypeTXT, - Class: dns.ClassINET, - } - m.Answer = append(m.Answer, &txt) - case dns.SRV: - srv := a - srv.Hdr = dns.RR_Header{ - Name: question.Name, - Rrtype: dns.TypeSRV, - Class: dns.ClassINET, - } - m.Answer = append(m.Answer, &srv) - case dns.NS: - rr := a - rr.Hdr = dns.RR_Header{ - Name: question.Name, - Rrtype: dns.TypeNS, - Class: dns.ClassINET, - } - m.Answer = append(m.Answer, &rr) - } - } - w.WriteMsg(m) - } -} - -func serveDNS(tb testing.TB, addr string, records ...any) *dns.Server { - if len(records)%2 != 0 { - panic("must have an even number of record values") - } - mux := dns.NewServeMux() - for i := 0; i < len(records); i += 2 { - name := records[i].(string) - handler := records[i+1].(dns.Handler) - mux.Handle(name, handler) - } - waitch := make(chan struct{}) - server := &dns.Server{ - Addr: addr, - Net: "udp", - Handler: mux, - NotifyStartedFunc: func() { close(waitch) }, - ReusePort: true, - } - - go func() { - err := server.ListenAndServe() - if err != nil { - panic(fmt.Sprintf("ListenAndServe(%q): %v", addr, err)) - } - }() - - <-waitch - return server -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package resolver
+
+import (
+ "fmt"
+ "net"
+ "net/netip"
+ "strings"
+ "testing"
+
+ "github.com/miekg/dns"
+)
+
+// This file exists to isolate the test infrastructure
+// that depends on github.com/miekg/dns
+// from the rest, which only depends on dnsmessage.
+
+// resolveToIP returns a handler function which responds
+// to queries of type A it receives with an A record containing ipv4,
+// to queries of type AAAA with an AAAA record containing ipv6,
+// to queries of type NS with an NS record containing name.
+func resolveToIP(ipv4, ipv6 netip.Addr, ns string) dns.HandlerFunc {
+ return func(w dns.ResponseWriter, req *dns.Msg) {
+ m := new(dns.Msg)
+ m.SetReply(req)
+
+ if len(req.Question) != 1 {
+ panic("not a single-question request")
+ }
+ question := req.Question[0]
+
+ var ans dns.RR
+ switch question.Qtype {
+ case dns.TypeA:
+ ans = &dns.A{
+ Hdr: dns.RR_Header{
+ Name: question.Name,
+ Rrtype: dns.TypeA,
+ Class: dns.ClassINET,
+ },
+ A: ipv4.AsSlice(),
+ }
+ case dns.TypeAAAA:
+ ans = &dns.AAAA{
+ Hdr: dns.RR_Header{
+ Name: question.Name,
+ Rrtype: dns.TypeAAAA,
+ Class: dns.ClassINET,
+ },
+ AAAA: ipv6.AsSlice(),
+ }
+ case dns.TypeNS:
+ ans = &dns.NS{
+ Hdr: dns.RR_Header{
+ Name: question.Name,
+ Rrtype: dns.TypeNS,
+ Class: dns.ClassINET,
+ },
+ Ns: ns,
+ }
+ }
+
+ m.Answer = append(m.Answer, ans)
+ w.WriteMsg(m)
+ }
+}
+
+// resolveToIPLowercase returns a handler function which canonicalizes responses
+// by lowercasing the question and answer names, and responds
+// to queries of type A it receives with an A record containing ipv4,
+// to queries of type AAAA with an AAAA record containing ipv6,
+// to queries of type NS with an NS record containing name.
+func resolveToIPLowercase(ipv4, ipv6 netip.Addr, ns string) dns.HandlerFunc {
+ return func(w dns.ResponseWriter, req *dns.Msg) {
+ m := new(dns.Msg)
+ m.SetReply(req)
+
+ if len(req.Question) != 1 {
+ panic("not a single-question request")
+ }
+ m.Question[0].Name = strings.ToLower(m.Question[0].Name)
+ question := req.Question[0]
+
+ var ans dns.RR
+ switch question.Qtype {
+ case dns.TypeA:
+ ans = &dns.A{
+ Hdr: dns.RR_Header{
+ Name: question.Name,
+ Rrtype: dns.TypeA,
+ Class: dns.ClassINET,
+ },
+ A: ipv4.AsSlice(),
+ }
+ case dns.TypeAAAA:
+ ans = &dns.AAAA{
+ Hdr: dns.RR_Header{
+ Name: question.Name,
+ Rrtype: dns.TypeAAAA,
+ Class: dns.ClassINET,
+ },
+ AAAA: ipv6.AsSlice(),
+ }
+ case dns.TypeNS:
+ ans = &dns.NS{
+ Hdr: dns.RR_Header{
+ Name: question.Name,
+ Rrtype: dns.TypeNS,
+ Class: dns.ClassINET,
+ },
+ Ns: ns,
+ }
+ }
+
+ m.Answer = append(m.Answer, ans)
+ w.WriteMsg(m)
+ }
+}
+
+// resolveToTXT returns a handler function which responds to queries of type TXT
+// it receives with the strings in txts.
+func resolveToTXT(txts []string, ednsMaxSize uint16) dns.HandlerFunc {
+ return func(w dns.ResponseWriter, req *dns.Msg) {
+ m := new(dns.Msg)
+ m.SetReply(req)
+
+ if len(req.Question) != 1 {
+ panic("not a single-question request")
+ }
+ question := req.Question[0]
+
+ if question.Qtype != dns.TypeTXT {
+ w.WriteMsg(m)
+ return
+ }
+
+ ans := &dns.TXT{
+ Hdr: dns.RR_Header{
+ Name: question.Name,
+ Rrtype: dns.TypeTXT,
+ Class: dns.ClassINET,
+ },
+ Txt: txts,
+ }
+
+ m.Answer = append(m.Answer, ans)
+
+ queryInfo := &dns.TXT{
+ Hdr: dns.RR_Header{
+ Name: "query-info.test.",
+ Rrtype: dns.TypeTXT,
+ Class: dns.ClassINET,
+ },
+ }
+
+ if edns := req.IsEdns0(); edns == nil {
+ queryInfo.Txt = []string{"EDNS=false"}
+ } else {
+ queryInfo.Txt = []string{"EDNS=true", fmt.Sprintf("maxSize=%v", edns.UDPSize())}
+ }
+
+ m.Extra = append(m.Extra, queryInfo)
+
+ if ednsMaxSize > 0 {
+ m.SetEdns0(ednsMaxSize, false)
+ }
+
+ if err := w.WriteMsg(m); err != nil {
+ panic(err)
+ }
+ }
+}
+
+var resolveToNXDOMAIN = dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
+ m := new(dns.Msg)
+ m.SetRcode(req, dns.RcodeNameError)
+ w.WriteMsg(m)
+})
+
+// weirdoGoCNAMEHandler returns a DNS handler that satisfies
+// Go's weird Resolver.LookupCNAME (read its godoc carefully!).
+//
+// This doesn't even return a CNAME record, because that's not
+// what Go looks for.
+func weirdoGoCNAMEHandler(target string) dns.HandlerFunc {
+ return func(w dns.ResponseWriter, req *dns.Msg) {
+ m := new(dns.Msg)
+ m.SetReply(req)
+ question := req.Question[0]
+
+ switch question.Qtype {
+ case dns.TypeA:
+ m.Answer = append(m.Answer, &dns.CNAME{
+ Hdr: dns.RR_Header{
+ Name: target,
+ Rrtype: dns.TypeCNAME,
+ Class: dns.ClassINET,
+ Ttl: 600,
+ },
+ Target: target,
+ })
+ case dns.TypeAAAA:
+ m.Answer = append(m.Answer, &dns.AAAA{
+ Hdr: dns.RR_Header{
+ Name: target,
+ Rrtype: dns.TypeAAAA,
+ Class: dns.ClassINET,
+ Ttl: 600,
+ },
+ AAAA: net.ParseIP("1::2"),
+ })
+ }
+ w.WriteMsg(m)
+ }
+}
+
+// dnsHandler returns a handler that replies with the answers/options
+// provided.
+//
+// Types supported: netip.Addr.
+func dnsHandler(answers ...any) dns.HandlerFunc {
+ return func(w dns.ResponseWriter, req *dns.Msg) {
+ m := new(dns.Msg)
+ m.SetReply(req)
+ if len(req.Question) != 1 {
+ panic("not a single-question request")
+ }
+ m.RecursionAvailable = true // to stop net package's errLameReferral on empty replies
+
+ question := req.Question[0]
+ for _, a := range answers {
+ switch a := a.(type) {
+ default:
+ panic(fmt.Sprintf("unsupported dnsHandler arg %T", a))
+ case netip.Addr:
+ ip := a
+ if ip.Is4() {
+ m.Answer = append(m.Answer, &dns.A{
+ Hdr: dns.RR_Header{
+ Name: question.Name,
+ Rrtype: dns.TypeA,
+ Class: dns.ClassINET,
+ },
+ A: ip.AsSlice(),
+ })
+ } else if ip.Is6() {
+ m.Answer = append(m.Answer, &dns.AAAA{
+ Hdr: dns.RR_Header{
+ Name: question.Name,
+ Rrtype: dns.TypeAAAA,
+ Class: dns.ClassINET,
+ },
+ AAAA: ip.AsSlice(),
+ })
+ }
+ case dns.PTR:
+ ptr := a
+ ptr.Hdr = dns.RR_Header{
+ Name: question.Name,
+ Rrtype: dns.TypePTR,
+ Class: dns.ClassINET,
+ }
+ m.Answer = append(m.Answer, &ptr)
+ case dns.CNAME:
+ c := a
+ c.Hdr = dns.RR_Header{
+ Name: question.Name,
+ Rrtype: dns.TypeCNAME,
+ Class: dns.ClassINET,
+ Ttl: 600,
+ }
+ m.Answer = append(m.Answer, &c)
+ case dns.TXT:
+ txt := a
+ txt.Hdr = dns.RR_Header{
+ Name: question.Name,
+ Rrtype: dns.TypeTXT,
+ Class: dns.ClassINET,
+ }
+ m.Answer = append(m.Answer, &txt)
+ case dns.SRV:
+ srv := a
+ srv.Hdr = dns.RR_Header{
+ Name: question.Name,
+ Rrtype: dns.TypeSRV,
+ Class: dns.ClassINET,
+ }
+ m.Answer = append(m.Answer, &srv)
+ case dns.NS:
+ rr := a
+ rr.Hdr = dns.RR_Header{
+ Name: question.Name,
+ Rrtype: dns.TypeNS,
+ Class: dns.ClassINET,
+ }
+ m.Answer = append(m.Answer, &rr)
+ }
+ }
+ w.WriteMsg(m)
+ }
+}
+
+func serveDNS(tb testing.TB, addr string, records ...any) *dns.Server {
+ if len(records)%2 != 0 {
+ panic("must have an even number of record values")
+ }
+ mux := dns.NewServeMux()
+ for i := 0; i < len(records); i += 2 {
+ name := records[i].(string)
+ handler := records[i+1].(dns.Handler)
+ mux.Handle(name, handler)
+ }
+ waitch := make(chan struct{})
+ server := &dns.Server{
+ Addr: addr,
+ Net: "udp",
+ Handler: mux,
+ NotifyStartedFunc: func() { close(waitch) },
+ ReusePort: true,
+ }
+
+ go func() {
+ err := server.ListenAndServe()
+ if err != nil {
+ panic(fmt.Sprintf("ListenAndServe(%q): %v", addr, err))
+ }
+ }()
+
+ <-waitch
+ return server
+}
diff --git a/net/dns/utf.go b/net/dns/utf.go index 0c1db69ac..267829c05 100644 --- a/net/dns/utf.go +++ b/net/dns/utf.go @@ -1,55 +1,55 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package dns - -// This code is only used in Windows builds, but is in an -// OS-independent file so tests can run all the time. - -import ( - "bytes" - "encoding/binary" - "unicode/utf16" -) - -// maybeUnUTF16 tries to detect whether bs contains UTF-16, and if so -// translates it to regular UTF-8. -// -// Some of wsl.exe's output get printed as UTF-16, which breaks a -// bunch of things. Try to detect this by looking for a zero byte in -// the first few bytes of output (which will appear if any of those -// codepoints are basic ASCII - very likely). From that we can infer -// that UTF-16 is being printed, and the byte order in use, and we -// decode that back to UTF-8. -// -// https://github.com/microsoft/WSL/issues/4607 -func maybeUnUTF16(bs []byte) []byte { - if len(bs)%2 != 0 { - // Can't be complete UTF-16. - return bs - } - checkLen := 20 - if len(bs) < checkLen { - checkLen = len(bs) - } - zeroOff := bytes.IndexByte(bs[:checkLen], 0) - if zeroOff == -1 { - return bs - } - - // We assume wsl.exe is trying to print an ASCII codepoint, - // meaning the zero byte is in the upper 8 bits of the - // codepoint. That means we can use the zero's byte offset to - // work out if we're seeing little-endian or big-endian - // UTF-16. - var endian binary.ByteOrder = binary.LittleEndian - if zeroOff%2 == 0 { - endian = binary.BigEndian - } - - var u16 []uint16 - for i := 0; i < len(bs); i += 2 { - u16 = append(u16, endian.Uint16(bs[i:])) - } - return []byte(string(utf16.Decode(u16))) -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package dns
+
+// This code is only used in Windows builds, but is in an
+// OS-independent file so tests can run all the time.
+
+import (
+ "bytes"
+ "encoding/binary"
+ "unicode/utf16"
+)
+
+// maybeUnUTF16 tries to detect whether bs contains UTF-16, and if so
+// translates it to regular UTF-8.
+//
+// Some of wsl.exe's output get printed as UTF-16, which breaks a
+// bunch of things. Try to detect this by looking for a zero byte in
+// the first few bytes of output (which will appear if any of those
+// codepoints are basic ASCII - very likely). From that we can infer
+// that UTF-16 is being printed, and the byte order in use, and we
+// decode that back to UTF-8.
+//
+// https://github.com/microsoft/WSL/issues/4607
+func maybeUnUTF16(bs []byte) []byte {
+ if len(bs)%2 != 0 {
+ // Can't be complete UTF-16.
+ return bs
+ }
+ checkLen := 20
+ if len(bs) < checkLen {
+ checkLen = len(bs)
+ }
+ zeroOff := bytes.IndexByte(bs[:checkLen], 0)
+ if zeroOff == -1 {
+ return bs
+ }
+
+ // We assume wsl.exe is trying to print an ASCII codepoint,
+ // meaning the zero byte is in the upper 8 bits of the
+ // codepoint. That means we can use the zero's byte offset to
+ // work out if we're seeing little-endian or big-endian
+ // UTF-16.
+ var endian binary.ByteOrder = binary.LittleEndian
+ if zeroOff%2 == 0 {
+ endian = binary.BigEndian
+ }
+
+ var u16 []uint16
+ for i := 0; i < len(bs); i += 2 {
+ u16 = append(u16, endian.Uint16(bs[i:]))
+ }
+ return []byte(string(utf16.Decode(u16)))
+}
diff --git a/net/dns/utf_test.go b/net/dns/utf_test.go index b5fd37262..fcf593497 100644 --- a/net/dns/utf_test.go +++ b/net/dns/utf_test.go @@ -1,24 +1,24 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package dns - -import "testing" - -func TestMaybeUnUTF16(t *testing.T) { - tests := []struct { - in string - want string - }{ - {"abc", "abc"}, // UTF-8 - {"a\x00b\x00c\x00", "abc"}, // UTF-16-LE - {"\x00a\x00b\x00c", "abc"}, // UTF-16-BE - } - - for _, test := range tests { - got := string(maybeUnUTF16([]byte(test.in))) - if got != test.want { - t.Errorf("maybeUnUTF16(%q) = %q, want %q", test.in, got, test.want) - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package dns
+
+import "testing"
+
+func TestMaybeUnUTF16(t *testing.T) {
+ tests := []struct {
+ in string
+ want string
+ }{
+ {"abc", "abc"}, // UTF-8
+ {"a\x00b\x00c\x00", "abc"}, // UTF-16-LE
+ {"\x00a\x00b\x00c", "abc"}, // UTF-16-BE
+ }
+
+ for _, test := range tests {
+ got := string(maybeUnUTF16([]byte(test.in)))
+ if got != test.want {
+ t.Errorf("maybeUnUTF16(%q) = %q, want %q", test.in, got, test.want)
+ }
+ }
+}
diff --git a/net/dnscache/dnscache_test.go b/net/dnscache/dnscache_test.go index ef4249b74..6a4b96931 100644 --- a/net/dnscache/dnscache_test.go +++ b/net/dnscache/dnscache_test.go @@ -1,242 +1,242 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package dnscache - -import ( - "context" - "errors" - "flag" - "fmt" - "net" - "net/netip" - "reflect" - "testing" - "time" - - "tailscale.com/tstest" -) - -var dialTest = flag.String("dial-test", "", "if non-empty, addr:port to test dial") - -func TestDialer(t *testing.T) { - if *dialTest == "" { - t.Skip("skipping; --dial-test is blank") - } - r := &Resolver{Logf: t.Logf} - var std net.Dialer - dialer := Dialer(std.DialContext, r) - t0 := time.Now() - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - c, err := dialer(ctx, "tcp", *dialTest) - if err != nil { - t.Fatal(err) - } - t.Logf("dialed in %v", time.Since(t0)) - c.Close() -} - -func TestDialCall_DNSWasTrustworthy(t *testing.T) { - type step struct { - ip netip.Addr // IP we pretended to dial - err error // the dial error or nil for success - } - mustIP := netip.MustParseAddr - errFail := errors.New("some connect failure") - tests := []struct { - name string - steps []step - want bool - }{ - { - name: "no-info", - want: false, - }, - { - name: "previous-dial", - steps: []step{ - {mustIP("2003::1"), nil}, - {mustIP("2003::1"), errFail}, - }, - want: true, - }, - { - name: "no-previous-dial", - steps: []step{ - {mustIP("2003::1"), errFail}, - }, - want: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - d := &dialer{ - pastConnect: map[netip.Addr]time.Time{}, - } - dc := &dialCall{ - d: d, - } - for _, st := range tt.steps { - dc.noteDialResult(st.ip, st.err) - } - got := dc.dnsWasTrustworthy() - if got != tt.want { - t.Errorf("got %v; want %v", got, tt.want) - } - }) - } -} - -func TestDialCall_uniqueIPs(t *testing.T) { - dc := &dialCall{} - mustIP := netip.MustParseAddr - errFail := errors.New("some connect failure") - dc.noteDialResult(mustIP("2003::1"), errFail) - dc.noteDialResult(mustIP("2003::2"), errFail) - got := dc.uniqueIPs([]netip.Addr{ - mustIP("2003::1"), - mustIP("2003::2"), - mustIP("2003::2"), - mustIP("2003::3"), - mustIP("2003::3"), - mustIP("2003::4"), - mustIP("2003::4"), - }) - want := []netip.Addr{ - mustIP("2003::3"), - mustIP("2003::4"), - } - if !reflect.DeepEqual(got, want) { - t.Errorf("got %v; want %v", got, want) - } -} - -func TestResolverAllHostStaticResult(t *testing.T) { - r := &Resolver{ - Logf: t.Logf, - SingleHost: "foo.bar", - SingleHostStaticResult: []netip.Addr{ - netip.MustParseAddr("2001:4860:4860::8888"), - netip.MustParseAddr("2001:4860:4860::8844"), - netip.MustParseAddr("8.8.8.8"), - netip.MustParseAddr("8.8.4.4"), - }, - } - ip4, ip6, allIPs, err := r.LookupIP(context.Background(), "foo.bar") - if err != nil { - t.Fatal(err) - } - if got, want := ip4.String(), "8.8.8.8"; got != want { - t.Errorf("ip4 got %q; want %q", got, want) - } - if got, want := ip6.String(), "2001:4860:4860::8888"; got != want { - t.Errorf("ip4 got %q; want %q", got, want) - } - if got, want := fmt.Sprintf("%q", allIPs), `["2001:4860:4860::8888" "2001:4860:4860::8844" "8.8.8.8" "8.8.4.4"]`; got != want { - t.Errorf("allIPs got %q; want %q", got, want) - } - - _, _, _, err = r.LookupIP(context.Background(), "bad") - if got, want := fmt.Sprint(err), `dnscache: unexpected hostname "bad" doesn't match expected "foo.bar"`; got != want { - t.Errorf("bad dial error got %q; want %q", got, want) - } -} - -func TestShouldTryBootstrap(t *testing.T) { - tstest.Replace(t, &debug, func() bool { return true }) - - type step struct { - ip netip.Addr // IP we pretended to dial - err error // the dial error or nil for success - } - - canceled, cancel := context.WithCancel(context.Background()) - cancel() - - deadlineExceeded, cancel := context.WithTimeout(context.Background(), 0) - defer cancel() - - ctx := context.Background() - errFailed := errors.New("some failure") - - cacheWithFallback := &Resolver{ - Logf: t.Logf, - LookupIPFallback: func(_ context.Context, _ string) ([]netip.Addr, error) { - panic("unimplemented") - }, - } - cacheNoFallback := &Resolver{Logf: t.Logf} - - testCases := []struct { - name string - steps []step - ctx context.Context - err error - noFallback bool - want bool - }{ - { - name: "no-error", - ctx: ctx, - err: nil, - want: false, - }, - { - name: "canceled", - ctx: canceled, - err: errFailed, - want: false, - }, - { - name: "deadline-exceeded", - ctx: deadlineExceeded, - err: errFailed, - want: false, - }, - { - name: "no-fallback", - ctx: ctx, - err: errFailed, - noFallback: true, - want: false, - }, - { - name: "dns-was-trustworthy", - ctx: ctx, - err: errFailed, - steps: []step{ - {netip.MustParseAddr("2003::1"), nil}, - {netip.MustParseAddr("2003::1"), errFailed}, - }, - want: false, - }, - { - name: "should-bootstrap", - ctx: ctx, - err: errFailed, - want: true, - }, - } - - for _, tt := range testCases { - t.Run(tt.name, func(t *testing.T) { - d := &dialer{ - pastConnect: map[netip.Addr]time.Time{}, - } - if tt.noFallback { - d.dnsCache = cacheNoFallback - } else { - d.dnsCache = cacheWithFallback - } - dc := &dialCall{d: d} - for _, st := range tt.steps { - dc.noteDialResult(st.ip, st.err) - } - got := d.shouldTryBootstrap(tt.ctx, tt.err, dc) - if got != tt.want { - t.Errorf("got %v; want %v", got, tt.want) - } - }) - } -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package dnscache
+
+import (
+ "context"
+ "errors"
+ "flag"
+ "fmt"
+ "net"
+ "net/netip"
+ "reflect"
+ "testing"
+ "time"
+
+ "tailscale.com/tstest"
+)
+
+var dialTest = flag.String("dial-test", "", "if non-empty, addr:port to test dial")
+
+func TestDialer(t *testing.T) {
+ if *dialTest == "" {
+ t.Skip("skipping; --dial-test is blank")
+ }
+ r := &Resolver{Logf: t.Logf}
+ var std net.Dialer
+ dialer := Dialer(std.DialContext, r)
+ t0 := time.Now()
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+ defer cancel()
+ c, err := dialer(ctx, "tcp", *dialTest)
+ if err != nil {
+ t.Fatal(err)
+ }
+ t.Logf("dialed in %v", time.Since(t0))
+ c.Close()
+}
+
+func TestDialCall_DNSWasTrustworthy(t *testing.T) {
+ type step struct {
+ ip netip.Addr // IP we pretended to dial
+ err error // the dial error or nil for success
+ }
+ mustIP := netip.MustParseAddr
+ errFail := errors.New("some connect failure")
+ tests := []struct {
+ name string
+ steps []step
+ want bool
+ }{
+ {
+ name: "no-info",
+ want: false,
+ },
+ {
+ name: "previous-dial",
+ steps: []step{
+ {mustIP("2003::1"), nil},
+ {mustIP("2003::1"), errFail},
+ },
+ want: true,
+ },
+ {
+ name: "no-previous-dial",
+ steps: []step{
+ {mustIP("2003::1"), errFail},
+ },
+ want: false,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ d := &dialer{
+ pastConnect: map[netip.Addr]time.Time{},
+ }
+ dc := &dialCall{
+ d: d,
+ }
+ for _, st := range tt.steps {
+ dc.noteDialResult(st.ip, st.err)
+ }
+ got := dc.dnsWasTrustworthy()
+ if got != tt.want {
+ t.Errorf("got %v; want %v", got, tt.want)
+ }
+ })
+ }
+}
+
+func TestDialCall_uniqueIPs(t *testing.T) {
+ dc := &dialCall{}
+ mustIP := netip.MustParseAddr
+ errFail := errors.New("some connect failure")
+ dc.noteDialResult(mustIP("2003::1"), errFail)
+ dc.noteDialResult(mustIP("2003::2"), errFail)
+ got := dc.uniqueIPs([]netip.Addr{
+ mustIP("2003::1"),
+ mustIP("2003::2"),
+ mustIP("2003::2"),
+ mustIP("2003::3"),
+ mustIP("2003::3"),
+ mustIP("2003::4"),
+ mustIP("2003::4"),
+ })
+ want := []netip.Addr{
+ mustIP("2003::3"),
+ mustIP("2003::4"),
+ }
+ if !reflect.DeepEqual(got, want) {
+ t.Errorf("got %v; want %v", got, want)
+ }
+}
+
+func TestResolverAllHostStaticResult(t *testing.T) {
+ r := &Resolver{
+ Logf: t.Logf,
+ SingleHost: "foo.bar",
+ SingleHostStaticResult: []netip.Addr{
+ netip.MustParseAddr("2001:4860:4860::8888"),
+ netip.MustParseAddr("2001:4860:4860::8844"),
+ netip.MustParseAddr("8.8.8.8"),
+ netip.MustParseAddr("8.8.4.4"),
+ },
+ }
+ ip4, ip6, allIPs, err := r.LookupIP(context.Background(), "foo.bar")
+ if err != nil {
+ t.Fatal(err)
+ }
+ if got, want := ip4.String(), "8.8.8.8"; got != want {
+ t.Errorf("ip4 got %q; want %q", got, want)
+ }
+ if got, want := ip6.String(), "2001:4860:4860::8888"; got != want {
+ t.Errorf("ip4 got %q; want %q", got, want)
+ }
+ if got, want := fmt.Sprintf("%q", allIPs), `["2001:4860:4860::8888" "2001:4860:4860::8844" "8.8.8.8" "8.8.4.4"]`; got != want {
+ t.Errorf("allIPs got %q; want %q", got, want)
+ }
+
+ _, _, _, err = r.LookupIP(context.Background(), "bad")
+ if got, want := fmt.Sprint(err), `dnscache: unexpected hostname "bad" doesn't match expected "foo.bar"`; got != want {
+ t.Errorf("bad dial error got %q; want %q", got, want)
+ }
+}
+
+func TestShouldTryBootstrap(t *testing.T) {
+ tstest.Replace(t, &debug, func() bool { return true })
+
+ type step struct {
+ ip netip.Addr // IP we pretended to dial
+ err error // the dial error or nil for success
+ }
+
+ canceled, cancel := context.WithCancel(context.Background())
+ cancel()
+
+ deadlineExceeded, cancel := context.WithTimeout(context.Background(), 0)
+ defer cancel()
+
+ ctx := context.Background()
+ errFailed := errors.New("some failure")
+
+ cacheWithFallback := &Resolver{
+ Logf: t.Logf,
+ LookupIPFallback: func(_ context.Context, _ string) ([]netip.Addr, error) {
+ panic("unimplemented")
+ },
+ }
+ cacheNoFallback := &Resolver{Logf: t.Logf}
+
+ testCases := []struct {
+ name string
+ steps []step
+ ctx context.Context
+ err error
+ noFallback bool
+ want bool
+ }{
+ {
+ name: "no-error",
+ ctx: ctx,
+ err: nil,
+ want: false,
+ },
+ {
+ name: "canceled",
+ ctx: canceled,
+ err: errFailed,
+ want: false,
+ },
+ {
+ name: "deadline-exceeded",
+ ctx: deadlineExceeded,
+ err: errFailed,
+ want: false,
+ },
+ {
+ name: "no-fallback",
+ ctx: ctx,
+ err: errFailed,
+ noFallback: true,
+ want: false,
+ },
+ {
+ name: "dns-was-trustworthy",
+ ctx: ctx,
+ err: errFailed,
+ steps: []step{
+ {netip.MustParseAddr("2003::1"), nil},
+ {netip.MustParseAddr("2003::1"), errFailed},
+ },
+ want: false,
+ },
+ {
+ name: "should-bootstrap",
+ ctx: ctx,
+ err: errFailed,
+ want: true,
+ },
+ }
+
+ for _, tt := range testCases {
+ t.Run(tt.name, func(t *testing.T) {
+ d := &dialer{
+ pastConnect: map[netip.Addr]time.Time{},
+ }
+ if tt.noFallback {
+ d.dnsCache = cacheNoFallback
+ } else {
+ d.dnsCache = cacheWithFallback
+ }
+ dc := &dialCall{d: d}
+ for _, st := range tt.steps {
+ dc.noteDialResult(st.ip, st.err)
+ }
+ got := d.shouldTryBootstrap(tt.ctx, tt.err, dc)
+ if got != tt.want {
+ t.Errorf("got %v; want %v", got, tt.want)
+ }
+ })
+ }
+}
diff --git a/net/dnscache/messagecache_test.go b/net/dnscache/messagecache_test.go index 41fc33448..18af32459 100644 --- a/net/dnscache/messagecache_test.go +++ b/net/dnscache/messagecache_test.go @@ -1,291 +1,291 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package dnscache - -import ( - "bytes" - "context" - "errors" - "fmt" - "net" - "runtime" - "testing" - "time" - - "golang.org/x/net/dns/dnsmessage" - "tailscale.com/tstest" -) - -func TestMessageCache(t *testing.T) { - clock := tstest.NewClock(tstest.ClockOpts{ - Start: time.Date(1987, 11, 1, 0, 0, 0, 0, time.UTC), - }) - mc := &MessageCache{Clock: clock.Now} - mc.SetMaxCacheSize(2) - clock.Advance(time.Second) - - var out bytes.Buffer - if err := mc.ReplyFromCache(&out, makeQ(1, "foo.com.")); err != ErrCacheMiss { - t.Fatalf("unexpected error: %v", err) - } - - if err := mc.AddCacheEntry( - makeQ(2, "foo.com."), - makeRes(2, "FOO.COM.", ttlOpt(10), - &dnsmessage.AResource{A: [4]byte{127, 0, 0, 1}}, - &dnsmessage.AResource{A: [4]byte{127, 0, 0, 2}})); err != nil { - t.Fatal(err) - } - - // Expect cache hit, with 10 seconds remaining. - out.Reset() - if err := mc.ReplyFromCache(&out, makeQ(3, "foo.com.")); err != nil { - t.Fatalf("expected cache hit; got: %v", err) - } - if p := mustParseResponse(t, out.Bytes()); p.TxID != 3 { - t.Errorf("TxID = %v; want %v", p.TxID, 3) - } else if p.TTL != 10 { - t.Errorf("TTL = %v; want 10", p.TTL) - } - - // One second elapses, expect a cache hit, with 9 seconds - // remaining. - clock.Advance(time.Second) - out.Reset() - if err := mc.ReplyFromCache(&out, makeQ(4, "foo.com.")); err != nil { - t.Fatalf("expected cache hit; got: %v", err) - } - if p := mustParseResponse(t, out.Bytes()); p.TxID != 4 { - t.Errorf("TxID = %v; want %v", p.TxID, 4) - } else if p.TTL != 9 { - t.Errorf("TTL = %v; want 9", p.TTL) - } - - // Expect cache miss on MX record. - if err := mc.ReplyFromCache(&out, makeQ(4, "foo.com.", dnsmessage.TypeMX)); err != ErrCacheMiss { - t.Fatalf("expected cache miss on MX; got: %v", err) - } - // Expect cache miss on CHAOS class. - if err := mc.ReplyFromCache(&out, makeQ(4, "foo.com.", dnsmessage.ClassCHAOS)); err != ErrCacheMiss { - t.Fatalf("expected cache miss on CHAOS; got: %v", err) - } - - // Ten seconds elapses; expect a cache miss. - clock.Advance(10 * time.Second) - if err := mc.ReplyFromCache(&out, makeQ(5, "foo.com.")); err != ErrCacheMiss { - t.Fatalf("expected cache miss, got: %v", err) - } -} - -type parsedMeta struct { - TxID uint16 - TTL uint32 -} - -func mustParseResponse(t testing.TB, r []byte) (ret parsedMeta) { - t.Helper() - var p dnsmessage.Parser - h, err := p.Start(r) - if err != nil { - t.Fatal(err) - } - ret.TxID = h.ID - qq, err := p.AllQuestions() - if err != nil { - t.Fatalf("AllQuestions: %v", err) - } - if len(qq) != 1 { - t.Fatalf("num questions = %v; want 1", len(qq)) - } - aa, err := p.AllAnswers() - if err != nil { - t.Fatalf("AllAnswers: %v", err) - } - for _, r := range aa { - if ret.TTL == 0 { - ret.TTL = r.Header.TTL - } - if ret.TTL != r.Header.TTL { - t.Fatal("mixed TTLs") - } - } - return ret -} - -type responseOpt bool - -type ttlOpt uint32 - -func makeQ(txID uint16, name string, opt ...any) []byte { - opt = append(opt, responseOpt(false)) - return makeDNSPkt(txID, name, opt...) -} - -func makeRes(txID uint16, name string, opt ...any) []byte { - opt = append(opt, responseOpt(true)) - return makeDNSPkt(txID, name, opt...) -} - -func makeDNSPkt(txID uint16, name string, opt ...any) []byte { - typ := dnsmessage.TypeA - class := dnsmessage.ClassINET - var response bool - var answers []dnsmessage.ResourceBody - var ttl uint32 = 1 // one second by default - for _, o := range opt { - switch o := o.(type) { - case dnsmessage.Type: - typ = o - case dnsmessage.Class: - class = o - case responseOpt: - response = bool(o) - case dnsmessage.ResourceBody: - answers = append(answers, o) - case ttlOpt: - ttl = uint32(o) - default: - panic(fmt.Sprintf("unknown opt type %T", o)) - } - } - qname := dnsmessage.MustNewName(name) - msg := dnsmessage.Message{ - Header: dnsmessage.Header{ID: txID, Response: response}, - Questions: []dnsmessage.Question{ - { - Name: qname, - Type: typ, - Class: class, - }, - }, - } - for _, rb := range answers { - msg.Answers = append(msg.Answers, dnsmessage.Resource{ - Header: dnsmessage.ResourceHeader{ - Name: qname, - Type: typ, - Class: class, - TTL: ttl, - }, - Body: rb, - }) - } - buf, err := msg.Pack() - if err != nil { - panic(err) - } - return buf -} - -func TestASCIILowerName(t *testing.T) { - n := asciiLowerName(dnsmessage.MustNewName("Foo.COM.")) - if got, want := n.String(), "foo.com."; got != want { - t.Errorf("got = %q; want %q", got, want) - } -} - -func TestGetDNSQueryCacheKey(t *testing.T) { - tests := []struct { - name string - pkt []byte - want msgQ - txID uint16 - anyTX bool - }{ - { - name: "empty", - }, - { - name: "a", - pkt: makeQ(123, "foo.com."), - want: msgQ{"foo.com.", dnsmessage.TypeA}, - txID: 123, - }, - { - name: "aaaa", - pkt: makeQ(6, "foo.com.", dnsmessage.TypeAAAA), - want: msgQ{"foo.com.", dnsmessage.TypeAAAA}, - txID: 6, - }, - { - name: "normalize_case", - pkt: makeQ(123, "FoO.CoM."), - want: msgQ{"foo.com.", dnsmessage.TypeA}, - txID: 123, - }, - { - name: "ignore_response", - pkt: makeRes(123, "foo.com."), - }, - { - name: "ignore_question_with_answers", - pkt: makeQ(2, "foo.com.", &dnsmessage.AResource{A: [4]byte{127, 0, 0, 1}}), - }, - { - name: "whatever_go_generates", // in case Go's net package grows functionality we don't handle - pkt: getGoNetPacketDNSQuery("from-go.foo."), - want: msgQ{"from-go.foo.", dnsmessage.TypeA}, - anyTX: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, gotTX, ok := getDNSQueryCacheKey(tt.pkt) - if !ok { - if tt.txID == 0 && got == (msgQ{}) { - return - } - t.Fatal("failed") - } - if got != tt.want { - t.Errorf("got %+v, want %+v", got, tt.want) - } - if gotTX != tt.txID && !tt.anyTX { - t.Errorf("got tx %v, want %v", gotTX, tt.txID) - } - }) - } -} - -func getGoNetPacketDNSQuery(name string) []byte { - if runtime.GOOS == "windows" { - // On Windows, Go's net.Resolver doesn't use the DNS client. - // See https://github.com/golang/go/issues/33097 which - // was approved but not yet implemented. - // For now just pretend it's implemented to make this test - // pass on Windows with complicated the caller. - return makeQ(123, name) - } - res := make(chan []byte, 1) - r := &net.Resolver{ - PreferGo: true, - Dial: func(ctx context.Context, network, address string) (net.Conn, error) { - return goResolverConn(res), nil - }, - } - r.LookupIP(context.Background(), "ip4", name) - return <-res -} - -type goResolverConn chan<- []byte - -func (goResolverConn) Close() error { return nil } -func (goResolverConn) LocalAddr() net.Addr { return todoAddr{} } -func (goResolverConn) RemoteAddr() net.Addr { return todoAddr{} } -func (goResolverConn) SetDeadline(t time.Time) error { return nil } -func (goResolverConn) SetReadDeadline(t time.Time) error { return nil } -func (goResolverConn) SetWriteDeadline(t time.Time) error { return nil } -func (goResolverConn) Read([]byte) (int, error) { return 0, errors.New("boom") } -func (c goResolverConn) Write(p []byte) (int, error) { - select { - case c <- p[2:]: // skip 2 byte length for TCP mode DNS query - default: - } - return 0, errors.New("boom") -} - -type todoAddr struct{} - -func (todoAddr) Network() string { return "unused" } -func (todoAddr) String() string { return "unused-todoAddr" } +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package dnscache
+
+import (
+ "bytes"
+ "context"
+ "errors"
+ "fmt"
+ "net"
+ "runtime"
+ "testing"
+ "time"
+
+ "golang.org/x/net/dns/dnsmessage"
+ "tailscale.com/tstest"
+)
+
+func TestMessageCache(t *testing.T) {
+ clock := tstest.NewClock(tstest.ClockOpts{
+ Start: time.Date(1987, 11, 1, 0, 0, 0, 0, time.UTC),
+ })
+ mc := &MessageCache{Clock: clock.Now}
+ mc.SetMaxCacheSize(2)
+ clock.Advance(time.Second)
+
+ var out bytes.Buffer
+ if err := mc.ReplyFromCache(&out, makeQ(1, "foo.com.")); err != ErrCacheMiss {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ if err := mc.AddCacheEntry(
+ makeQ(2, "foo.com."),
+ makeRes(2, "FOO.COM.", ttlOpt(10),
+ &dnsmessage.AResource{A: [4]byte{127, 0, 0, 1}},
+ &dnsmessage.AResource{A: [4]byte{127, 0, 0, 2}})); err != nil {
+ t.Fatal(err)
+ }
+
+ // Expect cache hit, with 10 seconds remaining.
+ out.Reset()
+ if err := mc.ReplyFromCache(&out, makeQ(3, "foo.com.")); err != nil {
+ t.Fatalf("expected cache hit; got: %v", err)
+ }
+ if p := mustParseResponse(t, out.Bytes()); p.TxID != 3 {
+ t.Errorf("TxID = %v; want %v", p.TxID, 3)
+ } else if p.TTL != 10 {
+ t.Errorf("TTL = %v; want 10", p.TTL)
+ }
+
+ // One second elapses, expect a cache hit, with 9 seconds
+ // remaining.
+ clock.Advance(time.Second)
+ out.Reset()
+ if err := mc.ReplyFromCache(&out, makeQ(4, "foo.com.")); err != nil {
+ t.Fatalf("expected cache hit; got: %v", err)
+ }
+ if p := mustParseResponse(t, out.Bytes()); p.TxID != 4 {
+ t.Errorf("TxID = %v; want %v", p.TxID, 4)
+ } else if p.TTL != 9 {
+ t.Errorf("TTL = %v; want 9", p.TTL)
+ }
+
+ // Expect cache miss on MX record.
+ if err := mc.ReplyFromCache(&out, makeQ(4, "foo.com.", dnsmessage.TypeMX)); err != ErrCacheMiss {
+ t.Fatalf("expected cache miss on MX; got: %v", err)
+ }
+ // Expect cache miss on CHAOS class.
+ if err := mc.ReplyFromCache(&out, makeQ(4, "foo.com.", dnsmessage.ClassCHAOS)); err != ErrCacheMiss {
+ t.Fatalf("expected cache miss on CHAOS; got: %v", err)
+ }
+
+ // Ten seconds elapses; expect a cache miss.
+ clock.Advance(10 * time.Second)
+ if err := mc.ReplyFromCache(&out, makeQ(5, "foo.com.")); err != ErrCacheMiss {
+ t.Fatalf("expected cache miss, got: %v", err)
+ }
+}
+
+type parsedMeta struct {
+ TxID uint16
+ TTL uint32
+}
+
+func mustParseResponse(t testing.TB, r []byte) (ret parsedMeta) {
+ t.Helper()
+ var p dnsmessage.Parser
+ h, err := p.Start(r)
+ if err != nil {
+ t.Fatal(err)
+ }
+ ret.TxID = h.ID
+ qq, err := p.AllQuestions()
+ if err != nil {
+ t.Fatalf("AllQuestions: %v", err)
+ }
+ if len(qq) != 1 {
+ t.Fatalf("num questions = %v; want 1", len(qq))
+ }
+ aa, err := p.AllAnswers()
+ if err != nil {
+ t.Fatalf("AllAnswers: %v", err)
+ }
+ for _, r := range aa {
+ if ret.TTL == 0 {
+ ret.TTL = r.Header.TTL
+ }
+ if ret.TTL != r.Header.TTL {
+ t.Fatal("mixed TTLs")
+ }
+ }
+ return ret
+}
+
+type responseOpt bool
+
+type ttlOpt uint32
+
+func makeQ(txID uint16, name string, opt ...any) []byte {
+ opt = append(opt, responseOpt(false))
+ return makeDNSPkt(txID, name, opt...)
+}
+
+func makeRes(txID uint16, name string, opt ...any) []byte {
+ opt = append(opt, responseOpt(true))
+ return makeDNSPkt(txID, name, opt...)
+}
+
+func makeDNSPkt(txID uint16, name string, opt ...any) []byte {
+ typ := dnsmessage.TypeA
+ class := dnsmessage.ClassINET
+ var response bool
+ var answers []dnsmessage.ResourceBody
+ var ttl uint32 = 1 // one second by default
+ for _, o := range opt {
+ switch o := o.(type) {
+ case dnsmessage.Type:
+ typ = o
+ case dnsmessage.Class:
+ class = o
+ case responseOpt:
+ response = bool(o)
+ case dnsmessage.ResourceBody:
+ answers = append(answers, o)
+ case ttlOpt:
+ ttl = uint32(o)
+ default:
+ panic(fmt.Sprintf("unknown opt type %T", o))
+ }
+ }
+ qname := dnsmessage.MustNewName(name)
+ msg := dnsmessage.Message{
+ Header: dnsmessage.Header{ID: txID, Response: response},
+ Questions: []dnsmessage.Question{
+ {
+ Name: qname,
+ Type: typ,
+ Class: class,
+ },
+ },
+ }
+ for _, rb := range answers {
+ msg.Answers = append(msg.Answers, dnsmessage.Resource{
+ Header: dnsmessage.ResourceHeader{
+ Name: qname,
+ Type: typ,
+ Class: class,
+ TTL: ttl,
+ },
+ Body: rb,
+ })
+ }
+ buf, err := msg.Pack()
+ if err != nil {
+ panic(err)
+ }
+ return buf
+}
+
+func TestASCIILowerName(t *testing.T) {
+ n := asciiLowerName(dnsmessage.MustNewName("Foo.COM."))
+ if got, want := n.String(), "foo.com."; got != want {
+ t.Errorf("got = %q; want %q", got, want)
+ }
+}
+
+func TestGetDNSQueryCacheKey(t *testing.T) {
+ tests := []struct {
+ name string
+ pkt []byte
+ want msgQ
+ txID uint16
+ anyTX bool
+ }{
+ {
+ name: "empty",
+ },
+ {
+ name: "a",
+ pkt: makeQ(123, "foo.com."),
+ want: msgQ{"foo.com.", dnsmessage.TypeA},
+ txID: 123,
+ },
+ {
+ name: "aaaa",
+ pkt: makeQ(6, "foo.com.", dnsmessage.TypeAAAA),
+ want: msgQ{"foo.com.", dnsmessage.TypeAAAA},
+ txID: 6,
+ },
+ {
+ name: "normalize_case",
+ pkt: makeQ(123, "FoO.CoM."),
+ want: msgQ{"foo.com.", dnsmessage.TypeA},
+ txID: 123,
+ },
+ {
+ name: "ignore_response",
+ pkt: makeRes(123, "foo.com."),
+ },
+ {
+ name: "ignore_question_with_answers",
+ pkt: makeQ(2, "foo.com.", &dnsmessage.AResource{A: [4]byte{127, 0, 0, 1}}),
+ },
+ {
+ name: "whatever_go_generates", // in case Go's net package grows functionality we don't handle
+ pkt: getGoNetPacketDNSQuery("from-go.foo."),
+ want: msgQ{"from-go.foo.", dnsmessage.TypeA},
+ anyTX: true,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got, gotTX, ok := getDNSQueryCacheKey(tt.pkt)
+ if !ok {
+ if tt.txID == 0 && got == (msgQ{}) {
+ return
+ }
+ t.Fatal("failed")
+ }
+ if got != tt.want {
+ t.Errorf("got %+v, want %+v", got, tt.want)
+ }
+ if gotTX != tt.txID && !tt.anyTX {
+ t.Errorf("got tx %v, want %v", gotTX, tt.txID)
+ }
+ })
+ }
+}
+
+func getGoNetPacketDNSQuery(name string) []byte {
+ if runtime.GOOS == "windows" {
+ // On Windows, Go's net.Resolver doesn't use the DNS client.
+ // See https://github.com/golang/go/issues/33097 which
+ // was approved but not yet implemented.
+ // For now just pretend it's implemented to make this test
+ // pass on Windows with complicated the caller.
+ return makeQ(123, name)
+ }
+ res := make(chan []byte, 1)
+ r := &net.Resolver{
+ PreferGo: true,
+ Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
+ return goResolverConn(res), nil
+ },
+ }
+ r.LookupIP(context.Background(), "ip4", name)
+ return <-res
+}
+
+type goResolverConn chan<- []byte
+
+func (goResolverConn) Close() error { return nil }
+func (goResolverConn) LocalAddr() net.Addr { return todoAddr{} }
+func (goResolverConn) RemoteAddr() net.Addr { return todoAddr{} }
+func (goResolverConn) SetDeadline(t time.Time) error { return nil }
+func (goResolverConn) SetReadDeadline(t time.Time) error { return nil }
+func (goResolverConn) SetWriteDeadline(t time.Time) error { return nil }
+func (goResolverConn) Read([]byte) (int, error) { return 0, errors.New("boom") }
+func (c goResolverConn) Write(p []byte) (int, error) {
+ select {
+ case c <- p[2:]: // skip 2 byte length for TCP mode DNS query
+ default:
+ }
+ return 0, errors.New("boom")
+}
+
+type todoAddr struct{}
+
+func (todoAddr) Network() string { return "unused" }
+func (todoAddr) String() string { return "unused-todoAddr" }
diff --git a/net/dnsfallback/update-dns-fallbacks.go b/net/dnsfallback/update-dns-fallbacks.go index 384e77e10..ebbfc2ad1 100644 --- a/net/dnsfallback/update-dns-fallbacks.go +++ b/net/dnsfallback/update-dns-fallbacks.go @@ -1,45 +1,45 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build ignore - -package main - -import ( - "encoding/json" - "fmt" - "log" - "net/http" - "os" - - "tailscale.com/tailcfg" -) - -func main() { - res, err := http.Get("https://login.tailscale.com/derpmap/default") - if err != nil { - log.Fatal(err) - } - if res.StatusCode != 200 { - res.Write(os.Stderr) - os.Exit(1) - } - dm := new(tailcfg.DERPMap) - if err := json.NewDecoder(res.Body).Decode(dm); err != nil { - log.Fatal(err) - } - for rid, r := range dm.Regions { - // Names misleading to check into git, as this is a - // static snapshot and doesn't reflect the live DERP - // map. - r.RegionCode = fmt.Sprintf("r%d", rid) - r.RegionName = r.RegionCode - } - out, err := json.MarshalIndent(dm, "", "\t") - if err != nil { - log.Fatal(err) - } - if err := os.WriteFile("dns-fallback-servers.json", out, 0644); err != nil { - log.Fatal(err) - } -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build ignore
+
+package main
+
+import (
+ "encoding/json"
+ "fmt"
+ "log"
+ "net/http"
+ "os"
+
+ "tailscale.com/tailcfg"
+)
+
+func main() {
+ res, err := http.Get("https://login.tailscale.com/derpmap/default")
+ if err != nil {
+ log.Fatal(err)
+ }
+ if res.StatusCode != 200 {
+ res.Write(os.Stderr)
+ os.Exit(1)
+ }
+ dm := new(tailcfg.DERPMap)
+ if err := json.NewDecoder(res.Body).Decode(dm); err != nil {
+ log.Fatal(err)
+ }
+ for rid, r := range dm.Regions {
+ // Names misleading to check into git, as this is a
+ // static snapshot and doesn't reflect the live DERP
+ // map.
+ r.RegionCode = fmt.Sprintf("r%d", rid)
+ r.RegionName = r.RegionCode
+ }
+ out, err := json.MarshalIndent(dm, "", "\t")
+ if err != nil {
+ log.Fatal(err)
+ }
+ if err := os.WriteFile("dns-fallback-servers.json", out, 0644); err != nil {
+ log.Fatal(err)
+ }
+}
diff --git a/net/memnet/conn.go b/net/memnet/conn.go index a9e1fd399..f599612d9 100644 --- a/net/memnet/conn.go +++ b/net/memnet/conn.go @@ -1,114 +1,114 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package memnet - -import ( - "net" - "net/netip" - "time" -) - -// NetworkName is the network name returned by [net.Addr.Network] -// for [net.Conn.LocalAddr] and [net.Conn.RemoteAddr] from the [Conn] type. -const NetworkName = "mem" - -// Conn is a net.Conn that can additionally have its reads and writes blocked and unblocked. -type Conn interface { - net.Conn - - // SetReadBlock blocks or unblocks the Read method of this Conn. - // It reports an error if the existing value matches the new value, - // or if the Conn has been Closed. - SetReadBlock(bool) error - - // SetWriteBlock blocks or unblocks the Write method of this Conn. - // It reports an error if the existing value matches the new value, - // or if the Conn has been Closed. - SetWriteBlock(bool) error -} - -// NewConn creates a pair of Conns that are wired together by pipes. -func NewConn(name string, maxBuf int) (Conn, Conn) { - r := NewPipe(name+"|0", maxBuf) - w := NewPipe(name+"|1", maxBuf) - - return &connHalf{r: r, w: w}, &connHalf{r: w, w: r} -} - -// NewTCPConn creates a pair of Conns that are wired together by pipes. -func NewTCPConn(src, dst netip.AddrPort, maxBuf int) (local Conn, remote Conn) { - r := NewPipe(src.String(), maxBuf) - w := NewPipe(dst.String(), maxBuf) - - lAddr := net.TCPAddrFromAddrPort(src) - rAddr := net.TCPAddrFromAddrPort(dst) - - return &connHalf{r: r, w: w, remote: rAddr, local: lAddr}, &connHalf{r: w, w: r, remote: lAddr, local: rAddr} -} - -type connAddr string - -func (a connAddr) Network() string { return NetworkName } -func (a connAddr) String() string { return string(a) } - -type connHalf struct { - local, remote net.Addr - r, w *Pipe -} - -func (c *connHalf) LocalAddr() net.Addr { - if c.local != nil { - return c.local - } - return connAddr(c.r.name) -} - -func (c *connHalf) RemoteAddr() net.Addr { - if c.remote != nil { - return c.remote - } - return connAddr(c.w.name) -} - -func (c *connHalf) Read(b []byte) (n int, err error) { - return c.r.Read(b) -} -func (c *connHalf) Write(b []byte) (n int, err error) { - return c.w.Write(b) -} - -func (c *connHalf) Close() error { - if err := c.w.Close(); err != nil { - return err - } - return c.r.Close() -} - -func (c *connHalf) SetDeadline(t time.Time) error { - err1 := c.SetReadDeadline(t) - err2 := c.SetWriteDeadline(t) - if err1 != nil { - return err1 - } - return err2 -} -func (c *connHalf) SetReadDeadline(t time.Time) error { - return c.r.SetReadDeadline(t) -} -func (c *connHalf) SetWriteDeadline(t time.Time) error { - return c.w.SetWriteDeadline(t) -} - -func (c *connHalf) SetReadBlock(b bool) error { - if b { - return c.r.Block() - } - return c.r.Unblock() -} -func (c *connHalf) SetWriteBlock(b bool) error { - if b { - return c.w.Block() - } - return c.w.Unblock() -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package memnet
+
+import (
+ "net"
+ "net/netip"
+ "time"
+)
+
+// NetworkName is the network name returned by [net.Addr.Network]
+// for [net.Conn.LocalAddr] and [net.Conn.RemoteAddr] from the [Conn] type.
+const NetworkName = "mem"
+
+// Conn is a net.Conn that can additionally have its reads and writes blocked and unblocked.
+type Conn interface {
+ net.Conn
+
+ // SetReadBlock blocks or unblocks the Read method of this Conn.
+ // It reports an error if the existing value matches the new value,
+ // or if the Conn has been Closed.
+ SetReadBlock(bool) error
+
+ // SetWriteBlock blocks or unblocks the Write method of this Conn.
+ // It reports an error if the existing value matches the new value,
+ // or if the Conn has been Closed.
+ SetWriteBlock(bool) error
+}
+
+// NewConn creates a pair of Conns that are wired together by pipes.
+func NewConn(name string, maxBuf int) (Conn, Conn) {
+ r := NewPipe(name+"|0", maxBuf)
+ w := NewPipe(name+"|1", maxBuf)
+
+ return &connHalf{r: r, w: w}, &connHalf{r: w, w: r}
+}
+
+// NewTCPConn creates a pair of Conns that are wired together by pipes.
+func NewTCPConn(src, dst netip.AddrPort, maxBuf int) (local Conn, remote Conn) {
+ r := NewPipe(src.String(), maxBuf)
+ w := NewPipe(dst.String(), maxBuf)
+
+ lAddr := net.TCPAddrFromAddrPort(src)
+ rAddr := net.TCPAddrFromAddrPort(dst)
+
+ return &connHalf{r: r, w: w, remote: rAddr, local: lAddr}, &connHalf{r: w, w: r, remote: lAddr, local: rAddr}
+}
+
+type connAddr string
+
+func (a connAddr) Network() string { return NetworkName }
+func (a connAddr) String() string { return string(a) }
+
+type connHalf struct {
+ local, remote net.Addr
+ r, w *Pipe
+}
+
+func (c *connHalf) LocalAddr() net.Addr {
+ if c.local != nil {
+ return c.local
+ }
+ return connAddr(c.r.name)
+}
+
+func (c *connHalf) RemoteAddr() net.Addr {
+ if c.remote != nil {
+ return c.remote
+ }
+ return connAddr(c.w.name)
+}
+
+func (c *connHalf) Read(b []byte) (n int, err error) {
+ return c.r.Read(b)
+}
+func (c *connHalf) Write(b []byte) (n int, err error) {
+ return c.w.Write(b)
+}
+
+func (c *connHalf) Close() error {
+ if err := c.w.Close(); err != nil {
+ return err
+ }
+ return c.r.Close()
+}
+
+func (c *connHalf) SetDeadline(t time.Time) error {
+ err1 := c.SetReadDeadline(t)
+ err2 := c.SetWriteDeadline(t)
+ if err1 != nil {
+ return err1
+ }
+ return err2
+}
+func (c *connHalf) SetReadDeadline(t time.Time) error {
+ return c.r.SetReadDeadline(t)
+}
+func (c *connHalf) SetWriteDeadline(t time.Time) error {
+ return c.w.SetWriteDeadline(t)
+}
+
+func (c *connHalf) SetReadBlock(b bool) error {
+ if b {
+ return c.r.Block()
+ }
+ return c.r.Unblock()
+}
+func (c *connHalf) SetWriteBlock(b bool) error {
+ if b {
+ return c.w.Block()
+ }
+ return c.w.Unblock()
+}
diff --git a/net/memnet/conn_test.go b/net/memnet/conn_test.go index 743ce5248..3eec80bc6 100644 --- a/net/memnet/conn_test.go +++ b/net/memnet/conn_test.go @@ -1,21 +1,21 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package memnet - -import ( - "net" - "testing" - - "golang.org/x/net/nettest" -) - -func TestConn(t *testing.T) { - nettest.TestConn(t, func() (c1 net.Conn, c2 net.Conn, stop func(), err error) { - c1, c2 = NewConn("test", bufferSize) - return c1, c2, func() { - c1.Close() - c2.Close() - }, nil - }) -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package memnet
+
+import (
+ "net"
+ "testing"
+
+ "golang.org/x/net/nettest"
+)
+
+func TestConn(t *testing.T) {
+ nettest.TestConn(t, func() (c1 net.Conn, c2 net.Conn, stop func(), err error) {
+ c1, c2 = NewConn("test", bufferSize)
+ return c1, c2, func() {
+ c1.Close()
+ c2.Close()
+ }, nil
+ })
+}
diff --git a/net/memnet/listener.go b/net/memnet/listener.go index d84a2e443..d1364d790 100644 --- a/net/memnet/listener.go +++ b/net/memnet/listener.go @@ -1,100 +1,100 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package memnet - -import ( - "context" - "net" - "strings" - "sync" -) - -const ( - bufferSize = 256 * 1024 -) - -// Listener is a net.Listener using NewConn to create pairs of network -// connections connected in memory using a buffered pipe. It also provides a -// Dial method to establish new connections. -type Listener struct { - addr connAddr - ch chan Conn - closeOnce sync.Once - closed chan struct{} - - // NewConn, if non-nil, is called to create a new pair of connections - // when dialing. If nil, NewConn is used. - NewConn func(network, addr string, maxBuf int) (Conn, Conn) -} - -// Listen returns a new Listener for the provided address. -func Listen(addr string) *Listener { - return &Listener{ - addr: connAddr(addr), - ch: make(chan Conn), - closed: make(chan struct{}), - } -} - -// Addr implements net.Listener.Addr. -func (l *Listener) Addr() net.Addr { - return l.addr -} - -// Close closes the pipe listener. -func (l *Listener) Close() error { - l.closeOnce.Do(func() { - close(l.closed) - }) - return nil -} - -// Accept blocks until a new connection is available or the listener is closed. -func (l *Listener) Accept() (net.Conn, error) { - select { - case c := <-l.ch: - return c, nil - case <-l.closed: - return nil, net.ErrClosed - } -} - -// Dial connects to the listener using the provided context. -// The provided Context must be non-nil. If the context expires before the -// connection is complete, an error is returned. Once successfully connected -// any expiration of the context will not affect the connection. -func (l *Listener) Dial(ctx context.Context, network, addr string) (_ net.Conn, err error) { - if !strings.HasSuffix(network, "tcp") { - return nil, net.UnknownNetworkError(network) - } - if connAddr(addr) != l.addr { - return nil, &net.AddrError{ - Err: "invalid address", - Addr: addr, - } - } - - newConn := l.NewConn - if newConn == nil { - newConn = func(network, addr string, maxBuf int) (Conn, Conn) { - return NewConn(addr, maxBuf) - } - } - c, s := newConn(network, addr, bufferSize) - defer func() { - if err != nil { - c.Close() - s.Close() - } - }() - - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-l.closed: - return nil, net.ErrClosed - case l.ch <- s: - return c, nil - } -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package memnet
+
+import (
+ "context"
+ "net"
+ "strings"
+ "sync"
+)
+
+const (
+ bufferSize = 256 * 1024
+)
+
+// Listener is a net.Listener using NewConn to create pairs of network
+// connections connected in memory using a buffered pipe. It also provides a
+// Dial method to establish new connections.
+type Listener struct {
+ addr connAddr
+ ch chan Conn
+ closeOnce sync.Once
+ closed chan struct{}
+
+ // NewConn, if non-nil, is called to create a new pair of connections
+ // when dialing. If nil, NewConn is used.
+ NewConn func(network, addr string, maxBuf int) (Conn, Conn)
+}
+
+// Listen returns a new Listener for the provided address.
+func Listen(addr string) *Listener {
+ return &Listener{
+ addr: connAddr(addr),
+ ch: make(chan Conn),
+ closed: make(chan struct{}),
+ }
+}
+
+// Addr implements net.Listener.Addr.
+func (l *Listener) Addr() net.Addr {
+ return l.addr
+}
+
+// Close closes the pipe listener.
+func (l *Listener) Close() error {
+ l.closeOnce.Do(func() {
+ close(l.closed)
+ })
+ return nil
+}
+
+// Accept blocks until a new connection is available or the listener is closed.
+func (l *Listener) Accept() (net.Conn, error) {
+ select {
+ case c := <-l.ch:
+ return c, nil
+ case <-l.closed:
+ return nil, net.ErrClosed
+ }
+}
+
+// Dial connects to the listener using the provided context.
+// The provided Context must be non-nil. If the context expires before the
+// connection is complete, an error is returned. Once successfully connected
+// any expiration of the context will not affect the connection.
+func (l *Listener) Dial(ctx context.Context, network, addr string) (_ net.Conn, err error) {
+ if !strings.HasSuffix(network, "tcp") {
+ return nil, net.UnknownNetworkError(network)
+ }
+ if connAddr(addr) != l.addr {
+ return nil, &net.AddrError{
+ Err: "invalid address",
+ Addr: addr,
+ }
+ }
+
+ newConn := l.NewConn
+ if newConn == nil {
+ newConn = func(network, addr string, maxBuf int) (Conn, Conn) {
+ return NewConn(addr, maxBuf)
+ }
+ }
+ c, s := newConn(network, addr, bufferSize)
+ defer func() {
+ if err != nil {
+ c.Close()
+ s.Close()
+ }
+ }()
+
+ select {
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ case <-l.closed:
+ return nil, net.ErrClosed
+ case l.ch <- s:
+ return c, nil
+ }
+}
diff --git a/net/memnet/listener_test.go b/net/memnet/listener_test.go index 73b67841a..989d5e9e4 100644 --- a/net/memnet/listener_test.go +++ b/net/memnet/listener_test.go @@ -1,33 +1,33 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package memnet - -import ( - "context" - "testing" -) - -func TestListener(t *testing.T) { - l := Listen("srv.local") - defer l.Close() - go func() { - c, err := l.Accept() - if err != nil { - t.Error(err) - return - } - defer c.Close() - }() - - if c, err := l.Dial(context.Background(), "tcp", "invalid"); err == nil { - c.Close() - t.Fatalf("dial to invalid address succeeded") - } - c, err := l.Dial(context.Background(), "tcp", "srv.local") - if err != nil { - t.Fatalf("dial failed: %v", err) - return - } - c.Close() -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package memnet
+
+import (
+ "context"
+ "testing"
+)
+
+func TestListener(t *testing.T) {
+ l := Listen("srv.local")
+ defer l.Close()
+ go func() {
+ c, err := l.Accept()
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ defer c.Close()
+ }()
+
+ if c, err := l.Dial(context.Background(), "tcp", "invalid"); err == nil {
+ c.Close()
+ t.Fatalf("dial to invalid address succeeded")
+ }
+ c, err := l.Dial(context.Background(), "tcp", "srv.local")
+ if err != nil {
+ t.Fatalf("dial failed: %v", err)
+ return
+ }
+ c.Close()
+}
diff --git a/net/memnet/memnet.go b/net/memnet/memnet.go index c8799bc17..2fc13b4b2 100644 --- a/net/memnet/memnet.go +++ b/net/memnet/memnet.go @@ -1,8 +1,8 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package memnet implements an in-memory network implementation. -// It is useful for dialing and listening on in-memory addresses -// in tests and other situations where you don't want to use the -// network. -package memnet +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Package memnet implements an in-memory network implementation.
+// It is useful for dialing and listening on in-memory addresses
+// in tests and other situations where you don't want to use the
+// network.
+package memnet
diff --git a/net/memnet/pipe.go b/net/memnet/pipe.go index 471635083..51bee1090 100644 --- a/net/memnet/pipe.go +++ b/net/memnet/pipe.go @@ -1,244 +1,244 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package memnet - -import ( - "bytes" - "context" - "fmt" - "io" - "log" - "net" - "os" - "sync" - "time" -) - -const debugPipe = false - -// Pipe implements an in-memory FIFO with timeouts. -type Pipe struct { - name string - maxBuf int - mu sync.Mutex - cnd *sync.Cond - - blocked bool - closed bool - buf bytes.Buffer - readTimeout time.Time - writeTimeout time.Time - cancelReadTimer func() - cancelWriteTimer func() -} - -// NewPipe creates a Pipe with a buffer size fixed at maxBuf. -func NewPipe(name string, maxBuf int) *Pipe { - p := &Pipe{ - name: name, - maxBuf: maxBuf, - } - p.cnd = sync.NewCond(&p.mu) - return p -} - -// readOrBlock attempts to read from the buffer, if the buffer is empty and -// the connection hasn't been closed it will block until there is a change. -func (p *Pipe) readOrBlock(b []byte) (int, error) { - p.mu.Lock() - defer p.mu.Unlock() - if !p.readTimeout.IsZero() && !time.Now().Before(p.readTimeout) { - return 0, os.ErrDeadlineExceeded - } - if p.blocked { - p.cnd.Wait() - return 0, nil - } - - n, err := p.buf.Read(b) - // err will either be nil or io.EOF. - if err == io.EOF { - if p.closed { - return n, err - } - // Wait for something to change. - p.cnd.Wait() - } - return n, nil -} - -// Read implements io.Reader. -// Once the buffer is drained (i.e. after Close), subsequent calls will -// return io.EOF. -func (p *Pipe) Read(b []byte) (n int, err error) { - if debugPipe { - orig := b - defer func() { - log.Printf("Pipe(%q).Read(%q) n=%d, err=%v", p.name, string(orig[:n]), n, err) - }() - } - for n == 0 { - n2, err := p.readOrBlock(b) - if err != nil { - return n2, err - } - n += n2 - } - p.cnd.Signal() - return n, nil -} - -// writeOrBlock attempts to write to the buffer, if the buffer is full it will -// block until there is a change. -func (p *Pipe) writeOrBlock(b []byte) (int, error) { - p.mu.Lock() - defer p.mu.Unlock() - if p.closed { - return 0, net.ErrClosed - } - if !p.writeTimeout.IsZero() && !time.Now().Before(p.writeTimeout) { - return 0, os.ErrDeadlineExceeded - } - if p.blocked { - p.cnd.Wait() - return 0, nil - } - - // Optimistically we want to write the entire slice. - n := len(b) - if limit := p.maxBuf - p.buf.Len(); limit < n { - // However, we don't have enough capacity to write everything. - n = limit - } - if n == 0 { - // Wait for something to change. - p.cnd.Wait() - return 0, nil - } - - p.buf.Write(b[:n]) - p.cnd.Signal() - return n, nil -} - -// Write implements io.Writer. -func (p *Pipe) Write(b []byte) (n int, err error) { - if debugPipe { - orig := b - defer func() { - log.Printf("Pipe(%q).Write(%q) n=%d, err=%v", p.name, string(orig), n, err) - }() - } - for len(b) > 0 { - n2, err := p.writeOrBlock(b) - if err != nil { - return n + n2, err - } - n += n2 - b = b[n2:] - } - return n, nil -} - -// Close closes the pipe. -func (p *Pipe) Close() error { - p.mu.Lock() - defer p.mu.Unlock() - p.closed = true - p.blocked = false - if p.cancelWriteTimer != nil { - p.cancelWriteTimer() - p.cancelWriteTimer = nil - } - if p.cancelReadTimer != nil { - p.cancelReadTimer() - p.cancelReadTimer = nil - } - p.cnd.Broadcast() - - return nil -} - -func (p *Pipe) deadlineTimer(t time.Time) func() { - if t.IsZero() { - return nil - } - if t.Before(time.Now()) { - p.cnd.Broadcast() - return nil - } - ctx, cancel := context.WithDeadline(context.Background(), t) - go func() { - <-ctx.Done() - if ctx.Err() == context.DeadlineExceeded { - p.cnd.Broadcast() - } - }() - return cancel -} - -// SetReadDeadline sets the deadline for future Read calls. -func (p *Pipe) SetReadDeadline(t time.Time) error { - p.mu.Lock() - defer p.mu.Unlock() - p.readTimeout = t - // If we already have a deadline, cancel it and create a new one. - if p.cancelReadTimer != nil { - p.cancelReadTimer() - p.cancelReadTimer = nil - } - p.cancelReadTimer = p.deadlineTimer(t) - return nil -} - -// SetWriteDeadline sets the deadline for future Write calls. -func (p *Pipe) SetWriteDeadline(t time.Time) error { - p.mu.Lock() - defer p.mu.Unlock() - p.writeTimeout = t - // If we already have a deadline, cancel it and create a new one. - if p.cancelWriteTimer != nil { - p.cancelWriteTimer() - p.cancelWriteTimer = nil - } - p.cancelWriteTimer = p.deadlineTimer(t) - return nil -} - -// Block will cause all calls to Read and Write to block until they either -// timeout, are unblocked or the pipe is closed. -func (p *Pipe) Block() error { - p.mu.Lock() - defer p.mu.Unlock() - closed := p.closed - blocked := p.blocked - p.blocked = true - - if closed { - return fmt.Errorf("memnet.Pipe(%q).Block: closed", p.name) - } - if blocked { - return fmt.Errorf("memnet.Pipe(%q).Block: already blocked", p.name) - } - p.cnd.Broadcast() - return nil -} - -// Unblock will cause all blocked Read/Write calls to continue execution. -func (p *Pipe) Unblock() error { - p.mu.Lock() - defer p.mu.Unlock() - closed := p.closed - blocked := p.blocked - p.blocked = false - - if closed { - return fmt.Errorf("memnet.Pipe(%q).Block: closed", p.name) - } - if !blocked { - return fmt.Errorf("memnet.Pipe(%q).Block: already unblocked", p.name) - } - p.cnd.Broadcast() - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package memnet
+
+import (
+ "bytes"
+ "context"
+ "fmt"
+ "io"
+ "log"
+ "net"
+ "os"
+ "sync"
+ "time"
+)
+
+const debugPipe = false
+
+// Pipe implements an in-memory FIFO with timeouts.
+type Pipe struct {
+ name string
+ maxBuf int
+ mu sync.Mutex
+ cnd *sync.Cond
+
+ blocked bool
+ closed bool
+ buf bytes.Buffer
+ readTimeout time.Time
+ writeTimeout time.Time
+ cancelReadTimer func()
+ cancelWriteTimer func()
+}
+
+// NewPipe creates a Pipe with a buffer size fixed at maxBuf.
+func NewPipe(name string, maxBuf int) *Pipe {
+ p := &Pipe{
+ name: name,
+ maxBuf: maxBuf,
+ }
+ p.cnd = sync.NewCond(&p.mu)
+ return p
+}
+
+// readOrBlock attempts to read from the buffer, if the buffer is empty and
+// the connection hasn't been closed it will block until there is a change.
+func (p *Pipe) readOrBlock(b []byte) (int, error) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ if !p.readTimeout.IsZero() && !time.Now().Before(p.readTimeout) {
+ return 0, os.ErrDeadlineExceeded
+ }
+ if p.blocked {
+ p.cnd.Wait()
+ return 0, nil
+ }
+
+ n, err := p.buf.Read(b)
+ // err will either be nil or io.EOF.
+ if err == io.EOF {
+ if p.closed {
+ return n, err
+ }
+ // Wait for something to change.
+ p.cnd.Wait()
+ }
+ return n, nil
+}
+
+// Read implements io.Reader.
+// Once the buffer is drained (i.e. after Close), subsequent calls will
+// return io.EOF.
+func (p *Pipe) Read(b []byte) (n int, err error) {
+ if debugPipe {
+ orig := b
+ defer func() {
+ log.Printf("Pipe(%q).Read(%q) n=%d, err=%v", p.name, string(orig[:n]), n, err)
+ }()
+ }
+ for n == 0 {
+ n2, err := p.readOrBlock(b)
+ if err != nil {
+ return n2, err
+ }
+ n += n2
+ }
+ p.cnd.Signal()
+ return n, nil
+}
+
+// writeOrBlock attempts to write to the buffer, if the buffer is full it will
+// block until there is a change.
+func (p *Pipe) writeOrBlock(b []byte) (int, error) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ if p.closed {
+ return 0, net.ErrClosed
+ }
+ if !p.writeTimeout.IsZero() && !time.Now().Before(p.writeTimeout) {
+ return 0, os.ErrDeadlineExceeded
+ }
+ if p.blocked {
+ p.cnd.Wait()
+ return 0, nil
+ }
+
+ // Optimistically we want to write the entire slice.
+ n := len(b)
+ if limit := p.maxBuf - p.buf.Len(); limit < n {
+ // However, we don't have enough capacity to write everything.
+ n = limit
+ }
+ if n == 0 {
+ // Wait for something to change.
+ p.cnd.Wait()
+ return 0, nil
+ }
+
+ p.buf.Write(b[:n])
+ p.cnd.Signal()
+ return n, nil
+}
+
+// Write implements io.Writer.
+func (p *Pipe) Write(b []byte) (n int, err error) {
+ if debugPipe {
+ orig := b
+ defer func() {
+ log.Printf("Pipe(%q).Write(%q) n=%d, err=%v", p.name, string(orig), n, err)
+ }()
+ }
+ for len(b) > 0 {
+ n2, err := p.writeOrBlock(b)
+ if err != nil {
+ return n + n2, err
+ }
+ n += n2
+ b = b[n2:]
+ }
+ return n, nil
+}
+
+// Close closes the pipe.
+func (p *Pipe) Close() error {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ p.closed = true
+ p.blocked = false
+ if p.cancelWriteTimer != nil {
+ p.cancelWriteTimer()
+ p.cancelWriteTimer = nil
+ }
+ if p.cancelReadTimer != nil {
+ p.cancelReadTimer()
+ p.cancelReadTimer = nil
+ }
+ p.cnd.Broadcast()
+
+ return nil
+}
+
+func (p *Pipe) deadlineTimer(t time.Time) func() {
+ if t.IsZero() {
+ return nil
+ }
+ if t.Before(time.Now()) {
+ p.cnd.Broadcast()
+ return nil
+ }
+ ctx, cancel := context.WithDeadline(context.Background(), t)
+ go func() {
+ <-ctx.Done()
+ if ctx.Err() == context.DeadlineExceeded {
+ p.cnd.Broadcast()
+ }
+ }()
+ return cancel
+}
+
+// SetReadDeadline sets the deadline for future Read calls.
+func (p *Pipe) SetReadDeadline(t time.Time) error {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ p.readTimeout = t
+ // If we already have a deadline, cancel it and create a new one.
+ if p.cancelReadTimer != nil {
+ p.cancelReadTimer()
+ p.cancelReadTimer = nil
+ }
+ p.cancelReadTimer = p.deadlineTimer(t)
+ return nil
+}
+
+// SetWriteDeadline sets the deadline for future Write calls.
+func (p *Pipe) SetWriteDeadline(t time.Time) error {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ p.writeTimeout = t
+ // If we already have a deadline, cancel it and create a new one.
+ if p.cancelWriteTimer != nil {
+ p.cancelWriteTimer()
+ p.cancelWriteTimer = nil
+ }
+ p.cancelWriteTimer = p.deadlineTimer(t)
+ return nil
+}
+
+// Block will cause all calls to Read and Write to block until they either
+// timeout, are unblocked or the pipe is closed.
+func (p *Pipe) Block() error {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ closed := p.closed
+ blocked := p.blocked
+ p.blocked = true
+
+ if closed {
+ return fmt.Errorf("memnet.Pipe(%q).Block: closed", p.name)
+ }
+ if blocked {
+ return fmt.Errorf("memnet.Pipe(%q).Block: already blocked", p.name)
+ }
+ p.cnd.Broadcast()
+ return nil
+}
+
+// Unblock will cause all blocked Read/Write calls to continue execution.
+func (p *Pipe) Unblock() error {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ closed := p.closed
+ blocked := p.blocked
+ p.blocked = false
+
+ if closed {
+ return fmt.Errorf("memnet.Pipe(%q).Block: closed", p.name)
+ }
+ if !blocked {
+ return fmt.Errorf("memnet.Pipe(%q).Block: already unblocked", p.name)
+ }
+ p.cnd.Broadcast()
+ return nil
+}
diff --git a/net/memnet/pipe_test.go b/net/memnet/pipe_test.go index a86d65388..b3775cf7f 100644 --- a/net/memnet/pipe_test.go +++ b/net/memnet/pipe_test.go @@ -1,117 +1,117 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package memnet - -import ( - "errors" - "fmt" - "os" - "testing" - "time" -) - -func TestPipeHello(t *testing.T) { - p := NewPipe("p1", 1<<16) - msg := "Hello, World!" - if n, err := p.Write([]byte(msg)); err != nil { - t.Fatal(err) - } else if n != len(msg) { - t.Errorf("p.Write(%q) n=%d, want %d", msg, n, len(msg)) - } - b := make([]byte, len(msg)) - if n, err := p.Read(b); err != nil { - t.Fatal(err) - } else if n != len(b) { - t.Errorf("p.Read(%q) n=%d, want %d", string(b[:n]), n, len(b)) - } - if got := string(b); got != msg { - t.Errorf("p.Read: %q, want %q", got, msg) - } -} - -func TestPipeTimeout(t *testing.T) { - t.Run("write", func(t *testing.T) { - p := NewPipe("p1", 1<<16) - p.SetWriteDeadline(time.Now().Add(-1 * time.Second)) - n, err := p.Write([]byte{'h'}) - if !errors.Is(err, os.ErrDeadlineExceeded) { - t.Errorf("missing write timeout got err: %v", err) - } - if n != 0 { - t.Errorf("n=%d on timeout", n) - } - }) - t.Run("read", func(t *testing.T) { - p := NewPipe("p1", 1<<16) - p.Write([]byte{'h'}) - - p.SetReadDeadline(time.Now().Add(-1 * time.Second)) - b := make([]byte, 1) - n, err := p.Read(b) - if !errors.Is(err, os.ErrDeadlineExceeded) { - t.Errorf("missing read timeout got err: %v", err) - } - if n != 0 { - t.Errorf("n=%d on timeout", n) - } - }) - t.Run("block-write", func(t *testing.T) { - p := NewPipe("p1", 1<<16) - p.SetWriteDeadline(time.Now().Add(10 * time.Millisecond)) - if err := p.Block(); err != nil { - t.Fatal(err) - } - if _, err := p.Write([]byte{'h'}); !errors.Is(err, os.ErrDeadlineExceeded) { - t.Fatalf("want write timeout got: %v", err) - } - }) - t.Run("block-read", func(t *testing.T) { - p := NewPipe("p1", 1<<16) - p.Write([]byte{'h', 'i'}) - p.SetReadDeadline(time.Now().Add(10 * time.Millisecond)) - b := make([]byte, 1) - if err := p.Block(); err != nil { - t.Fatal(err) - } - if _, err := p.Read(b); !errors.Is(err, os.ErrDeadlineExceeded) { - t.Fatalf("want read timeout got: %v", err) - } - }) -} - -func TestLimit(t *testing.T) { - p := NewPipe("p1", 1) - errCh := make(chan error) - go func() { - n, err := p.Write([]byte{'a', 'b', 'c'}) - if err != nil { - errCh <- err - } else if n != 3 { - errCh <- fmt.Errorf("p.Write n=%d, want 3", n) - } else { - errCh <- nil - } - }() - b := make([]byte, 3) - - if n, err := p.Read(b); err != nil { - t.Fatal(err) - } else if n != 1 { - t.Errorf("Read(%q): n=%d want 1", string(b), n) - } - if n, err := p.Read(b); err != nil { - t.Fatal(err) - } else if n != 1 { - t.Errorf("Read(%q): n=%d want 1", string(b), n) - } - if n, err := p.Read(b); err != nil { - t.Fatal(err) - } else if n != 1 { - t.Errorf("Read(%q): n=%d want 1", string(b), n) - } - - if err := <-errCh; err != nil { - t.Error(err) - } -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package memnet
+
+import (
+ "errors"
+ "fmt"
+ "os"
+ "testing"
+ "time"
+)
+
+func TestPipeHello(t *testing.T) {
+ p := NewPipe("p1", 1<<16)
+ msg := "Hello, World!"
+ if n, err := p.Write([]byte(msg)); err != nil {
+ t.Fatal(err)
+ } else if n != len(msg) {
+ t.Errorf("p.Write(%q) n=%d, want %d", msg, n, len(msg))
+ }
+ b := make([]byte, len(msg))
+ if n, err := p.Read(b); err != nil {
+ t.Fatal(err)
+ } else if n != len(b) {
+ t.Errorf("p.Read(%q) n=%d, want %d", string(b[:n]), n, len(b))
+ }
+ if got := string(b); got != msg {
+ t.Errorf("p.Read: %q, want %q", got, msg)
+ }
+}
+
+func TestPipeTimeout(t *testing.T) {
+ t.Run("write", func(t *testing.T) {
+ p := NewPipe("p1", 1<<16)
+ p.SetWriteDeadline(time.Now().Add(-1 * time.Second))
+ n, err := p.Write([]byte{'h'})
+ if !errors.Is(err, os.ErrDeadlineExceeded) {
+ t.Errorf("missing write timeout got err: %v", err)
+ }
+ if n != 0 {
+ t.Errorf("n=%d on timeout", n)
+ }
+ })
+ t.Run("read", func(t *testing.T) {
+ p := NewPipe("p1", 1<<16)
+ p.Write([]byte{'h'})
+
+ p.SetReadDeadline(time.Now().Add(-1 * time.Second))
+ b := make([]byte, 1)
+ n, err := p.Read(b)
+ if !errors.Is(err, os.ErrDeadlineExceeded) {
+ t.Errorf("missing read timeout got err: %v", err)
+ }
+ if n != 0 {
+ t.Errorf("n=%d on timeout", n)
+ }
+ })
+ t.Run("block-write", func(t *testing.T) {
+ p := NewPipe("p1", 1<<16)
+ p.SetWriteDeadline(time.Now().Add(10 * time.Millisecond))
+ if err := p.Block(); err != nil {
+ t.Fatal(err)
+ }
+ if _, err := p.Write([]byte{'h'}); !errors.Is(err, os.ErrDeadlineExceeded) {
+ t.Fatalf("want write timeout got: %v", err)
+ }
+ })
+ t.Run("block-read", func(t *testing.T) {
+ p := NewPipe("p1", 1<<16)
+ p.Write([]byte{'h', 'i'})
+ p.SetReadDeadline(time.Now().Add(10 * time.Millisecond))
+ b := make([]byte, 1)
+ if err := p.Block(); err != nil {
+ t.Fatal(err)
+ }
+ if _, err := p.Read(b); !errors.Is(err, os.ErrDeadlineExceeded) {
+ t.Fatalf("want read timeout got: %v", err)
+ }
+ })
+}
+
+func TestLimit(t *testing.T) {
+ p := NewPipe("p1", 1)
+ errCh := make(chan error)
+ go func() {
+ n, err := p.Write([]byte{'a', 'b', 'c'})
+ if err != nil {
+ errCh <- err
+ } else if n != 3 {
+ errCh <- fmt.Errorf("p.Write n=%d, want 3", n)
+ } else {
+ errCh <- nil
+ }
+ }()
+ b := make([]byte, 3)
+
+ if n, err := p.Read(b); err != nil {
+ t.Fatal(err)
+ } else if n != 1 {
+ t.Errorf("Read(%q): n=%d want 1", string(b), n)
+ }
+ if n, err := p.Read(b); err != nil {
+ t.Fatal(err)
+ } else if n != 1 {
+ t.Errorf("Read(%q): n=%d want 1", string(b), n)
+ }
+ if n, err := p.Read(b); err != nil {
+ t.Fatal(err)
+ } else if n != 1 {
+ t.Errorf("Read(%q): n=%d want 1", string(b), n)
+ }
+
+ if err := <-errCh; err != nil {
+ t.Error(err)
+ }
+}
diff --git a/net/netaddr/netaddr.go b/net/netaddr/netaddr.go index 1ab6c053a..6f85a52b7 100644 --- a/net/netaddr/netaddr.go +++ b/net/netaddr/netaddr.go @@ -1,49 +1,49 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package netaddr is a transitional package while we finish migrating from inet.af/netaddr -// to Go 1.18's net/netip. -// -// TODO(bradfitz): delete this package eventually. Tracking bug is -// https://github.com/tailscale/tailscale/issues/5162 -package netaddr - -import ( - "net" - "net/netip" -) - -// IPv4 returns the IP of the IPv4 address a.b.c.d. -func IPv4(a, b, c, d uint8) netip.Addr { - return netip.AddrFrom4([4]byte{a, b, c, d}) -} - -// Unmap returns the provided AddrPort with its Addr IP component Unmap'ed. -// -// See https://github.com/golang/go/issues/53607#issuecomment-1203466984 -func Unmap(ap netip.AddrPort) netip.AddrPort { - return netip.AddrPortFrom(ap.Addr().Unmap(), ap.Port()) -} - -// FromStdIPNet returns an IPPrefix from the standard library's IPNet type. -// If std is invalid, ok is false. -func FromStdIPNet(std *net.IPNet) (prefix netip.Prefix, ok bool) { - ip, ok := netip.AddrFromSlice(std.IP) - if !ok { - return netip.Prefix{}, false - } - ip = ip.Unmap() - - if l := len(std.Mask); l != net.IPv4len && l != net.IPv6len { - // Invalid mask. - return netip.Prefix{}, false - } - - ones, bits := std.Mask.Size() - if ones == 0 && bits == 0 { - // IPPrefix does not support non-contiguous masks. - return netip.Prefix{}, false - } - - return netip.PrefixFrom(ip, ones), true -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Package netaddr is a transitional package while we finish migrating from inet.af/netaddr
+// to Go 1.18's net/netip.
+//
+// TODO(bradfitz): delete this package eventually. Tracking bug is
+// https://github.com/tailscale/tailscale/issues/5162
+package netaddr
+
+import (
+ "net"
+ "net/netip"
+)
+
+// IPv4 returns the IP of the IPv4 address a.b.c.d.
+func IPv4(a, b, c, d uint8) netip.Addr {
+ return netip.AddrFrom4([4]byte{a, b, c, d})
+}
+
+// Unmap returns the provided AddrPort with its Addr IP component Unmap'ed.
+//
+// See https://github.com/golang/go/issues/53607#issuecomment-1203466984
+func Unmap(ap netip.AddrPort) netip.AddrPort {
+ return netip.AddrPortFrom(ap.Addr().Unmap(), ap.Port())
+}
+
+// FromStdIPNet returns an IPPrefix from the standard library's IPNet type.
+// If std is invalid, ok is false.
+func FromStdIPNet(std *net.IPNet) (prefix netip.Prefix, ok bool) {
+ ip, ok := netip.AddrFromSlice(std.IP)
+ if !ok {
+ return netip.Prefix{}, false
+ }
+ ip = ip.Unmap()
+
+ if l := len(std.Mask); l != net.IPv4len && l != net.IPv6len {
+ // Invalid mask.
+ return netip.Prefix{}, false
+ }
+
+ ones, bits := std.Mask.Size()
+ if ones == 0 && bits == 0 {
+ // IPPrefix does not support non-contiguous masks.
+ return netip.Prefix{}, false
+ }
+
+ return netip.PrefixFrom(ip, ones), true
+}
diff --git a/net/neterror/neterror.go b/net/neterror/neterror.go index e2387440d..f570b8930 100644 --- a/net/neterror/neterror.go +++ b/net/neterror/neterror.go @@ -1,82 +1,82 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package neterror classifies network errors. -package neterror - -import ( - "errors" - "fmt" - "runtime" - "syscall" -) - -var errEPERM error = syscall.EPERM // box it into interface just once - -// TreatAsLostUDP reports whether err is an error from a UDP send -// operation that should be treated as a UDP packet that just got -// lost. -// -// Notably, on Linux this reports true for EPERM errors (from outbound -// firewall blocks) which aren't really send errors; they're just -// sends that are never going to make it because the local OS blocked -// it. -func TreatAsLostUDP(err error) bool { - if err == nil { - return false - } - switch runtime.GOOS { - case "linux": - // Linux, while not documented in the man page, - // returns EPERM when there's an OUTPUT rule with -j - // DROP or -j REJECT. We use this very specific - // Linux+EPERM check rather than something super broad - // like net.Error.Temporary which could be anything. - // - // For now we only do this on Linux, as such outgoing - // firewall violations mapping to syscall errors - // hasn't yet been observed on other OSes. - return errors.Is(err, errEPERM) - } - return false -} - -var packetWasTruncated func(error) bool // non-nil on Windows at least - -// PacketWasTruncated reports whether err indicates truncation but the RecvFrom -// that generated err was otherwise successful. On Windows, Go's UDP RecvFrom -// calls WSARecvFrom which returns the WSAEMSGSIZE error code when the received -// datagram is larger than the provided buffer. When that happens, both a valid -// size and an error are returned (as per the partial fix for golang/go#14074). -// If the WSAEMSGSIZE error is returned, then we ignore the error to get -// semantics similar to the POSIX operating systems. One caveat is that it -// appears that the source address is not returned when WSAEMSGSIZE occurs, but -// we do not currently look at the source address. -func PacketWasTruncated(err error) bool { - if packetWasTruncated == nil { - return false - } - return packetWasTruncated(err) -} - -var shouldDisableUDPGSO func(error) bool // non-nil on Linux - -func ShouldDisableUDPGSO(err error) bool { - if shouldDisableUDPGSO == nil { - return false - } - return shouldDisableUDPGSO(err) -} - -type ErrUDPGSODisabled struct { - OnLaddr string - RetryErr error -} - -func (e ErrUDPGSODisabled) Error() string { - return fmt.Sprintf("disabled UDP GSO on %s, NIC(s) may not support checksum offload", e.OnLaddr) -} - -func (e ErrUDPGSODisabled) Unwrap() error { - return e.RetryErr -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Package neterror classifies network errors.
+package neterror
+
+import (
+ "errors"
+ "fmt"
+ "runtime"
+ "syscall"
+)
+
+var errEPERM error = syscall.EPERM // box it into interface just once
+
+// TreatAsLostUDP reports whether err is an error from a UDP send
+// operation that should be treated as a UDP packet that just got
+// lost.
+//
+// Notably, on Linux this reports true for EPERM errors (from outbound
+// firewall blocks) which aren't really send errors; they're just
+// sends that are never going to make it because the local OS blocked
+// it.
+func TreatAsLostUDP(err error) bool {
+ if err == nil {
+ return false
+ }
+ switch runtime.GOOS {
+ case "linux":
+ // Linux, while not documented in the man page,
+ // returns EPERM when there's an OUTPUT rule with -j
+ // DROP or -j REJECT. We use this very specific
+ // Linux+EPERM check rather than something super broad
+ // like net.Error.Temporary which could be anything.
+ //
+ // For now we only do this on Linux, as such outgoing
+ // firewall violations mapping to syscall errors
+ // hasn't yet been observed on other OSes.
+ return errors.Is(err, errEPERM)
+ }
+ return false
+}
+
+var packetWasTruncated func(error) bool // non-nil on Windows at least
+
+// PacketWasTruncated reports whether err indicates truncation but the RecvFrom
+// that generated err was otherwise successful. On Windows, Go's UDP RecvFrom
+// calls WSARecvFrom which returns the WSAEMSGSIZE error code when the received
+// datagram is larger than the provided buffer. When that happens, both a valid
+// size and an error are returned (as per the partial fix for golang/go#14074).
+// If the WSAEMSGSIZE error is returned, then we ignore the error to get
+// semantics similar to the POSIX operating systems. One caveat is that it
+// appears that the source address is not returned when WSAEMSGSIZE occurs, but
+// we do not currently look at the source address.
+func PacketWasTruncated(err error) bool {
+ if packetWasTruncated == nil {
+ return false
+ }
+ return packetWasTruncated(err)
+}
+
+var shouldDisableUDPGSO func(error) bool // non-nil on Linux
+
+func ShouldDisableUDPGSO(err error) bool {
+ if shouldDisableUDPGSO == nil {
+ return false
+ }
+ return shouldDisableUDPGSO(err)
+}
+
+type ErrUDPGSODisabled struct {
+ OnLaddr string
+ RetryErr error
+}
+
+func (e ErrUDPGSODisabled) Error() string {
+ return fmt.Sprintf("disabled UDP GSO on %s, NIC(s) may not support checksum offload", e.OnLaddr)
+}
+
+func (e ErrUDPGSODisabled) Unwrap() error {
+ return e.RetryErr
+}
diff --git a/net/neterror/neterror_linux.go b/net/neterror/neterror_linux.go index 857367fe8..3f402dd30 100644 --- a/net/neterror/neterror_linux.go +++ b/net/neterror/neterror_linux.go @@ -1,26 +1,26 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package neterror - -import ( - "errors" - "os" - - "golang.org/x/sys/unix" -) - -func init() { - shouldDisableUDPGSO = func(err error) bool { - var serr *os.SyscallError - if errors.As(err, &serr) { - // EIO is returned by udp_send_skb() if the device driver does not - // have tx checksumming enabled, which is a hard requirement of - // UDP_SEGMENT. See: - // https://git.kernel.org/pub/scm/docs/man-pages/man-pages.git/tree/man7/udp.7?id=806eabd74910447f21005160e90957bde4db0183#n228 - // https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/net/ipv4/udp.c?h=v6.2&id=c9c3395d5e3dcc6daee66c6908354d47bf98cb0c#n942 - return serr.Err == unix.EIO - } - return false - } -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package neterror
+
+import (
+ "errors"
+ "os"
+
+ "golang.org/x/sys/unix"
+)
+
+func init() {
+ shouldDisableUDPGSO = func(err error) bool {
+ var serr *os.SyscallError
+ if errors.As(err, &serr) {
+ // EIO is returned by udp_send_skb() if the device driver does not
+ // have tx checksumming enabled, which is a hard requirement of
+ // UDP_SEGMENT. See:
+ // https://git.kernel.org/pub/scm/docs/man-pages/man-pages.git/tree/man7/udp.7?id=806eabd74910447f21005160e90957bde4db0183#n228
+ // https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/net/ipv4/udp.c?h=v6.2&id=c9c3395d5e3dcc6daee66c6908354d47bf98cb0c#n942
+ return serr.Err == unix.EIO
+ }
+ return false
+ }
+}
diff --git a/net/neterror/neterror_linux_test.go b/net/neterror/neterror_linux_test.go index 5b9906074..1d600d6b6 100644 --- a/net/neterror/neterror_linux_test.go +++ b/net/neterror/neterror_linux_test.go @@ -1,54 +1,54 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package neterror - -import ( - "errors" - "net" - "os" - "syscall" - "testing" -) - -func TestTreatAsLostUDP(t *testing.T) { - tests := []struct { - name string - err error - want bool - }{ - {"nil", nil, false}, - {"non-nil", errors.New("foo"), false}, - {"eperm", syscall.EPERM, true}, - { - name: "operror", - err: &net.OpError{ - Op: "write", - Err: &os.SyscallError{ - Syscall: "sendto", - Err: syscall.EPERM, - }, - }, - want: true, - }, - { - name: "host_unreach", - err: &net.OpError{ - Op: "write", - Err: &os.SyscallError{ - Syscall: "sendto", - Err: syscall.EHOSTUNREACH, - }, - }, - want: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := TreatAsLostUDP(tt.err); got != tt.want { - t.Errorf("got = %v; want %v", got, tt.want) - } - }) - } - -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package neterror
+
+import (
+ "errors"
+ "net"
+ "os"
+ "syscall"
+ "testing"
+)
+
+func TestTreatAsLostUDP(t *testing.T) {
+ tests := []struct {
+ name string
+ err error
+ want bool
+ }{
+ {"nil", nil, false},
+ {"non-nil", errors.New("foo"), false},
+ {"eperm", syscall.EPERM, true},
+ {
+ name: "operror",
+ err: &net.OpError{
+ Op: "write",
+ Err: &os.SyscallError{
+ Syscall: "sendto",
+ Err: syscall.EPERM,
+ },
+ },
+ want: true,
+ },
+ {
+ name: "host_unreach",
+ err: &net.OpError{
+ Op: "write",
+ Err: &os.SyscallError{
+ Syscall: "sendto",
+ Err: syscall.EHOSTUNREACH,
+ },
+ },
+ want: false,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if got := TreatAsLostUDP(tt.err); got != tt.want {
+ t.Errorf("got = %v; want %v", got, tt.want)
+ }
+ })
+ }
+
+}
diff --git a/net/neterror/neterror_windows.go b/net/neterror/neterror_windows.go index bf112f5ed..c293ec4a9 100644 --- a/net/neterror/neterror_windows.go +++ b/net/neterror/neterror_windows.go @@ -1,16 +1,16 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package neterror - -import ( - "errors" - - "golang.org/x/sys/windows" -) - -func init() { - packetWasTruncated = func(err error) bool { - return errors.Is(err, windows.WSAEMSGSIZE) - } -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package neterror
+
+import (
+ "errors"
+
+ "golang.org/x/sys/windows"
+)
+
+func init() {
+ packetWasTruncated = func(err error) bool {
+ return errors.Is(err, windows.WSAEMSGSIZE)
+ }
+}
diff --git a/net/netkernelconf/netkernelconf.go b/net/netkernelconf/netkernelconf.go index 3ea502b37..23ec9c5b6 100644 --- a/net/netkernelconf/netkernelconf.go +++ b/net/netkernelconf/netkernelconf.go @@ -1,5 +1,5 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package netkernelconf contains code for checking kernel netdev config. -package netkernelconf +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Package netkernelconf contains code for checking kernel netdev config.
+package netkernelconf
diff --git a/net/netknob/netknob.go b/net/netknob/netknob.go index 53171f424..0b271fc95 100644 --- a/net/netknob/netknob.go +++ b/net/netknob/netknob.go @@ -1,29 +1,29 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package netknob has Tailscale network knobs. -package netknob - -import ( - "runtime" - "time" -) - -// PlatformTCPKeepAlive returns the default net.Dialer.KeepAlive -// value for the current runtime.GOOS. -func PlatformTCPKeepAlive() time.Duration { - switch runtime.GOOS { - case "ios", "android": - // Disable TCP keep-alives on mobile platforms. - // See https://github.com/golang/go/issues/48622. - // - // TODO(bradfitz): in 1.17.x, try disabling TCP - // keep-alives on for all platforms. - return -1 - } - - // Otherwise, default to 30 seconds, which is mostly what we - // used to do. In some places we used the zero value, which Go - // defaults to 15 seconds. But 30 seconds is fine. - return 30 * time.Second -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Package netknob has Tailscale network knobs.
+package netknob
+
+import (
+ "runtime"
+ "time"
+)
+
+// PlatformTCPKeepAlive returns the default net.Dialer.KeepAlive
+// value for the current runtime.GOOS.
+func PlatformTCPKeepAlive() time.Duration {
+ switch runtime.GOOS {
+ case "ios", "android":
+ // Disable TCP keep-alives on mobile platforms.
+ // See https://github.com/golang/go/issues/48622.
+ //
+ // TODO(bradfitz): in 1.17.x, try disabling TCP
+ // keep-alives on for all platforms.
+ return -1
+ }
+
+ // Otherwise, default to 30 seconds, which is mostly what we
+ // used to do. In some places we used the zero value, which Go
+ // defaults to 15 seconds. But 30 seconds is fine.
+ return 30 * time.Second
+}
diff --git a/net/netmon/netmon_darwin_test.go b/net/netmon/netmon_darwin_test.go index 84c67cf6f..77a212683 100644 --- a/net/netmon/netmon_darwin_test.go +++ b/net/netmon/netmon_darwin_test.go @@ -1,27 +1,27 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package netmon - -import ( - "encoding/hex" - "strings" - "testing" - - "golang.org/x/net/route" -) - -func TestIssue1416RIB(t *testing.T) { - const ribHex = `32 00 05 10 30 00 00 00 00 00 00 00 04 00 00 00 14 12 04 00 06 03 06 00 65 6e 30 ac 87 a3 19 7f 82 00 00 00 0e 12 00 00 00 00 06 00 91 e0 f0 01 00 00` - rtmMsg, err := hex.DecodeString(strings.ReplaceAll(ribHex, " ", "")) - if err != nil { - t.Fatal(err) - } - msgs, err := route.ParseRIB(route.RIBTypeRoute, rtmMsg) - if err != nil { - t.Logf("ParseRIB: %v", err) - t.Skip("skipping on known failure; see https://github.com/tailscale/tailscale/issues/1416") - t.Fatal(err) - } - t.Logf("Got: %#v", msgs) -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package netmon
+
+import (
+ "encoding/hex"
+ "strings"
+ "testing"
+
+ "golang.org/x/net/route"
+)
+
+func TestIssue1416RIB(t *testing.T) {
+ const ribHex = `32 00 05 10 30 00 00 00 00 00 00 00 04 00 00 00 14 12 04 00 06 03 06 00 65 6e 30 ac 87 a3 19 7f 82 00 00 00 0e 12 00 00 00 00 06 00 91 e0 f0 01 00 00`
+ rtmMsg, err := hex.DecodeString(strings.ReplaceAll(ribHex, " ", ""))
+ if err != nil {
+ t.Fatal(err)
+ }
+ msgs, err := route.ParseRIB(route.RIBTypeRoute, rtmMsg)
+ if err != nil {
+ t.Logf("ParseRIB: %v", err)
+ t.Skip("skipping on known failure; see https://github.com/tailscale/tailscale/issues/1416")
+ t.Fatal(err)
+ }
+ t.Logf("Got: %#v", msgs)
+}
diff --git a/net/netmon/netmon_freebsd.go b/net/netmon/netmon_freebsd.go index 30480a1d3..724f964c9 100644 --- a/net/netmon/netmon_freebsd.go +++ b/net/netmon/netmon_freebsd.go @@ -1,56 +1,56 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package netmon - -import ( - "bufio" - "fmt" - "net" - "strings" - - "tailscale.com/types/logger" -) - -// unspecifiedMessage is a minimal message implementation that should not -// be ignored. In general, OS-specific implementations should use better -// types and avoid this if they can. -type unspecifiedMessage struct{} - -func (unspecifiedMessage) ignore() bool { return false } - -// devdConn implements osMon using devd(8). -type devdConn struct { - conn net.Conn -} - -func newOSMon(logf logger.Logf, m *Monitor) (osMon, error) { - conn, err := net.Dial("unixpacket", "/var/run/devd.seqpacket.pipe") - if err != nil { - logf("devd dial error: %v, falling back to polling method", err) - return newPollingMon(logf, m) - } - return &devdConn{conn}, nil -} - -func (c *devdConn) IsInterestingInterface(iface string) bool { return true } - -func (c *devdConn) Close() error { - return c.conn.Close() -} - -func (c *devdConn) Receive() (message, error) { - for { - msg, err := bufio.NewReader(c.conn).ReadString('\n') - if err != nil { - return nil, fmt.Errorf("reading devd socket: %v", err) - } - // Only return messages related to the network subsystem. - if !strings.Contains(msg, "system=IFNET") { - continue - } - // TODO: this is where the devd-specific message would - // get converted into a "standard" event message and returned. - return unspecifiedMessage{}, nil - } -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package netmon
+
+import (
+ "bufio"
+ "fmt"
+ "net"
+ "strings"
+
+ "tailscale.com/types/logger"
+)
+
+// unspecifiedMessage is a minimal message implementation that should not
+// be ignored. In general, OS-specific implementations should use better
+// types and avoid this if they can.
+type unspecifiedMessage struct{}
+
+func (unspecifiedMessage) ignore() bool { return false }
+
+// devdConn implements osMon using devd(8).
+type devdConn struct {
+ conn net.Conn
+}
+
+func newOSMon(logf logger.Logf, m *Monitor) (osMon, error) {
+ conn, err := net.Dial("unixpacket", "/var/run/devd.seqpacket.pipe")
+ if err != nil {
+ logf("devd dial error: %v, falling back to polling method", err)
+ return newPollingMon(logf, m)
+ }
+ return &devdConn{conn}, nil
+}
+
+func (c *devdConn) IsInterestingInterface(iface string) bool { return true }
+
+func (c *devdConn) Close() error {
+ return c.conn.Close()
+}
+
+func (c *devdConn) Receive() (message, error) {
+ for {
+ msg, err := bufio.NewReader(c.conn).ReadString('\n')
+ if err != nil {
+ return nil, fmt.Errorf("reading devd socket: %v", err)
+ }
+ // Only return messages related to the network subsystem.
+ if !strings.Contains(msg, "system=IFNET") {
+ continue
+ }
+ // TODO: this is where the devd-specific message would
+ // get converted into a "standard" event message and returned.
+ return unspecifiedMessage{}, nil
+ }
+}
diff --git a/net/netmon/netmon_linux.go b/net/netmon/netmon_linux.go index dd23dd342..888afa92d 100644 --- a/net/netmon/netmon_linux.go +++ b/net/netmon/netmon_linux.go @@ -1,290 +1,290 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !android - -package netmon - -import ( - "net" - "net/netip" - "time" - - "github.com/jsimonetti/rtnetlink" - "github.com/mdlayher/netlink" - "golang.org/x/sys/unix" - "tailscale.com/envknob" - "tailscale.com/net/tsaddr" - "tailscale.com/types/logger" -) - -var debugNetlinkMessages = envknob.RegisterBool("TS_DEBUG_NETLINK") - -// unspecifiedMessage is a minimal message implementation that should not -// be ignored. In general, OS-specific implementations should use better -// types and avoid this if they can. -type unspecifiedMessage struct{} - -func (unspecifiedMessage) ignore() bool { return false } - -// nlConn wraps a *netlink.Conn and returns a monitor.Message -// instead of a netlink.Message. Currently, messages are discarded, -// but down the line, when messages trigger different logic depending -// on the type of event, this provides the capability of handling -// each architecture-specific message in a generic fashion. -type nlConn struct { - logf logger.Logf - conn *netlink.Conn - buffered []netlink.Message - - // addrCache maps interface indices to a set of addresses, and is - // used to suppress duplicate RTM_NEWADDR messages. It is populated - // by RTM_NEWADDR messages and de-populated by RTM_DELADDR. See - // issue #4282. - addrCache map[uint32]map[netip.Addr]bool -} - -func newOSMon(logf logger.Logf, m *Monitor) (osMon, error) { - conn, err := netlink.Dial(unix.NETLINK_ROUTE, &netlink.Config{ - // Routes get us most of the events of interest, but we need - // address as well to cover things like DHCP deciding to give - // us a new address upon renewal - routing wouldn't change, - // but all reachability would. - Groups: unix.RTMGRP_IPV4_IFADDR | unix.RTMGRP_IPV6_IFADDR | - unix.RTMGRP_IPV4_ROUTE | unix.RTMGRP_IPV6_ROUTE | - unix.RTMGRP_IPV4_RULE, // no IPV6_RULE in x/sys/unix - }) - if err != nil { - // Google Cloud Run does not implement NETLINK_ROUTE RTMGRP support - logf("monitor_linux: AF_NETLINK RTMGRP failed, falling back to polling") - return newPollingMon(logf, m) - } - return &nlConn{logf: logf, conn: conn, addrCache: make(map[uint32]map[netip.Addr]bool)}, nil -} - -func (c *nlConn) IsInterestingInterface(iface string) bool { return true } - -func (c *nlConn) Close() error { return c.conn.Close() } - -func (c *nlConn) Receive() (message, error) { - if len(c.buffered) == 0 { - var err error - c.buffered, err = c.conn.Receive() - if err != nil { - return nil, err - } - if len(c.buffered) == 0 { - // Unexpected. Not seen in wild, but sleep defensively. - time.Sleep(time.Second) - return ignoreMessage{}, nil - } - } - msg := c.buffered[0] - c.buffered = c.buffered[1:] - - // See https://github.com/torvalds/linux/blob/master/include/uapi/linux/rtnetlink.h - // And https://man7.org/linux/man-pages/man7/rtnetlink.7.html - switch msg.Header.Type { - case unix.RTM_NEWADDR, unix.RTM_DELADDR: - var rmsg rtnetlink.AddressMessage - if err := rmsg.UnmarshalBinary(msg.Data); err != nil { - c.logf("failed to parse type %v: %v", msg.Header.Type, err) - return unspecifiedMessage{}, nil - } - - nip := netaddrIP(rmsg.Attributes.Address) - - if debugNetlinkMessages() { - typ := "RTM_NEWADDR" - if msg.Header.Type == unix.RTM_DELADDR { - typ = "RTM_DELADDR" - } - - // label attributes are seemingly only populated for IPv4 addresses in the wild. - label := rmsg.Attributes.Label - if label == "" { - itf, err := net.InterfaceByIndex(int(rmsg.Index)) - if err == nil { - label = itf.Name - } - } - - c.logf("%s: %s(%d) %s / %s", typ, label, rmsg.Index, rmsg.Attributes.Address, rmsg.Attributes.Local) - } - - addrs := c.addrCache[rmsg.Index] - - // Ignore duplicate RTM_NEWADDR messages using c.addrCache to - // detect them. See nlConn.addrcache and issue #4282. - if msg.Header.Type == unix.RTM_NEWADDR { - if addrs == nil { - addrs = make(map[netip.Addr]bool) - c.addrCache[rmsg.Index] = addrs - } - - if addrs[nip] { - if debugNetlinkMessages() { - c.logf("ignored duplicate RTM_NEWADDR for %s", nip) - } - return ignoreMessage{}, nil - } - - addrs[nip] = true - } else { // msg.Header.Type == unix.RTM_DELADDR - if addrs != nil { - delete(addrs, nip) - } - - if len(addrs) == 0 { - delete(c.addrCache, rmsg.Index) - } - } - - nam := &newAddrMessage{ - IfIndex: rmsg.Index, - Addr: nip, - Delete: msg.Header.Type == unix.RTM_DELADDR, - } - if debugNetlinkMessages() { - c.logf("%+v", nam) - } - return nam, nil - case unix.RTM_NEWROUTE, unix.RTM_DELROUTE: - typeStr := "RTM_NEWROUTE" - if msg.Header.Type == unix.RTM_DELROUTE { - typeStr = "RTM_DELROUTE" - } - var rmsg rtnetlink.RouteMessage - if err := rmsg.UnmarshalBinary(msg.Data); err != nil { - c.logf("%s: failed to parse: %v", typeStr, err) - return unspecifiedMessage{}, nil - } - src := netaddrIPPrefix(rmsg.Attributes.Src, rmsg.SrcLength) - dst := netaddrIPPrefix(rmsg.Attributes.Dst, rmsg.DstLength) - gw := netaddrIP(rmsg.Attributes.Gateway) - - if msg.Header.Type == unix.RTM_NEWROUTE && - (rmsg.Attributes.Table == 255 || rmsg.Attributes.Table == 254) && - (dst.Addr().IsMulticast() || dst.Addr().IsLinkLocalUnicast()) { - - if debugNetlinkMessages() { - c.logf("%s ignored", typeStr) - } - - // Normal Linux route changes on new interface coming up; don't log or react. - return ignoreMessage{}, nil - } - - if rmsg.Table == tsTable && dst.IsSingleIP() { - // Don't log. Spammy and normal to see a bunch of these on start-up, - // which we make ourselves. - } else if tsaddr.IsTailscaleIP(dst.Addr()) { - // Verbose only. - c.logf("%s: [v1] src=%v, dst=%v, gw=%v, outif=%v, table=%v", typeStr, - condNetAddrPrefix(src), condNetAddrPrefix(dst), condNetAddrIP(gw), - rmsg.Attributes.OutIface, rmsg.Attributes.Table) - } else { - c.logf("%s: src=%v, dst=%v, gw=%v, outif=%v, table=%v", typeStr, - condNetAddrPrefix(src), condNetAddrPrefix(dst), condNetAddrIP(gw), - rmsg.Attributes.OutIface, rmsg.Attributes.Table) - } - if msg.Header.Type == unix.RTM_DELROUTE { - // Just logging it for now. - // (Debugging https://github.com/tailscale/tailscale/issues/643) - return unspecifiedMessage{}, nil - } - - nrm := &newRouteMessage{ - Table: rmsg.Table, - Src: src, - Dst: dst, - Gateway: gw, - } - if debugNetlinkMessages() { - c.logf("%+v", nrm) - } - return nrm, nil - case unix.RTM_NEWRULE: - // Probably ourselves adding it. - return ignoreMessage{}, nil - case unix.RTM_DELRULE: - // For https://github.com/tailscale/tailscale/issues/1591 where - // systemd-networkd deletes our rules. - var rmsg rtnetlink.RouteMessage - err := rmsg.UnmarshalBinary(msg.Data) - if err != nil { - c.logf("ip rule deleted; failed to parse netlink message: %v", err) - } else { - c.logf("ip rule deleted: %+v", rmsg) - // On `ip -4 rule del pref 5210 table main`, logs: - // monitor: ip rule deleted: {Family:2 DstLength:0 SrcLength:0 Tos:0 Table:254 Protocol:0 Scope:0 Type:1 Flags:0 Attributes:{Dst:<nil> Src:<nil> Gateway:<nil> OutIface:0 Priority:5210 Table:254 Mark:4294967295 Expires:<nil> Metrics:<nil> Multipath:[]}} - } - rdm := ipRuleDeletedMessage{ - table: rmsg.Table, - priority: rmsg.Attributes.Priority, - } - if debugNetlinkMessages() { - c.logf("%+v", rdm) - } - return rdm, nil - case unix.RTM_NEWLINK, unix.RTM_DELLINK: - // This is an unhandled message, but don't print an error. - // See https://github.com/tailscale/tailscale/issues/6806 - return unspecifiedMessage{}, nil - default: - c.logf("unhandled netlink msg type %+v, %q", msg.Header, msg.Data) - return unspecifiedMessage{}, nil - } -} - -func netaddrIP(std net.IP) netip.Addr { - ip, _ := netip.AddrFromSlice(std) - return ip.Unmap() -} - -func netaddrIPPrefix(std net.IP, bits uint8) netip.Prefix { - ip, _ := netip.AddrFromSlice(std) - return netip.PrefixFrom(ip.Unmap(), int(bits)) -} - -func condNetAddrPrefix(ipp netip.Prefix) string { - if !ipp.Addr().IsValid() { - return "" - } - return ipp.String() -} - -func condNetAddrIP(ip netip.Addr) string { - if !ip.IsValid() { - return "" - } - return ip.String() -} - -// newRouteMessage is a message for a new route being added. -type newRouteMessage struct { - Src, Dst netip.Prefix - Gateway netip.Addr - Table uint8 -} - -const tsTable = 52 - -func (m *newRouteMessage) ignore() bool { - return m.Table == tsTable || tsaddr.IsTailscaleIP(m.Dst.Addr()) -} - -// newAddrMessage is a message for a new address being added. -type newAddrMessage struct { - Delete bool - Addr netip.Addr - IfIndex uint32 // interface index -} - -func (m *newAddrMessage) ignore() bool { - return tsaddr.IsTailscaleIP(m.Addr) -} - -type ignoreMessage struct{} - -func (ignoreMessage) ignore() bool { return true } +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build !android
+
+package netmon
+
+import (
+ "net"
+ "net/netip"
+ "time"
+
+ "github.com/jsimonetti/rtnetlink"
+ "github.com/mdlayher/netlink"
+ "golang.org/x/sys/unix"
+ "tailscale.com/envknob"
+ "tailscale.com/net/tsaddr"
+ "tailscale.com/types/logger"
+)
+
+var debugNetlinkMessages = envknob.RegisterBool("TS_DEBUG_NETLINK")
+
+// unspecifiedMessage is a minimal message implementation that should not
+// be ignored. In general, OS-specific implementations should use better
+// types and avoid this if they can.
+type unspecifiedMessage struct{}
+
+func (unspecifiedMessage) ignore() bool { return false }
+
+// nlConn wraps a *netlink.Conn and returns a monitor.Message
+// instead of a netlink.Message. Currently, messages are discarded,
+// but down the line, when messages trigger different logic depending
+// on the type of event, this provides the capability of handling
+// each architecture-specific message in a generic fashion.
+type nlConn struct {
+ logf logger.Logf
+ conn *netlink.Conn
+ buffered []netlink.Message
+
+ // addrCache maps interface indices to a set of addresses, and is
+ // used to suppress duplicate RTM_NEWADDR messages. It is populated
+ // by RTM_NEWADDR messages and de-populated by RTM_DELADDR. See
+ // issue #4282.
+ addrCache map[uint32]map[netip.Addr]bool
+}
+
+func newOSMon(logf logger.Logf, m *Monitor) (osMon, error) {
+ conn, err := netlink.Dial(unix.NETLINK_ROUTE, &netlink.Config{
+ // Routes get us most of the events of interest, but we need
+ // address as well to cover things like DHCP deciding to give
+ // us a new address upon renewal - routing wouldn't change,
+ // but all reachability would.
+ Groups: unix.RTMGRP_IPV4_IFADDR | unix.RTMGRP_IPV6_IFADDR |
+ unix.RTMGRP_IPV4_ROUTE | unix.RTMGRP_IPV6_ROUTE |
+ unix.RTMGRP_IPV4_RULE, // no IPV6_RULE in x/sys/unix
+ })
+ if err != nil {
+ // Google Cloud Run does not implement NETLINK_ROUTE RTMGRP support
+ logf("monitor_linux: AF_NETLINK RTMGRP failed, falling back to polling")
+ return newPollingMon(logf, m)
+ }
+ return &nlConn{logf: logf, conn: conn, addrCache: make(map[uint32]map[netip.Addr]bool)}, nil
+}
+
+func (c *nlConn) IsInterestingInterface(iface string) bool { return true }
+
+func (c *nlConn) Close() error { return c.conn.Close() }
+
+func (c *nlConn) Receive() (message, error) {
+ if len(c.buffered) == 0 {
+ var err error
+ c.buffered, err = c.conn.Receive()
+ if err != nil {
+ return nil, err
+ }
+ if len(c.buffered) == 0 {
+ // Unexpected. Not seen in wild, but sleep defensively.
+ time.Sleep(time.Second)
+ return ignoreMessage{}, nil
+ }
+ }
+ msg := c.buffered[0]
+ c.buffered = c.buffered[1:]
+
+ // See https://github.com/torvalds/linux/blob/master/include/uapi/linux/rtnetlink.h
+ // And https://man7.org/linux/man-pages/man7/rtnetlink.7.html
+ switch msg.Header.Type {
+ case unix.RTM_NEWADDR, unix.RTM_DELADDR:
+ var rmsg rtnetlink.AddressMessage
+ if err := rmsg.UnmarshalBinary(msg.Data); err != nil {
+ c.logf("failed to parse type %v: %v", msg.Header.Type, err)
+ return unspecifiedMessage{}, nil
+ }
+
+ nip := netaddrIP(rmsg.Attributes.Address)
+
+ if debugNetlinkMessages() {
+ typ := "RTM_NEWADDR"
+ if msg.Header.Type == unix.RTM_DELADDR {
+ typ = "RTM_DELADDR"
+ }
+
+ // label attributes are seemingly only populated for IPv4 addresses in the wild.
+ label := rmsg.Attributes.Label
+ if label == "" {
+ itf, err := net.InterfaceByIndex(int(rmsg.Index))
+ if err == nil {
+ label = itf.Name
+ }
+ }
+
+ c.logf("%s: %s(%d) %s / %s", typ, label, rmsg.Index, rmsg.Attributes.Address, rmsg.Attributes.Local)
+ }
+
+ addrs := c.addrCache[rmsg.Index]
+
+ // Ignore duplicate RTM_NEWADDR messages using c.addrCache to
+ // detect them. See nlConn.addrcache and issue #4282.
+ if msg.Header.Type == unix.RTM_NEWADDR {
+ if addrs == nil {
+ addrs = make(map[netip.Addr]bool)
+ c.addrCache[rmsg.Index] = addrs
+ }
+
+ if addrs[nip] {
+ if debugNetlinkMessages() {
+ c.logf("ignored duplicate RTM_NEWADDR for %s", nip)
+ }
+ return ignoreMessage{}, nil
+ }
+
+ addrs[nip] = true
+ } else { // msg.Header.Type == unix.RTM_DELADDR
+ if addrs != nil {
+ delete(addrs, nip)
+ }
+
+ if len(addrs) == 0 {
+ delete(c.addrCache, rmsg.Index)
+ }
+ }
+
+ nam := &newAddrMessage{
+ IfIndex: rmsg.Index,
+ Addr: nip,
+ Delete: msg.Header.Type == unix.RTM_DELADDR,
+ }
+ if debugNetlinkMessages() {
+ c.logf("%+v", nam)
+ }
+ return nam, nil
+ case unix.RTM_NEWROUTE, unix.RTM_DELROUTE:
+ typeStr := "RTM_NEWROUTE"
+ if msg.Header.Type == unix.RTM_DELROUTE {
+ typeStr = "RTM_DELROUTE"
+ }
+ var rmsg rtnetlink.RouteMessage
+ if err := rmsg.UnmarshalBinary(msg.Data); err != nil {
+ c.logf("%s: failed to parse: %v", typeStr, err)
+ return unspecifiedMessage{}, nil
+ }
+ src := netaddrIPPrefix(rmsg.Attributes.Src, rmsg.SrcLength)
+ dst := netaddrIPPrefix(rmsg.Attributes.Dst, rmsg.DstLength)
+ gw := netaddrIP(rmsg.Attributes.Gateway)
+
+ if msg.Header.Type == unix.RTM_NEWROUTE &&
+ (rmsg.Attributes.Table == 255 || rmsg.Attributes.Table == 254) &&
+ (dst.Addr().IsMulticast() || dst.Addr().IsLinkLocalUnicast()) {
+
+ if debugNetlinkMessages() {
+ c.logf("%s ignored", typeStr)
+ }
+
+ // Normal Linux route changes on new interface coming up; don't log or react.
+ return ignoreMessage{}, nil
+ }
+
+ if rmsg.Table == tsTable && dst.IsSingleIP() {
+ // Don't log. Spammy and normal to see a bunch of these on start-up,
+ // which we make ourselves.
+ } else if tsaddr.IsTailscaleIP(dst.Addr()) {
+ // Verbose only.
+ c.logf("%s: [v1] src=%v, dst=%v, gw=%v, outif=%v, table=%v", typeStr,
+ condNetAddrPrefix(src), condNetAddrPrefix(dst), condNetAddrIP(gw),
+ rmsg.Attributes.OutIface, rmsg.Attributes.Table)
+ } else {
+ c.logf("%s: src=%v, dst=%v, gw=%v, outif=%v, table=%v", typeStr,
+ condNetAddrPrefix(src), condNetAddrPrefix(dst), condNetAddrIP(gw),
+ rmsg.Attributes.OutIface, rmsg.Attributes.Table)
+ }
+ if msg.Header.Type == unix.RTM_DELROUTE {
+ // Just logging it for now.
+ // (Debugging https://github.com/tailscale/tailscale/issues/643)
+ return unspecifiedMessage{}, nil
+ }
+
+ nrm := &newRouteMessage{
+ Table: rmsg.Table,
+ Src: src,
+ Dst: dst,
+ Gateway: gw,
+ }
+ if debugNetlinkMessages() {
+ c.logf("%+v", nrm)
+ }
+ return nrm, nil
+ case unix.RTM_NEWRULE:
+ // Probably ourselves adding it.
+ return ignoreMessage{}, nil
+ case unix.RTM_DELRULE:
+ // For https://github.com/tailscale/tailscale/issues/1591 where
+ // systemd-networkd deletes our rules.
+ var rmsg rtnetlink.RouteMessage
+ err := rmsg.UnmarshalBinary(msg.Data)
+ if err != nil {
+ c.logf("ip rule deleted; failed to parse netlink message: %v", err)
+ } else {
+ c.logf("ip rule deleted: %+v", rmsg)
+ // On `ip -4 rule del pref 5210 table main`, logs:
+ // monitor: ip rule deleted: {Family:2 DstLength:0 SrcLength:0 Tos:0 Table:254 Protocol:0 Scope:0 Type:1 Flags:0 Attributes:{Dst:<nil> Src:<nil> Gateway:<nil> OutIface:0 Priority:5210 Table:254 Mark:4294967295 Expires:<nil> Metrics:<nil> Multipath:[]}}
+ }
+ rdm := ipRuleDeletedMessage{
+ table: rmsg.Table,
+ priority: rmsg.Attributes.Priority,
+ }
+ if debugNetlinkMessages() {
+ c.logf("%+v", rdm)
+ }
+ return rdm, nil
+ case unix.RTM_NEWLINK, unix.RTM_DELLINK:
+ // This is an unhandled message, but don't print an error.
+ // See https://github.com/tailscale/tailscale/issues/6806
+ return unspecifiedMessage{}, nil
+ default:
+ c.logf("unhandled netlink msg type %+v, %q", msg.Header, msg.Data)
+ return unspecifiedMessage{}, nil
+ }
+}
+
+func netaddrIP(std net.IP) netip.Addr {
+ ip, _ := netip.AddrFromSlice(std)
+ return ip.Unmap()
+}
+
+func netaddrIPPrefix(std net.IP, bits uint8) netip.Prefix {
+ ip, _ := netip.AddrFromSlice(std)
+ return netip.PrefixFrom(ip.Unmap(), int(bits))
+}
+
+func condNetAddrPrefix(ipp netip.Prefix) string {
+ if !ipp.Addr().IsValid() {
+ return ""
+ }
+ return ipp.String()
+}
+
+func condNetAddrIP(ip netip.Addr) string {
+ if !ip.IsValid() {
+ return ""
+ }
+ return ip.String()
+}
+
+// newRouteMessage is a message for a new route being added.
+type newRouteMessage struct {
+ Src, Dst netip.Prefix
+ Gateway netip.Addr
+ Table uint8
+}
+
+const tsTable = 52
+
+func (m *newRouteMessage) ignore() bool {
+ return m.Table == tsTable || tsaddr.IsTailscaleIP(m.Dst.Addr())
+}
+
+// newAddrMessage is a message for a new address being added.
+type newAddrMessage struct {
+ Delete bool
+ Addr netip.Addr
+ IfIndex uint32 // interface index
+}
+
+func (m *newAddrMessage) ignore() bool {
+ return tsaddr.IsTailscaleIP(m.Addr)
+}
+
+type ignoreMessage struct{}
+
+func (ignoreMessage) ignore() bool { return true }
diff --git a/net/netmon/netmon_polling.go b/net/netmon/netmon_polling.go index 3d6f94731..1ce4a51de 100644 --- a/net/netmon/netmon_polling.go +++ b/net/netmon/netmon_polling.go @@ -1,21 +1,21 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build (!linux && !freebsd && !windows && !darwin) || android - -package netmon - -import ( - "tailscale.com/types/logger" -) - -func newOSMon(logf logger.Logf, m *Monitor) (osMon, error) { - return newPollingMon(logf, m) -} - -// unspecifiedMessage is a minimal message implementation that should not -// be ignored. In general, OS-specific implementations should use better -// types and avoid this if they can. -type unspecifiedMessage struct{} - -func (unspecifiedMessage) ignore() bool { return false } +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build (!linux && !freebsd && !windows && !darwin) || android
+
+package netmon
+
+import (
+ "tailscale.com/types/logger"
+)
+
+func newOSMon(logf logger.Logf, m *Monitor) (osMon, error) {
+ return newPollingMon(logf, m)
+}
+
+// unspecifiedMessage is a minimal message implementation that should not
+// be ignored. In general, OS-specific implementations should use better
+// types and avoid this if they can.
+type unspecifiedMessage struct{}
+
+func (unspecifiedMessage) ignore() bool { return false }
diff --git a/net/netmon/polling.go b/net/netmon/polling.go index ce1618ed6..bb7210b94 100644 --- a/net/netmon/polling.go +++ b/net/netmon/polling.go @@ -1,86 +1,86 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !windows && !darwin - -package netmon - -import ( - "bytes" - "errors" - "os" - "runtime" - "sync" - "time" - - "tailscale.com/types/logger" -) - -func newPollingMon(logf logger.Logf, m *Monitor) (osMon, error) { - return &pollingMon{ - logf: logf, - m: m, - stop: make(chan struct{}), - }, nil -} - -// pollingMon is a bad but portable implementation of the link monitor -// that works by polling the interface state every 10 seconds, in lieu -// of anything to subscribe to. -type pollingMon struct { - logf logger.Logf - m *Monitor - - closeOnce sync.Once - stop chan struct{} -} - -func (pm *pollingMon) IsInterestingInterface(iface string) bool { - return true -} - -func (pm *pollingMon) Close() error { - pm.closeOnce.Do(func() { - close(pm.stop) - }) - return nil -} - -func (pm *pollingMon) isCloudRun() bool { - // https: //cloud.google.com/run/docs/reference/container-contract#env-vars - if os.Getenv("K_REVISION") == "" || os.Getenv("K_CONFIGURATION") == "" || - os.Getenv("K_SERVICE") == "" || os.Getenv("PORT") == "" { - return false - } - vers, err := os.ReadFile("/proc/version") - if err != nil { - pm.logf("Failed to read /proc/version: %v", err) - return false - } - return string(bytes.TrimSpace(vers)) == "Linux version 4.4.0 #1 SMP Sun Jan 10 15:06:54 PST 2016" -} - -func (pm *pollingMon) Receive() (message, error) { - d := 10 * time.Second - if runtime.GOOS == "android" { - // We'll have Android notify the link monitor to wake up earlier, - // so this can go very slowly there, to save battery. - // https://github.com/tailscale/tailscale/issues/1427 - d = 10 * time.Minute - } else if pm.isCloudRun() { - // Cloud Run routes never change at runtime. the containers are killed within - // 15 minutes by default, set the interval long enough to be effectively infinite. - pm.logf("monitor polling: Cloud Run detected, reduce polling interval to 24h") - d = 24 * time.Hour - } - timer := time.NewTimer(d) - defer timer.Stop() - for { - select { - case <-timer.C: - return unspecifiedMessage{}, nil - case <-pm.stop: - return nil, errors.New("stopped") - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build !windows && !darwin
+
+package netmon
+
+import (
+ "bytes"
+ "errors"
+ "os"
+ "runtime"
+ "sync"
+ "time"
+
+ "tailscale.com/types/logger"
+)
+
+func newPollingMon(logf logger.Logf, m *Monitor) (osMon, error) {
+ return &pollingMon{
+ logf: logf,
+ m: m,
+ stop: make(chan struct{}),
+ }, nil
+}
+
+// pollingMon is a bad but portable implementation of the link monitor
+// that works by polling the interface state every 10 seconds, in lieu
+// of anything to subscribe to.
+type pollingMon struct {
+ logf logger.Logf
+ m *Monitor
+
+ closeOnce sync.Once
+ stop chan struct{}
+}
+
+func (pm *pollingMon) IsInterestingInterface(iface string) bool {
+ return true
+}
+
+func (pm *pollingMon) Close() error {
+ pm.closeOnce.Do(func() {
+ close(pm.stop)
+ })
+ return nil
+}
+
+func (pm *pollingMon) isCloudRun() bool {
+ // https: //cloud.google.com/run/docs/reference/container-contract#env-vars
+ if os.Getenv("K_REVISION") == "" || os.Getenv("K_CONFIGURATION") == "" ||
+ os.Getenv("K_SERVICE") == "" || os.Getenv("PORT") == "" {
+ return false
+ }
+ vers, err := os.ReadFile("/proc/version")
+ if err != nil {
+ pm.logf("Failed to read /proc/version: %v", err)
+ return false
+ }
+ return string(bytes.TrimSpace(vers)) == "Linux version 4.4.0 #1 SMP Sun Jan 10 15:06:54 PST 2016"
+}
+
+func (pm *pollingMon) Receive() (message, error) {
+ d := 10 * time.Second
+ if runtime.GOOS == "android" {
+ // We'll have Android notify the link monitor to wake up earlier,
+ // so this can go very slowly there, to save battery.
+ // https://github.com/tailscale/tailscale/issues/1427
+ d = 10 * time.Minute
+ } else if pm.isCloudRun() {
+ // Cloud Run routes never change at runtime. the containers are killed within
+ // 15 minutes by default, set the interval long enough to be effectively infinite.
+ pm.logf("monitor polling: Cloud Run detected, reduce polling interval to 24h")
+ d = 24 * time.Hour
+ }
+ timer := time.NewTimer(d)
+ defer timer.Stop()
+ for {
+ select {
+ case <-timer.C:
+ return unspecifiedMessage{}, nil
+ case <-pm.stop:
+ return nil, errors.New("stopped")
+ }
+ }
+}
diff --git a/net/netns/netns_android.go b/net/netns/netns_android.go index 162e5c79a..79084ff11 100644 --- a/net/netns/netns_android.go +++ b/net/netns/netns_android.go @@ -1,75 +1,75 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build android - -package netns - -import ( - "fmt" - "sync" - "syscall" - - "tailscale.com/net/netmon" - "tailscale.com/types/logger" -) - -var ( - androidProtectFuncMu sync.Mutex - androidProtectFunc func(fd int) error -) - -// UseSocketMark reports whether SO_MARK is in use. Android does not use SO_MARK. -func UseSocketMark() bool { - return false -} - -// SetAndroidProtectFunc register a func that Android provides that JNI calls into -// https://developer.android.com/reference/android/net/VpnService#protect(int) -// which is documented as: -// -// "Protect a socket from VPN connections. After protecting, data sent -// through this socket will go directly to the underlying network, so -// its traffic will not be forwarded through the VPN. This method is -// useful if some connections need to be kept outside of VPN. For -// example, a VPN tunnel should protect itself if its destination is -// covered by VPN routes. Otherwise its outgoing packets will be sent -// back to the VPN interface and cause an infinite loop. This method -// will fail if the application is not prepared or is revoked." -// -// A nil func disables the use the hook. -// -// This indirection is necessary because this is the supported, stable -// interface to use on Android, and doing the sockopts to set the -// fwmark return errors on Android. The actual implementation of -// VpnService.protect ends up doing an IPC to another process on -// Android, asking for the fwmark to be set. -func SetAndroidProtectFunc(f func(fd int) error) { - androidProtectFuncMu.Lock() - defer androidProtectFuncMu.Unlock() - androidProtectFunc = f -} - -func control(logger.Logf, *netmon.Monitor) func(network, address string, c syscall.RawConn) error { - return controlC -} - -// controlC marks c as necessary to dial in a separate network namespace. -// -// It's intentionally the same signature as net.Dialer.Control -// and net.ListenConfig.Control. -func controlC(network, address string, c syscall.RawConn) error { - var sockErr error - err := c.Control(func(fd uintptr) { - androidProtectFuncMu.Lock() - f := androidProtectFunc - androidProtectFuncMu.Unlock() - if f != nil { - sockErr = f(int(fd)) - } - }) - if err != nil { - return fmt.Errorf("RawConn.Control on %T: %w", c, err) - } - return sockErr -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build android
+
+package netns
+
+import (
+ "fmt"
+ "sync"
+ "syscall"
+
+ "tailscale.com/net/netmon"
+ "tailscale.com/types/logger"
+)
+
+var (
+ androidProtectFuncMu sync.Mutex
+ androidProtectFunc func(fd int) error
+)
+
+// UseSocketMark reports whether SO_MARK is in use. Android does not use SO_MARK.
+func UseSocketMark() bool {
+ return false
+}
+
+// SetAndroidProtectFunc register a func that Android provides that JNI calls into
+// https://developer.android.com/reference/android/net/VpnService#protect(int)
+// which is documented as:
+//
+// "Protect a socket from VPN connections. After protecting, data sent
+// through this socket will go directly to the underlying network, so
+// its traffic will not be forwarded through the VPN. This method is
+// useful if some connections need to be kept outside of VPN. For
+// example, a VPN tunnel should protect itself if its destination is
+// covered by VPN routes. Otherwise its outgoing packets will be sent
+// back to the VPN interface and cause an infinite loop. This method
+// will fail if the application is not prepared or is revoked."
+//
+// A nil func disables the use the hook.
+//
+// This indirection is necessary because this is the supported, stable
+// interface to use on Android, and doing the sockopts to set the
+// fwmark return errors on Android. The actual implementation of
+// VpnService.protect ends up doing an IPC to another process on
+// Android, asking for the fwmark to be set.
+func SetAndroidProtectFunc(f func(fd int) error) {
+ androidProtectFuncMu.Lock()
+ defer androidProtectFuncMu.Unlock()
+ androidProtectFunc = f
+}
+
+func control(logger.Logf, *netmon.Monitor) func(network, address string, c syscall.RawConn) error {
+ return controlC
+}
+
+// controlC marks c as necessary to dial in a separate network namespace.
+//
+// It's intentionally the same signature as net.Dialer.Control
+// and net.ListenConfig.Control.
+func controlC(network, address string, c syscall.RawConn) error {
+ var sockErr error
+ err := c.Control(func(fd uintptr) {
+ androidProtectFuncMu.Lock()
+ f := androidProtectFunc
+ androidProtectFuncMu.Unlock()
+ if f != nil {
+ sockErr = f(int(fd))
+ }
+ })
+ if err != nil {
+ return fmt.Errorf("RawConn.Control on %T: %w", c, err)
+ }
+ return sockErr
+}
diff --git a/net/netns/netns_default.go b/net/netns/netns_default.go index 94f24d8fa..02db19e75 100644 --- a/net/netns/netns_default.go +++ b/net/netns/netns_default.go @@ -1,22 +1,22 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !linux && !windows && !darwin - -package netns - -import ( - "syscall" - - "tailscale.com/net/netmon" - "tailscale.com/types/logger" -) - -func control(logger.Logf, *netmon.Monitor) func(network, address string, c syscall.RawConn) error { - return controlC -} - -// controlC does nothing to c. -func controlC(network, address string, c syscall.RawConn) error { - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build !linux && !windows && !darwin
+
+package netns
+
+import (
+ "syscall"
+
+ "tailscale.com/net/netmon"
+ "tailscale.com/types/logger"
+)
+
+func control(logger.Logf, *netmon.Monitor) func(network, address string, c syscall.RawConn) error {
+ return controlC
+}
+
+// controlC does nothing to c.
+func controlC(network, address string, c syscall.RawConn) error {
+ return nil
+}
diff --git a/net/netns/netns_linux_test.go b/net/netns/netns_linux_test.go index a5000f37f..cc221bcb1 100644 --- a/net/netns/netns_linux_test.go +++ b/net/netns/netns_linux_test.go @@ -1,14 +1,14 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package netns - -import ( - "testing" -) - -func TestSocketMarkWorks(t *testing.T) { - _ = socketMarkWorks() - // we cannot actually assert whether the test runner has SO_MARK available - // or not, as we don't know. We're just checking that it doesn't panic. -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package netns
+
+import (
+ "testing"
+)
+
+func TestSocketMarkWorks(t *testing.T) {
+ _ = socketMarkWorks()
+ // we cannot actually assert whether the test runner has SO_MARK available
+ // or not, as we don't know. We're just checking that it doesn't panic.
+}
diff --git a/net/netns/netns_test.go b/net/netns/netns_test.go index 82f919b94..1c6d699ac 100644 --- a/net/netns/netns_test.go +++ b/net/netns/netns_test.go @@ -1,78 +1,78 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package netns contains the common code for using the Go net package -// in a logical "network namespace" to avoid routing loops where -// Tailscale-created packets would otherwise loop back through -// Tailscale routes. -// -// Despite the name netns, the exact mechanism used differs by -// operating system, and perhaps even by version of the OS. -// -// The netns package also handles connecting via SOCKS proxies when -// configured by the environment. -package netns - -import ( - "flag" - "testing" -) - -var extNetwork = flag.Bool("use-external-network", false, "use the external network in tests") - -func TestDial(t *testing.T) { - if !*extNetwork { - t.Skip("skipping test without --use-external-network") - } - d := NewDialer(t.Logf, nil) - c, err := d.Dial("tcp", "google.com:80") - if err != nil { - t.Fatal(err) - } - defer c.Close() - t.Logf("got addr %v", c.RemoteAddr()) - - c, err = d.Dial("tcp4", "google.com:80") - if err != nil { - t.Fatal(err) - } - defer c.Close() - t.Logf("got addr %v", c.RemoteAddr()) -} - -func TestIsLocalhost(t *testing.T) { - tests := []struct { - name string - host string - want bool - }{ - {"IPv4 loopback", "127.0.0.1", true}, - {"IPv4 !loopback", "192.168.0.1", false}, - {"IPv4 loopback with port", "127.0.0.1:1", true}, - {"IPv4 !loopback with port", "192.168.0.1:1", false}, - {"IPv4 unspecified", "0.0.0.0", false}, - {"IPv4 unspecified with port", "0.0.0.0:1", false}, - {"IPv6 loopback", "::1", true}, - {"IPv6 !loopback", "2001:4860:4860::8888", false}, - {"IPv6 loopback with port", "[::1]:1", true}, - {"IPv6 !loopback with port", "[2001:4860:4860::8888]:1", false}, - {"IPv6 unspecified", "::", false}, - {"IPv6 unspecified with port", "[::]:1", false}, - {"empty", "", false}, - {"hostname", "example.com", false}, - {"localhost", "localhost", true}, - {"localhost6", "localhost6", true}, - {"localhost with port", "localhost:1", true}, - {"localhost6 with port", "localhost6:1", true}, - {"ip6-localhost", "ip6-localhost", true}, - {"ip6-localhost with port", "ip6-localhost:1", true}, - {"ip6-loopback", "ip6-loopback", true}, - {"ip6-loopback with port", "ip6-loopback:1", true}, - } - - for _, test := range tests { - if got := isLocalhost(test.host); got != test.want { - t.Errorf("isLocalhost(%q) = %v, want %v", test.name, got, test.want) - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Package netns contains the common code for using the Go net package
+// in a logical "network namespace" to avoid routing loops where
+// Tailscale-created packets would otherwise loop back through
+// Tailscale routes.
+//
+// Despite the name netns, the exact mechanism used differs by
+// operating system, and perhaps even by version of the OS.
+//
+// The netns package also handles connecting via SOCKS proxies when
+// configured by the environment.
+package netns
+
+import (
+ "flag"
+ "testing"
+)
+
+var extNetwork = flag.Bool("use-external-network", false, "use the external network in tests")
+
+func TestDial(t *testing.T) {
+ if !*extNetwork {
+ t.Skip("skipping test without --use-external-network")
+ }
+ d := NewDialer(t.Logf, nil)
+ c, err := d.Dial("tcp", "google.com:80")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+ t.Logf("got addr %v", c.RemoteAddr())
+
+ c, err = d.Dial("tcp4", "google.com:80")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+ t.Logf("got addr %v", c.RemoteAddr())
+}
+
+func TestIsLocalhost(t *testing.T) {
+ tests := []struct {
+ name string
+ host string
+ want bool
+ }{
+ {"IPv4 loopback", "127.0.0.1", true},
+ {"IPv4 !loopback", "192.168.0.1", false},
+ {"IPv4 loopback with port", "127.0.0.1:1", true},
+ {"IPv4 !loopback with port", "192.168.0.1:1", false},
+ {"IPv4 unspecified", "0.0.0.0", false},
+ {"IPv4 unspecified with port", "0.0.0.0:1", false},
+ {"IPv6 loopback", "::1", true},
+ {"IPv6 !loopback", "2001:4860:4860::8888", false},
+ {"IPv6 loopback with port", "[::1]:1", true},
+ {"IPv6 !loopback with port", "[2001:4860:4860::8888]:1", false},
+ {"IPv6 unspecified", "::", false},
+ {"IPv6 unspecified with port", "[::]:1", false},
+ {"empty", "", false},
+ {"hostname", "example.com", false},
+ {"localhost", "localhost", true},
+ {"localhost6", "localhost6", true},
+ {"localhost with port", "localhost:1", true},
+ {"localhost6 with port", "localhost6:1", true},
+ {"ip6-localhost", "ip6-localhost", true},
+ {"ip6-localhost with port", "ip6-localhost:1", true},
+ {"ip6-loopback", "ip6-loopback", true},
+ {"ip6-loopback with port", "ip6-loopback:1", true},
+ }
+
+ for _, test := range tests {
+ if got := isLocalhost(test.host); got != test.want {
+ t.Errorf("isLocalhost(%q) = %v, want %v", test.name, got, test.want)
+ }
+ }
+}
diff --git a/net/netns/socks.go b/net/netns/socks.go index eea69d865..a3d10d3ae 100644 --- a/net/netns/socks.go +++ b/net/netns/socks.go @@ -1,19 +1,19 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !ios && !js - -package netns - -import "golang.org/x/net/proxy" - -func init() { - wrapDialer = wrapSocks -} - -func wrapSocks(d Dialer) Dialer { - if cd, ok := proxy.FromEnvironmentUsing(d).(Dialer); ok { - return cd - } - return d -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build !ios && !js
+
+package netns
+
+import "golang.org/x/net/proxy"
+
+func init() {
+ wrapDialer = wrapSocks
+}
+
+func wrapSocks(d Dialer) Dialer {
+ if cd, ok := proxy.FromEnvironmentUsing(d).(Dialer); ok {
+ return cd
+ }
+ return d
+}
diff --git a/net/netstat/netstat.go b/net/netstat/netstat.go index 53c7d7757..53121dc52 100644 --- a/net/netstat/netstat.go +++ b/net/netstat/netstat.go @@ -1,35 +1,35 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package netstat returns the local machine's network connection table. -package netstat - -import ( - "errors" - "net/netip" - "runtime" -) - -var ErrNotImplemented = errors.New("not implemented for GOOS=" + runtime.GOOS) - -type Entry struct { - Local, Remote netip.AddrPort - Pid int - State string // TODO: type? - OSMetadata OSMetadata -} - -// Table contains local machine's TCP connection entries. -// -// Currently only TCP (IPv4 and IPv6) are included. -type Table struct { - Entries []Entry -} - -// Get returns the connection table. -// -// It returns ErrNotImplemented if the table is not available for the -// current operating system. -func Get() (*Table, error) { - return get() -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Package netstat returns the local machine's network connection table.
+package netstat
+
+import (
+ "errors"
+ "net/netip"
+ "runtime"
+)
+
+var ErrNotImplemented = errors.New("not implemented for GOOS=" + runtime.GOOS)
+
+type Entry struct {
+ Local, Remote netip.AddrPort
+ Pid int
+ State string // TODO: type?
+ OSMetadata OSMetadata
+}
+
+// Table contains local machine's TCP connection entries.
+//
+// Currently only TCP (IPv4 and IPv6) are included.
+type Table struct {
+ Entries []Entry
+}
+
+// Get returns the connection table.
+//
+// It returns ErrNotImplemented if the table is not available for the
+// current operating system.
+func Get() (*Table, error) {
+ return get()
+}
diff --git a/net/netstat/netstat_noimpl.go b/net/netstat/netstat_noimpl.go index e455c8ce9..608b1a617 100644 --- a/net/netstat/netstat_noimpl.go +++ b/net/netstat/netstat_noimpl.go @@ -1,14 +1,14 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !windows - -package netstat - -// OSMetadata includes any additional OS-specific information that may be -// obtained during the retrieval of a given Entry. -type OSMetadata struct{} - -func get() (*Table, error) { - return nil, ErrNotImplemented -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build !windows
+
+package netstat
+
+// OSMetadata includes any additional OS-specific information that may be
+// obtained during the retrieval of a given Entry.
+type OSMetadata struct{}
+
+func get() (*Table, error) {
+ return nil, ErrNotImplemented
+}
diff --git a/net/netstat/netstat_test.go b/net/netstat/netstat_test.go index 38827df5e..74f4fcec0 100644 --- a/net/netstat/netstat_test.go +++ b/net/netstat/netstat_test.go @@ -1,21 +1,21 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package netstat - -import ( - "testing" -) - -func TestGet(t *testing.T) { - nt, err := Get() - if err == ErrNotImplemented { - t.Skip("TODO: not implemented") - } - if err != nil { - t.Fatal(err) - } - for _, e := range nt.Entries { - t.Logf("Entry: %+v", e) - } -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package netstat
+
+import (
+ "testing"
+)
+
+func TestGet(t *testing.T) {
+ nt, err := Get()
+ if err == ErrNotImplemented {
+ t.Skip("TODO: not implemented")
+ }
+ if err != nil {
+ t.Fatal(err)
+ }
+ for _, e := range nt.Entries {
+ t.Logf("Entry: %+v", e)
+ }
+}
diff --git a/net/packet/doc.go b/net/packet/doc.go index ce6c0c307..f3cb93db8 100644 --- a/net/packet/doc.go +++ b/net/packet/doc.go @@ -1,15 +1,15 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package packet contains packet parsing and marshaling utilities. -// -// Parsed provides allocation-free minimal packet header decoding, for -// use in packet filtering. The other types in the package are for -// constructing and marshaling packets into []bytes. -// -// To support allocation-free parsing, this package defines IPv4 and -// IPv6 address types. You should prefer to use netaddr's types, -// except where you absolutely need allocation-free IP handling -// (i.e. in the tunnel datapath) and are willing to implement all -// codepaths and data structures twice, once per IP family. -package packet +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Package packet contains packet parsing and marshaling utilities.
+//
+// Parsed provides allocation-free minimal packet header decoding, for
+// use in packet filtering. The other types in the package are for
+// constructing and marshaling packets into []bytes.
+//
+// To support allocation-free parsing, this package defines IPv4 and
+// IPv6 address types. You should prefer to use netaddr's types,
+// except where you absolutely need allocation-free IP handling
+// (i.e. in the tunnel datapath) and are willing to implement all
+// codepaths and data structures twice, once per IP family.
+package packet
diff --git a/net/packet/header.go b/net/packet/header.go index dbe84429a..0b1947c0a 100644 --- a/net/packet/header.go +++ b/net/packet/header.go @@ -1,66 +1,66 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package packet - -import ( - "errors" - "math" -) - -const tcpHeaderLength = 20 -const sctpHeaderLength = 12 - -// maxPacketLength is the largest length that all headers support. -// IPv4 headers using uint16 for this forces an upper bound of 64KB. -const maxPacketLength = math.MaxUint16 - -var ( - // errSmallBuffer is returned when Marshal receives a buffer - // too small to contain the header to marshal. - errSmallBuffer = errors.New("buffer too small") - // errLargePacket is returned when Marshal receives a payload - // larger than the maximum representable size in header - // fields. - errLargePacket = errors.New("packet too large") -) - -// Header is a packet header capable of marshaling itself into a byte -// buffer. -type Header interface { - // Len returns the length of the marshaled packet. - Len() int - // Marshal serializes the header into buf, which must be at - // least Len() bytes long. Implementations of Marshal assume - // that bytes after the first Len() are payload bytes for the - // purpose of computing length and checksum fields. Marshal - // implementations must not allocate memory. - Marshal(buf []byte) error -} - -// HeaderChecksummer is implemented by Header implementations that -// need to do a checksum over their payloads. -type HeaderChecksummer interface { - Header - - // WriteCheck writes the correct checksum into buf, which should - // be be the already-marshalled header and payload. - WriteChecksum(buf []byte) -} - -// Generate generates a new packet with the given Header and -// payload. This function allocates memory, see Header.Marshal for an -// allocation-free option. -func Generate(h Header, payload []byte) []byte { - hlen := h.Len() - buf := make([]byte, hlen+len(payload)) - - copy(buf[hlen:], payload) - h.Marshal(buf) - - if hc, ok := h.(HeaderChecksummer); ok { - hc.WriteChecksum(buf) - } - - return buf -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package packet
+
+import (
+ "errors"
+ "math"
+)
+
+const tcpHeaderLength = 20
+const sctpHeaderLength = 12
+
+// maxPacketLength is the largest length that all headers support.
+// IPv4 headers using uint16 for this forces an upper bound of 64KB.
+const maxPacketLength = math.MaxUint16
+
+var (
+ // errSmallBuffer is returned when Marshal receives a buffer
+ // too small to contain the header to marshal.
+ errSmallBuffer = errors.New("buffer too small")
+ // errLargePacket is returned when Marshal receives a payload
+ // larger than the maximum representable size in header
+ // fields.
+ errLargePacket = errors.New("packet too large")
+)
+
+// Header is a packet header capable of marshaling itself into a byte
+// buffer.
+type Header interface {
+ // Len returns the length of the marshaled packet.
+ Len() int
+ // Marshal serializes the header into buf, which must be at
+ // least Len() bytes long. Implementations of Marshal assume
+ // that bytes after the first Len() are payload bytes for the
+ // purpose of computing length and checksum fields. Marshal
+ // implementations must not allocate memory.
+ Marshal(buf []byte) error
+}
+
+// HeaderChecksummer is implemented by Header implementations that
+// need to do a checksum over their payloads.
+type HeaderChecksummer interface {
+ Header
+
+ // WriteCheck writes the correct checksum into buf, which should
+ // be be the already-marshalled header and payload.
+ WriteChecksum(buf []byte)
+}
+
+// Generate generates a new packet with the given Header and
+// payload. This function allocates memory, see Header.Marshal for an
+// allocation-free option.
+func Generate(h Header, payload []byte) []byte {
+ hlen := h.Len()
+ buf := make([]byte, hlen+len(payload))
+
+ copy(buf[hlen:], payload)
+ h.Marshal(buf)
+
+ if hc, ok := h.(HeaderChecksummer); ok {
+ hc.WriteChecksum(buf)
+ }
+
+ return buf
+}
diff --git a/net/packet/icmp.go b/net/packet/icmp.go index 89a7aaa32..7b86edd81 100644 --- a/net/packet/icmp.go +++ b/net/packet/icmp.go @@ -1,28 +1,28 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package packet - -import ( - crand "crypto/rand" - - "encoding/binary" -) - -// ICMPEchoPayload generates a new random ID/Sequence pair, and returns a uint32 -// derived from them, along with the id, sequence and given payload in a buffer. -// It returns an error if the random source could not be read. -func ICMPEchoPayload(payload []byte) (idSeq uint32, buf []byte) { - buf = make([]byte, len(payload)+4) - - // make a completely random id/sequence combo, which is very unlikely to - // collide with a running ping sequence on the host system. Errors are - // ignored, that would result in collisions, but errors reading from the - // random device are rare, and will cause this process universe to soon end. - crand.Read(buf[:4]) - - idSeq = binary.LittleEndian.Uint32(buf) - copy(buf[4:], payload) - - return -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package packet
+
+import (
+ crand "crypto/rand"
+
+ "encoding/binary"
+)
+
+// ICMPEchoPayload generates a new random ID/Sequence pair, and returns a uint32
+// derived from them, along with the id, sequence and given payload in a buffer.
+// It returns an error if the random source could not be read.
+func ICMPEchoPayload(payload []byte) (idSeq uint32, buf []byte) {
+ buf = make([]byte, len(payload)+4)
+
+ // make a completely random id/sequence combo, which is very unlikely to
+ // collide with a running ping sequence on the host system. Errors are
+ // ignored, that would result in collisions, but errors reading from the
+ // random device are rare, and will cause this process universe to soon end.
+ crand.Read(buf[:4])
+
+ idSeq = binary.LittleEndian.Uint32(buf)
+ copy(buf[4:], payload)
+
+ return
+}
diff --git a/net/packet/icmp6_test.go b/net/packet/icmp6_test.go index f34883ca4..c2fab353a 100644 --- a/net/packet/icmp6_test.go +++ b/net/packet/icmp6_test.go @@ -1,79 +1,79 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package packet - -import ( - "net/netip" - "testing" - - "tailscale.com/types/ipproto" -) - -func TestICMPv6PingResponse(t *testing.T) { - pingHdr := ICMP6Header{ - IP6Header: IP6Header{ - Src: netip.MustParseAddr("1::1"), - Dst: netip.MustParseAddr("2::2"), - IPProto: ipproto.ICMPv6, - }, - Type: ICMP6EchoRequest, - Code: ICMP6NoCode, - } - - // echoReqLen is 2 bytes identifier + 2 bytes seq number. - // https://datatracker.ietf.org/doc/html/rfc4443#section-4.1 - // Packet.IsEchoRequest verifies that these 4 bytes are present. - const echoReqLen = 4 - buf := make([]byte, pingHdr.Len()+echoReqLen) - if err := pingHdr.Marshal(buf); err != nil { - t.Fatal(err) - } - - var p Parsed - p.Decode(buf) - if !p.IsEchoRequest() { - t.Fatalf("not an echo request, got: %+v", p) - } - - pingHdr.ToResponse() - buf = make([]byte, pingHdr.Len()+echoReqLen) - if err := pingHdr.Marshal(buf); err != nil { - t.Fatal(err) - } - - p.Decode(buf) - if p.IsEchoRequest() { - t.Fatalf("unexpectedly still an echo request: %+v", p) - } - if !p.IsEchoResponse() { - t.Fatalf("not an echo response: %+v", p) - } -} - -func TestICMPv6Checksum(t *testing.T) { - const req = "\x60\x0f\x07\x00\x00\x10\x3a\x40\xfd\x7a\x11\x5c\xa1\xe0\xab\x12" + - "\x48\x43\xcd\x96\x62\x7b\x65\x28\x26\x07\xf8\xb0\x40\x0a\x08\x07" + - "\x00\x00\x00\x00\x00\x00\x20\x0e\x80\x00\x4a\x9a\x2e\xea\x00\x02" + - "\x61\xb1\x9e\xad\x00\x06\x45\xaa" - // The packet that we'd originally generated incorrectly, but with the checksum - // bytes fixed per WireShark's correct calculation: - const wantRes = "\x60\x00\xf8\xff\x00\x10\x3a\x40\x26\x07\xf8\xb0\x40\x0a\x08\x07" + - "\x00\x00\x00\x00\x00\x00\x20\x0e\xfd\x7a\x11\x5c\xa1\xe0\xab\x12" + - "\x48\x43\xcd\x96\x62\x7b\x65\x28\x81\x00\x49\x9a\x2e\xea\x00\x02" + - "\x61\xb1\x9e\xad\x00\x06\x45\xaa" - - var p Parsed - p.Decode([]byte(req)) - if !p.IsEchoRequest() { - t.Fatalf("not an echo request, got: %+v", p) - } - - h := p.ICMP6Header() - h.ToResponse() - pong := Generate(&h, p.Payload()) - - if string(pong) != wantRes { - t.Errorf("wrong packet\n\n got: %x\nwant: %x", pong, wantRes) - } -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package packet
+
+import (
+ "net/netip"
+ "testing"
+
+ "tailscale.com/types/ipproto"
+)
+
+func TestICMPv6PingResponse(t *testing.T) {
+ pingHdr := ICMP6Header{
+ IP6Header: IP6Header{
+ Src: netip.MustParseAddr("1::1"),
+ Dst: netip.MustParseAddr("2::2"),
+ IPProto: ipproto.ICMPv6,
+ },
+ Type: ICMP6EchoRequest,
+ Code: ICMP6NoCode,
+ }
+
+ // echoReqLen is 2 bytes identifier + 2 bytes seq number.
+ // https://datatracker.ietf.org/doc/html/rfc4443#section-4.1
+ // Packet.IsEchoRequest verifies that these 4 bytes are present.
+ const echoReqLen = 4
+ buf := make([]byte, pingHdr.Len()+echoReqLen)
+ if err := pingHdr.Marshal(buf); err != nil {
+ t.Fatal(err)
+ }
+
+ var p Parsed
+ p.Decode(buf)
+ if !p.IsEchoRequest() {
+ t.Fatalf("not an echo request, got: %+v", p)
+ }
+
+ pingHdr.ToResponse()
+ buf = make([]byte, pingHdr.Len()+echoReqLen)
+ if err := pingHdr.Marshal(buf); err != nil {
+ t.Fatal(err)
+ }
+
+ p.Decode(buf)
+ if p.IsEchoRequest() {
+ t.Fatalf("unexpectedly still an echo request: %+v", p)
+ }
+ if !p.IsEchoResponse() {
+ t.Fatalf("not an echo response: %+v", p)
+ }
+}
+
+func TestICMPv6Checksum(t *testing.T) {
+ const req = "\x60\x0f\x07\x00\x00\x10\x3a\x40\xfd\x7a\x11\x5c\xa1\xe0\xab\x12" +
+ "\x48\x43\xcd\x96\x62\x7b\x65\x28\x26\x07\xf8\xb0\x40\x0a\x08\x07" +
+ "\x00\x00\x00\x00\x00\x00\x20\x0e\x80\x00\x4a\x9a\x2e\xea\x00\x02" +
+ "\x61\xb1\x9e\xad\x00\x06\x45\xaa"
+ // The packet that we'd originally generated incorrectly, but with the checksum
+ // bytes fixed per WireShark's correct calculation:
+ const wantRes = "\x60\x00\xf8\xff\x00\x10\x3a\x40\x26\x07\xf8\xb0\x40\x0a\x08\x07" +
+ "\x00\x00\x00\x00\x00\x00\x20\x0e\xfd\x7a\x11\x5c\xa1\xe0\xab\x12" +
+ "\x48\x43\xcd\x96\x62\x7b\x65\x28\x81\x00\x49\x9a\x2e\xea\x00\x02" +
+ "\x61\xb1\x9e\xad\x00\x06\x45\xaa"
+
+ var p Parsed
+ p.Decode([]byte(req))
+ if !p.IsEchoRequest() {
+ t.Fatalf("not an echo request, got: %+v", p)
+ }
+
+ h := p.ICMP6Header()
+ h.ToResponse()
+ pong := Generate(&h, p.Payload())
+
+ if string(pong) != wantRes {
+ t.Errorf("wrong packet\n\n got: %x\nwant: %x", pong, wantRes)
+ }
+}
diff --git a/net/packet/ip4.go b/net/packet/ip4.go index 967a8dba7..596bc766d 100644 --- a/net/packet/ip4.go +++ b/net/packet/ip4.go @@ -1,116 +1,116 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package packet - -import ( - "encoding/binary" - "errors" - "net/netip" - - "tailscale.com/types/ipproto" -) - -// ip4HeaderLength is the length of an IPv4 header with no IP options. -const ip4HeaderLength = 20 - -// IP4Header represents an IPv4 packet header. -type IP4Header struct { - IPProto ipproto.Proto - IPID uint16 - Src netip.Addr - Dst netip.Addr -} - -// Len implements Header. -func (h IP4Header) Len() int { - return ip4HeaderLength -} - -var errWrongFamily = errors.New("wrong address family for src/dst IP") - -// Marshal implements Header. -func (h IP4Header) Marshal(buf []byte) error { - if len(buf) < h.Len() { - return errSmallBuffer - } - if len(buf) > maxPacketLength { - return errLargePacket - } - if !h.Src.Is4() || !h.Dst.Is4() { - return errWrongFamily - } - - buf[0] = 0x40 | (byte(h.Len() >> 2)) // IPv4 + IHL - buf[1] = 0x00 // DSCP + ECN - binary.BigEndian.PutUint16(buf[2:4], uint16(len(buf))) // Total length - binary.BigEndian.PutUint16(buf[4:6], h.IPID) // ID - binary.BigEndian.PutUint16(buf[6:8], 0) // Flags + fragment offset - buf[8] = 64 // TTL - buf[9] = uint8(h.IPProto) // Inner protocol - // Blank checksum. This is necessary even though we overwrite - // it later, because the checksum computation runs over these - // bytes and expects them to be zero. - binary.BigEndian.PutUint16(buf[10:12], 0) - src := h.Src.As4() - dst := h.Dst.As4() - copy(buf[12:16], src[:]) - copy(buf[16:20], dst[:]) - - binary.BigEndian.PutUint16(buf[10:12], ip4Checksum(buf[0:20])) // Checksum - - return nil -} - -// ToResponse implements Header. -func (h *IP4Header) ToResponse() { - h.Src, h.Dst = h.Dst, h.Src - // Flip the bits in the IPID. If incoming IPIDs are distinct, so are these. - h.IPID = ^h.IPID -} - -// ip4Checksum computes an IPv4 checksum, as specified in -// https://tools.ietf.org/html/rfc1071 -func ip4Checksum(b []byte) uint16 { - var ac uint32 - i := 0 - n := len(b) - for n >= 2 { - ac += uint32(binary.BigEndian.Uint16(b[i : i+2])) - n -= 2 - i += 2 - } - if n == 1 { - ac += uint32(b[i]) << 8 - } - for (ac >> 16) > 0 { - ac = (ac >> 16) + (ac & 0xffff) - } - return uint16(^ac) -} - -// ip4PseudoHeaderOffset is the number of bytes by which the IPv4 UDP -// pseudo-header is smaller than the real IPv4 header. -const ip4PseudoHeaderOffset = 8 - -// marshalPseudo serializes h into buf in the "pseudo-header" form -// required when calculating UDP checksums. The pseudo-header starts -// at buf[ip4PseudoHeaderOffset] so as to abut the following UDP -// header, while leaving enough space in buf for a full IPv4 header. -func (h IP4Header) marshalPseudo(buf []byte) error { - if len(buf) < h.Len() { - return errSmallBuffer - } - if len(buf) > maxPacketLength { - return errLargePacket - } - - length := len(buf) - h.Len() - src, dst := h.Src.As4(), h.Dst.As4() - copy(buf[8:12], src[:]) - copy(buf[12:16], dst[:]) - buf[16] = 0x0 - buf[17] = uint8(h.IPProto) - binary.BigEndian.PutUint16(buf[18:20], uint16(length)) - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package packet
+
+import (
+ "encoding/binary"
+ "errors"
+ "net/netip"
+
+ "tailscale.com/types/ipproto"
+)
+
+// ip4HeaderLength is the length of an IPv4 header with no IP options.
+const ip4HeaderLength = 20
+
+// IP4Header represents an IPv4 packet header.
+type IP4Header struct {
+ IPProto ipproto.Proto
+ IPID uint16
+ Src netip.Addr
+ Dst netip.Addr
+}
+
+// Len implements Header.
+func (h IP4Header) Len() int {
+ return ip4HeaderLength
+}
+
+var errWrongFamily = errors.New("wrong address family for src/dst IP")
+
+// Marshal implements Header.
+func (h IP4Header) Marshal(buf []byte) error {
+ if len(buf) < h.Len() {
+ return errSmallBuffer
+ }
+ if len(buf) > maxPacketLength {
+ return errLargePacket
+ }
+ if !h.Src.Is4() || !h.Dst.Is4() {
+ return errWrongFamily
+ }
+
+ buf[0] = 0x40 | (byte(h.Len() >> 2)) // IPv4 + IHL
+ buf[1] = 0x00 // DSCP + ECN
+ binary.BigEndian.PutUint16(buf[2:4], uint16(len(buf))) // Total length
+ binary.BigEndian.PutUint16(buf[4:6], h.IPID) // ID
+ binary.BigEndian.PutUint16(buf[6:8], 0) // Flags + fragment offset
+ buf[8] = 64 // TTL
+ buf[9] = uint8(h.IPProto) // Inner protocol
+ // Blank checksum. This is necessary even though we overwrite
+ // it later, because the checksum computation runs over these
+ // bytes and expects them to be zero.
+ binary.BigEndian.PutUint16(buf[10:12], 0)
+ src := h.Src.As4()
+ dst := h.Dst.As4()
+ copy(buf[12:16], src[:])
+ copy(buf[16:20], dst[:])
+
+ binary.BigEndian.PutUint16(buf[10:12], ip4Checksum(buf[0:20])) // Checksum
+
+ return nil
+}
+
+// ToResponse implements Header.
+func (h *IP4Header) ToResponse() {
+ h.Src, h.Dst = h.Dst, h.Src
+ // Flip the bits in the IPID. If incoming IPIDs are distinct, so are these.
+ h.IPID = ^h.IPID
+}
+
+// ip4Checksum computes an IPv4 checksum, as specified in
+// https://tools.ietf.org/html/rfc1071
+func ip4Checksum(b []byte) uint16 {
+ var ac uint32
+ i := 0
+ n := len(b)
+ for n >= 2 {
+ ac += uint32(binary.BigEndian.Uint16(b[i : i+2]))
+ n -= 2
+ i += 2
+ }
+ if n == 1 {
+ ac += uint32(b[i]) << 8
+ }
+ for (ac >> 16) > 0 {
+ ac = (ac >> 16) + (ac & 0xffff)
+ }
+ return uint16(^ac)
+}
+
+// ip4PseudoHeaderOffset is the number of bytes by which the IPv4 UDP
+// pseudo-header is smaller than the real IPv4 header.
+const ip4PseudoHeaderOffset = 8
+
+// marshalPseudo serializes h into buf in the "pseudo-header" form
+// required when calculating UDP checksums. The pseudo-header starts
+// at buf[ip4PseudoHeaderOffset] so as to abut the following UDP
+// header, while leaving enough space in buf for a full IPv4 header.
+func (h IP4Header) marshalPseudo(buf []byte) error {
+ if len(buf) < h.Len() {
+ return errSmallBuffer
+ }
+ if len(buf) > maxPacketLength {
+ return errLargePacket
+ }
+
+ length := len(buf) - h.Len()
+ src, dst := h.Src.As4(), h.Dst.As4()
+ copy(buf[8:12], src[:])
+ copy(buf[12:16], dst[:])
+ buf[16] = 0x0
+ buf[17] = uint8(h.IPProto)
+ binary.BigEndian.PutUint16(buf[18:20], uint16(length))
+ return nil
+}
diff --git a/net/packet/ip6.go b/net/packet/ip6.go index d26b9a161..cebc46c53 100644 --- a/net/packet/ip6.go +++ b/net/packet/ip6.go @@ -1,76 +1,76 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package packet - -import ( - "encoding/binary" - "net/netip" - - "tailscale.com/types/ipproto" -) - -// ip6HeaderLength is the length of an IPv6 header with no IP options. -const ip6HeaderLength = 40 - -// IP6Header represents an IPv6 packet header. -type IP6Header struct { - IPProto ipproto.Proto - IPID uint32 // only lower 20 bits used - Src netip.Addr - Dst netip.Addr -} - -// Len implements Header. -func (h IP6Header) Len() int { - return ip6HeaderLength -} - -// Marshal implements Header. -func (h IP6Header) Marshal(buf []byte) error { - if len(buf) < h.Len() { - return errSmallBuffer - } - if len(buf) > maxPacketLength { - return errLargePacket - } - - binary.BigEndian.PutUint32(buf[:4], h.IPID&0x000FFFFF) - buf[0] = 0x60 - binary.BigEndian.PutUint16(buf[4:6], uint16(len(buf)-ip6HeaderLength)) // Total length - buf[6] = uint8(h.IPProto) // Inner protocol - buf[7] = 64 // TTL - src, dst := h.Src.As16(), h.Dst.As16() - copy(buf[8:24], src[:]) - copy(buf[24:40], dst[:]) - - return nil -} - -// ToResponse implements Header. -func (h *IP6Header) ToResponse() { - h.Src, h.Dst = h.Dst, h.Src - // Flip the bits in the IPID. If incoming IPIDs are distinct, so are these. - h.IPID = (^h.IPID) & 0x000FFFFF -} - -// marshalPseudo serializes h into buf in the "pseudo-header" form -// required when calculating UDP checksums. -func (h IP6Header) marshalPseudo(buf []byte, proto ipproto.Proto) error { - if len(buf) < h.Len() { - return errSmallBuffer - } - if len(buf) > maxPacketLength { - return errLargePacket - } - - src, dst := h.Src.As16(), h.Dst.As16() - copy(buf[:16], src[:]) - copy(buf[16:32], dst[:]) - binary.BigEndian.PutUint32(buf[32:36], uint32(len(buf)-h.Len())) - buf[36] = 0 - buf[37] = 0 - buf[38] = 0 - buf[39] = byte(proto) // NextProto - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package packet
+
+import (
+ "encoding/binary"
+ "net/netip"
+
+ "tailscale.com/types/ipproto"
+)
+
+// ip6HeaderLength is the length of an IPv6 header with no IP options.
+const ip6HeaderLength = 40
+
+// IP6Header represents an IPv6 packet header.
+type IP6Header struct {
+ IPProto ipproto.Proto
+ IPID uint32 // only lower 20 bits used
+ Src netip.Addr
+ Dst netip.Addr
+}
+
+// Len implements Header.
+func (h IP6Header) Len() int {
+ return ip6HeaderLength
+}
+
+// Marshal implements Header.
+func (h IP6Header) Marshal(buf []byte) error {
+ if len(buf) < h.Len() {
+ return errSmallBuffer
+ }
+ if len(buf) > maxPacketLength {
+ return errLargePacket
+ }
+
+ binary.BigEndian.PutUint32(buf[:4], h.IPID&0x000FFFFF)
+ buf[0] = 0x60
+ binary.BigEndian.PutUint16(buf[4:6], uint16(len(buf)-ip6HeaderLength)) // Total length
+ buf[6] = uint8(h.IPProto) // Inner protocol
+ buf[7] = 64 // TTL
+ src, dst := h.Src.As16(), h.Dst.As16()
+ copy(buf[8:24], src[:])
+ copy(buf[24:40], dst[:])
+
+ return nil
+}
+
+// ToResponse implements Header.
+func (h *IP6Header) ToResponse() {
+ h.Src, h.Dst = h.Dst, h.Src
+ // Flip the bits in the IPID. If incoming IPIDs are distinct, so are these.
+ h.IPID = (^h.IPID) & 0x000FFFFF
+}
+
+// marshalPseudo serializes h into buf in the "pseudo-header" form
+// required when calculating UDP checksums.
+func (h IP6Header) marshalPseudo(buf []byte, proto ipproto.Proto) error {
+ if len(buf) < h.Len() {
+ return errSmallBuffer
+ }
+ if len(buf) > maxPacketLength {
+ return errLargePacket
+ }
+
+ src, dst := h.Src.As16(), h.Dst.As16()
+ copy(buf[:16], src[:])
+ copy(buf[16:32], dst[:])
+ binary.BigEndian.PutUint32(buf[32:36], uint32(len(buf)-h.Len()))
+ buf[36] = 0
+ buf[37] = 0
+ buf[38] = 0
+ buf[39] = byte(proto) // NextProto
+ return nil
+}
diff --git a/net/packet/tsmp_test.go b/net/packet/tsmp_test.go index e261e6a41..4ec24e1ea 100644 --- a/net/packet/tsmp_test.go +++ b/net/packet/tsmp_test.go @@ -1,73 +1,73 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package packet - -import ( - "net/netip" - "testing" -) - -func TestTailscaleRejectedHeader(t *testing.T) { - tests := []struct { - h TailscaleRejectedHeader - wantStr string - }{ - { - h: TailscaleRejectedHeader{ - IPSrc: netip.MustParseAddr("5.5.5.5"), - IPDst: netip.MustParseAddr("1.2.3.4"), - Src: netip.MustParseAddrPort("1.2.3.4:567"), - Dst: netip.MustParseAddrPort("5.5.5.5:443"), - Proto: TCP, - Reason: RejectedDueToACLs, - }, - wantStr: "TSMP-reject-flow{TCP 1.2.3.4:567 > 5.5.5.5:443}: acl", - }, - { - h: TailscaleRejectedHeader{ - IPSrc: netip.MustParseAddr("2::2"), - IPDst: netip.MustParseAddr("1::1"), - Src: netip.MustParseAddrPort("[1::1]:567"), - Dst: netip.MustParseAddrPort("[2::2]:443"), - Proto: UDP, - Reason: RejectedDueToShieldsUp, - }, - wantStr: "TSMP-reject-flow{UDP [1::1]:567 > [2::2]:443}: shields", - }, - { - h: TailscaleRejectedHeader{ - IPSrc: netip.MustParseAddr("2::2"), - IPDst: netip.MustParseAddr("1::1"), - Src: netip.MustParseAddrPort("[1::1]:567"), - Dst: netip.MustParseAddrPort("[2::2]:443"), - Proto: UDP, - Reason: RejectedDueToIPForwarding, - MaybeBroken: true, - }, - wantStr: "TSMP-reject-flow{UDP [1::1]:567 > [2::2]:443}: host-ip-forwarding-unavailable", - }, - } - for i, tt := range tests { - gotStr := tt.h.String() - if gotStr != tt.wantStr { - t.Errorf("%v. String = %q; want %q", i, gotStr, tt.wantStr) - continue - } - pkt := make([]byte, tt.h.Len()) - tt.h.Marshal(pkt) - - var p Parsed - p.Decode(pkt) - t.Logf("Parsed: %+v", p) - t.Logf("Parsed: %s", p.String()) - back, ok := p.AsTailscaleRejectedHeader() - if !ok { - t.Errorf("%v. %q (%02x) didn't parse back", i, gotStr, pkt) - continue - } - if back != tt.h { - t.Errorf("%v. %q parsed back as %q", i, tt.h, back) - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package packet
+
+import (
+ "net/netip"
+ "testing"
+)
+
+func TestTailscaleRejectedHeader(t *testing.T) {
+ tests := []struct {
+ h TailscaleRejectedHeader
+ wantStr string
+ }{
+ {
+ h: TailscaleRejectedHeader{
+ IPSrc: netip.MustParseAddr("5.5.5.5"),
+ IPDst: netip.MustParseAddr("1.2.3.4"),
+ Src: netip.MustParseAddrPort("1.2.3.4:567"),
+ Dst: netip.MustParseAddrPort("5.5.5.5:443"),
+ Proto: TCP,
+ Reason: RejectedDueToACLs,
+ },
+ wantStr: "TSMP-reject-flow{TCP 1.2.3.4:567 > 5.5.5.5:443}: acl",
+ },
+ {
+ h: TailscaleRejectedHeader{
+ IPSrc: netip.MustParseAddr("2::2"),
+ IPDst: netip.MustParseAddr("1::1"),
+ Src: netip.MustParseAddrPort("[1::1]:567"),
+ Dst: netip.MustParseAddrPort("[2::2]:443"),
+ Proto: UDP,
+ Reason: RejectedDueToShieldsUp,
+ },
+ wantStr: "TSMP-reject-flow{UDP [1::1]:567 > [2::2]:443}: shields",
+ },
+ {
+ h: TailscaleRejectedHeader{
+ IPSrc: netip.MustParseAddr("2::2"),
+ IPDst: netip.MustParseAddr("1::1"),
+ Src: netip.MustParseAddrPort("[1::1]:567"),
+ Dst: netip.MustParseAddrPort("[2::2]:443"),
+ Proto: UDP,
+ Reason: RejectedDueToIPForwarding,
+ MaybeBroken: true,
+ },
+ wantStr: "TSMP-reject-flow{UDP [1::1]:567 > [2::2]:443}: host-ip-forwarding-unavailable",
+ },
+ }
+ for i, tt := range tests {
+ gotStr := tt.h.String()
+ if gotStr != tt.wantStr {
+ t.Errorf("%v. String = %q; want %q", i, gotStr, tt.wantStr)
+ continue
+ }
+ pkt := make([]byte, tt.h.Len())
+ tt.h.Marshal(pkt)
+
+ var p Parsed
+ p.Decode(pkt)
+ t.Logf("Parsed: %+v", p)
+ t.Logf("Parsed: %s", p.String())
+ back, ok := p.AsTailscaleRejectedHeader()
+ if !ok {
+ t.Errorf("%v. %q (%02x) didn't parse back", i, gotStr, pkt)
+ continue
+ }
+ if back != tt.h {
+ t.Errorf("%v. %q parsed back as %q", i, tt.h, back)
+ }
+ }
+}
diff --git a/net/packet/udp4.go b/net/packet/udp4.go index 0d5bca73e..c8761baef 100644 --- a/net/packet/udp4.go +++ b/net/packet/udp4.go @@ -1,58 +1,58 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package packet - -import ( - "encoding/binary" - - "tailscale.com/types/ipproto" -) - -// udpHeaderLength is the size of the UDP packet header, not including -// the outer IP header. -const udpHeaderLength = 8 - -// UDP4Header is an IPv4+UDP header. -type UDP4Header struct { - IP4Header - SrcPort uint16 - DstPort uint16 -} - -// Len implements Header. -func (h UDP4Header) Len() int { - return h.IP4Header.Len() + udpHeaderLength -} - -// Marshal implements Header. -func (h UDP4Header) Marshal(buf []byte) error { - if len(buf) < h.Len() { - return errSmallBuffer - } - if len(buf) > maxPacketLength { - return errLargePacket - } - // The caller does not need to set this. - h.IPProto = ipproto.UDP - - length := len(buf) - h.IP4Header.Len() - binary.BigEndian.PutUint16(buf[20:22], h.SrcPort) - binary.BigEndian.PutUint16(buf[22:24], h.DstPort) - binary.BigEndian.PutUint16(buf[24:26], uint16(length)) - binary.BigEndian.PutUint16(buf[26:28], 0) // blank checksum - - // UDP checksum with IP pseudo header. - h.IP4Header.marshalPseudo(buf) - binary.BigEndian.PutUint16(buf[26:28], ip4Checksum(buf[ip4PseudoHeaderOffset:])) - - h.IP4Header.Marshal(buf) - - return nil -} - -// ToResponse implements Header. -func (h *UDP4Header) ToResponse() { - h.SrcPort, h.DstPort = h.DstPort, h.SrcPort - h.IP4Header.ToResponse() -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package packet
+
+import (
+ "encoding/binary"
+
+ "tailscale.com/types/ipproto"
+)
+
+// udpHeaderLength is the size of the UDP packet header, not including
+// the outer IP header.
+const udpHeaderLength = 8
+
+// UDP4Header is an IPv4+UDP header.
+type UDP4Header struct {
+ IP4Header
+ SrcPort uint16
+ DstPort uint16
+}
+
+// Len implements Header.
+func (h UDP4Header) Len() int {
+ return h.IP4Header.Len() + udpHeaderLength
+}
+
+// Marshal implements Header.
+func (h UDP4Header) Marshal(buf []byte) error {
+ if len(buf) < h.Len() {
+ return errSmallBuffer
+ }
+ if len(buf) > maxPacketLength {
+ return errLargePacket
+ }
+ // The caller does not need to set this.
+ h.IPProto = ipproto.UDP
+
+ length := len(buf) - h.IP4Header.Len()
+ binary.BigEndian.PutUint16(buf[20:22], h.SrcPort)
+ binary.BigEndian.PutUint16(buf[22:24], h.DstPort)
+ binary.BigEndian.PutUint16(buf[24:26], uint16(length))
+ binary.BigEndian.PutUint16(buf[26:28], 0) // blank checksum
+
+ // UDP checksum with IP pseudo header.
+ h.IP4Header.marshalPseudo(buf)
+ binary.BigEndian.PutUint16(buf[26:28], ip4Checksum(buf[ip4PseudoHeaderOffset:]))
+
+ h.IP4Header.Marshal(buf)
+
+ return nil
+}
+
+// ToResponse implements Header.
+func (h *UDP4Header) ToResponse() {
+ h.SrcPort, h.DstPort = h.DstPort, h.SrcPort
+ h.IP4Header.ToResponse()
+}
diff --git a/net/packet/udp6.go b/net/packet/udp6.go index 10fdcb99e..c8634b508 100644 --- a/net/packet/udp6.go +++ b/net/packet/udp6.go @@ -1,54 +1,54 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package packet - -import ( - "encoding/binary" - - "tailscale.com/types/ipproto" -) - -// UDP6Header is an IPv6+UDP header. -type UDP6Header struct { - IP6Header - SrcPort uint16 - DstPort uint16 -} - -// Len implements Header. -func (h UDP6Header) Len() int { - return h.IP6Header.Len() + udpHeaderLength -} - -// Marshal implements Header. -func (h UDP6Header) Marshal(buf []byte) error { - if len(buf) < h.Len() { - return errSmallBuffer - } - if len(buf) > maxPacketLength { - return errLargePacket - } - // The caller does not need to set this. - h.IPProto = ipproto.UDP - - length := len(buf) - h.IP6Header.Len() - binary.BigEndian.PutUint16(buf[40:42], h.SrcPort) - binary.BigEndian.PutUint16(buf[42:44], h.DstPort) - binary.BigEndian.PutUint16(buf[44:46], uint16(length)) - binary.BigEndian.PutUint16(buf[46:48], 0) // blank checksum - - // UDP checksum with IP pseudo header. - h.IP6Header.marshalPseudo(buf, ipproto.UDP) - binary.BigEndian.PutUint16(buf[46:48], ip4Checksum(buf[:])) - - h.IP6Header.Marshal(buf) - - return nil -} - -// ToResponse implements Header. -func (h *UDP6Header) ToResponse() { - h.SrcPort, h.DstPort = h.DstPort, h.SrcPort - h.IP6Header.ToResponse() -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package packet
+
+import (
+ "encoding/binary"
+
+ "tailscale.com/types/ipproto"
+)
+
+// UDP6Header is an IPv6+UDP header.
+type UDP6Header struct {
+ IP6Header
+ SrcPort uint16
+ DstPort uint16
+}
+
+// Len implements Header.
+func (h UDP6Header) Len() int {
+ return h.IP6Header.Len() + udpHeaderLength
+}
+
+// Marshal implements Header.
+func (h UDP6Header) Marshal(buf []byte) error {
+ if len(buf) < h.Len() {
+ return errSmallBuffer
+ }
+ if len(buf) > maxPacketLength {
+ return errLargePacket
+ }
+ // The caller does not need to set this.
+ h.IPProto = ipproto.UDP
+
+ length := len(buf) - h.IP6Header.Len()
+ binary.BigEndian.PutUint16(buf[40:42], h.SrcPort)
+ binary.BigEndian.PutUint16(buf[42:44], h.DstPort)
+ binary.BigEndian.PutUint16(buf[44:46], uint16(length))
+ binary.BigEndian.PutUint16(buf[46:48], 0) // blank checksum
+
+ // UDP checksum with IP pseudo header.
+ h.IP6Header.marshalPseudo(buf, ipproto.UDP)
+ binary.BigEndian.PutUint16(buf[46:48], ip4Checksum(buf[:]))
+
+ h.IP6Header.Marshal(buf)
+
+ return nil
+}
+
+// ToResponse implements Header.
+func (h *UDP6Header) ToResponse() {
+ h.SrcPort, h.DstPort = h.DstPort, h.SrcPort
+ h.IP6Header.ToResponse()
+}
diff --git a/net/ping/ping.go b/net/ping/ping.go index 01f3dcf2c..f2093292a 100644 --- a/net/ping/ping.go +++ b/net/ping/ping.go @@ -1,343 +1,343 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package ping allows sending ICMP echo requests to a host in order to -// determine network latency. -package ping - -import ( - "bytes" - "context" - "crypto/rand" - "encoding/binary" - "fmt" - "io" - "log" - "net" - "net/netip" - "sync" - "sync/atomic" - "time" - - "golang.org/x/net/icmp" - "golang.org/x/net/ipv4" - "golang.org/x/net/ipv6" - "tailscale.com/types/logger" - "tailscale.com/util/mak" - "tailscale.com/util/multierr" -) - -const ( - v4Type = "ip4:icmp" - v6Type = "ip6:icmp" -) - -type response struct { - t time.Time - err error -} - -type outstanding struct { - ch chan response - data []byte -} - -// PacketListener defines the interface required to listen to packages -// on an address. -type ListenPacketer interface { - ListenPacket(ctx context.Context, typ string, addr string) (net.PacketConn, error) -} - -// Pinger represents a set of ICMP echo requests to be sent at a single time. -// -// A new instance should be created for each concurrent set of ping requests; -// this type should not be reused. -type Pinger struct { - lp ListenPacketer - - // closed guards against send incrementing the waitgroup concurrently with close. - closed atomic.Bool - Logf logger.Logf - Verbose bool - timeNow func() time.Time - id uint16 // uint16 per RFC 792 - wg sync.WaitGroup - - // Following fields protected by mu - mu sync.Mutex - // conns is a map of "type" to net.PacketConn, type is either - // "ip4:icmp" or "ip6:icmp" - conns map[string]net.PacketConn - seq uint16 // uint16 per RFC 792 - pings map[uint16]outstanding -} - -// New creates a new Pinger. The Context provided will be used to create -// network listeners, and to set an absolute deadline (if any) on the net.Conn -func New(ctx context.Context, logf logger.Logf, lp ListenPacketer) *Pinger { - var id [2]byte - if _, err := io.ReadFull(rand.Reader, id[:]); err != nil { - panic("net/ping: New:" + err.Error()) - } - - return &Pinger{ - lp: lp, - Logf: logf, - timeNow: time.Now, - id: binary.LittleEndian.Uint16(id[:]), - pings: make(map[uint16]outstanding), - } -} - -func (p *Pinger) mkconn(ctx context.Context, typ, addr string) (net.PacketConn, error) { - if p.closed.Load() { - return nil, net.ErrClosed - } - - c, err := p.lp.ListenPacket(ctx, typ, addr) - if err != nil { - return nil, err - } - - // Start by setting the deadline from the context; note that this - // applies to all future I/O, so we only need to do it once. - deadline, ok := ctx.Deadline() - if ok { - if err := c.SetReadDeadline(deadline); err != nil { - return nil, err - } - } - - p.wg.Add(1) - go p.run(ctx, c, typ) - - return c, err -} - -// getConn creates or returns a conn matching typ which is ip4:icmp -// or ip6:icmp. -func (p *Pinger) getConn(ctx context.Context, typ string) (net.PacketConn, error) { - p.mu.Lock() - defer p.mu.Unlock() - if c, ok := p.conns[typ]; ok { - return c, nil - } - - var addr = "0.0.0.0" - if typ == v6Type { - addr = "::" - } - c, err := p.mkconn(ctx, typ, addr) - if err != nil { - return nil, err - } - mak.Set(&p.conns, typ, c) - return c, nil -} - -func (p *Pinger) logf(format string, a ...any) { - if p.Logf != nil { - p.Logf(format, a...) - } else { - log.Printf(format, a...) - } -} - -func (p *Pinger) vlogf(format string, a ...any) { - if p.Verbose { - p.logf(format, a...) - } -} - -func (p *Pinger) Close() error { - p.closed.Store(true) - - p.mu.Lock() - conns := p.conns - p.conns = nil - p.mu.Unlock() - - var errors []error - for _, c := range conns { - if err := c.Close(); err != nil { - errors = append(errors, err) - } - } - - p.wg.Wait() - p.cleanupOutstanding() - - return multierr.New(errors...) -} - -func (p *Pinger) run(ctx context.Context, conn net.PacketConn, typ string) { - defer p.wg.Done() - defer func() { - conn.Close() - p.mu.Lock() - delete(p.conns, typ) - p.mu.Unlock() - }() - buf := make([]byte, 1500) - -loop: - for { - select { - case <-ctx.Done(): - break loop - default: - } - - n, _, err := conn.ReadFrom(buf) - if err != nil { - // Ignore temporary errors; everything else is fatal - if netErr, ok := err.(net.Error); !ok || !netErr.Temporary() { - break - } - continue - } - - p.handleResponse(buf[:n], p.timeNow(), typ) - } -} - -func (p *Pinger) cleanupOutstanding() { - // Complete outstanding requests - p.mu.Lock() - defer p.mu.Unlock() - for _, o := range p.pings { - o.ch <- response{err: net.ErrClosed} - } -} - -func (p *Pinger) handleResponse(buf []byte, now time.Time, typ string) { - // We need to handle responding to both IPv4 - // and IPv6. - var icmpType icmp.Type - switch typ { - case v4Type: - icmpType = ipv4.ICMPTypeEchoReply - case v6Type: - icmpType = ipv6.ICMPTypeEchoReply - default: - p.vlogf("handleResponse: unknown icmp.Type") - return - } - - m, err := icmp.ParseMessage(icmpType.Protocol(), buf) - if err != nil { - p.vlogf("handleResponse: invalid packet: %v", err) - return - } - - if m.Type != icmpType { - p.vlogf("handleResponse: wanted m.Type=%d; got %d", icmpType, m.Type) - return - } - - resp, ok := m.Body.(*icmp.Echo) - if !ok || resp == nil { - p.vlogf("handleResponse: wanted body=*icmp.Echo; got %v", m.Body) - return - } - - // We assume we sent this if the ID in the response is ours. - if uint16(resp.ID) != p.id { - p.vlogf("handleResponse: wanted ID=%d; got %d", p.id, resp.ID) - return - } - - // Search for existing running echo request - var o outstanding - p.mu.Lock() - if o, ok = p.pings[uint16(resp.Seq)]; ok { - // Ensure that the data matches before we delete from our map, - // so a future correct packet will be handled correctly. - if bytes.Equal(resp.Data, o.data) { - delete(p.pings, uint16(resp.Seq)) - } else { - p.vlogf("handleResponse: got response for Seq %d with mismatched data", resp.Seq) - ok = false - } - } else { - p.vlogf("handleResponse: got response for unknown Seq %d", resp.Seq) - } - p.mu.Unlock() - - if ok { - o.ch <- response{t: now} - } -} - -// Send sends an ICMP Echo Request packet to the destination, waits for a -// response, and returns the duration between when the request was sent and -// when the reply was received. -// -// If provided, "data" is sent with the packet and is compared upon receiving a -// reply. -func (p *Pinger) Send(ctx context.Context, dest net.Addr, data []byte) (time.Duration, error) { - // Use sequential sequence numbers on the assumption that we will not - // wrap around when using a single Pinger instance - p.mu.Lock() - p.seq++ - seq := p.seq - p.mu.Unlock() - - // Check whether the address is IPv4 or IPv6 to - // determine the icmp.Type and conn to use. - var conn net.PacketConn - var icmpType icmp.Type = ipv4.ICMPTypeEcho - ap, err := netip.ParseAddr(dest.String()) - if err != nil { - return 0, err - } - if ap.Is6() { - icmpType = ipv6.ICMPTypeEchoRequest - conn, err = p.getConn(ctx, v6Type) - } else { - conn, err = p.getConn(ctx, v4Type) - } - if err != nil { - return 0, err - } - - m := icmp.Message{ - Type: icmpType, - Code: 0, - Body: &icmp.Echo{ - ID: int(p.id), - Seq: int(seq), - Data: data, - }, - } - b, err := m.Marshal(nil) - if err != nil { - return 0, err - } - - // Register our response before sending since we could otherwise race a - // quick reply. - ch := make(chan response, 1) - p.mu.Lock() - p.pings[seq] = outstanding{ch: ch, data: data} - p.mu.Unlock() - - start := p.timeNow() - n, err := conn.WriteTo(b, dest) - if err != nil { - return 0, err - } else if n != len(b) { - return 0, fmt.Errorf("conn.WriteTo: got %v; want %v", n, len(b)) - } - - select { - case resp := <-ch: - if resp.err != nil { - return 0, resp.err - } - return resp.t.Sub(start), nil - - case <-ctx.Done(): - return 0, ctx.Err() - } -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Package ping allows sending ICMP echo requests to a host in order to
+// determine network latency.
+package ping
+
+import (
+ "bytes"
+ "context"
+ "crypto/rand"
+ "encoding/binary"
+ "fmt"
+ "io"
+ "log"
+ "net"
+ "net/netip"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "golang.org/x/net/icmp"
+ "golang.org/x/net/ipv4"
+ "golang.org/x/net/ipv6"
+ "tailscale.com/types/logger"
+ "tailscale.com/util/mak"
+ "tailscale.com/util/multierr"
+)
+
+const (
+ v4Type = "ip4:icmp"
+ v6Type = "ip6:icmp"
+)
+
+type response struct {
+ t time.Time
+ err error
+}
+
+type outstanding struct {
+ ch chan response
+ data []byte
+}
+
+// PacketListener defines the interface required to listen to packages
+// on an address.
+type ListenPacketer interface {
+ ListenPacket(ctx context.Context, typ string, addr string) (net.PacketConn, error)
+}
+
+// Pinger represents a set of ICMP echo requests to be sent at a single time.
+//
+// A new instance should be created for each concurrent set of ping requests;
+// this type should not be reused.
+type Pinger struct {
+ lp ListenPacketer
+
+ // closed guards against send incrementing the waitgroup concurrently with close.
+ closed atomic.Bool
+ Logf logger.Logf
+ Verbose bool
+ timeNow func() time.Time
+ id uint16 // uint16 per RFC 792
+ wg sync.WaitGroup
+
+ // Following fields protected by mu
+ mu sync.Mutex
+ // conns is a map of "type" to net.PacketConn, type is either
+ // "ip4:icmp" or "ip6:icmp"
+ conns map[string]net.PacketConn
+ seq uint16 // uint16 per RFC 792
+ pings map[uint16]outstanding
+}
+
+// New creates a new Pinger. The Context provided will be used to create
+// network listeners, and to set an absolute deadline (if any) on the net.Conn
+func New(ctx context.Context, logf logger.Logf, lp ListenPacketer) *Pinger {
+ var id [2]byte
+ if _, err := io.ReadFull(rand.Reader, id[:]); err != nil {
+ panic("net/ping: New:" + err.Error())
+ }
+
+ return &Pinger{
+ lp: lp,
+ Logf: logf,
+ timeNow: time.Now,
+ id: binary.LittleEndian.Uint16(id[:]),
+ pings: make(map[uint16]outstanding),
+ }
+}
+
+func (p *Pinger) mkconn(ctx context.Context, typ, addr string) (net.PacketConn, error) {
+ if p.closed.Load() {
+ return nil, net.ErrClosed
+ }
+
+ c, err := p.lp.ListenPacket(ctx, typ, addr)
+ if err != nil {
+ return nil, err
+ }
+
+ // Start by setting the deadline from the context; note that this
+ // applies to all future I/O, so we only need to do it once.
+ deadline, ok := ctx.Deadline()
+ if ok {
+ if err := c.SetReadDeadline(deadline); err != nil {
+ return nil, err
+ }
+ }
+
+ p.wg.Add(1)
+ go p.run(ctx, c, typ)
+
+ return c, err
+}
+
+// getConn creates or returns a conn matching typ which is ip4:icmp
+// or ip6:icmp.
+func (p *Pinger) getConn(ctx context.Context, typ string) (net.PacketConn, error) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ if c, ok := p.conns[typ]; ok {
+ return c, nil
+ }
+
+ var addr = "0.0.0.0"
+ if typ == v6Type {
+ addr = "::"
+ }
+ c, err := p.mkconn(ctx, typ, addr)
+ if err != nil {
+ return nil, err
+ }
+ mak.Set(&p.conns, typ, c)
+ return c, nil
+}
+
+func (p *Pinger) logf(format string, a ...any) {
+ if p.Logf != nil {
+ p.Logf(format, a...)
+ } else {
+ log.Printf(format, a...)
+ }
+}
+
+func (p *Pinger) vlogf(format string, a ...any) {
+ if p.Verbose {
+ p.logf(format, a...)
+ }
+}
+
+func (p *Pinger) Close() error {
+ p.closed.Store(true)
+
+ p.mu.Lock()
+ conns := p.conns
+ p.conns = nil
+ p.mu.Unlock()
+
+ var errors []error
+ for _, c := range conns {
+ if err := c.Close(); err != nil {
+ errors = append(errors, err)
+ }
+ }
+
+ p.wg.Wait()
+ p.cleanupOutstanding()
+
+ return multierr.New(errors...)
+}
+
+func (p *Pinger) run(ctx context.Context, conn net.PacketConn, typ string) {
+ defer p.wg.Done()
+ defer func() {
+ conn.Close()
+ p.mu.Lock()
+ delete(p.conns, typ)
+ p.mu.Unlock()
+ }()
+ buf := make([]byte, 1500)
+
+loop:
+ for {
+ select {
+ case <-ctx.Done():
+ break loop
+ default:
+ }
+
+ n, _, err := conn.ReadFrom(buf)
+ if err != nil {
+ // Ignore temporary errors; everything else is fatal
+ if netErr, ok := err.(net.Error); !ok || !netErr.Temporary() {
+ break
+ }
+ continue
+ }
+
+ p.handleResponse(buf[:n], p.timeNow(), typ)
+ }
+}
+
+func (p *Pinger) cleanupOutstanding() {
+ // Complete outstanding requests
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ for _, o := range p.pings {
+ o.ch <- response{err: net.ErrClosed}
+ }
+}
+
+func (p *Pinger) handleResponse(buf []byte, now time.Time, typ string) {
+ // We need to handle responding to both IPv4
+ // and IPv6.
+ var icmpType icmp.Type
+ switch typ {
+ case v4Type:
+ icmpType = ipv4.ICMPTypeEchoReply
+ case v6Type:
+ icmpType = ipv6.ICMPTypeEchoReply
+ default:
+ p.vlogf("handleResponse: unknown icmp.Type")
+ return
+ }
+
+ m, err := icmp.ParseMessage(icmpType.Protocol(), buf)
+ if err != nil {
+ p.vlogf("handleResponse: invalid packet: %v", err)
+ return
+ }
+
+ if m.Type != icmpType {
+ p.vlogf("handleResponse: wanted m.Type=%d; got %d", icmpType, m.Type)
+ return
+ }
+
+ resp, ok := m.Body.(*icmp.Echo)
+ if !ok || resp == nil {
+ p.vlogf("handleResponse: wanted body=*icmp.Echo; got %v", m.Body)
+ return
+ }
+
+ // We assume we sent this if the ID in the response is ours.
+ if uint16(resp.ID) != p.id {
+ p.vlogf("handleResponse: wanted ID=%d; got %d", p.id, resp.ID)
+ return
+ }
+
+ // Search for existing running echo request
+ var o outstanding
+ p.mu.Lock()
+ if o, ok = p.pings[uint16(resp.Seq)]; ok {
+ // Ensure that the data matches before we delete from our map,
+ // so a future correct packet will be handled correctly.
+ if bytes.Equal(resp.Data, o.data) {
+ delete(p.pings, uint16(resp.Seq))
+ } else {
+ p.vlogf("handleResponse: got response for Seq %d with mismatched data", resp.Seq)
+ ok = false
+ }
+ } else {
+ p.vlogf("handleResponse: got response for unknown Seq %d", resp.Seq)
+ }
+ p.mu.Unlock()
+
+ if ok {
+ o.ch <- response{t: now}
+ }
+}
+
+// Send sends an ICMP Echo Request packet to the destination, waits for a
+// response, and returns the duration between when the request was sent and
+// when the reply was received.
+//
+// If provided, "data" is sent with the packet and is compared upon receiving a
+// reply.
+func (p *Pinger) Send(ctx context.Context, dest net.Addr, data []byte) (time.Duration, error) {
+ // Use sequential sequence numbers on the assumption that we will not
+ // wrap around when using a single Pinger instance
+ p.mu.Lock()
+ p.seq++
+ seq := p.seq
+ p.mu.Unlock()
+
+ // Check whether the address is IPv4 or IPv6 to
+ // determine the icmp.Type and conn to use.
+ var conn net.PacketConn
+ var icmpType icmp.Type = ipv4.ICMPTypeEcho
+ ap, err := netip.ParseAddr(dest.String())
+ if err != nil {
+ return 0, err
+ }
+ if ap.Is6() {
+ icmpType = ipv6.ICMPTypeEchoRequest
+ conn, err = p.getConn(ctx, v6Type)
+ } else {
+ conn, err = p.getConn(ctx, v4Type)
+ }
+ if err != nil {
+ return 0, err
+ }
+
+ m := icmp.Message{
+ Type: icmpType,
+ Code: 0,
+ Body: &icmp.Echo{
+ ID: int(p.id),
+ Seq: int(seq),
+ Data: data,
+ },
+ }
+ b, err := m.Marshal(nil)
+ if err != nil {
+ return 0, err
+ }
+
+ // Register our response before sending since we could otherwise race a
+ // quick reply.
+ ch := make(chan response, 1)
+ p.mu.Lock()
+ p.pings[seq] = outstanding{ch: ch, data: data}
+ p.mu.Unlock()
+
+ start := p.timeNow()
+ n, err := conn.WriteTo(b, dest)
+ if err != nil {
+ return 0, err
+ } else if n != len(b) {
+ return 0, fmt.Errorf("conn.WriteTo: got %v; want %v", n, len(b))
+ }
+
+ select {
+ case resp := <-ch:
+ if resp.err != nil {
+ return 0, resp.err
+ }
+ return resp.t.Sub(start), nil
+
+ case <-ctx.Done():
+ return 0, ctx.Err()
+ }
+}
diff --git a/net/ping/ping_test.go b/net/ping/ping_test.go index bbedbcad8..5232f6ada 100644 --- a/net/ping/ping_test.go +++ b/net/ping/ping_test.go @@ -1,350 +1,350 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package ping - -import ( - "context" - "errors" - "fmt" - "net" - "testing" - "time" - - "golang.org/x/net/icmp" - "golang.org/x/net/ipv4" - "golang.org/x/net/ipv6" - "tailscale.com/tstest" - "tailscale.com/util/mak" -) - -var ( - localhost = &net.IPAddr{IP: net.IPv4(127, 0, 0, 1)} -) - -func TestPinger(t *testing.T) { - clock := &tstest.Clock{} - - ctx := context.Background() - ctx, cancel := context.WithTimeout(ctx, 5*time.Second) - defer cancel() - - p, closeP := mockPinger(t, clock) - defer closeP() - - bodyData := []byte("data goes here") - - // Start a ping in the background - r := make(chan time.Duration, 1) - go func() { - dur, err := p.Send(ctx, localhost, bodyData) - if err != nil { - t.Errorf("p.Send: %v", err) - r <- 0 - } else { - r <- dur - } - }() - - p.waitOutstanding(t, ctx, 1) - - // Fake a response from ourself - fakeResponse := mustMarshal(t, &icmp.Message{ - Type: ipv4.ICMPTypeEchoReply, - Code: ipv4.ICMPTypeEchoReply.Protocol(), - Body: &icmp.Echo{ - ID: 1234, - Seq: 1, - Data: bodyData, - }, - }) - - const fakeDuration = 100 * time.Millisecond - p.handleResponse(fakeResponse, clock.Now().Add(fakeDuration), v4Type) - - select { - case dur := <-r: - want := fakeDuration - if dur != want { - t.Errorf("wanted ping response time = %d; got %d", want, dur) - } - case <-ctx.Done(): - t.Fatal("did not get response by timeout") - } -} - -func TestV6Pinger(t *testing.T) { - if c, err := net.ListenPacket("udp6", "::1"); err != nil { - // skip test if we can't use IPv6. - t.Skipf("IPv6 not supported: %s", err) - } else { - c.Close() - } - - clock := &tstest.Clock{} - - ctx := context.Background() - ctx, cancel := context.WithTimeout(ctx, 5*time.Second) - defer cancel() - - p, closeP := mockPinger(t, clock) - defer closeP() - - bodyData := []byte("data goes here") - - // Start a ping in the background - r := make(chan time.Duration, 1) - go func() { - dur, err := p.Send(ctx, &net.IPAddr{IP: net.ParseIP("::")}, bodyData) - if err != nil { - t.Errorf("p.Send: %v", err) - r <- 0 - } else { - r <- dur - } - }() - - p.waitOutstanding(t, ctx, 1) - - // Fake a response from ourself - fakeResponse := mustMarshal(t, &icmp.Message{ - Type: ipv6.ICMPTypeEchoReply, - Code: ipv6.ICMPTypeEchoReply.Protocol(), - Body: &icmp.Echo{ - ID: 1234, - Seq: 1, - Data: bodyData, - }, - }) - - const fakeDuration = 100 * time.Millisecond - p.handleResponse(fakeResponse, clock.Now().Add(fakeDuration), v6Type) - - select { - case dur := <-r: - want := fakeDuration - if dur != want { - t.Errorf("wanted ping response time = %d; got %d", want, dur) - } - case <-ctx.Done(): - t.Fatal("did not get response by timeout") - } -} - -func TestPingerTimeout(t *testing.T) { - ctx := context.Background() - ctx, cancel := context.WithTimeout(ctx, 5*time.Second) - defer cancel() - - clock := &tstest.Clock{} - p, closeP := mockPinger(t, clock) - defer closeP() - - // Send a ping in the background - r := make(chan error, 1) - go func() { - _, err := p.Send(ctx, localhost, []byte("data goes here")) - r <- err - }() - - // Wait until we're blocking - p.waitOutstanding(t, ctx, 1) - - // Close everything down - p.cleanupOutstanding() - - // Should have got an error from the ping - err := <-r - if !errors.Is(err, net.ErrClosed) { - t.Errorf("wanted errors.Is(err, net.ErrClosed); got=%v", err) - } -} - -func TestPingerMismatch(t *testing.T) { - clock := &tstest.Clock{} - - ctx := context.Background() - ctx, cancel := context.WithTimeout(ctx, 1*time.Second) // intentionally short - defer cancel() - - p, closeP := mockPinger(t, clock) - defer closeP() - - bodyData := []byte("data goes here") - - // Start a ping in the background - r := make(chan time.Duration, 1) - go func() { - dur, err := p.Send(ctx, localhost, bodyData) - if err != nil && !errors.Is(err, context.DeadlineExceeded) { - t.Errorf("p.Send: %v", err) - r <- 0 - } else { - r <- dur - } - }() - - p.waitOutstanding(t, ctx, 1) - - // "Receive" a bunch of intentionally malformed packets that should not - // result in the Send call above returning - badPackets := []struct { - name string - pkt *icmp.Message - }{ - { - name: "wrong type", - pkt: &icmp.Message{ - Type: ipv4.ICMPTypeDestinationUnreachable, - Code: 0, - Body: &icmp.DstUnreach{}, - }, - }, - { - name: "wrong id", - pkt: &icmp.Message{ - Type: ipv4.ICMPTypeEchoReply, - Code: 0, - Body: &icmp.Echo{ - ID: 9999, - Seq: 1, - Data: bodyData, - }, - }, - }, - { - name: "wrong seq", - pkt: &icmp.Message{ - Type: ipv4.ICMPTypeEchoReply, - Code: 0, - Body: &icmp.Echo{ - ID: 1234, - Seq: 5, - Data: bodyData, - }, - }, - }, - { - name: "bad body", - pkt: &icmp.Message{ - Type: ipv4.ICMPTypeEchoReply, - Code: 0, - Body: &icmp.Echo{ - ID: 1234, - Seq: 1, - - // Intentionally missing first byte - Data: bodyData[1:], - }, - }, - }, - } - - const fakeDuration = 100 * time.Millisecond - tm := clock.Now().Add(fakeDuration) - - for _, tt := range badPackets { - fakeResponse := mustMarshal(t, tt.pkt) - p.handleResponse(fakeResponse, tm, v4Type) - } - - // Also "receive" a packet that does not unmarshal as an ICMP packet - p.handleResponse([]byte("foo"), tm, v4Type) - - select { - case <-r: - t.Fatal("wanted timeout") - case <-ctx.Done(): - t.Logf("test correctly timed out") - } -} - -// udpingPacketConn will convert potentially ICMP destination addrs to UDP -// destination addrs in WriteTo so that a test that is intending to send ICMP -// traffic will instead send UDP traffic, without the higher level Pinger being -// aware of this difference. -type udpingPacketConn struct { - net.PacketConn - // destPort will be configured by the test to be the peer expected to respond to a ping. - destPort uint16 -} - -func (u *udpingPacketConn) WriteTo(body []byte, dest net.Addr) (int, error) { - switch d := dest.(type) { - case *net.IPAddr: - udpAddr := &net.UDPAddr{ - IP: d.IP, - Port: int(u.destPort), - Zone: d.Zone, - } - return u.PacketConn.WriteTo(body, udpAddr) - } - return 0, fmt.Errorf("unimplemented udpingPacketConn for %T", dest) -} - -func mockPinger(t *testing.T, clock *tstest.Clock) (*Pinger, func()) { - p := New(context.Background(), t.Logf, nil) - p.timeNow = clock.Now - p.Verbose = true - p.id = 1234 - - // In tests, we use UDP so that we can test without being root; this - // doesn't matter because we mock out the ICMP reply below to be a real - // ICMP echo reply packet. - conn4, err := net.ListenPacket("udp4", "127.0.0.1:0") - if err != nil { - t.Fatalf("net.ListenPacket: %v", err) - } - - conn6, err := net.ListenPacket("udp6", "[::]:0") - if err != nil { - t.Fatalf("net.ListenPacket: %v", err) - } - - conn4 = &udpingPacketConn{ - destPort: 12345, - PacketConn: conn4, - } - conn6 = &udpingPacketConn{ - PacketConn: conn6, - destPort: 12345, - } - - mak.Set(&p.conns, v4Type, conn4) - mak.Set(&p.conns, v6Type, conn6) - done := func() { - if err := p.Close(); err != nil { - t.Errorf("error on close: %v", err) - } - } - return p, done -} - -func mustMarshal(t *testing.T, m *icmp.Message) []byte { - t.Helper() - - b, err := m.Marshal(nil) - if err != nil { - t.Fatal(err) - } - return b -} - -func (p *Pinger) waitOutstanding(t *testing.T, ctx context.Context, count int) { - // This is a bit janky, but... we busy-loop to wait for the Send call - // to write to our map so we know that a response will be handled. - var haveMapEntry bool - for !haveMapEntry { - time.Sleep(10 * time.Millisecond) - select { - case <-ctx.Done(): - t.Error("no entry in ping map before timeout") - return - default: - } - - p.mu.Lock() - haveMapEntry = len(p.pings) == count - p.mu.Unlock() - } -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package ping
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "net"
+ "testing"
+ "time"
+
+ "golang.org/x/net/icmp"
+ "golang.org/x/net/ipv4"
+ "golang.org/x/net/ipv6"
+ "tailscale.com/tstest"
+ "tailscale.com/util/mak"
+)
+
+var (
+ localhost = &net.IPAddr{IP: net.IPv4(127, 0, 0, 1)}
+)
+
+func TestPinger(t *testing.T) {
+ clock := &tstest.Clock{}
+
+ ctx := context.Background()
+ ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
+ defer cancel()
+
+ p, closeP := mockPinger(t, clock)
+ defer closeP()
+
+ bodyData := []byte("data goes here")
+
+ // Start a ping in the background
+ r := make(chan time.Duration, 1)
+ go func() {
+ dur, err := p.Send(ctx, localhost, bodyData)
+ if err != nil {
+ t.Errorf("p.Send: %v", err)
+ r <- 0
+ } else {
+ r <- dur
+ }
+ }()
+
+ p.waitOutstanding(t, ctx, 1)
+
+ // Fake a response from ourself
+ fakeResponse := mustMarshal(t, &icmp.Message{
+ Type: ipv4.ICMPTypeEchoReply,
+ Code: ipv4.ICMPTypeEchoReply.Protocol(),
+ Body: &icmp.Echo{
+ ID: 1234,
+ Seq: 1,
+ Data: bodyData,
+ },
+ })
+
+ const fakeDuration = 100 * time.Millisecond
+ p.handleResponse(fakeResponse, clock.Now().Add(fakeDuration), v4Type)
+
+ select {
+ case dur := <-r:
+ want := fakeDuration
+ if dur != want {
+ t.Errorf("wanted ping response time = %d; got %d", want, dur)
+ }
+ case <-ctx.Done():
+ t.Fatal("did not get response by timeout")
+ }
+}
+
+func TestV6Pinger(t *testing.T) {
+ if c, err := net.ListenPacket("udp6", "::1"); err != nil {
+ // skip test if we can't use IPv6.
+ t.Skipf("IPv6 not supported: %s", err)
+ } else {
+ c.Close()
+ }
+
+ clock := &tstest.Clock{}
+
+ ctx := context.Background()
+ ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
+ defer cancel()
+
+ p, closeP := mockPinger(t, clock)
+ defer closeP()
+
+ bodyData := []byte("data goes here")
+
+ // Start a ping in the background
+ r := make(chan time.Duration, 1)
+ go func() {
+ dur, err := p.Send(ctx, &net.IPAddr{IP: net.ParseIP("::")}, bodyData)
+ if err != nil {
+ t.Errorf("p.Send: %v", err)
+ r <- 0
+ } else {
+ r <- dur
+ }
+ }()
+
+ p.waitOutstanding(t, ctx, 1)
+
+ // Fake a response from ourself
+ fakeResponse := mustMarshal(t, &icmp.Message{
+ Type: ipv6.ICMPTypeEchoReply,
+ Code: ipv6.ICMPTypeEchoReply.Protocol(),
+ Body: &icmp.Echo{
+ ID: 1234,
+ Seq: 1,
+ Data: bodyData,
+ },
+ })
+
+ const fakeDuration = 100 * time.Millisecond
+ p.handleResponse(fakeResponse, clock.Now().Add(fakeDuration), v6Type)
+
+ select {
+ case dur := <-r:
+ want := fakeDuration
+ if dur != want {
+ t.Errorf("wanted ping response time = %d; got %d", want, dur)
+ }
+ case <-ctx.Done():
+ t.Fatal("did not get response by timeout")
+ }
+}
+
+func TestPingerTimeout(t *testing.T) {
+ ctx := context.Background()
+ ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
+ defer cancel()
+
+ clock := &tstest.Clock{}
+ p, closeP := mockPinger(t, clock)
+ defer closeP()
+
+ // Send a ping in the background
+ r := make(chan error, 1)
+ go func() {
+ _, err := p.Send(ctx, localhost, []byte("data goes here"))
+ r <- err
+ }()
+
+ // Wait until we're blocking
+ p.waitOutstanding(t, ctx, 1)
+
+ // Close everything down
+ p.cleanupOutstanding()
+
+ // Should have got an error from the ping
+ err := <-r
+ if !errors.Is(err, net.ErrClosed) {
+ t.Errorf("wanted errors.Is(err, net.ErrClosed); got=%v", err)
+ }
+}
+
+func TestPingerMismatch(t *testing.T) {
+ clock := &tstest.Clock{}
+
+ ctx := context.Background()
+ ctx, cancel := context.WithTimeout(ctx, 1*time.Second) // intentionally short
+ defer cancel()
+
+ p, closeP := mockPinger(t, clock)
+ defer closeP()
+
+ bodyData := []byte("data goes here")
+
+ // Start a ping in the background
+ r := make(chan time.Duration, 1)
+ go func() {
+ dur, err := p.Send(ctx, localhost, bodyData)
+ if err != nil && !errors.Is(err, context.DeadlineExceeded) {
+ t.Errorf("p.Send: %v", err)
+ r <- 0
+ } else {
+ r <- dur
+ }
+ }()
+
+ p.waitOutstanding(t, ctx, 1)
+
+ // "Receive" a bunch of intentionally malformed packets that should not
+ // result in the Send call above returning
+ badPackets := []struct {
+ name string
+ pkt *icmp.Message
+ }{
+ {
+ name: "wrong type",
+ pkt: &icmp.Message{
+ Type: ipv4.ICMPTypeDestinationUnreachable,
+ Code: 0,
+ Body: &icmp.DstUnreach{},
+ },
+ },
+ {
+ name: "wrong id",
+ pkt: &icmp.Message{
+ Type: ipv4.ICMPTypeEchoReply,
+ Code: 0,
+ Body: &icmp.Echo{
+ ID: 9999,
+ Seq: 1,
+ Data: bodyData,
+ },
+ },
+ },
+ {
+ name: "wrong seq",
+ pkt: &icmp.Message{
+ Type: ipv4.ICMPTypeEchoReply,
+ Code: 0,
+ Body: &icmp.Echo{
+ ID: 1234,
+ Seq: 5,
+ Data: bodyData,
+ },
+ },
+ },
+ {
+ name: "bad body",
+ pkt: &icmp.Message{
+ Type: ipv4.ICMPTypeEchoReply,
+ Code: 0,
+ Body: &icmp.Echo{
+ ID: 1234,
+ Seq: 1,
+
+ // Intentionally missing first byte
+ Data: bodyData[1:],
+ },
+ },
+ },
+ }
+
+ const fakeDuration = 100 * time.Millisecond
+ tm := clock.Now().Add(fakeDuration)
+
+ for _, tt := range badPackets {
+ fakeResponse := mustMarshal(t, tt.pkt)
+ p.handleResponse(fakeResponse, tm, v4Type)
+ }
+
+ // Also "receive" a packet that does not unmarshal as an ICMP packet
+ p.handleResponse([]byte("foo"), tm, v4Type)
+
+ select {
+ case <-r:
+ t.Fatal("wanted timeout")
+ case <-ctx.Done():
+ t.Logf("test correctly timed out")
+ }
+}
+
+// udpingPacketConn will convert potentially ICMP destination addrs to UDP
+// destination addrs in WriteTo so that a test that is intending to send ICMP
+// traffic will instead send UDP traffic, without the higher level Pinger being
+// aware of this difference.
+type udpingPacketConn struct {
+ net.PacketConn
+ // destPort will be configured by the test to be the peer expected to respond to a ping.
+ destPort uint16
+}
+
+func (u *udpingPacketConn) WriteTo(body []byte, dest net.Addr) (int, error) {
+ switch d := dest.(type) {
+ case *net.IPAddr:
+ udpAddr := &net.UDPAddr{
+ IP: d.IP,
+ Port: int(u.destPort),
+ Zone: d.Zone,
+ }
+ return u.PacketConn.WriteTo(body, udpAddr)
+ }
+ return 0, fmt.Errorf("unimplemented udpingPacketConn for %T", dest)
+}
+
+func mockPinger(t *testing.T, clock *tstest.Clock) (*Pinger, func()) {
+ p := New(context.Background(), t.Logf, nil)
+ p.timeNow = clock.Now
+ p.Verbose = true
+ p.id = 1234
+
+ // In tests, we use UDP so that we can test without being root; this
+ // doesn't matter because we mock out the ICMP reply below to be a real
+ // ICMP echo reply packet.
+ conn4, err := net.ListenPacket("udp4", "127.0.0.1:0")
+ if err != nil {
+ t.Fatalf("net.ListenPacket: %v", err)
+ }
+
+ conn6, err := net.ListenPacket("udp6", "[::]:0")
+ if err != nil {
+ t.Fatalf("net.ListenPacket: %v", err)
+ }
+
+ conn4 = &udpingPacketConn{
+ destPort: 12345,
+ PacketConn: conn4,
+ }
+ conn6 = &udpingPacketConn{
+ PacketConn: conn6,
+ destPort: 12345,
+ }
+
+ mak.Set(&p.conns, v4Type, conn4)
+ mak.Set(&p.conns, v6Type, conn6)
+ done := func() {
+ if err := p.Close(); err != nil {
+ t.Errorf("error on close: %v", err)
+ }
+ }
+ return p, done
+}
+
+func mustMarshal(t *testing.T, m *icmp.Message) []byte {
+ t.Helper()
+
+ b, err := m.Marshal(nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ return b
+}
+
+func (p *Pinger) waitOutstanding(t *testing.T, ctx context.Context, count int) {
+ // This is a bit janky, but... we busy-loop to wait for the Send call
+ // to write to our map so we know that a response will be handled.
+ var haveMapEntry bool
+ for !haveMapEntry {
+ time.Sleep(10 * time.Millisecond)
+ select {
+ case <-ctx.Done():
+ t.Error("no entry in ping map before timeout")
+ return
+ default:
+ }
+
+ p.mu.Lock()
+ haveMapEntry = len(p.pings) == count
+ p.mu.Unlock()
+ }
+}
diff --git a/net/portmapper/pcp_test.go b/net/portmapper/pcp_test.go index 8f8eef3ef..3dece7236 100644 --- a/net/portmapper/pcp_test.go +++ b/net/portmapper/pcp_test.go @@ -1,62 +1,62 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package portmapper - -import ( - "encoding/binary" - "net/netip" - "testing" - - "tailscale.com/net/netaddr" -) - -var examplePCPMapResponse = []byte{2, 129, 0, 0, 0, 0, 28, 32, 0, 2, 155, 237, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 129, 112, 9, 24, 241, 208, 251, 45, 157, 76, 10, 188, 17, 0, 0, 0, 4, 210, 4, 210, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 135, 180, 175, 246} - -func TestParsePCPMapResponse(t *testing.T) { - mapping, err := parsePCPMapResponse(examplePCPMapResponse) - if err != nil { - t.Fatalf("failed to parse PCP Map Response: %v", err) - } - if mapping == nil { - t.Fatalf("got nil mapping when expected non-nil") - } - expectedAddr := netip.MustParseAddrPort("135.180.175.246:1234") - if mapping.external != expectedAddr { - t.Errorf("mismatched external address, got: %v, want: %v", mapping.external, expectedAddr) - } -} - -const ( - serverResponseBit = 1 << 7 - fakeLifetimeSec = 1<<31 - 1 -) - -func buildPCPDiscoResponse(req []byte) []byte { - out := make([]byte, 24) - out[0] = pcpVersion - out[1] = req[1] | serverResponseBit - out[3] = 0 - // Do not put an epoch time in 8:12, when we start using it, tests that use it should fail. - return out -} - -func buildPCPMapResponse(req []byte) []byte { - out := make([]byte, 24+36) - out[0] = pcpVersion - out[1] = req[1] | serverResponseBit - out[3] = 0 - binary.BigEndian.PutUint32(out[4:8], 1<<30) - // Do not put an epoch time in 8:12, when we start using it, tests that use it should fail. - mapResp := out[24:] - mapReq := req[24:] - // copy nonce, protocol and internal port - copy(mapResp[:13], mapReq[:13]) - copy(mapResp[16:18], mapReq[16:18]) - // assign external port - binary.BigEndian.PutUint16(mapResp[18:20], 4242) - assignedIP := netaddr.IPv4(127, 0, 0, 1) - assignedIP16 := assignedIP.As16() - copy(mapResp[20:36], assignedIP16[:]) - return out -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package portmapper
+
+import (
+ "encoding/binary"
+ "net/netip"
+ "testing"
+
+ "tailscale.com/net/netaddr"
+)
+
+var examplePCPMapResponse = []byte{2, 129, 0, 0, 0, 0, 28, 32, 0, 2, 155, 237, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 129, 112, 9, 24, 241, 208, 251, 45, 157, 76, 10, 188, 17, 0, 0, 0, 4, 210, 4, 210, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 135, 180, 175, 246}
+
+func TestParsePCPMapResponse(t *testing.T) {
+ mapping, err := parsePCPMapResponse(examplePCPMapResponse)
+ if err != nil {
+ t.Fatalf("failed to parse PCP Map Response: %v", err)
+ }
+ if mapping == nil {
+ t.Fatalf("got nil mapping when expected non-nil")
+ }
+ expectedAddr := netip.MustParseAddrPort("135.180.175.246:1234")
+ if mapping.external != expectedAddr {
+ t.Errorf("mismatched external address, got: %v, want: %v", mapping.external, expectedAddr)
+ }
+}
+
+const (
+ serverResponseBit = 1 << 7
+ fakeLifetimeSec = 1<<31 - 1
+)
+
+func buildPCPDiscoResponse(req []byte) []byte {
+ out := make([]byte, 24)
+ out[0] = pcpVersion
+ out[1] = req[1] | serverResponseBit
+ out[3] = 0
+ // Do not put an epoch time in 8:12, when we start using it, tests that use it should fail.
+ return out
+}
+
+func buildPCPMapResponse(req []byte) []byte {
+ out := make([]byte, 24+36)
+ out[0] = pcpVersion
+ out[1] = req[1] | serverResponseBit
+ out[3] = 0
+ binary.BigEndian.PutUint32(out[4:8], 1<<30)
+ // Do not put an epoch time in 8:12, when we start using it, tests that use it should fail.
+ mapResp := out[24:]
+ mapReq := req[24:]
+ // copy nonce, protocol and internal port
+ copy(mapResp[:13], mapReq[:13])
+ copy(mapResp[16:18], mapReq[16:18])
+ // assign external port
+ binary.BigEndian.PutUint16(mapResp[18:20], 4242)
+ assignedIP := netaddr.IPv4(127, 0, 0, 1)
+ assignedIP16 := assignedIP.As16()
+ copy(mapResp[20:36], assignedIP16[:])
+ return out
+}
diff --git a/net/proxymux/mux.go b/net/proxymux/mux.go index ff5aaff3b..12c3107de 100644 --- a/net/proxymux/mux.go +++ b/net/proxymux/mux.go @@ -1,144 +1,144 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package proxymux splits a net.Listener in two, routing SOCKS5 -// connections to one and HTTP requests to the other. -// -// It allows for hosting both a SOCKS5 proxy and an HTTP proxy on the -// same listener. -package proxymux - -import ( - "io" - "net" - "sync" - "time" -) - -// SplitSOCKSAndHTTP accepts connections on ln and passes connections -// through to either socksListener or httpListener, depending the -// first byte sent by the client. -func SplitSOCKSAndHTTP(ln net.Listener) (socksListener, httpListener net.Listener) { - sl := &listener{ - addr: ln.Addr(), - c: make(chan net.Conn), - closed: make(chan struct{}), - } - hl := &listener{ - addr: ln.Addr(), - c: make(chan net.Conn), - closed: make(chan struct{}), - } - - go splitSOCKSAndHTTPListener(ln, sl, hl) - - return sl, hl -} - -func splitSOCKSAndHTTPListener(ln net.Listener, sl, hl *listener) { - for { - conn, err := ln.Accept() - if err != nil { - sl.Close() - hl.Close() - return - } - go routeConn(conn, sl, hl) - } -} - -func routeConn(c net.Conn, socksListener, httpListener *listener) { - if err := c.SetReadDeadline(time.Now().Add(15 * time.Second)); err != nil { - c.Close() - return - } - - var b [1]byte - if _, err := io.ReadFull(c, b[:]); err != nil { - c.Close() - return - } - - if err := c.SetReadDeadline(time.Time{}); err != nil { - c.Close() - return - } - - conn := &connWithOneByte{ - Conn: c, - b: b[0], - } - - // First byte of a SOCKS5 session is a version byte set to 5. - var ln *listener - if b[0] == 5 { - ln = socksListener - } else { - ln = httpListener - } - select { - case ln.c <- conn: - case <-ln.closed: - c.Close() - } -} - -type listener struct { - addr net.Addr - c chan net.Conn - mu sync.Mutex // serializes close() on closed. It's okay to receive on closed without locking. - closed chan struct{} -} - -func (ln *listener) Accept() (net.Conn, error) { - // Once closed, reliably stay closed, don't race with attempts at - // further connections. - select { - case <-ln.closed: - return nil, net.ErrClosed - default: - } - select { - case ret := <-ln.c: - return ret, nil - case <-ln.closed: - return nil, net.ErrClosed - } -} - -func (ln *listener) Close() error { - ln.mu.Lock() - defer ln.mu.Unlock() - select { - case <-ln.closed: - // Already closed - default: - close(ln.closed) - } - return nil -} - -func (ln *listener) Addr() net.Addr { - return ln.addr -} - -// connWithOneByte is a net.Conn that returns b for the first read -// request, then forwards everything else to Conn. -type connWithOneByte struct { - net.Conn - - b byte - bRead bool -} - -func (c *connWithOneByte) Read(bs []byte) (int, error) { - if c.bRead { - return c.Conn.Read(bs) - } - if len(bs) == 0 { - return 0, nil - } - c.bRead = true - bs[0] = c.b - return 1, nil -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Package proxymux splits a net.Listener in two, routing SOCKS5
+// connections to one and HTTP requests to the other.
+//
+// It allows for hosting both a SOCKS5 proxy and an HTTP proxy on the
+// same listener.
+package proxymux
+
+import (
+ "io"
+ "net"
+ "sync"
+ "time"
+)
+
+// SplitSOCKSAndHTTP accepts connections on ln and passes connections
+// through to either socksListener or httpListener, depending the
+// first byte sent by the client.
+func SplitSOCKSAndHTTP(ln net.Listener) (socksListener, httpListener net.Listener) {
+ sl := &listener{
+ addr: ln.Addr(),
+ c: make(chan net.Conn),
+ closed: make(chan struct{}),
+ }
+ hl := &listener{
+ addr: ln.Addr(),
+ c: make(chan net.Conn),
+ closed: make(chan struct{}),
+ }
+
+ go splitSOCKSAndHTTPListener(ln, sl, hl)
+
+ return sl, hl
+}
+
+func splitSOCKSAndHTTPListener(ln net.Listener, sl, hl *listener) {
+ for {
+ conn, err := ln.Accept()
+ if err != nil {
+ sl.Close()
+ hl.Close()
+ return
+ }
+ go routeConn(conn, sl, hl)
+ }
+}
+
+func routeConn(c net.Conn, socksListener, httpListener *listener) {
+ if err := c.SetReadDeadline(time.Now().Add(15 * time.Second)); err != nil {
+ c.Close()
+ return
+ }
+
+ var b [1]byte
+ if _, err := io.ReadFull(c, b[:]); err != nil {
+ c.Close()
+ return
+ }
+
+ if err := c.SetReadDeadline(time.Time{}); err != nil {
+ c.Close()
+ return
+ }
+
+ conn := &connWithOneByte{
+ Conn: c,
+ b: b[0],
+ }
+
+ // First byte of a SOCKS5 session is a version byte set to 5.
+ var ln *listener
+ if b[0] == 5 {
+ ln = socksListener
+ } else {
+ ln = httpListener
+ }
+ select {
+ case ln.c <- conn:
+ case <-ln.closed:
+ c.Close()
+ }
+}
+
+type listener struct {
+ addr net.Addr
+ c chan net.Conn
+ mu sync.Mutex // serializes close() on closed. It's okay to receive on closed without locking.
+ closed chan struct{}
+}
+
+func (ln *listener) Accept() (net.Conn, error) {
+ // Once closed, reliably stay closed, don't race with attempts at
+ // further connections.
+ select {
+ case <-ln.closed:
+ return nil, net.ErrClosed
+ default:
+ }
+ select {
+ case ret := <-ln.c:
+ return ret, nil
+ case <-ln.closed:
+ return nil, net.ErrClosed
+ }
+}
+
+func (ln *listener) Close() error {
+ ln.mu.Lock()
+ defer ln.mu.Unlock()
+ select {
+ case <-ln.closed:
+ // Already closed
+ default:
+ close(ln.closed)
+ }
+ return nil
+}
+
+func (ln *listener) Addr() net.Addr {
+ return ln.addr
+}
+
+// connWithOneByte is a net.Conn that returns b for the first read
+// request, then forwards everything else to Conn.
+type connWithOneByte struct {
+ net.Conn
+
+ b byte
+ bRead bool
+}
+
+func (c *connWithOneByte) Read(bs []byte) (int, error) {
+ if c.bRead {
+ return c.Conn.Read(bs)
+ }
+ if len(bs) == 0 {
+ return 0, nil
+ }
+ c.bRead = true
+ bs[0] = c.b
+ return 1, nil
+}
diff --git a/net/routetable/routetable_darwin.go b/net/routetable/routetable_darwin.go index 7f525ae32..7de80a662 100644 --- a/net/routetable/routetable_darwin.go +++ b/net/routetable/routetable_darwin.go @@ -1,36 +1,36 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build darwin - -package routetable - -import "golang.org/x/sys/unix" - -const ( - ribType = unix.NET_RT_DUMP2 - parseType = unix.NET_RT_IFLIST2 - rmExpectedType = unix.RTM_GET2 - - // Skip routes that were cloned from a parent - skipFlags = unix.RTF_WASCLONED -) - -var flags = map[int]string{ - unix.RTF_BLACKHOLE: "blackhole", - unix.RTF_BROADCAST: "broadcast", - unix.RTF_GATEWAY: "gateway", - unix.RTF_GLOBAL: "global", - unix.RTF_HOST: "host", - unix.RTF_IFSCOPE: "ifscope", - unix.RTF_LOCAL: "local", - unix.RTF_MULTICAST: "multicast", - unix.RTF_REJECT: "reject", - unix.RTF_ROUTER: "router", - unix.RTF_STATIC: "static", - unix.RTF_UP: "up", - // More obscure flags, just to have full coverage. - unix.RTF_LLINFO: "{RTF_LLINFO}", - unix.RTF_PRCLONING: "{RTF_PRCLONING}", - unix.RTF_CLONING: "{RTF_CLONING}", -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build darwin
+
+package routetable
+
+import "golang.org/x/sys/unix"
+
+const (
+ ribType = unix.NET_RT_DUMP2
+ parseType = unix.NET_RT_IFLIST2
+ rmExpectedType = unix.RTM_GET2
+
+ // Skip routes that were cloned from a parent
+ skipFlags = unix.RTF_WASCLONED
+)
+
+var flags = map[int]string{
+ unix.RTF_BLACKHOLE: "blackhole",
+ unix.RTF_BROADCAST: "broadcast",
+ unix.RTF_GATEWAY: "gateway",
+ unix.RTF_GLOBAL: "global",
+ unix.RTF_HOST: "host",
+ unix.RTF_IFSCOPE: "ifscope",
+ unix.RTF_LOCAL: "local",
+ unix.RTF_MULTICAST: "multicast",
+ unix.RTF_REJECT: "reject",
+ unix.RTF_ROUTER: "router",
+ unix.RTF_STATIC: "static",
+ unix.RTF_UP: "up",
+ // More obscure flags, just to have full coverage.
+ unix.RTF_LLINFO: "{RTF_LLINFO}",
+ unix.RTF_PRCLONING: "{RTF_PRCLONING}",
+ unix.RTF_CLONING: "{RTF_CLONING}",
+}
diff --git a/net/routetable/routetable_freebsd.go b/net/routetable/routetable_freebsd.go index 8e57a3302..aa4e03c41 100644 --- a/net/routetable/routetable_freebsd.go +++ b/net/routetable/routetable_freebsd.go @@ -1,28 +1,28 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build freebsd - -package routetable - -import "golang.org/x/sys/unix" - -const ( - ribType = unix.NET_RT_DUMP - parseType = unix.NET_RT_IFLIST - rmExpectedType = unix.RTM_GET - - // Nothing to skip - skipFlags = 0 -) - -var flags = map[int]string{ - unix.RTF_BLACKHOLE: "blackhole", - unix.RTF_BROADCAST: "broadcast", - unix.RTF_GATEWAY: "gateway", - unix.RTF_HOST: "host", - unix.RTF_MULTICAST: "multicast", - unix.RTF_REJECT: "reject", - unix.RTF_STATIC: "static", - unix.RTF_UP: "up", -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build freebsd
+
+package routetable
+
+import "golang.org/x/sys/unix"
+
+const (
+ ribType = unix.NET_RT_DUMP
+ parseType = unix.NET_RT_IFLIST
+ rmExpectedType = unix.RTM_GET
+
+ // Nothing to skip
+ skipFlags = 0
+)
+
+var flags = map[int]string{
+ unix.RTF_BLACKHOLE: "blackhole",
+ unix.RTF_BROADCAST: "broadcast",
+ unix.RTF_GATEWAY: "gateway",
+ unix.RTF_HOST: "host",
+ unix.RTF_MULTICAST: "multicast",
+ unix.RTF_REJECT: "reject",
+ unix.RTF_STATIC: "static",
+ unix.RTF_UP: "up",
+}
diff --git a/net/routetable/routetable_other.go b/net/routetable/routetable_other.go index 35c83e374..521fe1911 100644 --- a/net/routetable/routetable_other.go +++ b/net/routetable/routetable_other.go @@ -1,17 +1,17 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !linux && !darwin && !freebsd - -package routetable - -import ( - "errors" - "runtime" -) - -var errUnsupported = errors.New("cannot get route table on platform " + runtime.GOOS) - -func Get(max int) ([]RouteEntry, error) { - return nil, errUnsupported -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build !linux && !darwin && !freebsd
+
+package routetable
+
+import (
+ "errors"
+ "runtime"
+)
+
+var errUnsupported = errors.New("cannot get route table on platform " + runtime.GOOS)
+
+func Get(max int) ([]RouteEntry, error) {
+ return nil, errUnsupported
+}
diff --git a/net/sockstats/sockstats.go b/net/sockstats/sockstats.go index 715c1ee06..fb524a5c5 100644 --- a/net/sockstats/sockstats.go +++ b/net/sockstats/sockstats.go @@ -1,121 +1,121 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package sockstats collects statistics about network sockets used by -// the Tailscale client. The context where sockets are used must be -// instrumented with the WithSockStats() function. -// -// Only available on POSIX platforms when built with Tailscale's fork of Go. -package sockstats - -import ( - "context" - - "tailscale.com/net/netmon" - "tailscale.com/types/logger" -) - -// SockStats contains statistics for sockets instrumented with the -// WithSockStats() function -type SockStats struct { - Stats map[Label]SockStat - CurrentInterfaceCellular bool -} - -// SockStat contains the sent and received bytes for a socket instrumented with -// the WithSockStats() function. -type SockStat struct { - TxBytes uint64 - RxBytes uint64 -} - -// Label is an identifier for a socket that stats are collected for. A finite -// set of values that may be used to label a socket to encourage grouping and -// to make storage more efficient. -type Label uint8 - -//go:generate go run golang.org/x/tools/cmd/stringer -type Label -trimprefix Label - -// Labels are named after the package and function/struct that uses the socket. -// Values may be persisted and thus existing entries should not be re-numbered. -const ( - LabelControlClientAuto Label = 0 // control/controlclient/auto.go - LabelControlClientDialer Label = 1 // control/controlhttp/client.go - LabelDERPHTTPClient Label = 2 // derp/derphttp/derphttp_client.go - LabelLogtailLogger Label = 3 // logtail/logtail.go - LabelDNSForwarderDoH Label = 4 // net/dns/resolver/forwarder.go - LabelDNSForwarderUDP Label = 5 // net/dns/resolver/forwarder.go - LabelNetcheckClient Label = 6 // net/netcheck/netcheck.go - LabelPortmapperClient Label = 7 // net/portmapper/portmapper.go - LabelMagicsockConnUDP4 Label = 8 // wgengine/magicsock/magicsock.go - LabelMagicsockConnUDP6 Label = 9 // wgengine/magicsock/magicsock.go - LabelNetlogLogger Label = 10 // wgengine/netlog/logger.go - LabelSockstatlogLogger Label = 11 // log/sockstatlog/logger.go - LabelDNSForwarderTCP Label = 12 // net/dns/resolver/forwarder.go -) - -// WithSockStats instruments a context so that sockets created with it will -// have their statistics collected. -func WithSockStats(ctx context.Context, label Label, logf logger.Logf) context.Context { - return withSockStats(ctx, label, logf) -} - -// Get returns the current socket statistics. -func Get() *SockStats { - return get() -} - -// InterfaceSockStats contains statistics for sockets instrumented with the -// WithSockStats() function, broken down by interface. The statistics may be a -// subset of the total if interfaces were added after the instrumented socket -// was created. -type InterfaceSockStats struct { - Stats map[Label]InterfaceSockStat - Interfaces []string -} - -// InterfaceSockStat contains the per-interface sent and received bytes for a -// socket instrumented with the WithSockStats() function. -type InterfaceSockStat struct { - TxBytesByInterface map[string]uint64 - RxBytesByInterface map[string]uint64 -} - -// GetWithInterfaces is a variant of Get that returns the current socket -// statistics broken down by interface. It is slightly more expensive than Get. -func GetInterfaces() *InterfaceSockStats { - return getInterfaces() -} - -// ValidationSockStats contains external validation numbers for sockets -// instrumented with WithSockStats. It may be a subset of the all sockets, -// depending on what externa measurement mechanisms the platform supports. -type ValidationSockStats struct { - Stats map[Label]ValidationSockStat -} - -// ValidationSockStat contains the validation bytes for a socket instrumented -// with WithSockStats. -type ValidationSockStat struct { - TxBytes uint64 - RxBytes uint64 -} - -// GetValidation is a variant of Get that returns external validation numbers -// for stats. It is more expensive than Get and should be used in debug -// interfaces only. -func GetValidation() *ValidationSockStats { - return getValidation() -} - -// SetNetMon configures the sockstats package to monitor the active -// interface, so that per-interface stats can be collected. -func SetNetMon(netMon *netmon.Monitor) { - setNetMon(netMon) -} - -// DebugInfo returns a string containing debug information about the tracked -// statistics. -func DebugInfo() string { - return debugInfo() -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Package sockstats collects statistics about network sockets used by
+// the Tailscale client. The context where sockets are used must be
+// instrumented with the WithSockStats() function.
+//
+// Only available on POSIX platforms when built with Tailscale's fork of Go.
+package sockstats
+
+import (
+ "context"
+
+ "tailscale.com/net/netmon"
+ "tailscale.com/types/logger"
+)
+
+// SockStats contains statistics for sockets instrumented with the
+// WithSockStats() function
+type SockStats struct {
+ Stats map[Label]SockStat
+ CurrentInterfaceCellular bool
+}
+
+// SockStat contains the sent and received bytes for a socket instrumented with
+// the WithSockStats() function.
+type SockStat struct {
+ TxBytes uint64
+ RxBytes uint64
+}
+
+// Label is an identifier for a socket that stats are collected for. A finite
+// set of values that may be used to label a socket to encourage grouping and
+// to make storage more efficient.
+type Label uint8
+
+//go:generate go run golang.org/x/tools/cmd/stringer -type Label -trimprefix Label
+
+// Labels are named after the package and function/struct that uses the socket.
+// Values may be persisted and thus existing entries should not be re-numbered.
+const (
+ LabelControlClientAuto Label = 0 // control/controlclient/auto.go
+ LabelControlClientDialer Label = 1 // control/controlhttp/client.go
+ LabelDERPHTTPClient Label = 2 // derp/derphttp/derphttp_client.go
+ LabelLogtailLogger Label = 3 // logtail/logtail.go
+ LabelDNSForwarderDoH Label = 4 // net/dns/resolver/forwarder.go
+ LabelDNSForwarderUDP Label = 5 // net/dns/resolver/forwarder.go
+ LabelNetcheckClient Label = 6 // net/netcheck/netcheck.go
+ LabelPortmapperClient Label = 7 // net/portmapper/portmapper.go
+ LabelMagicsockConnUDP4 Label = 8 // wgengine/magicsock/magicsock.go
+ LabelMagicsockConnUDP6 Label = 9 // wgengine/magicsock/magicsock.go
+ LabelNetlogLogger Label = 10 // wgengine/netlog/logger.go
+ LabelSockstatlogLogger Label = 11 // log/sockstatlog/logger.go
+ LabelDNSForwarderTCP Label = 12 // net/dns/resolver/forwarder.go
+)
+
+// WithSockStats instruments a context so that sockets created with it will
+// have their statistics collected.
+func WithSockStats(ctx context.Context, label Label, logf logger.Logf) context.Context {
+ return withSockStats(ctx, label, logf)
+}
+
+// Get returns the current socket statistics.
+func Get() *SockStats {
+ return get()
+}
+
+// InterfaceSockStats contains statistics for sockets instrumented with the
+// WithSockStats() function, broken down by interface. The statistics may be a
+// subset of the total if interfaces were added after the instrumented socket
+// was created.
+type InterfaceSockStats struct {
+ Stats map[Label]InterfaceSockStat
+ Interfaces []string
+}
+
+// InterfaceSockStat contains the per-interface sent and received bytes for a
+// socket instrumented with the WithSockStats() function.
+type InterfaceSockStat struct {
+ TxBytesByInterface map[string]uint64
+ RxBytesByInterface map[string]uint64
+}
+
+// GetWithInterfaces is a variant of Get that returns the current socket
+// statistics broken down by interface. It is slightly more expensive than Get.
+func GetInterfaces() *InterfaceSockStats {
+ return getInterfaces()
+}
+
+// ValidationSockStats contains external validation numbers for sockets
+// instrumented with WithSockStats. It may be a subset of the all sockets,
+// depending on what externa measurement mechanisms the platform supports.
+type ValidationSockStats struct {
+ Stats map[Label]ValidationSockStat
+}
+
+// ValidationSockStat contains the validation bytes for a socket instrumented
+// with WithSockStats.
+type ValidationSockStat struct {
+ TxBytes uint64
+ RxBytes uint64
+}
+
+// GetValidation is a variant of Get that returns external validation numbers
+// for stats. It is more expensive than Get and should be used in debug
+// interfaces only.
+func GetValidation() *ValidationSockStats {
+ return getValidation()
+}
+
+// SetNetMon configures the sockstats package to monitor the active
+// interface, so that per-interface stats can be collected.
+func SetNetMon(netMon *netmon.Monitor) {
+ setNetMon(netMon)
+}
+
+// DebugInfo returns a string containing debug information about the tracked
+// statistics.
+func DebugInfo() string {
+ return debugInfo()
+}
diff --git a/net/sockstats/sockstats_noop.go b/net/sockstats/sockstats_noop.go index 96723111a..797fdc42b 100644 --- a/net/sockstats/sockstats_noop.go +++ b/net/sockstats/sockstats_noop.go @@ -1,38 +1,38 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !tailscale_go || !(darwin || ios || android || ts_enable_sockstats) - -package sockstats - -import ( - "context" - - "tailscale.com/net/netmon" - "tailscale.com/types/logger" -) - -const IsAvailable = false - -func withSockStats(ctx context.Context, label Label, logf logger.Logf) context.Context { - return ctx -} - -func get() *SockStats { - return nil -} - -func getInterfaces() *InterfaceSockStats { - return nil -} - -func getValidation() *ValidationSockStats { - return nil -} - -func setNetMon(netMon *netmon.Monitor) { -} - -func debugInfo() string { - return "" -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build !tailscale_go || !(darwin || ios || android || ts_enable_sockstats)
+
+package sockstats
+
+import (
+ "context"
+
+ "tailscale.com/net/netmon"
+ "tailscale.com/types/logger"
+)
+
+const IsAvailable = false
+
+func withSockStats(ctx context.Context, label Label, logf logger.Logf) context.Context {
+ return ctx
+}
+
+func get() *SockStats {
+ return nil
+}
+
+func getInterfaces() *InterfaceSockStats {
+ return nil
+}
+
+func getValidation() *ValidationSockStats {
+ return nil
+}
+
+func setNetMon(netMon *netmon.Monitor) {
+}
+
+func debugInfo() string {
+ return ""
+}
diff --git a/net/sockstats/sockstats_tsgo_darwin.go b/net/sockstats/sockstats_tsgo_darwin.go index 321d32e04..4b03ed616 100644 --- a/net/sockstats/sockstats_tsgo_darwin.go +++ b/net/sockstats/sockstats_tsgo_darwin.go @@ -1,30 +1,30 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build tailscale_go && (darwin || ios) - -package sockstats - -import ( - "syscall" - - "golang.org/x/sys/unix" -) - -func init() { - tcpConnStats = darwinTcpConnStats -} - -func darwinTcpConnStats(c syscall.RawConn) (tx, rx uint64) { - c.Control(func(fd uintptr) { - if rawInfo, err := unix.GetsockoptTCPConnectionInfo( - int(fd), - unix.IPPROTO_TCP, - unix.TCP_CONNECTION_INFO, - ); err == nil { - tx = uint64(rawInfo.Txbytes) - rx = uint64(rawInfo.Rxbytes) - } - }) - return -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build tailscale_go && (darwin || ios)
+
+package sockstats
+
+import (
+ "syscall"
+
+ "golang.org/x/sys/unix"
+)
+
+func init() {
+ tcpConnStats = darwinTcpConnStats
+}
+
+func darwinTcpConnStats(c syscall.RawConn) (tx, rx uint64) {
+ c.Control(func(fd uintptr) {
+ if rawInfo, err := unix.GetsockoptTCPConnectionInfo(
+ int(fd),
+ unix.IPPROTO_TCP,
+ unix.TCP_CONNECTION_INFO,
+ ); err == nil {
+ tx = uint64(rawInfo.Txbytes)
+ rx = uint64(rawInfo.Rxbytes)
+ }
+ })
+ return
+}
diff --git a/net/speedtest/speedtest.go b/net/speedtest/speedtest.go index 7ab0881cc..89639c12d 100644 --- a/net/speedtest/speedtest.go +++ b/net/speedtest/speedtest.go @@ -1,87 +1,87 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package speedtest contains both server and client code for -// running speedtests between tailscale nodes. -package speedtest - -import ( - "time" -) - -const ( - blockSize = 2 * 1024 * 1024 // size of the block of data to send - MinDuration = 5 * time.Second // minimum duration for a test - DefaultDuration = MinDuration // default duration for a test - MaxDuration = 30 * time.Second // maximum duration for a test - version = 2 // value used when comparing client and server versions - increment = time.Second // increment to display results for, in seconds - minInterval = 10 * time.Millisecond // minimum interval length for a result to be included - DefaultPort = 20333 -) - -// config is the initial message sent to the server, that contains information on how to -// conduct the test. -type config struct { - Version int `json:"version"` - TestDuration time.Duration `json:"time"` - Direction Direction `json:"direction"` -} - -// configResponse is the response to the testConfig message. If the server has an -// error with the config, the Error variable will hold that error value. -type configResponse struct { - Error string `json:"error,omitempty"` -} - -// This represents the Result of a speedtest within a specific interval -type Result struct { - Bytes int // number of bytes sent/received during the interval - IntervalStart time.Time // start of the interval - IntervalEnd time.Time // end of the interval - Total bool // if true, this result struct represents the entire test, rather than a segment of the test -} - -func (r Result) MBitsPerSecond() float64 { - return r.MegaBits() / r.IntervalEnd.Sub(r.IntervalStart).Seconds() -} - -func (r Result) MegaBytes() float64 { - return float64(r.Bytes) / 1000000.0 -} - -func (r Result) MegaBits() float64 { - return r.MegaBytes() * 8.0 -} - -func (r Result) Interval() time.Duration { - return r.IntervalEnd.Sub(r.IntervalStart) -} - -type Direction int - -const ( - Download Direction = iota - Upload -) - -func (d Direction) String() string { - switch d { - case Upload: - return "upload" - case Download: - return "download" - default: - return "" - } -} - -func (d *Direction) Reverse() { - switch *d { - case Upload: - *d = Download - case Download: - *d = Upload - default: - } -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Package speedtest contains both server and client code for
+// running speedtests between tailscale nodes.
+package speedtest
+
+import (
+ "time"
+)
+
+const (
+ blockSize = 2 * 1024 * 1024 // size of the block of data to send
+ MinDuration = 5 * time.Second // minimum duration for a test
+ DefaultDuration = MinDuration // default duration for a test
+ MaxDuration = 30 * time.Second // maximum duration for a test
+ version = 2 // value used when comparing client and server versions
+ increment = time.Second // increment to display results for, in seconds
+ minInterval = 10 * time.Millisecond // minimum interval length for a result to be included
+ DefaultPort = 20333
+)
+
+// config is the initial message sent to the server, that contains information on how to
+// conduct the test.
+type config struct {
+ Version int `json:"version"`
+ TestDuration time.Duration `json:"time"`
+ Direction Direction `json:"direction"`
+}
+
+// configResponse is the response to the testConfig message. If the server has an
+// error with the config, the Error variable will hold that error value.
+type configResponse struct {
+ Error string `json:"error,omitempty"`
+}
+
+// This represents the Result of a speedtest within a specific interval
+type Result struct {
+ Bytes int // number of bytes sent/received during the interval
+ IntervalStart time.Time // start of the interval
+ IntervalEnd time.Time // end of the interval
+ Total bool // if true, this result struct represents the entire test, rather than a segment of the test
+}
+
+func (r Result) MBitsPerSecond() float64 {
+ return r.MegaBits() / r.IntervalEnd.Sub(r.IntervalStart).Seconds()
+}
+
+func (r Result) MegaBytes() float64 {
+ return float64(r.Bytes) / 1000000.0
+}
+
+func (r Result) MegaBits() float64 {
+ return r.MegaBytes() * 8.0
+}
+
+func (r Result) Interval() time.Duration {
+ return r.IntervalEnd.Sub(r.IntervalStart)
+}
+
+type Direction int
+
+const (
+ Download Direction = iota
+ Upload
+)
+
+func (d Direction) String() string {
+ switch d {
+ case Upload:
+ return "upload"
+ case Download:
+ return "download"
+ default:
+ return ""
+ }
+}
+
+func (d *Direction) Reverse() {
+ switch *d {
+ case Upload:
+ *d = Download
+ case Download:
+ *d = Upload
+ default:
+ }
+}
diff --git a/net/speedtest/speedtest_client.go b/net/speedtest/speedtest_client.go index 299a12a8d..cc34c468c 100644 --- a/net/speedtest/speedtest_client.go +++ b/net/speedtest/speedtest_client.go @@ -1,41 +1,41 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package speedtest - -import ( - "encoding/json" - "errors" - "net" - "time" -) - -// RunClient dials the given address and starts a speedtest. -// It returns any errors that come up in the tests. -// If there are no errors in the test, it returns a slice of results. -func RunClient(direction Direction, duration time.Duration, host string) ([]Result, error) { - conn, err := net.Dial("tcp", host) - if err != nil { - return nil, err - } - - conf := config{TestDuration: duration, Version: version, Direction: direction} - - defer conn.Close() - encoder := json.NewEncoder(conn) - - if err = encoder.Encode(conf); err != nil { - return nil, err - } - - var response configResponse - decoder := json.NewDecoder(conn) - if err = decoder.Decode(&response); err != nil { - return nil, err - } - if response.Error != "" { - return nil, errors.New(response.Error) - } - - return doTest(conn, conf) -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package speedtest
+
+import (
+ "encoding/json"
+ "errors"
+ "net"
+ "time"
+)
+
+// RunClient dials the given address and starts a speedtest.
+// It returns any errors that come up in the tests.
+// If there are no errors in the test, it returns a slice of results.
+func RunClient(direction Direction, duration time.Duration, host string) ([]Result, error) {
+ conn, err := net.Dial("tcp", host)
+ if err != nil {
+ return nil, err
+ }
+
+ conf := config{TestDuration: duration, Version: version, Direction: direction}
+
+ defer conn.Close()
+ encoder := json.NewEncoder(conn)
+
+ if err = encoder.Encode(conf); err != nil {
+ return nil, err
+ }
+
+ var response configResponse
+ decoder := json.NewDecoder(conn)
+ if err = decoder.Decode(&response); err != nil {
+ return nil, err
+ }
+ if response.Error != "" {
+ return nil, errors.New(response.Error)
+ }
+
+ return doTest(conn, conf)
+}
diff --git a/net/speedtest/speedtest_server.go b/net/speedtest/speedtest_server.go index 9dd78b195..d2673464e 100644 --- a/net/speedtest/speedtest_server.go +++ b/net/speedtest/speedtest_server.go @@ -1,146 +1,146 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package speedtest - -import ( - "crypto/rand" - "encoding/json" - "errors" - "fmt" - "io" - "net" - "time" -) - -// Serve starts up the server on a given host and port pair. It starts to listen for -// connections and handles each one in a goroutine. Because it runs in an infinite loop, -// this function only returns if any of the speedtests return with errors, or if the -// listener is closed. -func Serve(l net.Listener) error { - for { - conn, err := l.Accept() - if errors.Is(err, net.ErrClosed) { - return nil - } - if err != nil { - return err - } - err = handleConnection(conn) - if err != nil { - return err - } - } -} - -// handleConnection handles the initial exchange between the server and the client. -// It reads the testconfig message into a config struct. If any errors occur with -// the testconfig (specifically, if there is a version mismatch), it will return those -// errors to the client with a configResponse. After the exchange, it will start -// the speed test. -func handleConnection(conn net.Conn) error { - defer conn.Close() - var conf config - - decoder := json.NewDecoder(conn) - err := decoder.Decode(&conf) - encoder := json.NewEncoder(conn) - - // Both return and encode errors that occurred before the test started. - if err != nil { - encoder.Encode(configResponse{Error: err.Error()}) - return err - } - - // The server should always be doing the opposite of what the client is doing. - conf.Direction.Reverse() - - if conf.Version != version { - err = fmt.Errorf("version mismatch! Server is version %d, client is version %d", version, conf.Version) - encoder.Encode(configResponse{Error: err.Error()}) - return err - } - - // Start the test - encoder.Encode(configResponse{}) - _, err = doTest(conn, conf) - return err -} - -// TODO include code to detect whether the code is direct vs DERP - -// doTest contains the code to run both the upload and download speedtest. -// the direction value in the config parameter determines which test to run. -func doTest(conn net.Conn, conf config) ([]Result, error) { - bufferData := make([]byte, blockSize) - - intervalBytes := 0 - totalBytes := 0 - - var currentTime time.Time - var results []Result - - if conf.Direction == Download { - conn.SetReadDeadline(time.Now().Add(conf.TestDuration).Add(5 * time.Second)) - } else { - _, err := rand.Read(bufferData) - if err != nil { - return nil, err - } - - } - - startTime := time.Now() - lastCalculated := startTime - -SpeedTestLoop: - for { - var n int - var err error - - if conf.Direction == Download { - n, err = io.ReadFull(conn, bufferData) - switch err { - case io.EOF, io.ErrUnexpectedEOF: - break SpeedTestLoop - case nil: - // successful read - default: - return nil, fmt.Errorf("unexpected error has occurred: %w", err) - } - } else { - n, err = conn.Write(bufferData) - if err != nil { - // If the write failed, there is most likely something wrong with the connection. - return nil, fmt.Errorf("upload failed: %w", err) - } - } - intervalBytes += n - - currentTime = time.Now() - // checks if the current time is more or equal to the lastCalculated time plus the increment - if currentTime.Sub(lastCalculated) >= increment { - results = append(results, Result{Bytes: intervalBytes, IntervalStart: lastCalculated, IntervalEnd: currentTime, Total: false}) - lastCalculated = currentTime - totalBytes += intervalBytes - intervalBytes = 0 - } - - if conf.Direction == Upload && currentTime.Sub(startTime) > conf.TestDuration { - break SpeedTestLoop - } - } - - // get last segment - if currentTime.Sub(lastCalculated) > minInterval { - results = append(results, Result{Bytes: intervalBytes, IntervalStart: lastCalculated, IntervalEnd: currentTime, Total: false}) - } - - // get total - totalBytes += intervalBytes - if currentTime.Sub(startTime) > minInterval { - results = append(results, Result{Bytes: totalBytes, IntervalStart: startTime, IntervalEnd: currentTime, Total: true}) - } - - return results, nil -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package speedtest
+
+import (
+ "crypto/rand"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "time"
+)
+
+// Serve starts up the server on a given host and port pair. It starts to listen for
+// connections and handles each one in a goroutine. Because it runs in an infinite loop,
+// this function only returns if any of the speedtests return with errors, or if the
+// listener is closed.
+func Serve(l net.Listener) error {
+ for {
+ conn, err := l.Accept()
+ if errors.Is(err, net.ErrClosed) {
+ return nil
+ }
+ if err != nil {
+ return err
+ }
+ err = handleConnection(conn)
+ if err != nil {
+ return err
+ }
+ }
+}
+
+// handleConnection handles the initial exchange between the server and the client.
+// It reads the testconfig message into a config struct. If any errors occur with
+// the testconfig (specifically, if there is a version mismatch), it will return those
+// errors to the client with a configResponse. After the exchange, it will start
+// the speed test.
+func handleConnection(conn net.Conn) error {
+ defer conn.Close()
+ var conf config
+
+ decoder := json.NewDecoder(conn)
+ err := decoder.Decode(&conf)
+ encoder := json.NewEncoder(conn)
+
+ // Both return and encode errors that occurred before the test started.
+ if err != nil {
+ encoder.Encode(configResponse{Error: err.Error()})
+ return err
+ }
+
+ // The server should always be doing the opposite of what the client is doing.
+ conf.Direction.Reverse()
+
+ if conf.Version != version {
+ err = fmt.Errorf("version mismatch! Server is version %d, client is version %d", version, conf.Version)
+ encoder.Encode(configResponse{Error: err.Error()})
+ return err
+ }
+
+ // Start the test
+ encoder.Encode(configResponse{})
+ _, err = doTest(conn, conf)
+ return err
+}
+
+// TODO include code to detect whether the code is direct vs DERP
+
+// doTest contains the code to run both the upload and download speedtest.
+// the direction value in the config parameter determines which test to run.
+func doTest(conn net.Conn, conf config) ([]Result, error) {
+ bufferData := make([]byte, blockSize)
+
+ intervalBytes := 0
+ totalBytes := 0
+
+ var currentTime time.Time
+ var results []Result
+
+ if conf.Direction == Download {
+ conn.SetReadDeadline(time.Now().Add(conf.TestDuration).Add(5 * time.Second))
+ } else {
+ _, err := rand.Read(bufferData)
+ if err != nil {
+ return nil, err
+ }
+
+ }
+
+ startTime := time.Now()
+ lastCalculated := startTime
+
+SpeedTestLoop:
+ for {
+ var n int
+ var err error
+
+ if conf.Direction == Download {
+ n, err = io.ReadFull(conn, bufferData)
+ switch err {
+ case io.EOF, io.ErrUnexpectedEOF:
+ break SpeedTestLoop
+ case nil:
+ // successful read
+ default:
+ return nil, fmt.Errorf("unexpected error has occurred: %w", err)
+ }
+ } else {
+ n, err = conn.Write(bufferData)
+ if err != nil {
+ // If the write failed, there is most likely something wrong with the connection.
+ return nil, fmt.Errorf("upload failed: %w", err)
+ }
+ }
+ intervalBytes += n
+
+ currentTime = time.Now()
+ // checks if the current time is more or equal to the lastCalculated time plus the increment
+ if currentTime.Sub(lastCalculated) >= increment {
+ results = append(results, Result{Bytes: intervalBytes, IntervalStart: lastCalculated, IntervalEnd: currentTime, Total: false})
+ lastCalculated = currentTime
+ totalBytes += intervalBytes
+ intervalBytes = 0
+ }
+
+ if conf.Direction == Upload && currentTime.Sub(startTime) > conf.TestDuration {
+ break SpeedTestLoop
+ }
+ }
+
+ // get last segment
+ if currentTime.Sub(lastCalculated) > minInterval {
+ results = append(results, Result{Bytes: intervalBytes, IntervalStart: lastCalculated, IntervalEnd: currentTime, Total: false})
+ }
+
+ // get total
+ totalBytes += intervalBytes
+ if currentTime.Sub(startTime) > minInterval {
+ results = append(results, Result{Bytes: totalBytes, IntervalStart: startTime, IntervalEnd: currentTime, Total: true})
+ }
+
+ return results, nil
+}
diff --git a/net/speedtest/speedtest_test.go b/net/speedtest/speedtest_test.go index 55dcbeea1..a413e9efa 100644 --- a/net/speedtest/speedtest_test.go +++ b/net/speedtest/speedtest_test.go @@ -1,83 +1,83 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package speedtest - -import ( - "net" - "testing" - "time" -) - -func TestDownload(t *testing.T) { - // start a listener and find the port where the server will be listening. - l, err := net.Listen("tcp", ":0") - if err != nil { - t.Fatal(err) - } - t.Cleanup(func() { l.Close() }) - - serverIP := l.Addr().String() - t.Log("server IP found:", serverIP) - - type state struct { - err error - } - displayResult := func(t *testing.T, r Result, start time.Time) { - t.Helper() - t.Logf("{ Megabytes: %.2f, Start: %.1f, End: %.1f, Total: %t }", r.MegaBytes(), r.IntervalStart.Sub(start).Seconds(), r.IntervalEnd.Sub(start).Seconds(), r.Total) - } - stateChan := make(chan state, 1) - - go func() { - err := Serve(l) - stateChan <- state{err: err} - }() - - // ensure that the test returns an appropriate number of Result structs - expectedLen := int(DefaultDuration.Seconds()) + 1 - - t.Run("download test", func(t *testing.T) { - // conduct a download test - results, err := RunClient(Download, DefaultDuration, serverIP) - - if err != nil { - t.Fatal("download test failed:", err) - } - - if len(results) < expectedLen { - t.Fatalf("download results: expected length: %d, actual length: %d", expectedLen, len(results)) - } - - start := results[0].IntervalStart - for _, result := range results { - displayResult(t, result, start) - } - }) - - t.Run("upload test", func(t *testing.T) { - // conduct an upload test - results, err := RunClient(Upload, DefaultDuration, serverIP) - - if err != nil { - t.Fatal("upload test failed:", err) - } - - if len(results) < expectedLen { - t.Fatalf("upload results: expected length: %d, actual length: %d", expectedLen, len(results)) - } - - start := results[0].IntervalStart - for _, result := range results { - displayResult(t, result, start) - } - }) - - // causes the server goroutine to finish - l.Close() - - testState := <-stateChan - if testState.err != nil { - t.Error("server error:", err) - } -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package speedtest
+
+import (
+ "net"
+ "testing"
+ "time"
+)
+
+func TestDownload(t *testing.T) {
+ // start a listener and find the port where the server will be listening.
+ l, err := net.Listen("tcp", ":0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ t.Cleanup(func() { l.Close() })
+
+ serverIP := l.Addr().String()
+ t.Log("server IP found:", serverIP)
+
+ type state struct {
+ err error
+ }
+ displayResult := func(t *testing.T, r Result, start time.Time) {
+ t.Helper()
+ t.Logf("{ Megabytes: %.2f, Start: %.1f, End: %.1f, Total: %t }", r.MegaBytes(), r.IntervalStart.Sub(start).Seconds(), r.IntervalEnd.Sub(start).Seconds(), r.Total)
+ }
+ stateChan := make(chan state, 1)
+
+ go func() {
+ err := Serve(l)
+ stateChan <- state{err: err}
+ }()
+
+ // ensure that the test returns an appropriate number of Result structs
+ expectedLen := int(DefaultDuration.Seconds()) + 1
+
+ t.Run("download test", func(t *testing.T) {
+ // conduct a download test
+ results, err := RunClient(Download, DefaultDuration, serverIP)
+
+ if err != nil {
+ t.Fatal("download test failed:", err)
+ }
+
+ if len(results) < expectedLen {
+ t.Fatalf("download results: expected length: %d, actual length: %d", expectedLen, len(results))
+ }
+
+ start := results[0].IntervalStart
+ for _, result := range results {
+ displayResult(t, result, start)
+ }
+ })
+
+ t.Run("upload test", func(t *testing.T) {
+ // conduct an upload test
+ results, err := RunClient(Upload, DefaultDuration, serverIP)
+
+ if err != nil {
+ t.Fatal("upload test failed:", err)
+ }
+
+ if len(results) < expectedLen {
+ t.Fatalf("upload results: expected length: %d, actual length: %d", expectedLen, len(results))
+ }
+
+ start := results[0].IntervalStart
+ for _, result := range results {
+ displayResult(t, result, start)
+ }
+ })
+
+ // causes the server goroutine to finish
+ l.Close()
+
+ testState := <-stateChan
+ if testState.err != nil {
+ t.Error("server error:", err)
+ }
+}
diff --git a/net/stun/stun.go b/net/stun/stun.go index eeac23cbb..81cf9b608 100644 --- a/net/stun/stun.go +++ b/net/stun/stun.go @@ -1,312 +1,312 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package STUN generates STUN request packets and parses response packets. -package stun - -import ( - "bytes" - crand "crypto/rand" - "encoding/binary" - "errors" - "hash/crc32" - "net" - "net/netip" -) - -const ( - attrNumSoftware = 0x8022 - attrNumFingerprint = 0x8028 - attrMappedAddress = 0x0001 - attrXorMappedAddress = 0x0020 - // This alternative attribute type is not - // mentioned in the RFC, but the shift into - // the "comprehension-optional" range seems - // like an easy mistake for a server to make. - // And servers appear to send it. - attrXorMappedAddressAlt = 0x8020 - - software = "tailnode" // notably: 8 bytes long, so no padding - bindingRequest = "\x00\x01" - magicCookie = "\x21\x12\xa4\x42" - lenFingerprint = 8 // 2+byte header + 2-byte length + 4-byte crc32 - headerLen = 20 -) - -// TxID is a transaction ID. -type TxID [12]byte - -// NewTxID returns a new random TxID. -func NewTxID() TxID { - var tx TxID - if _, err := crand.Read(tx[:]); err != nil { - panic(err) - } - return tx -} - -// Request generates a binding request STUN packet. -// The transaction ID, tID, should be a random sequence of bytes. -func Request(tID TxID) []byte { - // STUN header, RFC5389 Section 6. - const lenAttrSoftware = 4 + len(software) - b := make([]byte, 0, headerLen+lenAttrSoftware+lenFingerprint) - b = append(b, bindingRequest...) - b = appendU16(b, uint16(lenAttrSoftware+lenFingerprint)) // number of bytes following header - b = append(b, magicCookie...) - b = append(b, tID[:]...) - - // Attribute SOFTWARE, RFC5389 Section 15.5. - b = appendU16(b, attrNumSoftware) - b = appendU16(b, uint16(len(software))) - b = append(b, software...) - - // Attribute FINGERPRINT, RFC5389 Section 15.5. - fp := fingerPrint(b) - b = appendU16(b, attrNumFingerprint) - b = appendU16(b, 4) - b = appendU32(b, fp) - - return b -} - -func fingerPrint(b []byte) uint32 { return crc32.ChecksumIEEE(b) ^ 0x5354554e } - -func appendU16(b []byte, v uint16) []byte { - return append(b, byte(v>>8), byte(v)) -} - -func appendU32(b []byte, v uint32) []byte { - return append(b, byte(v>>24), byte(v>>16), byte(v>>8), byte(v)) -} - -// ParseBindingRequest parses a STUN binding request. -// -// It returns an error unless it advertises that it came from -// Tailscale. -func ParseBindingRequest(b []byte) (TxID, error) { - if !Is(b) { - return TxID{}, ErrNotSTUN - } - if string(b[:len(bindingRequest)]) != bindingRequest { - return TxID{}, ErrNotBindingRequest - } - var txID TxID - copy(txID[:], b[8:8+len(txID)]) - var softwareOK bool - var lastAttr uint16 - var gotFP uint32 - if err := foreachAttr(b[headerLen:], func(attrType uint16, a []byte) error { - lastAttr = attrType - if attrType == attrNumSoftware && string(a) == software { - softwareOK = true - } - if attrType == attrNumFingerprint && len(a) == 4 { - gotFP = binary.BigEndian.Uint32(a) - } - return nil - }); err != nil { - return TxID{}, err - } - if !softwareOK { - return TxID{}, ErrWrongSoftware - } - if lastAttr != attrNumFingerprint { - return TxID{}, ErrNoFingerprint - } - wantFP := fingerPrint(b[:len(b)-lenFingerprint]) - if gotFP != wantFP { - return TxID{}, ErrWrongFingerprint - } - return txID, nil -} - -var ( - ErrNotSTUN = errors.New("response is not a STUN packet") - ErrNotSuccessResponse = errors.New("STUN packet is not a response") - ErrMalformedAttrs = errors.New("STUN response has malformed attributes") - ErrNotBindingRequest = errors.New("STUN request not a binding request") - ErrWrongSoftware = errors.New("STUN request came from non-Tailscale software") - ErrNoFingerprint = errors.New("STUN request didn't end in fingerprint") - ErrWrongFingerprint = errors.New("STUN request had bogus fingerprint") -) - -func foreachAttr(b []byte, fn func(attrType uint16, a []byte) error) error { - for len(b) > 0 { - if len(b) < 4 { - return ErrMalformedAttrs - } - attrType := binary.BigEndian.Uint16(b[:2]) - attrLen := int(binary.BigEndian.Uint16(b[2:4])) - attrLenWithPad := (attrLen + 3) &^ 3 - b = b[4:] - if attrLenWithPad > len(b) { - return ErrMalformedAttrs - } - if err := fn(attrType, b[:attrLen]); err != nil { - return err - } - b = b[attrLenWithPad:] - } - return nil -} - -// Response generates a binding response. -func Response(txID TxID, addrPort netip.AddrPort) []byte { - addr := addrPort.Addr() - - var fam byte - if addr.Is4() { - fam = 1 - } else if addr.Is6() { - fam = 2 - } else { - return nil - } - attrsLen := 8 + addr.BitLen()/8 - b := make([]byte, 0, headerLen+attrsLen) - - // Header - b = append(b, 0x01, 0x01) // success - b = appendU16(b, uint16(attrsLen)) - b = append(b, magicCookie...) - b = append(b, txID[:]...) - - // Attributes (well, one) - b = appendU16(b, attrXorMappedAddress) - b = appendU16(b, uint16(4+addr.BitLen()/8)) - b = append(b, - 0, // unused byte - fam) - b = appendU16(b, addrPort.Port()^0x2112) // first half of magicCookie - ipa := addr.As16() - for i, o := range ipa[16-addr.BitLen()/8:] { - if i < 4 { - b = append(b, o^magicCookie[i]) - } else { - b = append(b, o^txID[i-len(magicCookie)]) - } - } - return b -} - -// ParseResponse parses a successful binding response STUN packet. -// The IP address is extracted from the XOR-MAPPED-ADDRESS attribute. -func ParseResponse(b []byte) (tID TxID, addr netip.AddrPort, err error) { - if !Is(b) { - return tID, netip.AddrPort{}, ErrNotSTUN - } - copy(tID[:], b[8:8+len(tID)]) - if b[0] != 0x01 || b[1] != 0x01 { - return tID, netip.AddrPort{}, ErrNotSuccessResponse - } - attrsLen := int(binary.BigEndian.Uint16(b[2:4])) - b = b[headerLen:] // remove STUN header - if attrsLen > len(b) { - return tID, netip.AddrPort{}, ErrMalformedAttrs - } else if len(b) > attrsLen { - b = b[:attrsLen] // trim trailing packet bytes - } - - var fallbackAddr netip.AddrPort - - // Read through the attributes. - // The the addr+port reported by XOR-MAPPED-ADDRESS - // as the canonical value. If the attribute is not - // present but the STUN server responds with - // MAPPED-ADDRESS we fall back to it. - if err := foreachAttr(b, func(attrType uint16, attr []byte) error { - switch attrType { - case attrXorMappedAddress, attrXorMappedAddressAlt: - ipSlice, port, err := xorMappedAddress(tID, attr) - if err != nil { - return err - } - if ip, ok := netip.AddrFromSlice(ipSlice); ok { - addr = netip.AddrPortFrom(ip.Unmap(), port) - } - case attrMappedAddress: - ipSlice, port, err := mappedAddress(attr) - if err != nil { - return ErrMalformedAttrs - } - if ip, ok := netip.AddrFromSlice(ipSlice); ok { - fallbackAddr = netip.AddrPortFrom(ip.Unmap(), port) - } - } - return nil - - }); err != nil { - return TxID{}, netip.AddrPort{}, err - } - - if addr.IsValid() { - return tID, addr, nil - } - if fallbackAddr.IsValid() { - return tID, fallbackAddr, nil - } - return tID, netip.AddrPort{}, ErrMalformedAttrs -} - -func xorMappedAddress(tID TxID, b []byte) (addr []byte, port uint16, err error) { - // XOR-MAPPED-ADDRESS attribute, RFC5389 Section 15.2 - if len(b) < 4 { - return nil, 0, ErrMalformedAttrs - } - xorPort := binary.BigEndian.Uint16(b[2:4]) - addrField := b[4:] - port = xorPort ^ 0x2112 // first half of magicCookie - - addrLen := familyAddrLen(b[1]) - if addrLen == 0 { - return nil, 0, ErrMalformedAttrs - } - if len(addrField) < addrLen { - return nil, 0, ErrMalformedAttrs - } - xorAddr := addrField[:addrLen] - addr = make([]byte, addrLen) - for i := range xorAddr { - if i < len(magicCookie) { - addr[i] = xorAddr[i] ^ magicCookie[i] - } else { - addr[i] = xorAddr[i] ^ tID[i-len(magicCookie)] - } - } - return addr, port, nil -} - -func familyAddrLen(fam byte) int { - switch fam { - case 0x01: // IPv4 - return net.IPv4len - case 0x02: // IPv6 - return net.IPv6len - default: - return 0 - } -} - -func mappedAddress(b []byte) (addr []byte, port uint16, err error) { - if len(b) < 4 { - return nil, 0, ErrMalformedAttrs - } - port = uint16(b[2])<<8 | uint16(b[3]) - addrField := b[4:] - addrLen := familyAddrLen(b[1]) - if addrLen == 0 { - return nil, 0, ErrMalformedAttrs - } - if len(addrField) < addrLen { - return nil, 0, ErrMalformedAttrs - } - return bytes.Clone(addrField[:addrLen]), port, nil -} - -// Is reports whether b is a STUN message. -func Is(b []byte) bool { - return len(b) >= headerLen && - b[0]&0b11000000 == 0 && // top two bits must be zero - string(b[4:8]) == magicCookie -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Package STUN generates STUN request packets and parses response packets.
+package stun
+
+import (
+ "bytes"
+ crand "crypto/rand"
+ "encoding/binary"
+ "errors"
+ "hash/crc32"
+ "net"
+ "net/netip"
+)
+
+const (
+ attrNumSoftware = 0x8022
+ attrNumFingerprint = 0x8028
+ attrMappedAddress = 0x0001
+ attrXorMappedAddress = 0x0020
+ // This alternative attribute type is not
+ // mentioned in the RFC, but the shift into
+ // the "comprehension-optional" range seems
+ // like an easy mistake for a server to make.
+ // And servers appear to send it.
+ attrXorMappedAddressAlt = 0x8020
+
+ software = "tailnode" // notably: 8 bytes long, so no padding
+ bindingRequest = "\x00\x01"
+ magicCookie = "\x21\x12\xa4\x42"
+ lenFingerprint = 8 // 2+byte header + 2-byte length + 4-byte crc32
+ headerLen = 20
+)
+
+// TxID is a transaction ID.
+type TxID [12]byte
+
+// NewTxID returns a new random TxID.
+func NewTxID() TxID {
+ var tx TxID
+ if _, err := crand.Read(tx[:]); err != nil {
+ panic(err)
+ }
+ return tx
+}
+
+// Request generates a binding request STUN packet.
+// The transaction ID, tID, should be a random sequence of bytes.
+func Request(tID TxID) []byte {
+ // STUN header, RFC5389 Section 6.
+ const lenAttrSoftware = 4 + len(software)
+ b := make([]byte, 0, headerLen+lenAttrSoftware+lenFingerprint)
+ b = append(b, bindingRequest...)
+ b = appendU16(b, uint16(lenAttrSoftware+lenFingerprint)) // number of bytes following header
+ b = append(b, magicCookie...)
+ b = append(b, tID[:]...)
+
+ // Attribute SOFTWARE, RFC5389 Section 15.5.
+ b = appendU16(b, attrNumSoftware)
+ b = appendU16(b, uint16(len(software)))
+ b = append(b, software...)
+
+ // Attribute FINGERPRINT, RFC5389 Section 15.5.
+ fp := fingerPrint(b)
+ b = appendU16(b, attrNumFingerprint)
+ b = appendU16(b, 4)
+ b = appendU32(b, fp)
+
+ return b
+}
+
+func fingerPrint(b []byte) uint32 { return crc32.ChecksumIEEE(b) ^ 0x5354554e }
+
+func appendU16(b []byte, v uint16) []byte {
+ return append(b, byte(v>>8), byte(v))
+}
+
+func appendU32(b []byte, v uint32) []byte {
+ return append(b, byte(v>>24), byte(v>>16), byte(v>>8), byte(v))
+}
+
+// ParseBindingRequest parses a STUN binding request.
+//
+// It returns an error unless it advertises that it came from
+// Tailscale.
+func ParseBindingRequest(b []byte) (TxID, error) {
+ if !Is(b) {
+ return TxID{}, ErrNotSTUN
+ }
+ if string(b[:len(bindingRequest)]) != bindingRequest {
+ return TxID{}, ErrNotBindingRequest
+ }
+ var txID TxID
+ copy(txID[:], b[8:8+len(txID)])
+ var softwareOK bool
+ var lastAttr uint16
+ var gotFP uint32
+ if err := foreachAttr(b[headerLen:], func(attrType uint16, a []byte) error {
+ lastAttr = attrType
+ if attrType == attrNumSoftware && string(a) == software {
+ softwareOK = true
+ }
+ if attrType == attrNumFingerprint && len(a) == 4 {
+ gotFP = binary.BigEndian.Uint32(a)
+ }
+ return nil
+ }); err != nil {
+ return TxID{}, err
+ }
+ if !softwareOK {
+ return TxID{}, ErrWrongSoftware
+ }
+ if lastAttr != attrNumFingerprint {
+ return TxID{}, ErrNoFingerprint
+ }
+ wantFP := fingerPrint(b[:len(b)-lenFingerprint])
+ if gotFP != wantFP {
+ return TxID{}, ErrWrongFingerprint
+ }
+ return txID, nil
+}
+
+var (
+ ErrNotSTUN = errors.New("response is not a STUN packet")
+ ErrNotSuccessResponse = errors.New("STUN packet is not a response")
+ ErrMalformedAttrs = errors.New("STUN response has malformed attributes")
+ ErrNotBindingRequest = errors.New("STUN request not a binding request")
+ ErrWrongSoftware = errors.New("STUN request came from non-Tailscale software")
+ ErrNoFingerprint = errors.New("STUN request didn't end in fingerprint")
+ ErrWrongFingerprint = errors.New("STUN request had bogus fingerprint")
+)
+
+func foreachAttr(b []byte, fn func(attrType uint16, a []byte) error) error {
+ for len(b) > 0 {
+ if len(b) < 4 {
+ return ErrMalformedAttrs
+ }
+ attrType := binary.BigEndian.Uint16(b[:2])
+ attrLen := int(binary.BigEndian.Uint16(b[2:4]))
+ attrLenWithPad := (attrLen + 3) &^ 3
+ b = b[4:]
+ if attrLenWithPad > len(b) {
+ return ErrMalformedAttrs
+ }
+ if err := fn(attrType, b[:attrLen]); err != nil {
+ return err
+ }
+ b = b[attrLenWithPad:]
+ }
+ return nil
+}
+
+// Response generates a binding response.
+func Response(txID TxID, addrPort netip.AddrPort) []byte {
+ addr := addrPort.Addr()
+
+ var fam byte
+ if addr.Is4() {
+ fam = 1
+ } else if addr.Is6() {
+ fam = 2
+ } else {
+ return nil
+ }
+ attrsLen := 8 + addr.BitLen()/8
+ b := make([]byte, 0, headerLen+attrsLen)
+
+ // Header
+ b = append(b, 0x01, 0x01) // success
+ b = appendU16(b, uint16(attrsLen))
+ b = append(b, magicCookie...)
+ b = append(b, txID[:]...)
+
+ // Attributes (well, one)
+ b = appendU16(b, attrXorMappedAddress)
+ b = appendU16(b, uint16(4+addr.BitLen()/8))
+ b = append(b,
+ 0, // unused byte
+ fam)
+ b = appendU16(b, addrPort.Port()^0x2112) // first half of magicCookie
+ ipa := addr.As16()
+ for i, o := range ipa[16-addr.BitLen()/8:] {
+ if i < 4 {
+ b = append(b, o^magicCookie[i])
+ } else {
+ b = append(b, o^txID[i-len(magicCookie)])
+ }
+ }
+ return b
+}
+
+// ParseResponse parses a successful binding response STUN packet.
+// The IP address is extracted from the XOR-MAPPED-ADDRESS attribute.
+func ParseResponse(b []byte) (tID TxID, addr netip.AddrPort, err error) {
+ if !Is(b) {
+ return tID, netip.AddrPort{}, ErrNotSTUN
+ }
+ copy(tID[:], b[8:8+len(tID)])
+ if b[0] != 0x01 || b[1] != 0x01 {
+ return tID, netip.AddrPort{}, ErrNotSuccessResponse
+ }
+ attrsLen := int(binary.BigEndian.Uint16(b[2:4]))
+ b = b[headerLen:] // remove STUN header
+ if attrsLen > len(b) {
+ return tID, netip.AddrPort{}, ErrMalformedAttrs
+ } else if len(b) > attrsLen {
+ b = b[:attrsLen] // trim trailing packet bytes
+ }
+
+ var fallbackAddr netip.AddrPort
+
+ // Read through the attributes.
+ // The the addr+port reported by XOR-MAPPED-ADDRESS
+ // as the canonical value. If the attribute is not
+ // present but the STUN server responds with
+ // MAPPED-ADDRESS we fall back to it.
+ if err := foreachAttr(b, func(attrType uint16, attr []byte) error {
+ switch attrType {
+ case attrXorMappedAddress, attrXorMappedAddressAlt:
+ ipSlice, port, err := xorMappedAddress(tID, attr)
+ if err != nil {
+ return err
+ }
+ if ip, ok := netip.AddrFromSlice(ipSlice); ok {
+ addr = netip.AddrPortFrom(ip.Unmap(), port)
+ }
+ case attrMappedAddress:
+ ipSlice, port, err := mappedAddress(attr)
+ if err != nil {
+ return ErrMalformedAttrs
+ }
+ if ip, ok := netip.AddrFromSlice(ipSlice); ok {
+ fallbackAddr = netip.AddrPortFrom(ip.Unmap(), port)
+ }
+ }
+ return nil
+
+ }); err != nil {
+ return TxID{}, netip.AddrPort{}, err
+ }
+
+ if addr.IsValid() {
+ return tID, addr, nil
+ }
+ if fallbackAddr.IsValid() {
+ return tID, fallbackAddr, nil
+ }
+ return tID, netip.AddrPort{}, ErrMalformedAttrs
+}
+
+func xorMappedAddress(tID TxID, b []byte) (addr []byte, port uint16, err error) {
+ // XOR-MAPPED-ADDRESS attribute, RFC5389 Section 15.2
+ if len(b) < 4 {
+ return nil, 0, ErrMalformedAttrs
+ }
+ xorPort := binary.BigEndian.Uint16(b[2:4])
+ addrField := b[4:]
+ port = xorPort ^ 0x2112 // first half of magicCookie
+
+ addrLen := familyAddrLen(b[1])
+ if addrLen == 0 {
+ return nil, 0, ErrMalformedAttrs
+ }
+ if len(addrField) < addrLen {
+ return nil, 0, ErrMalformedAttrs
+ }
+ xorAddr := addrField[:addrLen]
+ addr = make([]byte, addrLen)
+ for i := range xorAddr {
+ if i < len(magicCookie) {
+ addr[i] = xorAddr[i] ^ magicCookie[i]
+ } else {
+ addr[i] = xorAddr[i] ^ tID[i-len(magicCookie)]
+ }
+ }
+ return addr, port, nil
+}
+
+func familyAddrLen(fam byte) int {
+ switch fam {
+ case 0x01: // IPv4
+ return net.IPv4len
+ case 0x02: // IPv6
+ return net.IPv6len
+ default:
+ return 0
+ }
+}
+
+func mappedAddress(b []byte) (addr []byte, port uint16, err error) {
+ if len(b) < 4 {
+ return nil, 0, ErrMalformedAttrs
+ }
+ port = uint16(b[2])<<8 | uint16(b[3])
+ addrField := b[4:]
+ addrLen := familyAddrLen(b[1])
+ if addrLen == 0 {
+ return nil, 0, ErrMalformedAttrs
+ }
+ if len(addrField) < addrLen {
+ return nil, 0, ErrMalformedAttrs
+ }
+ return bytes.Clone(addrField[:addrLen]), port, nil
+}
+
+// Is reports whether b is a STUN message.
+func Is(b []byte) bool {
+ return len(b) >= headerLen &&
+ b[0]&0b11000000 == 0 && // top two bits must be zero
+ string(b[4:8]) == magicCookie
+}
diff --git a/net/stun/stun_fuzzer.go b/net/stun/stun_fuzzer.go index 6f0c9e3b0..9ddb41895 100644 --- a/net/stun/stun_fuzzer.go +++ b/net/stun/stun_fuzzer.go @@ -1,12 +1,12 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause -//go:build gofuzz - -package stun - -func FuzzStunParser(data []byte) int { - _, _, _ = ParseResponse(data) - - _, _ = ParseBindingRequest(data) - return 1 -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+//go:build gofuzz
+
+package stun
+
+func FuzzStunParser(data []byte) int {
+ _, _, _ = ParseResponse(data)
+
+ _, _ = ParseBindingRequest(data)
+ return 1
+}
diff --git a/net/tcpinfo/tcpinfo.go b/net/tcpinfo/tcpinfo.go index a757add9f..adc40ca37 100644 --- a/net/tcpinfo/tcpinfo.go +++ b/net/tcpinfo/tcpinfo.go @@ -1,51 +1,51 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package tcpinfo provides platform-agnostic accessors to information about a -// TCP connection (e.g. RTT, MSS, etc.). -package tcpinfo - -import ( - "errors" - "net" - "time" -) - -var ( - ErrNotTCP = errors.New("tcpinfo: not a TCP conn") - ErrUnimplemented = errors.New("tcpinfo: unimplemented") -) - -// RTT returns the RTT for the given net.Conn. -// -// If the net.Conn is not a *net.TCPConn and cannot be unwrapped into one, then -// ErrNotTCP will be returned. If retrieving the RTT is not supported on the -// current platform, ErrUnimplemented will be returned. -func RTT(conn net.Conn) (time.Duration, error) { - tcpConn, err := unwrap(conn) - if err != nil { - return 0, err - } - - return rttImpl(tcpConn) -} - -// netConner is implemented by crypto/tls.Conn to unwrap into an underlying -// net.Conn. -type netConner interface { - NetConn() net.Conn -} - -// unwrap attempts to unwrap a net.Conn into an underlying *net.TCPConn -func unwrap(nc net.Conn) (*net.TCPConn, error) { - for { - switch v := nc.(type) { - case *net.TCPConn: - return v, nil - case netConner: - nc = v.NetConn() - default: - return nil, ErrNotTCP - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Package tcpinfo provides platform-agnostic accessors to information about a
+// TCP connection (e.g. RTT, MSS, etc.).
+package tcpinfo
+
+import (
+ "errors"
+ "net"
+ "time"
+)
+
+var (
+ ErrNotTCP = errors.New("tcpinfo: not a TCP conn")
+ ErrUnimplemented = errors.New("tcpinfo: unimplemented")
+)
+
+// RTT returns the RTT for the given net.Conn.
+//
+// If the net.Conn is not a *net.TCPConn and cannot be unwrapped into one, then
+// ErrNotTCP will be returned. If retrieving the RTT is not supported on the
+// current platform, ErrUnimplemented will be returned.
+func RTT(conn net.Conn) (time.Duration, error) {
+ tcpConn, err := unwrap(conn)
+ if err != nil {
+ return 0, err
+ }
+
+ return rttImpl(tcpConn)
+}
+
+// netConner is implemented by crypto/tls.Conn to unwrap into an underlying
+// net.Conn.
+type netConner interface {
+ NetConn() net.Conn
+}
+
+// unwrap attempts to unwrap a net.Conn into an underlying *net.TCPConn
+func unwrap(nc net.Conn) (*net.TCPConn, error) {
+ for {
+ switch v := nc.(type) {
+ case *net.TCPConn:
+ return v, nil
+ case netConner:
+ nc = v.NetConn()
+ default:
+ return nil, ErrNotTCP
+ }
+ }
+}
diff --git a/net/tcpinfo/tcpinfo_darwin.go b/net/tcpinfo/tcpinfo_darwin.go index 53fa22fbf..bc4ac08b3 100644 --- a/net/tcpinfo/tcpinfo_darwin.go +++ b/net/tcpinfo/tcpinfo_darwin.go @@ -1,33 +1,33 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tcpinfo - -import ( - "net" - "time" - - "golang.org/x/sys/unix" -) - -func rttImpl(conn *net.TCPConn) (time.Duration, error) { - rawConn, err := conn.SyscallConn() - if err != nil { - return 0, err - } - - var ( - tcpInfo *unix.TCPConnectionInfo - sysErr error - ) - err = rawConn.Control(func(fd uintptr) { - tcpInfo, sysErr = unix.GetsockoptTCPConnectionInfo(int(fd), unix.IPPROTO_TCP, unix.TCP_CONNECTION_INFO) - }) - if err != nil { - return 0, err - } else if sysErr != nil { - return 0, sysErr - } - - return time.Duration(tcpInfo.Rttcur) * time.Millisecond, nil -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package tcpinfo
+
+import (
+ "net"
+ "time"
+
+ "golang.org/x/sys/unix"
+)
+
+func rttImpl(conn *net.TCPConn) (time.Duration, error) {
+ rawConn, err := conn.SyscallConn()
+ if err != nil {
+ return 0, err
+ }
+
+ var (
+ tcpInfo *unix.TCPConnectionInfo
+ sysErr error
+ )
+ err = rawConn.Control(func(fd uintptr) {
+ tcpInfo, sysErr = unix.GetsockoptTCPConnectionInfo(int(fd), unix.IPPROTO_TCP, unix.TCP_CONNECTION_INFO)
+ })
+ if err != nil {
+ return 0, err
+ } else if sysErr != nil {
+ return 0, sysErr
+ }
+
+ return time.Duration(tcpInfo.Rttcur) * time.Millisecond, nil
+}
diff --git a/net/tcpinfo/tcpinfo_linux.go b/net/tcpinfo/tcpinfo_linux.go index 885d462c9..5d86055bb 100644 --- a/net/tcpinfo/tcpinfo_linux.go +++ b/net/tcpinfo/tcpinfo_linux.go @@ -1,33 +1,33 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tcpinfo - -import ( - "net" - "time" - - "golang.org/x/sys/unix" -) - -func rttImpl(conn *net.TCPConn) (time.Duration, error) { - rawConn, err := conn.SyscallConn() - if err != nil { - return 0, err - } - - var ( - tcpInfo *unix.TCPInfo - sysErr error - ) - err = rawConn.Control(func(fd uintptr) { - tcpInfo, sysErr = unix.GetsockoptTCPInfo(int(fd), unix.IPPROTO_TCP, unix.TCP_INFO) - }) - if err != nil { - return 0, err - } else if sysErr != nil { - return 0, sysErr - } - - return time.Duration(tcpInfo.Rtt) * time.Microsecond, nil -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package tcpinfo
+
+import (
+ "net"
+ "time"
+
+ "golang.org/x/sys/unix"
+)
+
+func rttImpl(conn *net.TCPConn) (time.Duration, error) {
+ rawConn, err := conn.SyscallConn()
+ if err != nil {
+ return 0, err
+ }
+
+ var (
+ tcpInfo *unix.TCPInfo
+ sysErr error
+ )
+ err = rawConn.Control(func(fd uintptr) {
+ tcpInfo, sysErr = unix.GetsockoptTCPInfo(int(fd), unix.IPPROTO_TCP, unix.TCP_INFO)
+ })
+ if err != nil {
+ return 0, err
+ } else if sysErr != nil {
+ return 0, sysErr
+ }
+
+ return time.Duration(tcpInfo.Rtt) * time.Microsecond, nil
+}
diff --git a/net/tcpinfo/tcpinfo_other.go b/net/tcpinfo/tcpinfo_other.go index be45523ae..f219cda1b 100644 --- a/net/tcpinfo/tcpinfo_other.go +++ b/net/tcpinfo/tcpinfo_other.go @@ -1,15 +1,15 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !linux && !darwin - -package tcpinfo - -import ( - "net" - "time" -) - -func rttImpl(conn *net.TCPConn) (time.Duration, error) { - return 0, ErrUnimplemented -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build !linux && !darwin
+
+package tcpinfo
+
+import (
+ "net"
+ "time"
+)
+
+func rttImpl(conn *net.TCPConn) (time.Duration, error) {
+ return 0, ErrUnimplemented
+}
diff --git a/net/tlsdial/deps_test.go b/net/tlsdial/deps_test.go index 7a93899c2..750cb300a 100644 --- a/net/tlsdial/deps_test.go +++ b/net/tlsdial/deps_test.go @@ -1,8 +1,8 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build for_go_mod_tidy_only - -package tlsdial - -import _ "filippo.io/mkcert" +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build for_go_mod_tidy_only
+
+package tlsdial
+
+import _ "filippo.io/mkcert"
diff --git a/net/tsdial/dnsmap_test.go b/net/tsdial/dnsmap_test.go index 43461a135..f846b853e 100644 --- a/net/tsdial/dnsmap_test.go +++ b/net/tsdial/dnsmap_test.go @@ -1,125 +1,125 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tsdial - -import ( - "net/netip" - "reflect" - "testing" - - "tailscale.com/tailcfg" - "tailscale.com/types/netmap" -) - -func nodeViews(v []*tailcfg.Node) []tailcfg.NodeView { - nv := make([]tailcfg.NodeView, len(v)) - for i, n := range v { - nv[i] = n.View() - } - return nv -} - -func TestDNSMapFromNetworkMap(t *testing.T) { - pfx := netip.MustParsePrefix - ip := netip.MustParseAddr - tests := []struct { - name string - nm *netmap.NetworkMap - want dnsMap - }{ - { - name: "self", - nm: &netmap.NetworkMap{ - Name: "foo.tailnet", - SelfNode: (&tailcfg.Node{ - Addresses: []netip.Prefix{ - pfx("100.102.103.104/32"), - pfx("100::123/128"), - }, - }).View(), - }, - want: dnsMap{ - "foo": ip("100.102.103.104"), - "foo.tailnet": ip("100.102.103.104"), - }, - }, - { - name: "self_and_peers", - nm: &netmap.NetworkMap{ - Name: "foo.tailnet", - SelfNode: (&tailcfg.Node{ - Addresses: []netip.Prefix{ - pfx("100.102.103.104/32"), - pfx("100::123/128"), - }, - }).View(), - Peers: []tailcfg.NodeView{ - (&tailcfg.Node{ - Name: "a.tailnet", - Addresses: []netip.Prefix{ - pfx("100.0.0.201/32"), - pfx("100::201/128"), - }, - }).View(), - (&tailcfg.Node{ - Name: "b.tailnet", - Addresses: []netip.Prefix{ - pfx("100::202/128"), - }, - }).View(), - }, - }, - want: dnsMap{ - "foo": ip("100.102.103.104"), - "foo.tailnet": ip("100.102.103.104"), - "a": ip("100.0.0.201"), - "a.tailnet": ip("100.0.0.201"), - "b": ip("100::202"), - "b.tailnet": ip("100::202"), - }, - }, - { - name: "self_has_v6_only", - nm: &netmap.NetworkMap{ - Name: "foo.tailnet", - SelfNode: (&tailcfg.Node{ - Addresses: []netip.Prefix{ - pfx("100::123/128"), - }, - }).View(), - Peers: nodeViews([]*tailcfg.Node{ - { - Name: "a.tailnet", - Addresses: []netip.Prefix{ - pfx("100.0.0.201/32"), - pfx("100::201/128"), - }, - }, - { - Name: "b.tailnet", - Addresses: []netip.Prefix{ - pfx("100::202/128"), - }, - }, - }), - }, - want: dnsMap{ - "foo": ip("100::123"), - "foo.tailnet": ip("100::123"), - "a": ip("100::201"), - "a.tailnet": ip("100::201"), - "b": ip("100::202"), - "b.tailnet": ip("100::202"), - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := dnsMapFromNetworkMap(tt.nm) - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("mismatch:\n got %v\nwant %v\n", got, tt.want) - } - }) - } -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package tsdial
+
+import (
+ "net/netip"
+ "reflect"
+ "testing"
+
+ "tailscale.com/tailcfg"
+ "tailscale.com/types/netmap"
+)
+
+func nodeViews(v []*tailcfg.Node) []tailcfg.NodeView {
+ nv := make([]tailcfg.NodeView, len(v))
+ for i, n := range v {
+ nv[i] = n.View()
+ }
+ return nv
+}
+
+func TestDNSMapFromNetworkMap(t *testing.T) {
+ pfx := netip.MustParsePrefix
+ ip := netip.MustParseAddr
+ tests := []struct {
+ name string
+ nm *netmap.NetworkMap
+ want dnsMap
+ }{
+ {
+ name: "self",
+ nm: &netmap.NetworkMap{
+ Name: "foo.tailnet",
+ SelfNode: (&tailcfg.Node{
+ Addresses: []netip.Prefix{
+ pfx("100.102.103.104/32"),
+ pfx("100::123/128"),
+ },
+ }).View(),
+ },
+ want: dnsMap{
+ "foo": ip("100.102.103.104"),
+ "foo.tailnet": ip("100.102.103.104"),
+ },
+ },
+ {
+ name: "self_and_peers",
+ nm: &netmap.NetworkMap{
+ Name: "foo.tailnet",
+ SelfNode: (&tailcfg.Node{
+ Addresses: []netip.Prefix{
+ pfx("100.102.103.104/32"),
+ pfx("100::123/128"),
+ },
+ }).View(),
+ Peers: []tailcfg.NodeView{
+ (&tailcfg.Node{
+ Name: "a.tailnet",
+ Addresses: []netip.Prefix{
+ pfx("100.0.0.201/32"),
+ pfx("100::201/128"),
+ },
+ }).View(),
+ (&tailcfg.Node{
+ Name: "b.tailnet",
+ Addresses: []netip.Prefix{
+ pfx("100::202/128"),
+ },
+ }).View(),
+ },
+ },
+ want: dnsMap{
+ "foo": ip("100.102.103.104"),
+ "foo.tailnet": ip("100.102.103.104"),
+ "a": ip("100.0.0.201"),
+ "a.tailnet": ip("100.0.0.201"),
+ "b": ip("100::202"),
+ "b.tailnet": ip("100::202"),
+ },
+ },
+ {
+ name: "self_has_v6_only",
+ nm: &netmap.NetworkMap{
+ Name: "foo.tailnet",
+ SelfNode: (&tailcfg.Node{
+ Addresses: []netip.Prefix{
+ pfx("100::123/128"),
+ },
+ }).View(),
+ Peers: nodeViews([]*tailcfg.Node{
+ {
+ Name: "a.tailnet",
+ Addresses: []netip.Prefix{
+ pfx("100.0.0.201/32"),
+ pfx("100::201/128"),
+ },
+ },
+ {
+ Name: "b.tailnet",
+ Addresses: []netip.Prefix{
+ pfx("100::202/128"),
+ },
+ },
+ }),
+ },
+ want: dnsMap{
+ "foo": ip("100::123"),
+ "foo.tailnet": ip("100::123"),
+ "a": ip("100::201"),
+ "a.tailnet": ip("100::201"),
+ "b": ip("100::202"),
+ "b.tailnet": ip("100::202"),
+ },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := dnsMapFromNetworkMap(tt.nm)
+ if !reflect.DeepEqual(got, tt.want) {
+ t.Errorf("mismatch:\n got %v\nwant %v\n", got, tt.want)
+ }
+ })
+ }
+}
diff --git a/net/tsdial/dohclient.go b/net/tsdial/dohclient.go index d830398cd..64c127fd3 100644 --- a/net/tsdial/dohclient.go +++ b/net/tsdial/dohclient.go @@ -1,100 +1,100 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tsdial - -import ( - "bytes" - "context" - "errors" - "fmt" - "io" - "net" - "net/http" - "time" - - "tailscale.com/net/dnscache" -) - -// dohConn is a net.PacketConn suitable for returning from -// net.Dialer.Dial to send DNS queries over PeerAPI to exit nodes' -// ExitDNS DoH proxy service. -type dohConn struct { - ctx context.Context - baseURL string - hc *http.Client // if nil, default is used - dnsCache *dnscache.MessageCache - - rbuf bytes.Buffer -} - -var ( - _ net.Conn = (*dohConn)(nil) - _ net.PacketConn = (*dohConn)(nil) // be a PacketConn to change net.Resolver semantics -) - -func (*dohConn) Close() error { return nil } -func (*dohConn) LocalAddr() net.Addr { return todoAddr{} } -func (*dohConn) RemoteAddr() net.Addr { return todoAddr{} } -func (*dohConn) SetDeadline(t time.Time) error { return nil } -func (*dohConn) SetReadDeadline(t time.Time) error { return nil } -func (*dohConn) SetWriteDeadline(t time.Time) error { return nil } - -func (c *dohConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { - return c.Write(p) -} - -func (c *dohConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { - n, err = c.Read(p) - return n, todoAddr{}, err -} - -func (c *dohConn) Read(p []byte) (n int, err error) { - return c.rbuf.Read(p) -} - -func (c *dohConn) Write(packet []byte) (n int, err error) { - if c.dnsCache != nil { - err := c.dnsCache.ReplyFromCache(&c.rbuf, packet) - if err == nil { - // Cache hit. - // TODO(bradfitz): add clientmetric - return len(packet), nil - } - c.rbuf.Reset() - } - req, err := http.NewRequestWithContext(c.ctx, "POST", c.baseURL, bytes.NewReader(packet)) - if err != nil { - return 0, err - } - const dohType = "application/dns-message" - req.Header.Set("Content-Type", dohType) - hc := c.hc - if hc == nil { - hc = http.DefaultClient - } - hres, err := hc.Do(req) - if err != nil { - return 0, err - } - defer hres.Body.Close() - if hres.StatusCode != 200 { - return 0, errors.New(hres.Status) - } - if ct := hres.Header.Get("Content-Type"); ct != dohType { - return 0, fmt.Errorf("unexpected response Content-Type %q", ct) - } - _, err = io.Copy(&c.rbuf, hres.Body) - if err != nil { - return 0, err - } - if c.dnsCache != nil { - c.dnsCache.AddCacheEntry(packet, c.rbuf.Bytes()) - } - return len(packet), nil -} - -type todoAddr struct{} - -func (todoAddr) Network() string { return "unused" } -func (todoAddr) String() string { return "unused-todoAddr" } +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package tsdial
+
+import (
+ "bytes"
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "net/http"
+ "time"
+
+ "tailscale.com/net/dnscache"
+)
+
+// dohConn is a net.PacketConn suitable for returning from
+// net.Dialer.Dial to send DNS queries over PeerAPI to exit nodes'
+// ExitDNS DoH proxy service.
+type dohConn struct {
+ ctx context.Context
+ baseURL string
+ hc *http.Client // if nil, default is used
+ dnsCache *dnscache.MessageCache
+
+ rbuf bytes.Buffer
+}
+
+var (
+ _ net.Conn = (*dohConn)(nil)
+ _ net.PacketConn = (*dohConn)(nil) // be a PacketConn to change net.Resolver semantics
+)
+
+func (*dohConn) Close() error { return nil }
+func (*dohConn) LocalAddr() net.Addr { return todoAddr{} }
+func (*dohConn) RemoteAddr() net.Addr { return todoAddr{} }
+func (*dohConn) SetDeadline(t time.Time) error { return nil }
+func (*dohConn) SetReadDeadline(t time.Time) error { return nil }
+func (*dohConn) SetWriteDeadline(t time.Time) error { return nil }
+
+func (c *dohConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
+ return c.Write(p)
+}
+
+func (c *dohConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
+ n, err = c.Read(p)
+ return n, todoAddr{}, err
+}
+
+func (c *dohConn) Read(p []byte) (n int, err error) {
+ return c.rbuf.Read(p)
+}
+
+func (c *dohConn) Write(packet []byte) (n int, err error) {
+ if c.dnsCache != nil {
+ err := c.dnsCache.ReplyFromCache(&c.rbuf, packet)
+ if err == nil {
+ // Cache hit.
+ // TODO(bradfitz): add clientmetric
+ return len(packet), nil
+ }
+ c.rbuf.Reset()
+ }
+ req, err := http.NewRequestWithContext(c.ctx, "POST", c.baseURL, bytes.NewReader(packet))
+ if err != nil {
+ return 0, err
+ }
+ const dohType = "application/dns-message"
+ req.Header.Set("Content-Type", dohType)
+ hc := c.hc
+ if hc == nil {
+ hc = http.DefaultClient
+ }
+ hres, err := hc.Do(req)
+ if err != nil {
+ return 0, err
+ }
+ defer hres.Body.Close()
+ if hres.StatusCode != 200 {
+ return 0, errors.New(hres.Status)
+ }
+ if ct := hres.Header.Get("Content-Type"); ct != dohType {
+ return 0, fmt.Errorf("unexpected response Content-Type %q", ct)
+ }
+ _, err = io.Copy(&c.rbuf, hres.Body)
+ if err != nil {
+ return 0, err
+ }
+ if c.dnsCache != nil {
+ c.dnsCache.AddCacheEntry(packet, c.rbuf.Bytes())
+ }
+ return len(packet), nil
+}
+
+type todoAddr struct{}
+
+func (todoAddr) Network() string { return "unused" }
+func (todoAddr) String() string { return "unused-todoAddr" }
diff --git a/net/tsdial/dohclient_test.go b/net/tsdial/dohclient_test.go index 23255769f..41a66f8f7 100644 --- a/net/tsdial/dohclient_test.go +++ b/net/tsdial/dohclient_test.go @@ -1,31 +1,31 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tsdial - -import ( - "context" - "flag" - "net" - "testing" - "time" -) - -var dohBase = flag.String("doh-base", "", "DoH base URL for manual DoH tests; e.g. \"http://100.68.82.120:47830/dns-query\"") - -func TestDoHResolve(t *testing.T) { - if *dohBase == "" { - t.Skip("skipping manual test without --doh-base= set") - } - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - var r net.Resolver - r.Dial = func(ctx context.Context, network, address string) (net.Conn, error) { - return &dohConn{ctx: ctx, baseURL: *dohBase}, nil - } - addrs, err := r.LookupIP(ctx, "ip4", "google.com.") - if err != nil { - t.Fatal(err) - } - t.Logf("Got: %q", addrs) -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package tsdial
+
+import (
+ "context"
+ "flag"
+ "net"
+ "testing"
+ "time"
+)
+
+var dohBase = flag.String("doh-base", "", "DoH base URL for manual DoH tests; e.g. \"http://100.68.82.120:47830/dns-query\"")
+
+func TestDoHResolve(t *testing.T) {
+ if *dohBase == "" {
+ t.Skip("skipping manual test without --doh-base= set")
+ }
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+ defer cancel()
+ var r net.Resolver
+ r.Dial = func(ctx context.Context, network, address string) (net.Conn, error) {
+ return &dohConn{ctx: ctx, baseURL: *dohBase}, nil
+ }
+ addrs, err := r.LookupIP(ctx, "ip4", "google.com.")
+ if err != nil {
+ t.Fatal(err)
+ }
+ t.Logf("Got: %q", addrs)
+}
diff --git a/net/tshttpproxy/mksyscall.go b/net/tshttpproxy/mksyscall.go index f8fdae89b..467dc4917 100644 --- a/net/tshttpproxy/mksyscall.go +++ b/net/tshttpproxy/mksyscall.go @@ -1,11 +1,11 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tshttpproxy - -//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go mksyscall.go - -//sys globalFree(hglobal winHGlobal) (err error) [failretval==0] = kernel32.GlobalFree -//sys winHTTPCloseHandle(whi winHTTPInternet) (err error) [failretval==0] = winhttp.WinHttpCloseHandle -//sys winHTTPGetProxyForURL(whi winHTTPInternet, url *uint16, options *winHTTPAutoProxyOptions, proxyInfo *winHTTPProxyInfo) (err error) [failretval==0] = winhttp.WinHttpGetProxyForUrl -//sys winHTTPOpen(agent *uint16, accessType uint32, proxy *uint16, proxyBypass *uint16, flags uint32) (whi winHTTPInternet, err error) [failretval==0] = winhttp.WinHttpOpen +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package tshttpproxy
+
+//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go mksyscall.go
+
+//sys globalFree(hglobal winHGlobal) (err error) [failretval==0] = kernel32.GlobalFree
+//sys winHTTPCloseHandle(whi winHTTPInternet) (err error) [failretval==0] = winhttp.WinHttpCloseHandle
+//sys winHTTPGetProxyForURL(whi winHTTPInternet, url *uint16, options *winHTTPAutoProxyOptions, proxyInfo *winHTTPProxyInfo) (err error) [failretval==0] = winhttp.WinHttpGetProxyForUrl
+//sys winHTTPOpen(agent *uint16, accessType uint32, proxy *uint16, proxyBypass *uint16, flags uint32) (whi winHTTPInternet, err error) [failretval==0] = winhttp.WinHttpOpen
diff --git a/net/tshttpproxy/tshttpproxy_linux.go b/net/tshttpproxy/tshttpproxy_linux.go index b241c256d..09019893a 100644 --- a/net/tshttpproxy/tshttpproxy_linux.go +++ b/net/tshttpproxy/tshttpproxy_linux.go @@ -1,24 +1,24 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build linux - -package tshttpproxy - -import ( - "net/http" - "net/url" - - "tailscale.com/version/distro" -) - -func init() { - sysProxyFromEnv = linuxSysProxyFromEnv -} - -func linuxSysProxyFromEnv(req *http.Request) (*url.URL, error) { - if distro.Get() == distro.Synology { - return synologyProxyFromConfigCached(req) - } - return nil, nil -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build linux
+
+package tshttpproxy
+
+import (
+ "net/http"
+ "net/url"
+
+ "tailscale.com/version/distro"
+)
+
+func init() {
+ sysProxyFromEnv = linuxSysProxyFromEnv
+}
+
+func linuxSysProxyFromEnv(req *http.Request) (*url.URL, error) {
+ if distro.Get() == distro.Synology {
+ return synologyProxyFromConfigCached(req)
+ }
+ return nil, nil
+}
diff --git a/net/tshttpproxy/tshttpproxy_synology_test.go b/net/tshttpproxy/tshttpproxy_synology_test.go index 3061740f3..e11c9d059 100644 --- a/net/tshttpproxy/tshttpproxy_synology_test.go +++ b/net/tshttpproxy/tshttpproxy_synology_test.go @@ -1,376 +1,376 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build linux - -package tshttpproxy - -import ( - "errors" - "fmt" - "io" - "net/http" - "net/url" - "os" - "path/filepath" - "strings" - "testing" - "time" - - "tailscale.com/tstest" -) - -func TestSynologyProxyFromConfigCached(t *testing.T) { - req, err := http.NewRequest("GET", "http://example.org/", nil) - if err != nil { - t.Fatal(err) - } - - tstest.Replace(t, &synologyProxyConfigPath, filepath.Join(t.TempDir(), "proxy.conf")) - - t.Run("no config file", func(t *testing.T) { - if _, err := os.Stat(synologyProxyConfigPath); err == nil { - t.Fatalf("%s must not exist for this test", synologyProxyConfigPath) - } - - cache.updated = time.Time{} - cache.httpProxy = nil - cache.httpsProxy = nil - - if val, err := synologyProxyFromConfigCached(req); val != nil || err != nil { - t.Fatalf("got %s, %v; want nil, nil", val, err) - } - - if got, want := cache.updated, time.Unix(0, 0); got != want { - t.Fatalf("got %s, want %s", got, want) - } - if cache.httpProxy != nil { - t.Fatalf("got %s, want nil", cache.httpProxy) - } - if cache.httpsProxy != nil { - t.Fatalf("got %s, want nil", cache.httpsProxy) - } - }) - - t.Run("config file updated", func(t *testing.T) { - cache.updated = time.Now() - cache.httpProxy = nil - cache.httpsProxy = nil - - if err := os.WriteFile(synologyProxyConfigPath, []byte(` -proxy_enabled=yes -http_host=10.0.0.55 -http_port=80 -https_host=10.0.0.66 -https_port=443 - `), 0600); err != nil { - t.Fatal(err) - } - - val, err := synologyProxyFromConfigCached(req) - if err != nil { - t.Fatal(err) - } - - if cache.httpProxy == nil { - t.Fatal("http proxy was not cached") - } - if cache.httpsProxy == nil { - t.Fatal("https proxy was not cached") - } - - if want := urlMustParse("http://10.0.0.55:80"); val.String() != want.String() { - t.Fatalf("got %s; want %s", val, want) - } - }) - - t.Run("config file removed", func(t *testing.T) { - cache.updated = time.Now() - cache.httpProxy = urlMustParse("http://127.0.0.1/") - cache.httpsProxy = urlMustParse("http://127.0.0.1/") - - if err := os.Remove(synologyProxyConfigPath); err != nil && !os.IsNotExist(err) { - t.Fatal(err) - } - - val, err := synologyProxyFromConfigCached(req) - if err != nil { - t.Fatal(err) - } - if val != nil { - t.Fatalf("got %s; want nil", val) - } - if cache.httpProxy != nil { - t.Fatalf("got %s, want nil", cache.httpProxy) - } - if cache.httpsProxy != nil { - t.Fatalf("got %s, want nil", cache.httpsProxy) - } - }) - - t.Run("picks proxy from request scheme", func(t *testing.T) { - cache.updated = time.Now() - cache.httpProxy = nil - cache.httpsProxy = nil - - if err := os.WriteFile(synologyProxyConfigPath, []byte(` -proxy_enabled=yes -http_host=10.0.0.55 -http_port=80 -https_host=10.0.0.66 -https_port=443 - `), 0600); err != nil { - t.Fatal(err) - } - - httpReq, err := http.NewRequest("GET", "http://example.com", nil) - if err != nil { - t.Fatal(err) - } - val, err := synologyProxyFromConfigCached(httpReq) - if err != nil { - t.Fatal(err) - } - if val == nil { - t.Fatalf("got nil, want an http URL") - } - if got, want := val.String(), "http://10.0.0.55:80"; got != want { - t.Fatalf("got %q, want %q", got, want) - } - - httpsReq, err := http.NewRequest("GET", "https://example.com", nil) - if err != nil { - t.Fatal(err) - } - val, err = synologyProxyFromConfigCached(httpsReq) - if err != nil { - t.Fatal(err) - } - if val == nil { - t.Fatalf("got nil, want an http URL") - } - if got, want := val.String(), "http://10.0.0.66:443"; got != want { - t.Fatalf("got %q, want %q", got, want) - } - }) -} - -func TestSynologyProxiesFromConfig(t *testing.T) { - var ( - openReader io.ReadCloser - openErr error - ) - tstest.Replace(t, &openSynologyProxyConf, func() (io.ReadCloser, error) { - return openReader, openErr - }) - - t.Run("with config", func(t *testing.T) { - mc := &mustCloser{Reader: strings.NewReader(` -proxy_user=foo -proxy_pwd=bar -proxy_enabled=yes -adv_enabled=yes -bypass_enabled=yes -auth_enabled=yes -https_host=10.0.0.66 -https_port=8443 -http_host=10.0.0.55 -http_port=80 - `)} - defer mc.check(t) - openReader = mc - - httpProxy, httpsProxy, err := synologyProxiesFromConfig() - - if got, want := err, openErr; got != want { - t.Fatalf("got %s, want %s", got, want) - } - - if got, want := httpsProxy, urlMustParse("http://foo:bar@10.0.0.66:8443"); got.String() != want.String() { - t.Fatalf("got %s, want %s", got, want) - } - - if got, want := err, openErr; got != want { - t.Fatalf("got %s, want %s", got, want) - } - - if got, want := httpProxy, urlMustParse("http://foo:bar@10.0.0.55:80"); got.String() != want.String() { - t.Fatalf("got %s, want %s", got, want) - } - - }) - - t.Run("nonexistent config", func(t *testing.T) { - openReader = nil - openErr = os.ErrNotExist - - httpProxy, httpsProxy, err := synologyProxiesFromConfig() - if err != nil { - t.Fatalf("expected no error, got %s", err) - } - if httpProxy != nil { - t.Fatalf("expected no url, got %s", httpProxy) - } - if httpsProxy != nil { - t.Fatalf("expected no url, got %s", httpsProxy) - } - }) - - t.Run("error opening config", func(t *testing.T) { - openReader = nil - openErr = errors.New("example error") - - httpProxy, httpsProxy, err := synologyProxiesFromConfig() - if err != openErr { - t.Fatalf("expected %s, got %s", openErr, err) - } - if httpProxy != nil { - t.Fatalf("expected no url, got %s", httpProxy) - } - if httpsProxy != nil { - t.Fatalf("expected no url, got %s", httpsProxy) - } - }) - -} - -func TestParseSynologyConfig(t *testing.T) { - cases := map[string]struct { - input string - httpProxy *url.URL - httpsProxy *url.URL - err error - }{ - "populated": { - input: ` -proxy_user=foo -proxy_pwd=bar -proxy_enabled=yes -adv_enabled=yes -bypass_enabled=yes -auth_enabled=yes -https_host=10.0.0.66 -https_port=8443 -http_host=10.0.0.55 -http_port=80 -`, - httpProxy: urlMustParse("http://foo:bar@10.0.0.55:80"), - httpsProxy: urlMustParse("http://foo:bar@10.0.0.66:8443"), - err: nil, - }, - "no-auth": { - input: ` -proxy_user=foo -proxy_pwd=bar -proxy_enabled=yes -adv_enabled=yes -bypass_enabled=yes -auth_enabled=no -https_host=10.0.0.66 -https_port=8443 -http_host=10.0.0.55 -http_port=80 -`, - httpProxy: urlMustParse("http://10.0.0.55:80"), - httpsProxy: urlMustParse("http://10.0.0.66:8443"), - err: nil, - }, - "http-only": { - input: ` -proxy_user=foo -proxy_pwd=bar -proxy_enabled=yes -adv_enabled=yes -bypass_enabled=yes -auth_enabled=yes -https_host= -https_port=8443 -http_host=10.0.0.55 -http_port=80 -`, - httpProxy: urlMustParse("http://foo:bar@10.0.0.55:80"), - httpsProxy: nil, - err: nil, - }, - "empty": { - input: ` -proxy_user= -proxy_pwd= -proxy_enabled= -adv_enabled= -bypass_enabled= -auth_enabled= -https_host= -https_port= -http_host= -http_port= -`, - httpProxy: nil, - httpsProxy: nil, - err: nil, - }, - } - - for name, example := range cases { - t.Run(name, func(t *testing.T) { - httpProxy, httpsProxy, err := parseSynologyConfig(strings.NewReader(example.input)) - if err != example.err { - t.Fatal(err) - } - if example.err != nil { - return - } - - if example.httpProxy == nil && httpProxy != nil { - t.Fatalf("got %s, want nil", httpProxy) - } - - if example.httpProxy != nil { - if httpProxy == nil { - t.Fatalf("got nil, want %s", example.httpProxy) - } - - if got, want := example.httpProxy.String(), httpProxy.String(); got != want { - t.Fatalf("got %s, want %s", got, want) - } - } - - if example.httpsProxy == nil && httpsProxy != nil { - t.Fatalf("got %s, want nil", httpProxy) - } - - if example.httpsProxy != nil { - if httpsProxy == nil { - t.Fatalf("got nil, want %s", example.httpsProxy) - } - - if got, want := example.httpsProxy.String(), httpsProxy.String(); got != want { - t.Fatalf("got %s, want %s", got, want) - } - } - }) - } -} -func urlMustParse(u string) *url.URL { - r, err := url.Parse(u) - if err != nil { - panic(fmt.Sprintf("urlMustParse: %s", err)) - } - return r -} - -type mustCloser struct { - io.Reader - closed bool -} - -func (m *mustCloser) Close() error { - m.closed = true - return nil -} - -func (m *mustCloser) check(t *testing.T) { - if !m.closed { - t.Errorf("mustCloser wrapping %#v was not closed at time of check", m.Reader) - } -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build linux
+
+package tshttpproxy
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "net/url"
+ "os"
+ "path/filepath"
+ "strings"
+ "testing"
+ "time"
+
+ "tailscale.com/tstest"
+)
+
+func TestSynologyProxyFromConfigCached(t *testing.T) {
+ req, err := http.NewRequest("GET", "http://example.org/", nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ tstest.Replace(t, &synologyProxyConfigPath, filepath.Join(t.TempDir(), "proxy.conf"))
+
+ t.Run("no config file", func(t *testing.T) {
+ if _, err := os.Stat(synologyProxyConfigPath); err == nil {
+ t.Fatalf("%s must not exist for this test", synologyProxyConfigPath)
+ }
+
+ cache.updated = time.Time{}
+ cache.httpProxy = nil
+ cache.httpsProxy = nil
+
+ if val, err := synologyProxyFromConfigCached(req); val != nil || err != nil {
+ t.Fatalf("got %s, %v; want nil, nil", val, err)
+ }
+
+ if got, want := cache.updated, time.Unix(0, 0); got != want {
+ t.Fatalf("got %s, want %s", got, want)
+ }
+ if cache.httpProxy != nil {
+ t.Fatalf("got %s, want nil", cache.httpProxy)
+ }
+ if cache.httpsProxy != nil {
+ t.Fatalf("got %s, want nil", cache.httpsProxy)
+ }
+ })
+
+ t.Run("config file updated", func(t *testing.T) {
+ cache.updated = time.Now()
+ cache.httpProxy = nil
+ cache.httpsProxy = nil
+
+ if err := os.WriteFile(synologyProxyConfigPath, []byte(`
+proxy_enabled=yes
+http_host=10.0.0.55
+http_port=80
+https_host=10.0.0.66
+https_port=443
+ `), 0600); err != nil {
+ t.Fatal(err)
+ }
+
+ val, err := synologyProxyFromConfigCached(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if cache.httpProxy == nil {
+ t.Fatal("http proxy was not cached")
+ }
+ if cache.httpsProxy == nil {
+ t.Fatal("https proxy was not cached")
+ }
+
+ if want := urlMustParse("http://10.0.0.55:80"); val.String() != want.String() {
+ t.Fatalf("got %s; want %s", val, want)
+ }
+ })
+
+ t.Run("config file removed", func(t *testing.T) {
+ cache.updated = time.Now()
+ cache.httpProxy = urlMustParse("http://127.0.0.1/")
+ cache.httpsProxy = urlMustParse("http://127.0.0.1/")
+
+ if err := os.Remove(synologyProxyConfigPath); err != nil && !os.IsNotExist(err) {
+ t.Fatal(err)
+ }
+
+ val, err := synologyProxyFromConfigCached(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if val != nil {
+ t.Fatalf("got %s; want nil", val)
+ }
+ if cache.httpProxy != nil {
+ t.Fatalf("got %s, want nil", cache.httpProxy)
+ }
+ if cache.httpsProxy != nil {
+ t.Fatalf("got %s, want nil", cache.httpsProxy)
+ }
+ })
+
+ t.Run("picks proxy from request scheme", func(t *testing.T) {
+ cache.updated = time.Now()
+ cache.httpProxy = nil
+ cache.httpsProxy = nil
+
+ if err := os.WriteFile(synologyProxyConfigPath, []byte(`
+proxy_enabled=yes
+http_host=10.0.0.55
+http_port=80
+https_host=10.0.0.66
+https_port=443
+ `), 0600); err != nil {
+ t.Fatal(err)
+ }
+
+ httpReq, err := http.NewRequest("GET", "http://example.com", nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ val, err := synologyProxyFromConfigCached(httpReq)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if val == nil {
+ t.Fatalf("got nil, want an http URL")
+ }
+ if got, want := val.String(), "http://10.0.0.55:80"; got != want {
+ t.Fatalf("got %q, want %q", got, want)
+ }
+
+ httpsReq, err := http.NewRequest("GET", "https://example.com", nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ val, err = synologyProxyFromConfigCached(httpsReq)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if val == nil {
+ t.Fatalf("got nil, want an http URL")
+ }
+ if got, want := val.String(), "http://10.0.0.66:443"; got != want {
+ t.Fatalf("got %q, want %q", got, want)
+ }
+ })
+}
+
+func TestSynologyProxiesFromConfig(t *testing.T) {
+ var (
+ openReader io.ReadCloser
+ openErr error
+ )
+ tstest.Replace(t, &openSynologyProxyConf, func() (io.ReadCloser, error) {
+ return openReader, openErr
+ })
+
+ t.Run("with config", func(t *testing.T) {
+ mc := &mustCloser{Reader: strings.NewReader(`
+proxy_user=foo
+proxy_pwd=bar
+proxy_enabled=yes
+adv_enabled=yes
+bypass_enabled=yes
+auth_enabled=yes
+https_host=10.0.0.66
+https_port=8443
+http_host=10.0.0.55
+http_port=80
+ `)}
+ defer mc.check(t)
+ openReader = mc
+
+ httpProxy, httpsProxy, err := synologyProxiesFromConfig()
+
+ if got, want := err, openErr; got != want {
+ t.Fatalf("got %s, want %s", got, want)
+ }
+
+ if got, want := httpsProxy, urlMustParse("http://foo:bar@10.0.0.66:8443"); got.String() != want.String() {
+ t.Fatalf("got %s, want %s", got, want)
+ }
+
+ if got, want := err, openErr; got != want {
+ t.Fatalf("got %s, want %s", got, want)
+ }
+
+ if got, want := httpProxy, urlMustParse("http://foo:bar@10.0.0.55:80"); got.String() != want.String() {
+ t.Fatalf("got %s, want %s", got, want)
+ }
+
+ })
+
+ t.Run("nonexistent config", func(t *testing.T) {
+ openReader = nil
+ openErr = os.ErrNotExist
+
+ httpProxy, httpsProxy, err := synologyProxiesFromConfig()
+ if err != nil {
+ t.Fatalf("expected no error, got %s", err)
+ }
+ if httpProxy != nil {
+ t.Fatalf("expected no url, got %s", httpProxy)
+ }
+ if httpsProxy != nil {
+ t.Fatalf("expected no url, got %s", httpsProxy)
+ }
+ })
+
+ t.Run("error opening config", func(t *testing.T) {
+ openReader = nil
+ openErr = errors.New("example error")
+
+ httpProxy, httpsProxy, err := synologyProxiesFromConfig()
+ if err != openErr {
+ t.Fatalf("expected %s, got %s", openErr, err)
+ }
+ if httpProxy != nil {
+ t.Fatalf("expected no url, got %s", httpProxy)
+ }
+ if httpsProxy != nil {
+ t.Fatalf("expected no url, got %s", httpsProxy)
+ }
+ })
+
+}
+
+func TestParseSynologyConfig(t *testing.T) {
+ cases := map[string]struct {
+ input string
+ httpProxy *url.URL
+ httpsProxy *url.URL
+ err error
+ }{
+ "populated": {
+ input: `
+proxy_user=foo
+proxy_pwd=bar
+proxy_enabled=yes
+adv_enabled=yes
+bypass_enabled=yes
+auth_enabled=yes
+https_host=10.0.0.66
+https_port=8443
+http_host=10.0.0.55
+http_port=80
+`,
+ httpProxy: urlMustParse("http://foo:bar@10.0.0.55:80"),
+ httpsProxy: urlMustParse("http://foo:bar@10.0.0.66:8443"),
+ err: nil,
+ },
+ "no-auth": {
+ input: `
+proxy_user=foo
+proxy_pwd=bar
+proxy_enabled=yes
+adv_enabled=yes
+bypass_enabled=yes
+auth_enabled=no
+https_host=10.0.0.66
+https_port=8443
+http_host=10.0.0.55
+http_port=80
+`,
+ httpProxy: urlMustParse("http://10.0.0.55:80"),
+ httpsProxy: urlMustParse("http://10.0.0.66:8443"),
+ err: nil,
+ },
+ "http-only": {
+ input: `
+proxy_user=foo
+proxy_pwd=bar
+proxy_enabled=yes
+adv_enabled=yes
+bypass_enabled=yes
+auth_enabled=yes
+https_host=
+https_port=8443
+http_host=10.0.0.55
+http_port=80
+`,
+ httpProxy: urlMustParse("http://foo:bar@10.0.0.55:80"),
+ httpsProxy: nil,
+ err: nil,
+ },
+ "empty": {
+ input: `
+proxy_user=
+proxy_pwd=
+proxy_enabled=
+adv_enabled=
+bypass_enabled=
+auth_enabled=
+https_host=
+https_port=
+http_host=
+http_port=
+`,
+ httpProxy: nil,
+ httpsProxy: nil,
+ err: nil,
+ },
+ }
+
+ for name, example := range cases {
+ t.Run(name, func(t *testing.T) {
+ httpProxy, httpsProxy, err := parseSynologyConfig(strings.NewReader(example.input))
+ if err != example.err {
+ t.Fatal(err)
+ }
+ if example.err != nil {
+ return
+ }
+
+ if example.httpProxy == nil && httpProxy != nil {
+ t.Fatalf("got %s, want nil", httpProxy)
+ }
+
+ if example.httpProxy != nil {
+ if httpProxy == nil {
+ t.Fatalf("got nil, want %s", example.httpProxy)
+ }
+
+ if got, want := example.httpProxy.String(), httpProxy.String(); got != want {
+ t.Fatalf("got %s, want %s", got, want)
+ }
+ }
+
+ if example.httpsProxy == nil && httpsProxy != nil {
+ t.Fatalf("got %s, want nil", httpProxy)
+ }
+
+ if example.httpsProxy != nil {
+ if httpsProxy == nil {
+ t.Fatalf("got nil, want %s", example.httpsProxy)
+ }
+
+ if got, want := example.httpsProxy.String(), httpsProxy.String(); got != want {
+ t.Fatalf("got %s, want %s", got, want)
+ }
+ }
+ })
+ }
+}
+func urlMustParse(u string) *url.URL {
+ r, err := url.Parse(u)
+ if err != nil {
+ panic(fmt.Sprintf("urlMustParse: %s", err))
+ }
+ return r
+}
+
+type mustCloser struct {
+ io.Reader
+ closed bool
+}
+
+func (m *mustCloser) Close() error {
+ m.closed = true
+ return nil
+}
+
+func (m *mustCloser) check(t *testing.T) {
+ if !m.closed {
+ t.Errorf("mustCloser wrapping %#v was not closed at time of check", m.Reader)
+ }
+}
diff --git a/net/tshttpproxy/tshttpproxy_windows.go b/net/tshttpproxy/tshttpproxy_windows.go index 06a1f5ae4..cb6b24c83 100644 --- a/net/tshttpproxy/tshttpproxy_windows.go +++ b/net/tshttpproxy/tshttpproxy_windows.go @@ -1,276 +1,276 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tshttpproxy - -import ( - "context" - "encoding/base64" - "fmt" - "log" - "net/http" - "net/url" - "runtime" - "strings" - "sync" - "syscall" - "time" - "unsafe" - - "github.com/alexbrainman/sspi/negotiate" - "golang.org/x/sys/windows" - "tailscale.com/hostinfo" - "tailscale.com/syncs" - "tailscale.com/types/logger" - "tailscale.com/util/clientmetric" - "tailscale.com/util/cmpver" -) - -func init() { - sysProxyFromEnv = proxyFromWinHTTPOrCache - sysAuthHeader = sysAuthHeaderWindows -} - -var cachedProxy struct { - sync.Mutex - val *url.URL -} - -// proxyErrorf is a rate-limited logger specifically for errors asking -// WinHTTP for the proxy information. We don't want to log about -// errors often, otherwise the log message itself will generate a new -// HTTP request which ultimately will call back into us to log again, -// forever. So for errors, we only log a bit. -var proxyErrorf = logger.RateLimitedFn(log.Printf, 10*time.Minute, 2 /* burst*/, 10 /* maxCache */) - -var ( - metricSuccess = clientmetric.NewCounter("winhttp_proxy_success") - metricErrDetectionFailed = clientmetric.NewCounter("winhttp_proxy_err_detection_failed") - metricErrInvalidParameters = clientmetric.NewCounter("winhttp_proxy_err_invalid_param") - metricErrDownloadScript = clientmetric.NewCounter("winhttp_proxy_err_download_script") - metricErrTimeout = clientmetric.NewCounter("winhttp_proxy_err_timeout") - metricErrOther = clientmetric.NewCounter("winhttp_proxy_err_other") -) - -func proxyFromWinHTTPOrCache(req *http.Request) (*url.URL, error) { - if req.URL == nil { - return nil, nil - } - urlStr := req.URL.String() - - ctx, cancel := context.WithTimeout(req.Context(), 5*time.Second) - defer cancel() - - type result struct { - proxy *url.URL - err error - } - resc := make(chan result, 1) - go func() { - proxy, err := proxyFromWinHTTP(ctx, urlStr) - resc <- result{proxy, err} - }() - - select { - case res := <-resc: - err := res.err - if err == nil { - metricSuccess.Add(1) - cachedProxy.Lock() - defer cachedProxy.Unlock() - if was, now := fmt.Sprint(cachedProxy.val), fmt.Sprint(res.proxy); was != now { - log.Printf("tshttpproxy: winhttp: updating cached proxy setting from %v to %v", was, now) - } - cachedProxy.val = res.proxy - return res.proxy, nil - } - - // See https://docs.microsoft.com/en-us/windows/win32/winhttp/error-messages - const ( - ERROR_WINHTTP_AUTODETECTION_FAILED = 12180 - ERROR_WINHTTP_UNABLE_TO_DOWNLOAD_SCRIPT = 12167 - ) - if err == syscall.Errno(ERROR_WINHTTP_AUTODETECTION_FAILED) { - metricErrDetectionFailed.Add(1) - setNoProxyUntil(10 * time.Second) - return nil, nil - } - if err == windows.ERROR_INVALID_PARAMETER { - metricErrInvalidParameters.Add(1) - // Seen on Windows 8.1. (https://github.com/tailscale/tailscale/issues/879) - // TODO(bradfitz): figure this out. - setNoProxyUntil(time.Hour) - proxyErrorf("tshttpproxy: winhttp: GetProxyForURL(%q): ERROR_INVALID_PARAMETER [unexpected]", urlStr) - return nil, nil - } - proxyErrorf("tshttpproxy: winhttp: GetProxyForURL(%q): %v/%#v", urlStr, err, err) - if err == syscall.Errno(ERROR_WINHTTP_UNABLE_TO_DOWNLOAD_SCRIPT) { - metricErrDownloadScript.Add(1) - setNoProxyUntil(10 * time.Second) - return nil, nil - } - metricErrOther.Add(1) - return nil, err - case <-ctx.Done(): - metricErrTimeout.Add(1) - cachedProxy.Lock() - defer cachedProxy.Unlock() - proxyErrorf("tshttpproxy: winhttp: GetProxyForURL(%q): timeout; using cached proxy %v", urlStr, cachedProxy.val) - return cachedProxy.val, nil - } -} - -func proxyFromWinHTTP(ctx context.Context, urlStr string) (proxy *url.URL, err error) { - runtime.LockOSThread() - defer runtime.UnlockOSThread() - - whi, err := httpOpen() - if err != nil { - proxyErrorf("winhttp: Open: %v", err) - return nil, err - } - defer whi.Close() - - t0 := time.Now() - v, err := whi.GetProxyForURL(urlStr) - td := time.Since(t0).Round(time.Millisecond) - if err := ctx.Err(); err != nil { - log.Printf("tshttpproxy: winhttp: context canceled, ignoring GetProxyForURL(%q) after %v", urlStr, td) - return nil, err - } - if err != nil { - return nil, err - } - if v == "" { - return nil, nil - } - // Discard all but first proxy value for now. - if i := strings.Index(v, ";"); i != -1 { - v = v[:i] - } - if !strings.HasPrefix(v, "https://") { - v = "http://" + v - } - return url.Parse(v) -} - -var userAgent = windows.StringToUTF16Ptr("Tailscale") - -const ( - winHTTP_ACCESS_TYPE_DEFAULT_PROXY = 0 - winHTTP_ACCESS_TYPE_AUTOMATIC_PROXY = 4 - winHTTP_AUTOPROXY_ALLOW_AUTOCONFIG = 0x00000100 - winHTTP_AUTOPROXY_AUTO_DETECT = 1 - winHTTP_AUTO_DETECT_TYPE_DHCP = 0x00000001 - winHTTP_AUTO_DETECT_TYPE_DNS_A = 0x00000002 -) - -// Windows 8.1 is actually Windows 6.3 under the hood. Yay, marketing! -const win8dot1Ver = "6.3" - -// accessType is the flag we must pass to WinHttpOpen for proxy resolution -// depending on whether or not we're running Windows < 8.1 -var accessType syncs.AtomicValue[uint32] - -func getAccessFlag() uint32 { - if flag, ok := accessType.LoadOk(); ok { - return flag - } - var flag uint32 - if cmpver.Compare(hostinfo.GetOSVersion(), win8dot1Ver) < 0 { - flag = winHTTP_ACCESS_TYPE_DEFAULT_PROXY - } else { - flag = winHTTP_ACCESS_TYPE_AUTOMATIC_PROXY - } - accessType.Store(flag) - return flag -} - -func httpOpen() (winHTTPInternet, error) { - return winHTTPOpen( - userAgent, - getAccessFlag(), - nil, /* WINHTTP_NO_PROXY_NAME */ - nil, /* WINHTTP_NO_PROXY_BYPASS */ - 0, - ) -} - -type winHTTPInternet windows.Handle - -func (hi winHTTPInternet) Close() error { - return winHTTPCloseHandle(hi) -} - -// WINHTTP_AUTOPROXY_OPTIONS -// https://docs.microsoft.com/en-us/windows/win32/api/winhttp/ns-winhttp-winhttp_autoproxy_options -type winHTTPAutoProxyOptions struct { - DwFlags uint32 - DwAutoDetectFlags uint32 - AutoConfigUrl *uint16 - _ uintptr - _ uint32 - FAutoLogonIfChallenged int32 // BOOL -} - -// WINHTTP_PROXY_INFO -// https://docs.microsoft.com/en-us/windows/win32/api/winhttp/ns-winhttp-winhttp_proxy_info -type winHTTPProxyInfo struct { - AccessType uint32 - Proxy *uint16 - ProxyBypass *uint16 -} - -type winHGlobal windows.Handle - -func globalFreeUTF16Ptr(p *uint16) error { - return globalFree((winHGlobal)(unsafe.Pointer(p))) -} - -func (pi *winHTTPProxyInfo) free() { - if pi.Proxy != nil { - globalFreeUTF16Ptr(pi.Proxy) - pi.Proxy = nil - } - if pi.ProxyBypass != nil { - globalFreeUTF16Ptr(pi.ProxyBypass) - pi.ProxyBypass = nil - } -} - -var proxyForURLOpts = &winHTTPAutoProxyOptions{ - DwFlags: winHTTP_AUTOPROXY_ALLOW_AUTOCONFIG | winHTTP_AUTOPROXY_AUTO_DETECT, - DwAutoDetectFlags: winHTTP_AUTO_DETECT_TYPE_DHCP, // | winHTTP_AUTO_DETECT_TYPE_DNS_A, -} - -func (hi winHTTPInternet) GetProxyForURL(urlStr string) (string, error) { - var out winHTTPProxyInfo - err := winHTTPGetProxyForURL( - hi, - windows.StringToUTF16Ptr(urlStr), - proxyForURLOpts, - &out, - ) - if err != nil { - return "", err - } - defer out.free() - return windows.UTF16PtrToString(out.Proxy), nil -} - -func sysAuthHeaderWindows(u *url.URL) (string, error) { - spn := "HTTP/" + u.Hostname() - creds, err := negotiate.AcquireCurrentUserCredentials() - if err != nil { - return "", fmt.Errorf("negotiate.AcquireCurrentUserCredentials: %w", err) - } - defer creds.Release() - - secCtx, token, err := negotiate.NewClientContext(creds, spn) - if err != nil { - return "", fmt.Errorf("negotiate.NewClientContext: %w", err) - } - defer secCtx.Release() - - return "Negotiate " + base64.StdEncoding.EncodeToString(token), nil -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package tshttpproxy
+
+import (
+ "context"
+ "encoding/base64"
+ "fmt"
+ "log"
+ "net/http"
+ "net/url"
+ "runtime"
+ "strings"
+ "sync"
+ "syscall"
+ "time"
+ "unsafe"
+
+ "github.com/alexbrainman/sspi/negotiate"
+ "golang.org/x/sys/windows"
+ "tailscale.com/hostinfo"
+ "tailscale.com/syncs"
+ "tailscale.com/types/logger"
+ "tailscale.com/util/clientmetric"
+ "tailscale.com/util/cmpver"
+)
+
+func init() {
+ sysProxyFromEnv = proxyFromWinHTTPOrCache
+ sysAuthHeader = sysAuthHeaderWindows
+}
+
+var cachedProxy struct {
+ sync.Mutex
+ val *url.URL
+}
+
+// proxyErrorf is a rate-limited logger specifically for errors asking
+// WinHTTP for the proxy information. We don't want to log about
+// errors often, otherwise the log message itself will generate a new
+// HTTP request which ultimately will call back into us to log again,
+// forever. So for errors, we only log a bit.
+var proxyErrorf = logger.RateLimitedFn(log.Printf, 10*time.Minute, 2 /* burst*/, 10 /* maxCache */)
+
+var (
+ metricSuccess = clientmetric.NewCounter("winhttp_proxy_success")
+ metricErrDetectionFailed = clientmetric.NewCounter("winhttp_proxy_err_detection_failed")
+ metricErrInvalidParameters = clientmetric.NewCounter("winhttp_proxy_err_invalid_param")
+ metricErrDownloadScript = clientmetric.NewCounter("winhttp_proxy_err_download_script")
+ metricErrTimeout = clientmetric.NewCounter("winhttp_proxy_err_timeout")
+ metricErrOther = clientmetric.NewCounter("winhttp_proxy_err_other")
+)
+
+func proxyFromWinHTTPOrCache(req *http.Request) (*url.URL, error) {
+ if req.URL == nil {
+ return nil, nil
+ }
+ urlStr := req.URL.String()
+
+ ctx, cancel := context.WithTimeout(req.Context(), 5*time.Second)
+ defer cancel()
+
+ type result struct {
+ proxy *url.URL
+ err error
+ }
+ resc := make(chan result, 1)
+ go func() {
+ proxy, err := proxyFromWinHTTP(ctx, urlStr)
+ resc <- result{proxy, err}
+ }()
+
+ select {
+ case res := <-resc:
+ err := res.err
+ if err == nil {
+ metricSuccess.Add(1)
+ cachedProxy.Lock()
+ defer cachedProxy.Unlock()
+ if was, now := fmt.Sprint(cachedProxy.val), fmt.Sprint(res.proxy); was != now {
+ log.Printf("tshttpproxy: winhttp: updating cached proxy setting from %v to %v", was, now)
+ }
+ cachedProxy.val = res.proxy
+ return res.proxy, nil
+ }
+
+ // See https://docs.microsoft.com/en-us/windows/win32/winhttp/error-messages
+ const (
+ ERROR_WINHTTP_AUTODETECTION_FAILED = 12180
+ ERROR_WINHTTP_UNABLE_TO_DOWNLOAD_SCRIPT = 12167
+ )
+ if err == syscall.Errno(ERROR_WINHTTP_AUTODETECTION_FAILED) {
+ metricErrDetectionFailed.Add(1)
+ setNoProxyUntil(10 * time.Second)
+ return nil, nil
+ }
+ if err == windows.ERROR_INVALID_PARAMETER {
+ metricErrInvalidParameters.Add(1)
+ // Seen on Windows 8.1. (https://github.com/tailscale/tailscale/issues/879)
+ // TODO(bradfitz): figure this out.
+ setNoProxyUntil(time.Hour)
+ proxyErrorf("tshttpproxy: winhttp: GetProxyForURL(%q): ERROR_INVALID_PARAMETER [unexpected]", urlStr)
+ return nil, nil
+ }
+ proxyErrorf("tshttpproxy: winhttp: GetProxyForURL(%q): %v/%#v", urlStr, err, err)
+ if err == syscall.Errno(ERROR_WINHTTP_UNABLE_TO_DOWNLOAD_SCRIPT) {
+ metricErrDownloadScript.Add(1)
+ setNoProxyUntil(10 * time.Second)
+ return nil, nil
+ }
+ metricErrOther.Add(1)
+ return nil, err
+ case <-ctx.Done():
+ metricErrTimeout.Add(1)
+ cachedProxy.Lock()
+ defer cachedProxy.Unlock()
+ proxyErrorf("tshttpproxy: winhttp: GetProxyForURL(%q): timeout; using cached proxy %v", urlStr, cachedProxy.val)
+ return cachedProxy.val, nil
+ }
+}
+
+func proxyFromWinHTTP(ctx context.Context, urlStr string) (proxy *url.URL, err error) {
+ runtime.LockOSThread()
+ defer runtime.UnlockOSThread()
+
+ whi, err := httpOpen()
+ if err != nil {
+ proxyErrorf("winhttp: Open: %v", err)
+ return nil, err
+ }
+ defer whi.Close()
+
+ t0 := time.Now()
+ v, err := whi.GetProxyForURL(urlStr)
+ td := time.Since(t0).Round(time.Millisecond)
+ if err := ctx.Err(); err != nil {
+ log.Printf("tshttpproxy: winhttp: context canceled, ignoring GetProxyForURL(%q) after %v", urlStr, td)
+ return nil, err
+ }
+ if err != nil {
+ return nil, err
+ }
+ if v == "" {
+ return nil, nil
+ }
+ // Discard all but first proxy value for now.
+ if i := strings.Index(v, ";"); i != -1 {
+ v = v[:i]
+ }
+ if !strings.HasPrefix(v, "https://") {
+ v = "http://" + v
+ }
+ return url.Parse(v)
+}
+
+var userAgent = windows.StringToUTF16Ptr("Tailscale")
+
+const (
+ winHTTP_ACCESS_TYPE_DEFAULT_PROXY = 0
+ winHTTP_ACCESS_TYPE_AUTOMATIC_PROXY = 4
+ winHTTP_AUTOPROXY_ALLOW_AUTOCONFIG = 0x00000100
+ winHTTP_AUTOPROXY_AUTO_DETECT = 1
+ winHTTP_AUTO_DETECT_TYPE_DHCP = 0x00000001
+ winHTTP_AUTO_DETECT_TYPE_DNS_A = 0x00000002
+)
+
+// Windows 8.1 is actually Windows 6.3 under the hood. Yay, marketing!
+const win8dot1Ver = "6.3"
+
+// accessType is the flag we must pass to WinHttpOpen for proxy resolution
+// depending on whether or not we're running Windows < 8.1
+var accessType syncs.AtomicValue[uint32]
+
+func getAccessFlag() uint32 {
+ if flag, ok := accessType.LoadOk(); ok {
+ return flag
+ }
+ var flag uint32
+ if cmpver.Compare(hostinfo.GetOSVersion(), win8dot1Ver) < 0 {
+ flag = winHTTP_ACCESS_TYPE_DEFAULT_PROXY
+ } else {
+ flag = winHTTP_ACCESS_TYPE_AUTOMATIC_PROXY
+ }
+ accessType.Store(flag)
+ return flag
+}
+
+func httpOpen() (winHTTPInternet, error) {
+ return winHTTPOpen(
+ userAgent,
+ getAccessFlag(),
+ nil, /* WINHTTP_NO_PROXY_NAME */
+ nil, /* WINHTTP_NO_PROXY_BYPASS */
+ 0,
+ )
+}
+
+type winHTTPInternet windows.Handle
+
+func (hi winHTTPInternet) Close() error {
+ return winHTTPCloseHandle(hi)
+}
+
+// WINHTTP_AUTOPROXY_OPTIONS
+// https://docs.microsoft.com/en-us/windows/win32/api/winhttp/ns-winhttp-winhttp_autoproxy_options
+type winHTTPAutoProxyOptions struct {
+ DwFlags uint32
+ DwAutoDetectFlags uint32
+ AutoConfigUrl *uint16
+ _ uintptr
+ _ uint32
+ FAutoLogonIfChallenged int32 // BOOL
+}
+
+// WINHTTP_PROXY_INFO
+// https://docs.microsoft.com/en-us/windows/win32/api/winhttp/ns-winhttp-winhttp_proxy_info
+type winHTTPProxyInfo struct {
+ AccessType uint32
+ Proxy *uint16
+ ProxyBypass *uint16
+}
+
+type winHGlobal windows.Handle
+
+func globalFreeUTF16Ptr(p *uint16) error {
+ return globalFree((winHGlobal)(unsafe.Pointer(p)))
+}
+
+func (pi *winHTTPProxyInfo) free() {
+ if pi.Proxy != nil {
+ globalFreeUTF16Ptr(pi.Proxy)
+ pi.Proxy = nil
+ }
+ if pi.ProxyBypass != nil {
+ globalFreeUTF16Ptr(pi.ProxyBypass)
+ pi.ProxyBypass = nil
+ }
+}
+
+var proxyForURLOpts = &winHTTPAutoProxyOptions{
+ DwFlags: winHTTP_AUTOPROXY_ALLOW_AUTOCONFIG | winHTTP_AUTOPROXY_AUTO_DETECT,
+ DwAutoDetectFlags: winHTTP_AUTO_DETECT_TYPE_DHCP, // | winHTTP_AUTO_DETECT_TYPE_DNS_A,
+}
+
+func (hi winHTTPInternet) GetProxyForURL(urlStr string) (string, error) {
+ var out winHTTPProxyInfo
+ err := winHTTPGetProxyForURL(
+ hi,
+ windows.StringToUTF16Ptr(urlStr),
+ proxyForURLOpts,
+ &out,
+ )
+ if err != nil {
+ return "", err
+ }
+ defer out.free()
+ return windows.UTF16PtrToString(out.Proxy), nil
+}
+
+func sysAuthHeaderWindows(u *url.URL) (string, error) {
+ spn := "HTTP/" + u.Hostname()
+ creds, err := negotiate.AcquireCurrentUserCredentials()
+ if err != nil {
+ return "", fmt.Errorf("negotiate.AcquireCurrentUserCredentials: %w", err)
+ }
+ defer creds.Release()
+
+ secCtx, token, err := negotiate.NewClientContext(creds, spn)
+ if err != nil {
+ return "", fmt.Errorf("negotiate.NewClientContext: %w", err)
+ }
+ defer secCtx.Release()
+
+ return "Negotiate " + base64.StdEncoding.EncodeToString(token), nil
+}
diff --git a/net/tstun/fake.go b/net/tstun/fake.go index 3d86bb3df..a002952a3 100644 --- a/net/tstun/fake.go +++ b/net/tstun/fake.go @@ -1,58 +1,58 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tstun - -import ( - "io" - "os" - - "github.com/tailscale/wireguard-go/tun" -) - -type fakeTUN struct { - evchan chan tun.Event - closechan chan struct{} -} - -// NewFake returns a tun.Device that does nothing. -func NewFake() tun.Device { - return &fakeTUN{ - evchan: make(chan tun.Event), - closechan: make(chan struct{}), - } -} - -func (t *fakeTUN) File() *os.File { - panic("fakeTUN.File() called, which makes no sense") -} - -func (t *fakeTUN) Close() error { - close(t.closechan) - close(t.evchan) - return nil -} - -func (t *fakeTUN) Read(out [][]byte, sizes []int, offset int) (int, error) { - <-t.closechan - return 0, io.EOF -} - -func (t *fakeTUN) Write(b [][]byte, n int) (int, error) { - select { - case <-t.closechan: - return 0, ErrClosed - default: - } - return 1, nil -} - -// FakeTUNName is the name of the fake TUN device. -const FakeTUNName = "FakeTUN" - -func (t *fakeTUN) Flush() error { return nil } -func (t *fakeTUN) MTU() (int, error) { return 1500, nil } -func (t *fakeTUN) Name() (string, error) { return FakeTUNName, nil } -func (t *fakeTUN) Events() <-chan tun.Event { return t.evchan } -func (t *fakeTUN) BatchSize() int { return 1 } -func (t *fakeTUN) IsFakeTun() bool { return true } +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package tstun
+
+import (
+ "io"
+ "os"
+
+ "github.com/tailscale/wireguard-go/tun"
+)
+
+type fakeTUN struct {
+ evchan chan tun.Event
+ closechan chan struct{}
+}
+
+// NewFake returns a tun.Device that does nothing.
+func NewFake() tun.Device {
+ return &fakeTUN{
+ evchan: make(chan tun.Event),
+ closechan: make(chan struct{}),
+ }
+}
+
+func (t *fakeTUN) File() *os.File {
+ panic("fakeTUN.File() called, which makes no sense")
+}
+
+func (t *fakeTUN) Close() error {
+ close(t.closechan)
+ close(t.evchan)
+ return nil
+}
+
+func (t *fakeTUN) Read(out [][]byte, sizes []int, offset int) (int, error) {
+ <-t.closechan
+ return 0, io.EOF
+}
+
+func (t *fakeTUN) Write(b [][]byte, n int) (int, error) {
+ select {
+ case <-t.closechan:
+ return 0, ErrClosed
+ default:
+ }
+ return 1, nil
+}
+
+// FakeTUNName is the name of the fake TUN device.
+const FakeTUNName = "FakeTUN"
+
+func (t *fakeTUN) Flush() error { return nil }
+func (t *fakeTUN) MTU() (int, error) { return 1500, nil }
+func (t *fakeTUN) Name() (string, error) { return FakeTUNName, nil }
+func (t *fakeTUN) Events() <-chan tun.Event { return t.evchan }
+func (t *fakeTUN) BatchSize() int { return 1 }
+func (t *fakeTUN) IsFakeTun() bool { return true }
diff --git a/net/tstun/ifstatus_noop.go b/net/tstun/ifstatus_noop.go index 8cf569f98..4d453b72c 100644 --- a/net/tstun/ifstatus_noop.go +++ b/net/tstun/ifstatus_noop.go @@ -1,18 +1,18 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !windows - -package tstun - -import ( - "time" - - "github.com/tailscale/wireguard-go/tun" - "tailscale.com/types/logger" -) - -// Dummy implementation that does nothing. -func waitInterfaceUp(iface tun.Device, timeout time.Duration, logf logger.Logf) error { - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build !windows
+
+package tstun
+
+import (
+ "time"
+
+ "github.com/tailscale/wireguard-go/tun"
+ "tailscale.com/types/logger"
+)
+
+// Dummy implementation that does nothing.
+func waitInterfaceUp(iface tun.Device, timeout time.Duration, logf logger.Logf) error {
+ return nil
+}
diff --git a/net/tstun/ifstatus_windows.go b/net/tstun/ifstatus_windows.go index fd9fc2112..6c6377bb4 100644 --- a/net/tstun/ifstatus_windows.go +++ b/net/tstun/ifstatus_windows.go @@ -1,109 +1,109 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tstun - -import ( - "fmt" - "sync" - "time" - - "github.com/tailscale/wireguard-go/tun" - "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" - "tailscale.com/types/logger" -) - -// ifaceWatcher waits for an interface to be up. -type ifaceWatcher struct { - logf logger.Logf - luid winipcfg.LUID - - mu sync.Mutex // guards following - done bool - sig chan bool -} - -// callback is the callback we register with Windows to call when IP interface changes. -func (iw *ifaceWatcher) callback(notificationType winipcfg.MibNotificationType, iface *winipcfg.MibIPInterfaceRow) { - // Probably should check only when MibParameterNotification, but just in case included MibAddInstance also. - if notificationType == winipcfg.MibParameterNotification || notificationType == winipcfg.MibAddInstance { - // Out of paranoia, start a goroutine to finish our work, to return to Windows out of this callback. - go iw.isUp() - } -} - -func (iw *ifaceWatcher) isUp() bool { - iw.mu.Lock() - defer iw.mu.Unlock() - - if iw.done { - // We already know that it's up - return true - } - - if iw.getOperStatus() != winipcfg.IfOperStatusUp { - return false - } - - iw.done = true - iw.sig <- true - return true -} - -func (iw *ifaceWatcher) getOperStatus() winipcfg.IfOperStatus { - ifc, err := iw.luid.Interface() - if err != nil { - iw.logf("iw.luid.Interface error: %v", err) - return 0 - } - return ifc.OperStatus -} - -func waitInterfaceUp(iface tun.Device, timeout time.Duration, logf logger.Logf) error { - iw := &ifaceWatcher{ - luid: winipcfg.LUID(iface.(*tun.NativeTun).LUID()), - logf: logger.WithPrefix(logf, "waitInterfaceUp: "), - } - - // Just in case check the status first - if iw.getOperStatus() == winipcfg.IfOperStatusUp { - iw.logf("TUN interface already up; no need to wait") - return nil - } - - iw.sig = make(chan bool, 1) - cb, err := winipcfg.RegisterInterfaceChangeCallback(iw.callback) - if err != nil { - iw.logf("RegisterInterfaceChangeCallback error: %v", err) - return err - } - defer cb.Unregister() - - t0 := time.Now() - expires := t0.Add(timeout) - ticker := time.NewTicker(10 * time.Second) - defer ticker.Stop() - - for { - iw.logf("waiting for TUN interface to come up...") - - select { - case <-iw.sig: - iw.logf("TUN interface is up after %v", time.Since(t0)) - return nil - case <-ticker.C: - } - - if iw.isUp() { - // Very unlikely to happen - either NotifyIpInterfaceChange doesn't work - // or it came up in the same moment as tick. Indicate this in the log message. - iw.logf("TUN interface is up after %v (on poll, without notification)", time.Since(t0)) - return nil - } - - if expires.Before(time.Now()) { - iw.logf("timeout waiting %v for TUN interface to come up", timeout) - return fmt.Errorf("timeout waiting for TUN interface to come up") - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package tstun
+
+import (
+ "fmt"
+ "sync"
+ "time"
+
+ "github.com/tailscale/wireguard-go/tun"
+ "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
+ "tailscale.com/types/logger"
+)
+
+// ifaceWatcher waits for an interface to be up.
+type ifaceWatcher struct {
+ logf logger.Logf
+ luid winipcfg.LUID
+
+ mu sync.Mutex // guards following
+ done bool
+ sig chan bool
+}
+
+// callback is the callback we register with Windows to call when IP interface changes.
+func (iw *ifaceWatcher) callback(notificationType winipcfg.MibNotificationType, iface *winipcfg.MibIPInterfaceRow) {
+ // Probably should check only when MibParameterNotification, but just in case included MibAddInstance also.
+ if notificationType == winipcfg.MibParameterNotification || notificationType == winipcfg.MibAddInstance {
+ // Out of paranoia, start a goroutine to finish our work, to return to Windows out of this callback.
+ go iw.isUp()
+ }
+}
+
+func (iw *ifaceWatcher) isUp() bool {
+ iw.mu.Lock()
+ defer iw.mu.Unlock()
+
+ if iw.done {
+ // We already know that it's up
+ return true
+ }
+
+ if iw.getOperStatus() != winipcfg.IfOperStatusUp {
+ return false
+ }
+
+ iw.done = true
+ iw.sig <- true
+ return true
+}
+
+func (iw *ifaceWatcher) getOperStatus() winipcfg.IfOperStatus {
+ ifc, err := iw.luid.Interface()
+ if err != nil {
+ iw.logf("iw.luid.Interface error: %v", err)
+ return 0
+ }
+ return ifc.OperStatus
+}
+
+func waitInterfaceUp(iface tun.Device, timeout time.Duration, logf logger.Logf) error {
+ iw := &ifaceWatcher{
+ luid: winipcfg.LUID(iface.(*tun.NativeTun).LUID()),
+ logf: logger.WithPrefix(logf, "waitInterfaceUp: "),
+ }
+
+ // Just in case check the status first
+ if iw.getOperStatus() == winipcfg.IfOperStatusUp {
+ iw.logf("TUN interface already up; no need to wait")
+ return nil
+ }
+
+ iw.sig = make(chan bool, 1)
+ cb, err := winipcfg.RegisterInterfaceChangeCallback(iw.callback)
+ if err != nil {
+ iw.logf("RegisterInterfaceChangeCallback error: %v", err)
+ return err
+ }
+ defer cb.Unregister()
+
+ t0 := time.Now()
+ expires := t0.Add(timeout)
+ ticker := time.NewTicker(10 * time.Second)
+ defer ticker.Stop()
+
+ for {
+ iw.logf("waiting for TUN interface to come up...")
+
+ select {
+ case <-iw.sig:
+ iw.logf("TUN interface is up after %v", time.Since(t0))
+ return nil
+ case <-ticker.C:
+ }
+
+ if iw.isUp() {
+ // Very unlikely to happen - either NotifyIpInterfaceChange doesn't work
+ // or it came up in the same moment as tick. Indicate this in the log message.
+ iw.logf("TUN interface is up after %v (on poll, without notification)", time.Since(t0))
+ return nil
+ }
+
+ if expires.Before(time.Now()) {
+ iw.logf("timeout waiting %v for TUN interface to come up", timeout)
+ return fmt.Errorf("timeout waiting for TUN interface to come up")
+ }
+ }
+}
diff --git a/net/tstun/linkattrs_linux.go b/net/tstun/linkattrs_linux.go index 681e79269..7f5461109 100644 --- a/net/tstun/linkattrs_linux.go +++ b/net/tstun/linkattrs_linux.go @@ -1,63 +1,63 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tstun - -import ( - "github.com/mdlayher/genetlink" - "github.com/mdlayher/netlink" - "github.com/tailscale/wireguard-go/tun" - "golang.org/x/sys/unix" -) - -// setLinkSpeed sets the advertised link speed of the TUN interface. -func setLinkSpeed(iface tun.Device, mbps int) error { - name, err := iface.Name() - if err != nil { - return err - } - - conn, err := genetlink.Dial(&netlink.Config{Strict: true}) - if err != nil { - return err - } - - defer conn.Close() - - f, err := conn.GetFamily(unix.ETHTOOL_GENL_NAME) - if err != nil { - return err - } - - ae := netlink.NewAttributeEncoder() - ae.Nested(unix.ETHTOOL_A_LINKMODES_HEADER, func(nae *netlink.AttributeEncoder) error { - nae.String(unix.ETHTOOL_A_HEADER_DEV_NAME, name) - return nil - }) - ae.Uint32(unix.ETHTOOL_A_LINKMODES_SPEED, uint32(mbps)) - - b, err := ae.Encode() - if err != nil { - return err - } - - _, err = conn.Execute( - genetlink.Message{ - Header: genetlink.Header{ - Command: unix.ETHTOOL_MSG_LINKMODES_SET, - Version: unix.ETHTOOL_GENL_VERSION, - }, - Data: b, - }, - f.ID, - netlink.Request|netlink.Acknowledge, - ) - return err -} - -// setLinkAttrs sets up link attributes that can be queried by external tools. -// Its failure is non-fatal to interface bringup. -func setLinkAttrs(iface tun.Device) error { - // By default the link speed is 10Mbps, which is easily exceeded and causes monitoring tools to complain (#3933). - return setLinkSpeed(iface, unix.SPEED_UNKNOWN) -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package tstun
+
+import (
+ "github.com/mdlayher/genetlink"
+ "github.com/mdlayher/netlink"
+ "github.com/tailscale/wireguard-go/tun"
+ "golang.org/x/sys/unix"
+)
+
+// setLinkSpeed sets the advertised link speed of the TUN interface.
+func setLinkSpeed(iface tun.Device, mbps int) error {
+ name, err := iface.Name()
+ if err != nil {
+ return err
+ }
+
+ conn, err := genetlink.Dial(&netlink.Config{Strict: true})
+ if err != nil {
+ return err
+ }
+
+ defer conn.Close()
+
+ f, err := conn.GetFamily(unix.ETHTOOL_GENL_NAME)
+ if err != nil {
+ return err
+ }
+
+ ae := netlink.NewAttributeEncoder()
+ ae.Nested(unix.ETHTOOL_A_LINKMODES_HEADER, func(nae *netlink.AttributeEncoder) error {
+ nae.String(unix.ETHTOOL_A_HEADER_DEV_NAME, name)
+ return nil
+ })
+ ae.Uint32(unix.ETHTOOL_A_LINKMODES_SPEED, uint32(mbps))
+
+ b, err := ae.Encode()
+ if err != nil {
+ return err
+ }
+
+ _, err = conn.Execute(
+ genetlink.Message{
+ Header: genetlink.Header{
+ Command: unix.ETHTOOL_MSG_LINKMODES_SET,
+ Version: unix.ETHTOOL_GENL_VERSION,
+ },
+ Data: b,
+ },
+ f.ID,
+ netlink.Request|netlink.Acknowledge,
+ )
+ return err
+}
+
+// setLinkAttrs sets up link attributes that can be queried by external tools.
+// Its failure is non-fatal to interface bringup.
+func setLinkAttrs(iface tun.Device) error {
+ // By default the link speed is 10Mbps, which is easily exceeded and causes monitoring tools to complain (#3933).
+ return setLinkSpeed(iface, unix.SPEED_UNKNOWN)
+}
diff --git a/net/tstun/linkattrs_notlinux.go b/net/tstun/linkattrs_notlinux.go index 7a7b40fc2..45dd000b3 100644 --- a/net/tstun/linkattrs_notlinux.go +++ b/net/tstun/linkattrs_notlinux.go @@ -1,12 +1,12 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !linux - -package tstun - -import "github.com/tailscale/wireguard-go/tun" - -func setLinkAttrs(iface tun.Device) error { - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build !linux
+
+package tstun
+
+import "github.com/tailscale/wireguard-go/tun"
+
+func setLinkAttrs(iface tun.Device) error {
+ return nil
+}
diff --git a/net/tstun/mtu.go b/net/tstun/mtu.go index 004529c20..b72a19bde 100644 --- a/net/tstun/mtu.go +++ b/net/tstun/mtu.go @@ -1,161 +1,161 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tstun - -import ( - "tailscale.com/envknob" -) - -// The MTU (Maximum Transmission Unit) of a network interface is the largest -// packet that can be sent or received through that interface, including all -// headers above the link layer (e.g. IP headers, UDP headers, Wireguard -// headers, etc.). We have to think about several different values of MTU: -// -// Wire MTU: The MTU of an interface underneath the tailscale TUN, e.g. an -// Ethernet network card will default to a 1500 byte MTU. The user may change -// this MTU at any time. -// -// TUN MTU: The current MTU of the tailscale TUN. This MTU is adjusted downward -// to make room for the wireguard/tailscale headers. For example, if the -// underlying network interface's MTU is 1500 bytes, the maximum size of a -// packet entering the tailscale TUN is 1420 bytes. The user may change this MTU -// at any time via the OS's tools (ifconfig, ip, etc.). -// -// User configured initial MTU: The MTU the tailscale TUN should be created -// with, set by the user via TS_DEBUG_MTU. It should be adjusted down from the -// underlying interface MTU by 80 bytes to make room for the wireguard -// headers. This envknob is mostly for debugging. This value is used once at TUN -// creation and ignored thereafter. -// -// User configured current MTU: The MTU set via the OS's tools (ifconfig, ip, -// etc.). This MTU can change at any time. Setting the MTU this way goes through -// the MTU() method of tailscale's TUN wrapper. -// -// Maximum probed MTU: This is the largest MTU size that we send probe packets -// for. -// -// Safe MTU: If the tailscale TUN MTU is set to this value, almost all packets -// will get to their destination. Tailscale defaults to this MTU in the absence -// of path MTU probe information or user MTU configuration. We may occasionally -// find a path that needs a smaller MTU but it is very rare. -// -// Peer MTU: This is the path MTU to a peer's current best endpoint. It defaults -// to the Safe MTU unless we have path MTU probe results that tell us otherwise. -// -// Initial MTU: This is the MTU tailscaled creates the TUN with. In order of -// priority, it is: -// -// 1. If set, the value of TS_DEBUG_MTU clamped to a maximum of 65536 -// 2. If TS_DEBUG_ENABLE_PMTUD is set, the maximum size MTU we probe, minus wg -// overhead -// 3. If TS_DEBUG_ENABLE_PMTUD is not set, the Safe MTU -// -// Current MTU: This the MTU of the tailscale TUN at any given moment -// after TUN creation. In order of priority, it is: -// -// 1. The MTU set by the user via the OS, if it has ever been set -// 2. If TS_DEBUG_ENABLE_PMTUD is set, the maximum size MTU we probe, minus wg -// overhead -// 4. If TS_DEBUG_ENABLE_PMTUD is not set, the Safe MTU - -// TUNMTU is the MTU for the tailscale TUN. -type TUNMTU uint32 - -// WireMTU is the MTU for the underlying network devices. -type WireMTU uint32 - -const ( - // maxTUNMTU is the largest MTU we will consider for the Tailscale - // TUN. This is inherited from wireguard-go and can be surprisingly - // small; on Windows it is currently 2048 - 32 bytes and iOS it is 1700 - // - 32 bytes. - // TODO(val,raggi): On Windows this seems to derive from RIO driver - // constraints in Wireguard but we don't use RIO so could probably make - // this bigger. - maxTUNMTU TUNMTU = TUNMTU(MaxPacketSize) - // safeTUNMTU is the default "safe" MTU for the Tailscale TUN that we - // use in the absence of other information such as path MTU probes. - safeTUNMTU TUNMTU = 1280 -) - -// WireMTUsToProbe is a list of the on-the-wire MTUs we want to probe. Each time -// magicsock discovery begins, it will send a set of pings, one of each size -// listed below. -var WireMTUsToProbe = []WireMTU{ - WireMTU(safeTUNMTU), // Tailscale over Tailscale :) - TUNToWireMTU(safeTUNMTU), // Smallest MTU allowed for IPv6, current default - 1400, // Most common MTU minus a few bytes for tunnels - 1500, // Most common MTU - 8000, // Should fit inside all jumbo frame sizes - 9000, // Most jumbo frames are this size or larger -} - -// wgHeaderLen is the length of all the headers Wireguard adds to a packet -// in the worst case (IPv6). This constant is for use when we can't or -// shouldn't use information about the IP version of a specific packet -// (e.g., calculating the MTU for the Tailscale interface. -// -// A Wireguard header includes: -// -// - 20-byte IPv4 header or 40-byte IPv6 header -// - 8-byte UDP header -// - 4-byte type -// - 4-byte key index -// - 8-byte nonce -// - 16-byte authentication tag -const wgHeaderLen = 40 + 8 + 4 + 4 + 8 + 16 - -// TUNToWireMTU takes the MTU that the Tailscale TUN presents to the user and -// returns the on-the-wire MTU necessary to transmit the largest packet that -// will fit through the TUN, given that we have to add wireguard headers. -func TUNToWireMTU(t TUNMTU) WireMTU { - return WireMTU(t + wgHeaderLen) -} - -// WireToTUNMTU takes the MTU of an underlying network device and returns the -// largest possible MTU for a Tailscale TUN operating on top of that device, -// given that we have to add wireguard headers. -func WireToTUNMTU(w WireMTU) TUNMTU { - if w < wgHeaderLen { - return 0 - } - return TUNMTU(w - wgHeaderLen) -} - -// DefaultTUNMTU returns the MTU we use to set the Tailscale TUN -// MTU. It is also the path MTU that we default to if we have no -// information about the path to a peer. -// -// 1. If set, the value of TS_DEBUG_MTU clamped to a maximum of MaxTUNMTU -// 2. If TS_DEBUG_ENABLE_PMTUD is set, the maximum size MTU we probe, minus wg overhead -// 3. If TS_DEBUG_ENABLE_PMTUD is not set, the Safe MTU -func DefaultTUNMTU() TUNMTU { - if m, ok := envknob.LookupUintSized("TS_DEBUG_MTU", 10, 32); ok { - return min(TUNMTU(m), maxTUNMTU) - } - - debugPMTUD, _ := envknob.LookupBool("TS_DEBUG_ENABLE_PMTUD") - if debugPMTUD { - // TODO: While we are just probing MTU but not generating PTB, - // this has to continue to return the safe MTU. When we add the - // code to generate PTB, this will be: - // - // return WireToTUNMTU(maxProbedWireMTU) - return safeTUNMTU - } - - return safeTUNMTU -} - -// SafeWireMTU returns the wire MTU that is safe to use if we have no -// information about the path MTU to this peer. -func SafeWireMTU() WireMTU { - return TUNToWireMTU(safeTUNMTU) -} - -// DefaultWireMTU returns the default TUN MTU, adjusted for wireguard -// overhead. -func DefaultWireMTU() WireMTU { - return TUNToWireMTU(DefaultTUNMTU()) -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package tstun
+
+import (
+ "tailscale.com/envknob"
+)
+
+// The MTU (Maximum Transmission Unit) of a network interface is the largest
+// packet that can be sent or received through that interface, including all
+// headers above the link layer (e.g. IP headers, UDP headers, Wireguard
+// headers, etc.). We have to think about several different values of MTU:
+//
+// Wire MTU: The MTU of an interface underneath the tailscale TUN, e.g. an
+// Ethernet network card will default to a 1500 byte MTU. The user may change
+// this MTU at any time.
+//
+// TUN MTU: The current MTU of the tailscale TUN. This MTU is adjusted downward
+// to make room for the wireguard/tailscale headers. For example, if the
+// underlying network interface's MTU is 1500 bytes, the maximum size of a
+// packet entering the tailscale TUN is 1420 bytes. The user may change this MTU
+// at any time via the OS's tools (ifconfig, ip, etc.).
+//
+// User configured initial MTU: The MTU the tailscale TUN should be created
+// with, set by the user via TS_DEBUG_MTU. It should be adjusted down from the
+// underlying interface MTU by 80 bytes to make room for the wireguard
+// headers. This envknob is mostly for debugging. This value is used once at TUN
+// creation and ignored thereafter.
+//
+// User configured current MTU: The MTU set via the OS's tools (ifconfig, ip,
+// etc.). This MTU can change at any time. Setting the MTU this way goes through
+// the MTU() method of tailscale's TUN wrapper.
+//
+// Maximum probed MTU: This is the largest MTU size that we send probe packets
+// for.
+//
+// Safe MTU: If the tailscale TUN MTU is set to this value, almost all packets
+// will get to their destination. Tailscale defaults to this MTU in the absence
+// of path MTU probe information or user MTU configuration. We may occasionally
+// find a path that needs a smaller MTU but it is very rare.
+//
+// Peer MTU: This is the path MTU to a peer's current best endpoint. It defaults
+// to the Safe MTU unless we have path MTU probe results that tell us otherwise.
+//
+// Initial MTU: This is the MTU tailscaled creates the TUN with. In order of
+// priority, it is:
+//
+// 1. If set, the value of TS_DEBUG_MTU clamped to a maximum of 65536
+// 2. If TS_DEBUG_ENABLE_PMTUD is set, the maximum size MTU we probe, minus wg
+// overhead
+// 3. If TS_DEBUG_ENABLE_PMTUD is not set, the Safe MTU
+//
+// Current MTU: This the MTU of the tailscale TUN at any given moment
+// after TUN creation. In order of priority, it is:
+//
+// 1. The MTU set by the user via the OS, if it has ever been set
+// 2. If TS_DEBUG_ENABLE_PMTUD is set, the maximum size MTU we probe, minus wg
+// overhead
+// 4. If TS_DEBUG_ENABLE_PMTUD is not set, the Safe MTU
+
+// TUNMTU is the MTU for the tailscale TUN.
+type TUNMTU uint32
+
+// WireMTU is the MTU for the underlying network devices.
+type WireMTU uint32
+
+const (
+ // maxTUNMTU is the largest MTU we will consider for the Tailscale
+ // TUN. This is inherited from wireguard-go and can be surprisingly
+ // small; on Windows it is currently 2048 - 32 bytes and iOS it is 1700
+ // - 32 bytes.
+ // TODO(val,raggi): On Windows this seems to derive from RIO driver
+ // constraints in Wireguard but we don't use RIO so could probably make
+ // this bigger.
+ maxTUNMTU TUNMTU = TUNMTU(MaxPacketSize)
+ // safeTUNMTU is the default "safe" MTU for the Tailscale TUN that we
+ // use in the absence of other information such as path MTU probes.
+ safeTUNMTU TUNMTU = 1280
+)
+
+// WireMTUsToProbe is a list of the on-the-wire MTUs we want to probe. Each time
+// magicsock discovery begins, it will send a set of pings, one of each size
+// listed below.
+var WireMTUsToProbe = []WireMTU{
+ WireMTU(safeTUNMTU), // Tailscale over Tailscale :)
+ TUNToWireMTU(safeTUNMTU), // Smallest MTU allowed for IPv6, current default
+ 1400, // Most common MTU minus a few bytes for tunnels
+ 1500, // Most common MTU
+ 8000, // Should fit inside all jumbo frame sizes
+ 9000, // Most jumbo frames are this size or larger
+}
+
+// wgHeaderLen is the length of all the headers Wireguard adds to a packet
+// in the worst case (IPv6). This constant is for use when we can't or
+// shouldn't use information about the IP version of a specific packet
+// (e.g., calculating the MTU for the Tailscale interface.
+//
+// A Wireguard header includes:
+//
+// - 20-byte IPv4 header or 40-byte IPv6 header
+// - 8-byte UDP header
+// - 4-byte type
+// - 4-byte key index
+// - 8-byte nonce
+// - 16-byte authentication tag
+const wgHeaderLen = 40 + 8 + 4 + 4 + 8 + 16
+
+// TUNToWireMTU takes the MTU that the Tailscale TUN presents to the user and
+// returns the on-the-wire MTU necessary to transmit the largest packet that
+// will fit through the TUN, given that we have to add wireguard headers.
+func TUNToWireMTU(t TUNMTU) WireMTU {
+ return WireMTU(t + wgHeaderLen)
+}
+
+// WireToTUNMTU takes the MTU of an underlying network device and returns the
+// largest possible MTU for a Tailscale TUN operating on top of that device,
+// given that we have to add wireguard headers.
+func WireToTUNMTU(w WireMTU) TUNMTU {
+ if w < wgHeaderLen {
+ return 0
+ }
+ return TUNMTU(w - wgHeaderLen)
+}
+
+// DefaultTUNMTU returns the MTU we use to set the Tailscale TUN
+// MTU. It is also the path MTU that we default to if we have no
+// information about the path to a peer.
+//
+// 1. If set, the value of TS_DEBUG_MTU clamped to a maximum of MaxTUNMTU
+// 2. If TS_DEBUG_ENABLE_PMTUD is set, the maximum size MTU we probe, minus wg overhead
+// 3. If TS_DEBUG_ENABLE_PMTUD is not set, the Safe MTU
+func DefaultTUNMTU() TUNMTU {
+ if m, ok := envknob.LookupUintSized("TS_DEBUG_MTU", 10, 32); ok {
+ return min(TUNMTU(m), maxTUNMTU)
+ }
+
+ debugPMTUD, _ := envknob.LookupBool("TS_DEBUG_ENABLE_PMTUD")
+ if debugPMTUD {
+ // TODO: While we are just probing MTU but not generating PTB,
+ // this has to continue to return the safe MTU. When we add the
+ // code to generate PTB, this will be:
+ //
+ // return WireToTUNMTU(maxProbedWireMTU)
+ return safeTUNMTU
+ }
+
+ return safeTUNMTU
+}
+
+// SafeWireMTU returns the wire MTU that is safe to use if we have no
+// information about the path MTU to this peer.
+func SafeWireMTU() WireMTU {
+ return TUNToWireMTU(safeTUNMTU)
+}
+
+// DefaultWireMTU returns the default TUN MTU, adjusted for wireguard
+// overhead.
+func DefaultWireMTU() WireMTU {
+ return TUNToWireMTU(DefaultTUNMTU())
+}
diff --git a/net/tstun/mtu_test.go b/net/tstun/mtu_test.go index 8d165bfd3..fc5274ae1 100644 --- a/net/tstun/mtu_test.go +++ b/net/tstun/mtu_test.go @@ -1,99 +1,99 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause -package tstun - -import ( - "os" - "strconv" - "testing" -) - -// Test the default MTU in the presence of various envknobs. -func TestDefaultTunMTU(t *testing.T) { - // Save and restore the envknobs we will be changing. - - // TS_DEBUG_MTU sets the MTU to a specific value. - defer os.Setenv("TS_DEBUG_MTU", os.Getenv("TS_DEBUG_MTU")) - os.Setenv("TS_DEBUG_MTU", "") - - // TS_DEBUG_ENABLE_PMTUD enables path MTU discovery. - defer os.Setenv("TS_DEBUG_ENABLE_PMTUD", os.Getenv("TS_DEBUG_ENABLE_PMTUD")) - os.Setenv("TS_DEBUG_ENABLE_PMTUD", "") - - // With no MTU envknobs set, we should get the conservative MTU. - if DefaultTUNMTU() != safeTUNMTU { - t.Errorf("default TUN MTU = %d, want %d", DefaultTUNMTU(), safeTUNMTU) - } - - // If set, TS_DEBUG_MTU should set the MTU. - mtu := maxTUNMTU - 1 - os.Setenv("TS_DEBUG_MTU", strconv.Itoa(int(mtu))) - if DefaultTUNMTU() != mtu { - t.Errorf("default TUN MTU = %d, want %d, TS_DEBUG_MTU ignored", DefaultTUNMTU(), mtu) - } - - // MTU should be clamped to maxTunMTU. - mtu = maxTUNMTU + 1 - os.Setenv("TS_DEBUG_MTU", strconv.Itoa(int(mtu))) - if DefaultTUNMTU() != maxTUNMTU { - t.Errorf("default TUN MTU = %d, want %d, clamping failed", DefaultTUNMTU(), maxTUNMTU) - } - - // If PMTUD is enabled, the MTU should default to the safe MTU, but only - // if the user hasn't requested a specific MTU. - // - // TODO: When PMTUD is generating PTB responses, this will become the - // largest MTU we probe. - os.Setenv("TS_DEBUG_MTU", "") - os.Setenv("TS_DEBUG_ENABLE_PMTUD", "true") - if DefaultTUNMTU() != safeTUNMTU { - t.Errorf("default TUN MTU = %d, want %d", DefaultTUNMTU(), safeTUNMTU) - } - // TS_DEBUG_MTU should take precedence over TS_DEBUG_ENABLE_PMTUD. - mtu = WireToTUNMTU(MaxPacketSize - 1) - os.Setenv("TS_DEBUG_MTU", strconv.Itoa(int(mtu))) - if DefaultTUNMTU() != mtu { - t.Errorf("default TUN MTU = %d, want %d", DefaultTUNMTU(), mtu) - } -} - -// Test the conversion of wire MTU to/from Tailscale TUN MTU corner cases. -func TestMTUConversion(t *testing.T) { - tests := []struct { - w WireMTU - t TUNMTU - }{ - {w: 0, t: 0}, - {w: wgHeaderLen - 1, t: 0}, - {w: wgHeaderLen, t: 0}, - {w: wgHeaderLen + 1, t: 1}, - {w: 1360, t: 1280}, - {w: 1500, t: 1420}, - {w: 9000, t: 8920}, - } - - for _, tt := range tests { - m := WireToTUNMTU(tt.w) - if m != tt.t { - t.Errorf("conversion of wire MTU %v to TUN MTU = %v, want %v", tt.w, m, tt.t) - } - } - - tests2 := []struct { - t TUNMTU - w WireMTU - }{ - {t: 0, w: wgHeaderLen}, - {t: 1, w: wgHeaderLen + 1}, - {t: 1280, w: 1360}, - {t: 1420, w: 1500}, - {t: 8920, w: 9000}, - } - - for _, tt := range tests2 { - m := TUNToWireMTU(tt.t) - if m != tt.w { - t.Errorf("conversion of TUN MTU %v to wire MTU = %v, want %v", tt.t, m, tt.w) - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+package tstun
+
+import (
+ "os"
+ "strconv"
+ "testing"
+)
+
+// Test the default MTU in the presence of various envknobs.
+func TestDefaultTunMTU(t *testing.T) {
+ // Save and restore the envknobs we will be changing.
+
+ // TS_DEBUG_MTU sets the MTU to a specific value.
+ defer os.Setenv("TS_DEBUG_MTU", os.Getenv("TS_DEBUG_MTU"))
+ os.Setenv("TS_DEBUG_MTU", "")
+
+ // TS_DEBUG_ENABLE_PMTUD enables path MTU discovery.
+ defer os.Setenv("TS_DEBUG_ENABLE_PMTUD", os.Getenv("TS_DEBUG_ENABLE_PMTUD"))
+ os.Setenv("TS_DEBUG_ENABLE_PMTUD", "")
+
+ // With no MTU envknobs set, we should get the conservative MTU.
+ if DefaultTUNMTU() != safeTUNMTU {
+ t.Errorf("default TUN MTU = %d, want %d", DefaultTUNMTU(), safeTUNMTU)
+ }
+
+ // If set, TS_DEBUG_MTU should set the MTU.
+ mtu := maxTUNMTU - 1
+ os.Setenv("TS_DEBUG_MTU", strconv.Itoa(int(mtu)))
+ if DefaultTUNMTU() != mtu {
+ t.Errorf("default TUN MTU = %d, want %d, TS_DEBUG_MTU ignored", DefaultTUNMTU(), mtu)
+ }
+
+ // MTU should be clamped to maxTunMTU.
+ mtu = maxTUNMTU + 1
+ os.Setenv("TS_DEBUG_MTU", strconv.Itoa(int(mtu)))
+ if DefaultTUNMTU() != maxTUNMTU {
+ t.Errorf("default TUN MTU = %d, want %d, clamping failed", DefaultTUNMTU(), maxTUNMTU)
+ }
+
+ // If PMTUD is enabled, the MTU should default to the safe MTU, but only
+ // if the user hasn't requested a specific MTU.
+ //
+ // TODO: When PMTUD is generating PTB responses, this will become the
+ // largest MTU we probe.
+ os.Setenv("TS_DEBUG_MTU", "")
+ os.Setenv("TS_DEBUG_ENABLE_PMTUD", "true")
+ if DefaultTUNMTU() != safeTUNMTU {
+ t.Errorf("default TUN MTU = %d, want %d", DefaultTUNMTU(), safeTUNMTU)
+ }
+ // TS_DEBUG_MTU should take precedence over TS_DEBUG_ENABLE_PMTUD.
+ mtu = WireToTUNMTU(MaxPacketSize - 1)
+ os.Setenv("TS_DEBUG_MTU", strconv.Itoa(int(mtu)))
+ if DefaultTUNMTU() != mtu {
+ t.Errorf("default TUN MTU = %d, want %d", DefaultTUNMTU(), mtu)
+ }
+}
+
+// Test the conversion of wire MTU to/from Tailscale TUN MTU corner cases.
+func TestMTUConversion(t *testing.T) {
+ tests := []struct {
+ w WireMTU
+ t TUNMTU
+ }{
+ {w: 0, t: 0},
+ {w: wgHeaderLen - 1, t: 0},
+ {w: wgHeaderLen, t: 0},
+ {w: wgHeaderLen + 1, t: 1},
+ {w: 1360, t: 1280},
+ {w: 1500, t: 1420},
+ {w: 9000, t: 8920},
+ }
+
+ for _, tt := range tests {
+ m := WireToTUNMTU(tt.w)
+ if m != tt.t {
+ t.Errorf("conversion of wire MTU %v to TUN MTU = %v, want %v", tt.w, m, tt.t)
+ }
+ }
+
+ tests2 := []struct {
+ t TUNMTU
+ w WireMTU
+ }{
+ {t: 0, w: wgHeaderLen},
+ {t: 1, w: wgHeaderLen + 1},
+ {t: 1280, w: 1360},
+ {t: 1420, w: 1500},
+ {t: 8920, w: 9000},
+ }
+
+ for _, tt := range tests2 {
+ m := TUNToWireMTU(tt.t)
+ if m != tt.w {
+ t.Errorf("conversion of TUN MTU %v to wire MTU = %v, want %v", tt.t, m, tt.w)
+ }
+ }
+}
diff --git a/net/tstun/tun_linux.go b/net/tstun/tun_linux.go index 9600ceb77..e08f12bc1 100644 --- a/net/tstun/tun_linux.go +++ b/net/tstun/tun_linux.go @@ -1,103 +1,103 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tstun - -import ( - "bytes" - "errors" - "os" - "os/exec" - "strings" - "syscall" - - "tailscale.com/types/logger" - "tailscale.com/version/distro" -) - -func init() { - tunDiagnoseFailure = diagnoseLinuxTUNFailure -} - -func diagnoseLinuxTUNFailure(tunName string, logf logger.Logf, createErr error) { - if errors.Is(createErr, syscall.EBUSY) { - logf("TUN device %s is busy; another process probably still has it open (from old version of Tailscale that had a bug)", tunName) - logf("To fix, kill the process that has it open. Find with:\n\n$ sudo lsof -n /dev/net/tun\n\n") - logf("... and then kill those PID(s)") - return - } - - var un syscall.Utsname - err := syscall.Uname(&un) - if err != nil { - logf("no TUN, and failed to look up kernel version: %v", err) - return - } - kernel := utsReleaseField(&un) - logf("Linux kernel version: %s", kernel) - - modprobeOut, err := exec.Command("/sbin/modprobe", "tun").CombinedOutput() - if err == nil { - logf("'modprobe tun' successful") - // Either tun is currently loaded, or it's statically - // compiled into the kernel (which modprobe checks - // with /lib/modules/$(uname -r)/modules.builtin) - // - // So if there's a problem at this point, it's - // probably because /dev/net/tun doesn't exist. - const dev = "/dev/net/tun" - if fi, err := os.Stat(dev); err != nil { - logf("tun module loaded in kernel, but %s does not exist", dev) - } else { - logf("%s: %v", dev, fi.Mode()) - } - - // We failed to find why it failed. Just let our - // caller report the error it got from wireguard-go. - return - } - logf("is CONFIG_TUN enabled in your kernel? `modprobe tun` failed with: %s", modprobeOut) - - switch distro.Get() { - case distro.Debian: - dpkgOut, err := exec.Command("dpkg", "-S", "kernel/drivers/net/tun.ko").CombinedOutput() - if len(bytes.TrimSpace(dpkgOut)) == 0 || err != nil { - logf("tun module not loaded nor found on disk") - return - } - if !bytes.Contains(dpkgOut, []byte(kernel)) { - logf("kernel/drivers/net/tun.ko found on disk, but not for current kernel; are you in middle of a system update and haven't rebooted? found: %s", dpkgOut) - } - case distro.Arch: - findOut, err := exec.Command("find", "/lib/modules/", "-path", "*/net/tun.ko*").CombinedOutput() - if len(bytes.TrimSpace(findOut)) == 0 || err != nil { - logf("tun module not loaded nor found on disk") - return - } - if !bytes.Contains(findOut, []byte(kernel)) { - logf("kernel/drivers/net/tun.ko found on disk, but not for current kernel; are you in middle of a system update and haven't rebooted? found: %s", findOut) - } - case distro.OpenWrt: - out, err := exec.Command("opkg", "list-installed").CombinedOutput() - if err != nil { - logf("error querying OpenWrt installed packages: %s", out) - return - } - for _, pkg := range []string{"kmod-tun", "ca-bundle"} { - if !bytes.Contains(out, []byte(pkg+" - ")) { - logf("Missing required package %s; run: opkg install %s", pkg, pkg) - } - } - } -} - -func utsReleaseField(u *syscall.Utsname) string { - var sb strings.Builder - for _, v := range u.Release { - if v == 0 { - break - } - sb.WriteByte(byte(v)) - } - return strings.TrimSpace(sb.String()) -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package tstun
+
+import (
+ "bytes"
+ "errors"
+ "os"
+ "os/exec"
+ "strings"
+ "syscall"
+
+ "tailscale.com/types/logger"
+ "tailscale.com/version/distro"
+)
+
+func init() {
+ tunDiagnoseFailure = diagnoseLinuxTUNFailure
+}
+
+func diagnoseLinuxTUNFailure(tunName string, logf logger.Logf, createErr error) {
+ if errors.Is(createErr, syscall.EBUSY) {
+ logf("TUN device %s is busy; another process probably still has it open (from old version of Tailscale that had a bug)", tunName)
+ logf("To fix, kill the process that has it open. Find with:\n\n$ sudo lsof -n /dev/net/tun\n\n")
+ logf("... and then kill those PID(s)")
+ return
+ }
+
+ var un syscall.Utsname
+ err := syscall.Uname(&un)
+ if err != nil {
+ logf("no TUN, and failed to look up kernel version: %v", err)
+ return
+ }
+ kernel := utsReleaseField(&un)
+ logf("Linux kernel version: %s", kernel)
+
+ modprobeOut, err := exec.Command("/sbin/modprobe", "tun").CombinedOutput()
+ if err == nil {
+ logf("'modprobe tun' successful")
+ // Either tun is currently loaded, or it's statically
+ // compiled into the kernel (which modprobe checks
+ // with /lib/modules/$(uname -r)/modules.builtin)
+ //
+ // So if there's a problem at this point, it's
+ // probably because /dev/net/tun doesn't exist.
+ const dev = "/dev/net/tun"
+ if fi, err := os.Stat(dev); err != nil {
+ logf("tun module loaded in kernel, but %s does not exist", dev)
+ } else {
+ logf("%s: %v", dev, fi.Mode())
+ }
+
+ // We failed to find why it failed. Just let our
+ // caller report the error it got from wireguard-go.
+ return
+ }
+ logf("is CONFIG_TUN enabled in your kernel? `modprobe tun` failed with: %s", modprobeOut)
+
+ switch distro.Get() {
+ case distro.Debian:
+ dpkgOut, err := exec.Command("dpkg", "-S", "kernel/drivers/net/tun.ko").CombinedOutput()
+ if len(bytes.TrimSpace(dpkgOut)) == 0 || err != nil {
+ logf("tun module not loaded nor found on disk")
+ return
+ }
+ if !bytes.Contains(dpkgOut, []byte(kernel)) {
+ logf("kernel/drivers/net/tun.ko found on disk, but not for current kernel; are you in middle of a system update and haven't rebooted? found: %s", dpkgOut)
+ }
+ case distro.Arch:
+ findOut, err := exec.Command("find", "/lib/modules/", "-path", "*/net/tun.ko*").CombinedOutput()
+ if len(bytes.TrimSpace(findOut)) == 0 || err != nil {
+ logf("tun module not loaded nor found on disk")
+ return
+ }
+ if !bytes.Contains(findOut, []byte(kernel)) {
+ logf("kernel/drivers/net/tun.ko found on disk, but not for current kernel; are you in middle of a system update and haven't rebooted? found: %s", findOut)
+ }
+ case distro.OpenWrt:
+ out, err := exec.Command("opkg", "list-installed").CombinedOutput()
+ if err != nil {
+ logf("error querying OpenWrt installed packages: %s", out)
+ return
+ }
+ for _, pkg := range []string{"kmod-tun", "ca-bundle"} {
+ if !bytes.Contains(out, []byte(pkg+" - ")) {
+ logf("Missing required package %s; run: opkg install %s", pkg, pkg)
+ }
+ }
+ }
+}
+
+func utsReleaseField(u *syscall.Utsname) string {
+ var sb strings.Builder
+ for _, v := range u.Release {
+ if v == 0 {
+ break
+ }
+ sb.WriteByte(byte(v))
+ }
+ return strings.TrimSpace(sb.String())
+}
diff --git a/net/tstun/tun_macos.go b/net/tstun/tun_macos.go index 3506f05b1..f71494f0b 100644 --- a/net/tstun/tun_macos.go +++ b/net/tstun/tun_macos.go @@ -1,25 +1,25 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build darwin && !ios - -package tstun - -import ( - "os" - - "tailscale.com/types/logger" -) - -func init() { - tunDiagnoseFailure = diagnoseDarwinTUNFailure -} - -func diagnoseDarwinTUNFailure(tunName string, logf logger.Logf, err error) { - if os.Getuid() != 0 { - logf("failed to create TUN device as non-root user; use 'sudo tailscaled', or run under launchd with 'sudo tailscaled install-system-daemon'") - } - if tunName != "utun" { - logf("failed to create TUN device %q; try using tun device \"utun\" instead for automatic selection", tunName) - } -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build darwin && !ios
+
+package tstun
+
+import (
+ "os"
+
+ "tailscale.com/types/logger"
+)
+
+func init() {
+ tunDiagnoseFailure = diagnoseDarwinTUNFailure
+}
+
+func diagnoseDarwinTUNFailure(tunName string, logf logger.Logf, err error) {
+ if os.Getuid() != 0 {
+ logf("failed to create TUN device as non-root user; use 'sudo tailscaled', or run under launchd with 'sudo tailscaled install-system-daemon'")
+ }
+ if tunName != "utun" {
+ logf("failed to create TUN device %q; try using tun device \"utun\" instead for automatic selection", tunName)
+ }
+}
diff --git a/net/tstun/tun_notwindows.go b/net/tstun/tun_notwindows.go index 087fcd4ee..60f1c62ba 100644 --- a/net/tstun/tun_notwindows.go +++ b/net/tstun/tun_notwindows.go @@ -1,12 +1,12 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !windows - -package tstun - -import "github.com/tailscale/wireguard-go/tun" - -func interfaceName(dev tun.Device) (string, error) { - return dev.Name() -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build !windows
+
+package tstun
+
+import "github.com/tailscale/wireguard-go/tun"
+
+func interfaceName(dev tun.Device) (string, error) {
+ return dev.Name()
+}
|
