diff options
| author | Naman Sood <mail@nsood.in> | 2021-03-29 14:28:08 -0400 |
|---|---|---|
| committer | Naman Sood <mail@nsood.in> | 2021-03-29 14:28:08 -0400 |
| commit | c0a88a0129ebf0f9886b93b1f4e4f04a7c3bb86f (patch) | |
| tree | 57d5aef2985e3424e5bb6f4c810628aa3ccbf5d0 /wgengine | |
| parent | 47bd3c4cf5543fd7ecb049302c37c1001fa9f2d6 (diff) | |
| parent | a4c679e64691a3f0ba41ad9078312ca67e5e67fd (diff) | |
| download | tailscale-naman/netstack-subnet-routing.tar.xz tailscale-naman/netstack-subnet-routing.zip | |
merge with main (naman/netstack-subnet-routing)
Signed-off-by: Naman Sood <mail@nsood.in>
Diffstat (limited to 'wgengine')
55 files changed, 966 insertions, 5196 deletions
diff --git a/wgengine/filter/filter.go b/wgengine/filter/filter.go index cbb985114..3c4964c34 100644 --- a/wgengine/filter/filter.go +++ b/wgengine/filter/filter.go @@ -14,6 +14,7 @@ import ( "inet.af/netaddr" "tailscale.com/net/flowtrack" "tailscale.com/net/packet" + "tailscale.com/types/ipproto" "tailscale.com/types/logger" ) @@ -182,6 +183,7 @@ func matchesFamily(ms matches, keep func(netaddr.IP) bool) matches { var ret matches for _, m := range ms { var retm Match + retm.IPProto = m.IPProto for _, src := range m.Srcs { if keep(src.IP) { retm.Srcs = append(retm.Srcs, src) @@ -266,7 +268,7 @@ func (f *Filter) CheckTCP(srcIP, dstIP netaddr.IP, dstPort uint16) Response { } pkt.Src.IP = srcIP pkt.Dst.IP = dstIP - pkt.IPProto = packet.TCP + pkt.IPProto = ipproto.TCP pkt.TCPFlags = packet.TCPSyn pkt.Src.Port = 0 pkt.Dst.Port = dstPort @@ -324,7 +326,7 @@ func (f *Filter) runIn4(q *packet.Parsed) (r Response, why string) { } switch q.IPProto { - case packet.ICMPv4: + case ipproto.ICMPv4: if q.IsEchoResponse() || q.IsError() { // ICMP responses are allowed. // TODO(apenwarr): consider using conntrack state. @@ -336,7 +338,7 @@ func (f *Filter) runIn4(q *packet.Parsed) (r Response, why string) { // If any port is open to an IP, allow ICMP to it. return Accept, "icmp ok" } - case packet.TCP: + case ipproto.TCP: // For TCP, we want to allow *outgoing* connections, // which means we want to allow return packets on those // connections. 
To make this restriction work, we need to @@ -351,20 +353,20 @@ func (f *Filter) runIn4(q *packet.Parsed) (r Response, why string) { if f.matches4.match(q) { return Accept, "tcp ok" } - case packet.UDP: - t := flowtrack.Tuple{Src: q.Src, Dst: q.Dst} + case ipproto.UDP, ipproto.SCTP: + t := flowtrack.Tuple{Proto: q.IPProto, Src: q.Src, Dst: q.Dst} f.state.mu.Lock() _, ok := f.state.lru.Get(t) f.state.mu.Unlock() if ok { - return Accept, "udp cached" + return Accept, "cached" } if f.matches4.match(q) { - return Accept, "udp ok" + return Accept, "ok" } - case packet.TSMP: + case ipproto.TSMP: return Accept, "tsmp ok" default: return Drop, "Unknown proto" @@ -381,7 +383,7 @@ func (f *Filter) runIn6(q *packet.Parsed) (r Response, why string) { } switch q.IPProto { - case packet.ICMPv6: + case ipproto.ICMPv6: if q.IsEchoResponse() || q.IsError() { // ICMP responses are allowed. // TODO(apenwarr): consider using conntrack state. @@ -393,7 +395,7 @@ func (f *Filter) runIn6(q *packet.Parsed) (r Response, why string) { // If any port is open to an IP, allow ICMP to it. return Accept, "icmp ok" } - case packet.TCP: + case ipproto.TCP: // For TCP, we want to allow *outgoing* connections, // which means we want to allow return packets on those // connections. To make this restriction work, we need to @@ -402,25 +404,27 @@ func (f *Filter) runIn6(q *packet.Parsed) (r Response, why string) { // can't be initiated without first sending a SYN. // It happens to also be much faster. // TODO(apenwarr): Skip the rest of decoding in this path? 
- if q.IPProto == packet.TCP && !q.IsTCPSyn() { + if q.IPProto == ipproto.TCP && !q.IsTCPSyn() { return Accept, "tcp non-syn" } if f.matches6.match(q) { return Accept, "tcp ok" } - case packet.UDP: - t := flowtrack.Tuple{Src: q.Src, Dst: q.Dst} + case ipproto.UDP, ipproto.SCTP: + t := flowtrack.Tuple{Proto: q.IPProto, Src: q.Src, Dst: q.Dst} f.state.mu.Lock() _, ok := f.state.lru.Get(t) f.state.mu.Unlock() if ok { - return Accept, "udp cached" + return Accept, "cached" } if f.matches6.match(q) { - return Accept, "udp ok" + return Accept, "ok" } + case ipproto.TSMP: + return Accept, "tsmp ok" default: return Drop, "Unknown proto" } @@ -429,15 +433,16 @@ func (f *Filter) runIn6(q *packet.Parsed) (r Response, why string) { // runIn runs the output-specific part of the filter logic. func (f *Filter) runOut(q *packet.Parsed) (r Response, why string) { - if q.IPProto != packet.UDP { - return Accept, "ok out" + switch q.IPProto { + case ipproto.UDP, ipproto.SCTP: + tuple := flowtrack.Tuple{ + Proto: q.IPProto, + Src: q.Dst, Dst: q.Src, // src/dst reversed + } + f.state.mu.Lock() + f.state.lru.Add(tuple, nil) + f.state.mu.Unlock() } - - tuple := flowtrack.Tuple{Src: q.Dst, Dst: q.Src} // src/dst reversed - - f.state.mu.Lock() - f.state.lru.Add(tuple, nil) - f.state.mu.Unlock() return Accept, "ok out" } @@ -485,11 +490,11 @@ func (f *Filter) pre(q *packet.Parsed, rf RunFlags, dir direction) Response { } switch q.IPProto { - case packet.Unknown: + case ipproto.Unknown: // Unknown packets are dangerous; always drop them. f.logRateLimit(rf, q, dir, Drop, "unknown") return Drop - case packet.Fragment: + case ipproto.Fragment: // Fragments after the first always need to be passed through. // Very small fragments are considered Junk by Parsed. 
f.logRateLimit(rf, q, dir, Accept, "fragment") @@ -513,5 +518,5 @@ func omitDropLogging(p *packet.Parsed, dir direction) bool { return false } - return p.Dst.IP.IsMulticast() || (p.Dst.IP.IsLinkLocalUnicast() && p.Dst.IP != gcpDNSAddr) || p.IPProto == packet.IGMP + return p.Dst.IP.IsMulticast() || (p.Dst.IP.IsLinkLocalUnicast() && p.Dst.IP != gcpDNSAddr) || p.IPProto == ipproto.IGMP } diff --git a/wgengine/filter/filter_test.go b/wgengine/filter/filter_test.go index 9f98761f6..ca807a5fd 100644 --- a/wgengine/filter/filter_test.go +++ b/wgengine/filter/filter_test.go @@ -7,6 +7,7 @@ package filter import ( "encoding/hex" "fmt" + "reflect" "strconv" "strings" "testing" @@ -16,19 +17,32 @@ import ( "inet.af/netaddr" "tailscale.com/net/packet" "tailscale.com/net/tsaddr" + "tailscale.com/tailcfg" + "tailscale.com/types/ipproto" "tailscale.com/types/logger" ) func newFilter(logf logger.Logf) *Filter { + m := func(srcs []netaddr.IPPrefix, dsts []NetPortRange, protos ...ipproto.Proto) Match { + if protos == nil { + protos = defaultProtos + } + return Match{ + IPProto: protos, + Srcs: srcs, + Dsts: dsts, + } + } matches := []Match{ - {Srcs: nets("8.1.1.1", "8.2.2.2"), Dsts: netports("1.2.3.4:22", "5.6.7.8:23-24")}, - {Srcs: nets("8.1.1.1", "8.2.2.2"), Dsts: netports("5.6.7.8:27-28")}, - {Srcs: nets("2.2.2.2"), Dsts: netports("8.1.1.1:22")}, - {Srcs: nets("0.0.0.0/0"), Dsts: netports("100.122.98.50:*")}, - {Srcs: nets("0.0.0.0/0"), Dsts: netports("0.0.0.0/0:443")}, - {Srcs: nets("153.1.1.1", "153.1.1.2", "153.3.3.3"), Dsts: netports("1.2.3.4:999")}, - {Srcs: nets("::1", "::2"), Dsts: netports("2001::1:22", "2001::2:22")}, - {Srcs: nets("::/0"), Dsts: netports("::/0:443")}, + m(nets("8.1.1.1", "8.2.2.2"), netports("1.2.3.4:22", "5.6.7.8:23-24")), + m(nets("9.1.1.1", "9.2.2.2"), netports("1.2.3.4:22", "5.6.7.8:23-24"), ipproto.SCTP), + m(nets("8.1.1.1", "8.2.2.2"), netports("5.6.7.8:27-28")), + m(nets("2.2.2.2"), netports("8.1.1.1:22")), + m(nets("0.0.0.0/0"), 
netports("100.122.98.50:*")), + m(nets("0.0.0.0/0"), netports("0.0.0.0/0:443")), + m(nets("153.1.1.1", "153.1.1.2", "153.3.3.3"), netports("1.2.3.4:999")), + m(nets("::1", "::2"), netports("2001::1:22", "2001::2:22")), + m(nets("::/0"), netports("::/0:443")), } // Expects traffic to 100.122.98.50, 1.2.3.4, 5.6.7.8, @@ -52,43 +66,48 @@ func TestFilter(t *testing.T) { } tests := []InOut{ // allow 8.1.1.1 => 1.2.3.4:22 - {Accept, parsed(packet.TCP, "8.1.1.1", "1.2.3.4", 999, 22)}, - {Accept, parsed(packet.ICMPv4, "8.1.1.1", "1.2.3.4", 0, 0)}, - {Drop, parsed(packet.TCP, "8.1.1.1", "1.2.3.4", 0, 0)}, - {Accept, parsed(packet.TCP, "8.1.1.1", "1.2.3.4", 0, 22)}, - {Drop, parsed(packet.TCP, "8.1.1.1", "1.2.3.4", 0, 21)}, + {Accept, parsed(ipproto.TCP, "8.1.1.1", "1.2.3.4", 999, 22)}, + {Accept, parsed(ipproto.ICMPv4, "8.1.1.1", "1.2.3.4", 0, 0)}, + {Drop, parsed(ipproto.TCP, "8.1.1.1", "1.2.3.4", 0, 0)}, + {Accept, parsed(ipproto.TCP, "8.1.1.1", "1.2.3.4", 0, 22)}, + {Drop, parsed(ipproto.TCP, "8.1.1.1", "1.2.3.4", 0, 21)}, // allow 8.2.2.2. 
=> 1.2.3.4:22 - {Accept, parsed(packet.TCP, "8.2.2.2", "1.2.3.4", 0, 22)}, - {Drop, parsed(packet.TCP, "8.2.2.2", "1.2.3.4", 0, 23)}, - {Drop, parsed(packet.TCP, "8.3.3.3", "1.2.3.4", 0, 22)}, + {Accept, parsed(ipproto.TCP, "8.2.2.2", "1.2.3.4", 0, 22)}, + {Drop, parsed(ipproto.TCP, "8.2.2.2", "1.2.3.4", 0, 23)}, + {Drop, parsed(ipproto.TCP, "8.3.3.3", "1.2.3.4", 0, 22)}, // allow 8.1.1.1 => 5.6.7.8:23-24 - {Accept, parsed(packet.TCP, "8.1.1.1", "5.6.7.8", 0, 23)}, - {Accept, parsed(packet.TCP, "8.1.1.1", "5.6.7.8", 0, 24)}, - {Drop, parsed(packet.TCP, "8.1.1.3", "5.6.7.8", 0, 24)}, - {Drop, parsed(packet.TCP, "8.1.1.1", "5.6.7.8", 0, 22)}, + {Accept, parsed(ipproto.TCP, "8.1.1.1", "5.6.7.8", 0, 23)}, + {Accept, parsed(ipproto.TCP, "8.1.1.1", "5.6.7.8", 0, 24)}, + {Drop, parsed(ipproto.TCP, "8.1.1.3", "5.6.7.8", 0, 24)}, + {Drop, parsed(ipproto.TCP, "8.1.1.1", "5.6.7.8", 0, 22)}, // allow * => *:443 - {Accept, parsed(packet.TCP, "17.34.51.68", "8.1.34.51", 0, 443)}, - {Drop, parsed(packet.TCP, "17.34.51.68", "8.1.34.51", 0, 444)}, + {Accept, parsed(ipproto.TCP, "17.34.51.68", "8.1.34.51", 0, 443)}, + {Drop, parsed(ipproto.TCP, "17.34.51.68", "8.1.34.51", 0, 444)}, // allow * => 100.122.98.50:* - {Accept, parsed(packet.TCP, "17.34.51.68", "100.122.98.50", 0, 999)}, - {Accept, parsed(packet.TCP, "17.34.51.68", "100.122.98.50", 0, 0)}, + {Accept, parsed(ipproto.TCP, "17.34.51.68", "100.122.98.50", 0, 999)}, + {Accept, parsed(ipproto.TCP, "17.34.51.68", "100.122.98.50", 0, 0)}, // allow ::1, ::2 => [2001::1]:22 - {Accept, parsed(packet.TCP, "::1", "2001::1", 0, 22)}, - {Accept, parsed(packet.ICMPv6, "::1", "2001::1", 0, 0)}, - {Accept, parsed(packet.TCP, "::2", "2001::1", 0, 22)}, - {Accept, parsed(packet.TCP, "::2", "2001::2", 0, 22)}, - {Drop, parsed(packet.TCP, "::1", "2001::1", 0, 23)}, - {Drop, parsed(packet.TCP, "::1", "2001::3", 0, 22)}, - {Drop, parsed(packet.TCP, "::3", "2001::1", 0, 22)}, + {Accept, parsed(ipproto.TCP, "::1", "2001::1", 0, 22)}, + {Accept, 
parsed(ipproto.ICMPv6, "::1", "2001::1", 0, 0)}, + {Accept, parsed(ipproto.TCP, "::2", "2001::1", 0, 22)}, + {Accept, parsed(ipproto.TCP, "::2", "2001::2", 0, 22)}, + {Drop, parsed(ipproto.TCP, "::1", "2001::1", 0, 23)}, + {Drop, parsed(ipproto.TCP, "::1", "2001::3", 0, 22)}, + {Drop, parsed(ipproto.TCP, "::3", "2001::1", 0, 22)}, // allow * => *:443 - {Accept, parsed(packet.TCP, "::1", "2001::1", 0, 443)}, - {Drop, parsed(packet.TCP, "::1", "2001::1", 0, 444)}, + {Accept, parsed(ipproto.TCP, "::1", "2001::1", 0, 443)}, + {Drop, parsed(ipproto.TCP, "::1", "2001::1", 0, 444)}, // localNets prefilter - accepted by policy filter, but // unexpected dst IP. - {Drop, parsed(packet.TCP, "8.1.1.1", "16.32.48.64", 0, 443)}, - {Drop, parsed(packet.TCP, "1::", "2602::1", 0, 443)}, + {Drop, parsed(ipproto.TCP, "8.1.1.1", "16.32.48.64", 0, 443)}, + {Drop, parsed(ipproto.TCP, "1::", "2602::1", 0, 443)}, + + // Don't allow protocols not specified by filter + {Drop, parsed(ipproto.SCTP, "8.1.1.1", "1.2.3.4", 999, 22)}, + // But SCTP is allowed for 9.1.1.1 + {Accept, parsed(ipproto.SCTP, "9.1.1.1", "1.2.3.4", 999, 22)}, } for i, test := range tests { aclFunc := acl.runIn4 @@ -98,7 +117,7 @@ func TestFilter(t *testing.T) { if got, why := aclFunc(&test.p); test.want != got { t.Errorf("#%d runIn got=%v want=%v why=%q packet:%v", i, got, test.want, why, test.p) } - if test.p.IPProto == packet.TCP { + if test.p.IPProto == ipproto.TCP { var got Response if test.p.IPVersion == 4 { got = acl.CheckTCP(test.p.Src.IP, test.p.Dst.IP, test.p.Dst.Port) @@ -109,7 +128,7 @@ func TestFilter(t *testing.T) { t.Errorf("#%d CheckTCP got=%v want=%v packet:%v", i, got, test.want, test.p) } // TCP and UDP are treated equivalently in the filter - verify that. 
- test.p.IPProto = packet.UDP + test.p.IPProto = ipproto.UDP if got, why := aclFunc(&test.p); test.want != got { t.Errorf("#%d runIn (UDP) got=%v want=%v why=%q packet:%v", i, got, test.want, why, test.p) } @@ -123,8 +142,8 @@ func TestUDPState(t *testing.T) { acl := newFilter(t.Logf) flags := LogDrops | LogAccepts - a4 := parsed(packet.UDP, "119.119.119.119", "102.102.102.102", 4242, 4343) - b4 := parsed(packet.UDP, "102.102.102.102", "119.119.119.119", 4343, 4242) + a4 := parsed(ipproto.UDP, "119.119.119.119", "102.102.102.102", 4242, 4343) + b4 := parsed(ipproto.UDP, "102.102.102.102", "119.119.119.119", 4343, 4242) // Unsollicited UDP traffic gets dropped if got := acl.RunIn(&a4, flags); got != Drop { @@ -139,8 +158,8 @@ func TestUDPState(t *testing.T) { t.Fatalf("incoming response packet not accepted, got=%v: %v", got, a4) } - a6 := parsed(packet.UDP, "2001::2", "2001::1", 4242, 4343) - b6 := parsed(packet.UDP, "2001::1", "2001::2", 4343, 4242) + a6 := parsed(ipproto.UDP, "2001::2", "2001::1", 4242, 4343) + b6 := parsed(ipproto.UDP, "2001::1", "2001::2", 4343, 4242) // Unsollicited UDP traffic gets dropped if got := acl.RunIn(&a6, flags); got != Drop { @@ -159,10 +178,10 @@ func TestUDPState(t *testing.T) { func TestNoAllocs(t *testing.T) { acl := newFilter(t.Logf) - tcp4Packet := raw4(packet.TCP, "8.1.1.1", "1.2.3.4", 999, 22, 0) - udp4Packet := raw4(packet.UDP, "8.1.1.1", "1.2.3.4", 999, 22, 0) - tcp6Packet := raw6(packet.TCP, "2001::1", "2001::2", 999, 22, 0) - udp6Packet := raw6(packet.UDP, "2001::1", "2001::2", 999, 22, 0) + tcp4Packet := raw4(ipproto.TCP, "8.1.1.1", "1.2.3.4", 999, 22, 0) + udp4Packet := raw4(ipproto.UDP, "8.1.1.1", "1.2.3.4", 999, 22, 0) + tcp6Packet := raw6(ipproto.TCP, "2001::1", "2001::2", 999, 22, 0) + udp6Packet := raw6(ipproto.UDP, "2001::1", "2001::2", 999, 22, 0) tests := []struct { name string @@ -243,13 +262,13 @@ func TestParseIPSet(t *testing.T) { } func BenchmarkFilter(b *testing.B) { - tcp4Packet := raw4(packet.TCP, 
"8.1.1.1", "1.2.3.4", 999, 22, 0) - udp4Packet := raw4(packet.UDP, "8.1.1.1", "1.2.3.4", 999, 22, 0) - icmp4Packet := raw4(packet.ICMPv4, "8.1.1.1", "1.2.3.4", 0, 0, 0) + tcp4Packet := raw4(ipproto.TCP, "8.1.1.1", "1.2.3.4", 999, 22, 0) + udp4Packet := raw4(ipproto.UDP, "8.1.1.1", "1.2.3.4", 999, 22, 0) + icmp4Packet := raw4(ipproto.ICMPv4, "8.1.1.1", "1.2.3.4", 0, 0, 0) - tcp6Packet := raw6(packet.TCP, "::1", "2001::1", 999, 22, 0) - udp6Packet := raw6(packet.UDP, "::1", "2001::1", 999, 22, 0) - icmp6Packet := raw6(packet.ICMPv6, "::1", "2001::1", 0, 0, 0) + tcp6Packet := raw6(ipproto.TCP, "::1", "2001::1", 999, 22, 0) + udp6Packet := raw6(ipproto.UDP, "::1", "2001::1", 999, 22, 0) + icmp6Packet := raw6(ipproto.ICMPv6, "::1", "2001::1", 0, 0, 0) benches := []struct { name string @@ -296,11 +315,11 @@ func TestPreFilter(t *testing.T) { }{ {"empty", Accept, []byte{}}, {"short", Drop, []byte("short")}, - {"junk", Drop, raw4default(packet.Unknown, 10)}, - {"fragment", Accept, raw4default(packet.Fragment, 40)}, - {"tcp", noVerdict, raw4default(packet.TCP, 0)}, - {"udp", noVerdict, raw4default(packet.UDP, 0)}, - {"icmp", noVerdict, raw4default(packet.ICMPv4, 0)}, + {"junk", Drop, raw4default(ipproto.Unknown, 10)}, + {"fragment", Accept, raw4default(ipproto.Fragment, 40)}, + {"tcp", noVerdict, raw4default(ipproto.TCP, 0)}, + {"udp", noVerdict, raw4default(ipproto.UDP, 0)}, + {"icmp", noVerdict, raw4default(ipproto.ICMPv4, 0)}, } f := NewAllowNone(t.Logf, &netaddr.IPSet{}) for _, testPacket := range packets { @@ -322,7 +341,7 @@ func TestOmitDropLogging(t *testing.T) { }{ { name: "v4_tcp_out", - pkt: &packet.Parsed{IPVersion: 4, IPProto: packet.TCP}, + pkt: &packet.Parsed{IPVersion: 4, IPProto: ipproto.TCP}, dir: out, want: false, }, @@ -420,73 +439,73 @@ func TestLoggingPrivacy(t *testing.T) { }{ { name: "ts_to_ts_v4_out", - pkt: &packet.Parsed{IPVersion: 4, IPProto: packet.TCP, Src: ts4, Dst: ts4}, + pkt: &packet.Parsed{IPVersion: 4, IPProto: ipproto.TCP, Src: ts4, Dst: 
ts4}, dir: out, logged: true, }, { name: "ts_to_internet_v4_out", - pkt: &packet.Parsed{IPVersion: 4, IPProto: packet.TCP, Src: ts4, Dst: internet4}, + pkt: &packet.Parsed{IPVersion: 4, IPProto: ipproto.TCP, Src: ts4, Dst: internet4}, dir: out, logged: false, }, { name: "internet_to_ts_v4_out", - pkt: &packet.Parsed{IPVersion: 4, IPProto: packet.TCP, Src: internet4, Dst: ts4}, + pkt: &packet.Parsed{IPVersion: 4, IPProto: ipproto.TCP, Src: internet4, Dst: ts4}, dir: out, logged: false, }, { name: "ts_to_ts_v4_in", - pkt: &packet.Parsed{IPVersion: 4, IPProto: packet.TCP, Src: ts4, Dst: ts4}, + pkt: &packet.Parsed{IPVersion: 4, IPProto: ipproto.TCP, Src: ts4, Dst: ts4}, dir: in, logged: true, }, { name: "ts_to_internet_v4_in", - pkt: &packet.Parsed{IPVersion: 4, IPProto: packet.TCP, Src: ts4, Dst: internet4}, + pkt: &packet.Parsed{IPVersion: 4, IPProto: ipproto.TCP, Src: ts4, Dst: internet4}, dir: in, logged: false, }, { name: "internet_to_ts_v4_in", - pkt: &packet.Parsed{IPVersion: 4, IPProto: packet.TCP, Src: internet4, Dst: ts4}, + pkt: &packet.Parsed{IPVersion: 4, IPProto: ipproto.TCP, Src: internet4, Dst: ts4}, dir: in, logged: false, }, { name: "ts_to_ts_v6_out", - pkt: &packet.Parsed{IPVersion: 6, IPProto: packet.TCP, Src: ts6, Dst: ts6}, + pkt: &packet.Parsed{IPVersion: 6, IPProto: ipproto.TCP, Src: ts6, Dst: ts6}, dir: out, logged: true, }, { name: "ts_to_internet_v6_out", - pkt: &packet.Parsed{IPVersion: 6, IPProto: packet.TCP, Src: ts6, Dst: internet6}, + pkt: &packet.Parsed{IPVersion: 6, IPProto: ipproto.TCP, Src: ts6, Dst: internet6}, dir: out, logged: false, }, { name: "internet_to_ts_v6_out", - pkt: &packet.Parsed{IPVersion: 6, IPProto: packet.TCP, Src: internet6, Dst: ts6}, + pkt: &packet.Parsed{IPVersion: 6, IPProto: ipproto.TCP, Src: internet6, Dst: ts6}, dir: out, logged: false, }, { name: "ts_to_ts_v6_in", - pkt: &packet.Parsed{IPVersion: 6, IPProto: packet.TCP, Src: ts6, Dst: ts6}, + pkt: &packet.Parsed{IPVersion: 6, IPProto: ipproto.TCP, Src: 
ts6, Dst: ts6}, dir: in, logged: true, }, { name: "ts_to_internet_v6_in", - pkt: &packet.Parsed{IPVersion: 6, IPProto: packet.TCP, Src: ts6, Dst: internet6}, + pkt: &packet.Parsed{IPVersion: 6, IPProto: ipproto.TCP, Src: ts6, Dst: internet6}, dir: in, logged: false, }, { name: "internet_to_ts_v6_in", - pkt: &packet.Parsed{IPVersion: 6, IPProto: packet.TCP, Src: internet6, Dst: ts6}, + pkt: &packet.Parsed{IPVersion: 6, IPProto: ipproto.TCP, Src: internet6, Dst: ts6}, dir: in, logged: false, }, @@ -520,7 +539,7 @@ func mustIP(s string) netaddr.IP { return ip } -func parsed(proto packet.IPProto, src, dst string, sport, dport uint16) packet.Parsed { +func parsed(proto ipproto.Proto, src, dst string, sport, dport uint16) packet.Parsed { sip, dip := mustIP(src), mustIP(dst) var ret packet.Parsed @@ -541,7 +560,7 @@ func parsed(proto packet.IPProto, src, dst string, sport, dport uint16) packet.P return ret } -func raw6(proto packet.IPProto, src, dst string, sport, dport uint16, trimLen int) []byte { +func raw6(proto ipproto.Proto, src, dst string, sport, dport uint16, trimLen int) []byte { u := packet.UDP6Header{ IP6Header: packet.IP6Header{ Src: mustIP(src), @@ -570,7 +589,7 @@ func raw6(proto packet.IPProto, src, dst string, sport, dport uint16, trimLen in } } -func raw4(proto packet.IPProto, src, dst string, sport, dport uint16, trimLength int) []byte { +func raw4(proto ipproto.Proto, src, dst string, sport, dport uint16, trimLength int) []byte { u := packet.UDP4Header{ IP4Header: packet.IP4Header{ Src: mustIP(src), @@ -588,7 +607,7 @@ func raw4(proto packet.IPProto, src, dst string, sport, dport uint16, trimLength // UDP marshaling clobbers IPProto, so override it here. 
switch proto { - case packet.Unknown, packet.Fragment: + case ipproto.Unknown, ipproto.Fragment: default: u.IP4Header.IPProto = proto } @@ -596,7 +615,7 @@ func raw4(proto packet.IPProto, src, dst string, sport, dport uint16, trimLength panic(err) } - if proto == packet.Fragment { + if proto == ipproto.Fragment { // Set some fragment offset. This makes the IP // checksum wrong, but we don't validate the checksum // when parsing. @@ -610,7 +629,7 @@ func raw4(proto packet.IPProto, src, dst string, sport, dport uint16, trimLength } } -func raw4default(proto packet.IPProto, trimLength int) []byte { +func raw4default(proto ipproto.Proto, trimLength int) []byte { return raw4(proto, "8.8.8.8", "8.8.8.8", 53, 53, trimLength) } @@ -707,3 +726,91 @@ func netports(netPorts ...string) (ret []NetPortRange) { } return ret } + +func TestMatchesFromFilterRules(t *testing.T) { + tests := []struct { + name string + in []tailcfg.FilterRule + want []Match + }{ + { + name: "empty", + want: []Match{}, + }, + { + name: "implicit_protos", + in: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.1.1"}, + DstPorts: []tailcfg.NetPortRange{{ + IP: "*", + Ports: tailcfg.PortRange{First: 22, Last: 22}, + }}, + }, + }, + want: []Match{ + { + IPProto: []ipproto.Proto{ + ipproto.TCP, + ipproto.UDP, + ipproto.ICMPv4, + ipproto.ICMPv6, + }, + Dsts: []NetPortRange{ + { + Net: netaddr.MustParseIPPrefix("0.0.0.0/0"), + Ports: PortRange{22, 22}, + }, + { + Net: netaddr.MustParseIPPrefix("::0/0"), + Ports: PortRange{22, 22}, + }, + }, + Srcs: []netaddr.IPPrefix{ + netaddr.MustParseIPPrefix("100.64.1.1/32"), + }, + }, + }, + }, + { + name: "explicit_protos", + in: []tailcfg.FilterRule{ + { + IPProto: []int{int(ipproto.TCP)}, + SrcIPs: []string{"100.64.1.1"}, + DstPorts: []tailcfg.NetPortRange{{ + IP: "1.2.0.0/16", + Ports: tailcfg.PortRange{First: 22, Last: 22}, + }}, + }, + }, + want: []Match{ + { + IPProto: []ipproto.Proto{ + ipproto.TCP, + }, + Dsts: []NetPortRange{ + { + Net: 
netaddr.MustParseIPPrefix("1.2.0.0/16"), + Ports: PortRange{22, 22}, + }, + }, + Srcs: []netaddr.IPPrefix{ + netaddr.MustParseIPPrefix("100.64.1.1/32"), + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := MatchesFromFilterRules(tt.in) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("wrong\n got: %v\nwant: %v\n", got, tt.want) + } + }) + } +} diff --git a/wgengine/filter/match.go b/wgengine/filter/match.go index c30c37552..a1b356113 100644 --- a/wgengine/filter/match.go +++ b/wgengine/filter/match.go @@ -10,6 +10,7 @@ import ( "inet.af/netaddr" "tailscale.com/net/packet" + "tailscale.com/types/ipproto" ) //go:generate go run tailscale.com/cmd/cloner --type=Match --output=match_clone.go @@ -47,11 +48,13 @@ func (npr NetPortRange) String() string { // Match matches packets from any IP address in Srcs to any ip:port in // Dsts. type Match struct { - Dsts []NetPortRange - Srcs []netaddr.IPPrefix + IPProto []ipproto.Proto // required set (no default value at this layer) + Dsts []NetPortRange + Srcs []netaddr.IPPrefix } func (m Match) String() string { + // TODO(bradfitz): use strings.Builder, add String tests srcs := []string{} for _, src := range m.Srcs { srcs = append(srcs, src.String()) @@ -72,13 +75,16 @@ func (m Match) String() string { } else { ds = "[" + strings.Join(dsts, ",") + "]" } - return fmt.Sprintf("%v=>%v", ss, ds) + return fmt.Sprintf("%v%v=>%v", m.IPProto, ss, ds) } type matches []Match func (ms matches) match(q *packet.Parsed) bool { for _, m := range ms { + if !protoInList(q.IPProto, m.IPProto) { + continue + } if !ipInList(q.Src.IP, m.Srcs) { continue } @@ -117,3 +123,12 @@ func ipInList(ip netaddr.IP, netlist []netaddr.IPPrefix) bool { } return false } + +func protoInList(proto ipproto.Proto, valid []ipproto.Proto) bool { + for _, v := range valid { + if proto == v { + return true + } + } + return false +} diff --git a/wgengine/filter/match_clone.go 
b/wgengine/filter/match_clone.go index 571664bd5..04874ddec 100644 --- a/wgengine/filter/match_clone.go +++ b/wgengine/filter/match_clone.go @@ -8,6 +8,7 @@ package filter import ( "inet.af/netaddr" + "tailscale.com/types/ipproto" ) // Clone makes a deep copy of Match. @@ -18,6 +19,7 @@ func (src *Match) Clone() *Match { } dst := new(Match) *dst = *src + dst.IPProto = append(src.IPProto[:0:0], src.IPProto...) dst.Dsts = append(src.Dsts[:0:0], src.Dsts...) dst.Srcs = append(src.Srcs[:0:0], src.Srcs...) return dst @@ -26,6 +28,7 @@ func (src *Match) Clone() *Match { // A compilation failure here means this code must be regenerated, with command: // tailscale.com/cmd/cloner -type Match var _MatchNeedsRegeneration = Match(struct { - Dsts []NetPortRange - Srcs []netaddr.IPPrefix + IPProto []ipproto.Proto + Dsts []NetPortRange + Srcs []netaddr.IPPrefix }{}) diff --git a/wgengine/filter/tailcfg.go b/wgengine/filter/tailcfg.go index 2f20cdb61..1338a75b4 100644 --- a/wgengine/filter/tailcfg.go +++ b/wgengine/filter/tailcfg.go @@ -10,8 +10,16 @@ import ( "inet.af/netaddr" "tailscale.com/tailcfg" + "tailscale.com/types/ipproto" ) +var defaultProtos = []ipproto.Proto{ + ipproto.TCP, + ipproto.UDP, + ipproto.ICMPv4, + ipproto.ICMPv6, +} + // MatchesFromFilterRules converts tailcfg FilterRules into Matches. // If an error is returned, the Matches result is still valid, // containing the rules that were successfully converted. @@ -22,6 +30,17 @@ func MatchesFromFilterRules(pf []tailcfg.FilterRule) ([]Match, error) { for _, r := range pf { m := Match{} + if len(r.IPProto) == 0 { + m.IPProto = append([]ipproto.Proto(nil), defaultProtos...) 
+ } else { + m.IPProto = make([]ipproto.Proto, 0, len(r.IPProto)) + for _, n := range r.IPProto { + if n >= 0 && n <= 0xff { + m.IPProto = append(m.IPProto, ipproto.Proto(n)) + } + } + } + for i, s := range r.SrcIPs { var bits *int if len(r.SrcBits) > i { diff --git a/wgengine/ifstatus_noop.go b/wgengine/ifstatus_noop.go deleted file mode 100644 index 7564d67ec..000000000 --- a/wgengine/ifstatus_noop.go +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// +build !windows - -package wgengine - -import ( - "time" - - "github.com/tailscale/wireguard-go/tun" - "tailscale.com/types/logger" -) - -// Dummy implementation that does nothing. -func waitInterfaceUp(iface tun.Device, timeout time.Duration, logf logger.Logf) error { - return nil -} diff --git a/wgengine/ifstatus_windows.go b/wgengine/ifstatus_windows.go deleted file mode 100644 index 840b6cf39..000000000 --- a/wgengine/ifstatus_windows.go +++ /dev/null @@ -1,111 +0,0 @@ -// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package wgengine - -import ( - "fmt" - "sync" - "time" - - "github.com/tailscale/wireguard-go/tun" - "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" - "tailscale.com/types/logger" -) - -// ifaceWatcher waits for an interface to be up. -type ifaceWatcher struct { - logf logger.Logf - luid winipcfg.LUID - - mu sync.Mutex // guards following - done bool - sig chan bool -} - -// callback is the callback we register with Windows to call when IP interface changes. -func (iw *ifaceWatcher) callback(notificationType winipcfg.MibNotificationType, iface *winipcfg.MibIPInterfaceRow) { - // Probably should check only when MibParameterNotification, but just in case included MibAddInstance also. 
- if notificationType == winipcfg.MibParameterNotification || notificationType == winipcfg.MibAddInstance { - // Out of paranoia, start a goroutine to finish our work, to return to Windows out of this callback. - go iw.isUp() - } -} - -func (iw *ifaceWatcher) isUp() bool { - iw.mu.Lock() - defer iw.mu.Unlock() - - if iw.done { - // We already know that it's up - return true - } - - if iw.getOperStatus() != winipcfg.IfOperStatusUp { - return false - } - - iw.done = true - iw.sig <- true - return true -} - -func (iw *ifaceWatcher) getOperStatus() winipcfg.IfOperStatus { - ifc, err := iw.luid.Interface() - if err != nil { - iw.logf("iw.luid.Interface error: %v", err) - return 0 - } - return ifc.OperStatus -} - -func waitInterfaceUp(iface tun.Device, timeout time.Duration, logf logger.Logf) error { - iw := &ifaceWatcher{ - luid: winipcfg.LUID(iface.(*tun.NativeTun).LUID()), - logf: logger.WithPrefix(logf, "waitInterfaceUp: "), - } - - // Just in case check the status first - if iw.getOperStatus() == winipcfg.IfOperStatusUp { - iw.logf("TUN interface already up; no need to wait") - return nil - } - - iw.sig = make(chan bool, 1) - cb, err := winipcfg.RegisterInterfaceChangeCallback(iw.callback) - if err != nil { - iw.logf("RegisterInterfaceChangeCallback error: %v", err) - return err - } - defer cb.Unregister() - - t0 := time.Now() - expires := t0.Add(timeout) - ticker := time.NewTicker(10 * time.Second) - defer ticker.Stop() - - for { - iw.logf("waiting for TUN interface to come up...") - - select { - case <-iw.sig: - iw.logf("TUN interface is up after %v", time.Since(t0)) - return nil - case <-ticker.C: - break - } - - if iw.isUp() { - // Very unlikely to happen - either NotifyIpInterfaceChange doesn't work - // or it came up in the same moment as tick. Indicate this in the log message. 
- iw.logf("TUN interface is up after %v (on poll, without notification)", time.Since(t0)) - return nil - } - - if expires.Before(time.Now()) { - iw.logf("timeout waiting %v for TUN interface to come up", timeout) - return fmt.Errorf("timeout waiting for TUN interface to come up") - } - } -} diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index 917e55828..7a437f9e8 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -630,7 +630,7 @@ func (c *Conn) setEndpoints(endpoints []string, reasons map[string]string) (chan delete(c.onEndpointRefreshed, de) } - if stringsEqual(endpoints, c.lastEndpoints) { + if stringSetsEqual(endpoints, c.lastEndpoints) { return false } c.lastEndpoints = endpoints @@ -814,46 +814,6 @@ func (c *Conn) SetNetInfoCallback(fn func(*tailcfg.NetInfo)) { } } -// peerForIP returns the Node in nm that's responsible for -// handling the given IP address. -func peerForIP(nm *netmap.NetworkMap, ip netaddr.IP) (n *tailcfg.Node, ok bool) { - if nm == nil { - return nil, false - } - // Check for exact matches before looking for subnet matches. - for _, p := range nm.Peers { - for _, a := range p.Addresses { - if a.IP == ip { - return p, true - } - } - } - - // TODO(bradfitz): this is O(n peers). Add ART to netaddr? - var best netaddr.IPPrefix - for _, p := range nm.Peers { - for _, cidr := range p.AllowedIPs { - if cidr.Contains(ip) { - if best.IsZero() || cidr.Bits > best.Bits { - n = p - best = cidr - } - } - } - } - return n, n != nil -} - -// PeerForIP returns the node that ip should route to. -func (c *Conn) PeerForIP(ip netaddr.IP) (n *tailcfg.Node, ok bool) { - c.mu.Lock() - defer c.mu.Unlock() - if c.netMap == nil { - return - } - return peerForIP(c.netMap, ip) -} - // LastRecvActivityOfDisco returns the time we last got traffic from // this endpoint (updated every ~10 seconds). 
func (c *Conn) LastRecvActivityOfDisco(dk tailcfg.DiscoKey) time.Time { @@ -871,21 +831,14 @@ func (c *Conn) LastRecvActivityOfDisco(dk tailcfg.DiscoKey) time.Time { } // Ping handles a "tailscale ping" CLI query. -func (c *Conn) Ping(ip netaddr.IP, cb func(*ipnstate.PingResult)) { +func (c *Conn) Ping(peer *tailcfg.Node, res *ipnstate.PingResult, cb func(*ipnstate.PingResult)) { c.mu.Lock() defer c.mu.Unlock() - res := &ipnstate.PingResult{IP: ip.String()} if c.privateKey.IsZero() { res.Err = "local tailscaled stopped" cb(res) return } - peer, ok := peerForIP(c.netMap, ip) - if !ok { - res.Err = "no matching peer" - cb(res) - return - } if len(peer.Addresses) > 0 { res.NodeIP = peer.Addresses[0].IP.String() } @@ -1111,12 +1064,32 @@ func (c *Conn) determineEndpoints(ctx context.Context) (ipPorts []string, reason return eps, already, nil } -func stringsEqual(x, y []string) bool { - if len(x) != len(y) { - return false +// stringSetsEqual reports whether x and y represent the same set of +// strings. The order doesn't matter. +// +// It does not mutate the slices. 
+func stringSetsEqual(x, y []string) bool { + if len(x) == len(y) { + orderMatches := true + for i := range x { + if x[i] != y[i] { + orderMatches = false + break + } + } + if orderMatches { + return true + } } - for i := range x { - if x[i] != y[i] { + m := map[string]int{} + for _, v := range x { + m[v] |= 1 + } + for _, v := range y { + m[v] |= 2 + } + for _, n := range m { + if n != 3 { return false } } @@ -2034,7 +2007,7 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netaddr.IPPort) (isDiscoMsg bo return } if de != nil { - c.logf("magicsock: disco: %v<-%v (%v, %v) got call-me-maybe, %d endpoints", + c.logf("[v1] magicsock: disco: %v<-%v (%v, %v) got call-me-maybe, %d endpoints", c.discoShort, de.discoShort, de.publicKey.ShortString(), derpStr(src.String()), len(dm.MyNumber)) @@ -3012,25 +2985,7 @@ func (c *Conn) UpdateStatus(sb *ipnstate.StatusBuilder) { c.mu.Lock() defer c.mu.Unlock() - ss := &ipnstate.PeerStatus{ - PublicKey: c.privateKey.Public(), - Addrs: c.lastEndpoints, - OS: version.OS(), - } - if c.netMap != nil { - ss.HostName = c.netMap.Hostinfo.Hostname - ss.DNSName = c.netMap.Name - ss.UserID = c.netMap.User - } else { - ss.HostName, _ = os.Hostname() - } - if c.derpMap != nil { - derpRegion, ok := c.derpMap.Regions[c.myDerp] - if ok { - ss.Relay = derpRegion.RegionCode - } - } - + var tailAddr string if c.netMap != nil { for _, addr := range c.netMap.Addresses { if !addr.IsSingleIP() { @@ -3041,11 +2996,30 @@ func (c *Conn) UpdateStatus(sb *ipnstate.StatusBuilder) { // readability of `tailscale status`, make it the IPv4 // address. 
if addr.IP.Is4() { - ss.TailAddr = addr.IP.String() + tailAddr = addr.IP.String() } } } - sb.SetSelfStatus(ss) + + sb.MutateSelfStatus(func(ss *ipnstate.PeerStatus) { + ss.PublicKey = c.privateKey.Public() + ss.Addrs = c.lastEndpoints + ss.OS = version.OS() + if c.netMap != nil { + ss.HostName = c.netMap.Hostinfo.Hostname + ss.DNSName = c.netMap.Name + ss.UserID = c.netMap.User + } else { + ss.HostName, _ = os.Hostname() + } + if c.derpMap != nil { + derpRegion, ok := c.derpMap.Regions[c.myDerp] + if ok { + ss.Relay = derpRegion.RegionCode + } + } + ss.TailAddr = tailAddr + }) for dk, n := range c.nodeOfDisco { ps := &ipnstate.PeerStatus{InMagicSock: true} @@ -3106,10 +3080,9 @@ type discoEndpoint struct { lastFullPing time.Time // last time we pinged all endpoints derpAddr netaddr.IPPort // fallback/bootstrap path, if non-zero (non-zero for well-behaved clients) - bestAddr netaddr.IPPort // best non-DERP path; zero if none - bestAddrLatency time.Duration - bestAddrAt time.Time // time best address re-confirmed - trustBestAddrUntil time.Time // time when bestAddr expires + bestAddr addrLatency // best non-DERP path; zero if none + bestAddrAt time.Time // time best address re-confirmed + trustBestAddrUntil time.Time // time when bestAddr expires sentPing map[stun.TxID]sentPing endpointState map[netaddr.IPPort]*endpointState isCallMeMaybeEP map[netaddr.IPPort]bool @@ -3214,8 +3187,8 @@ func (st *endpointState) shouldDeleteLocked() bool { func (de *discoEndpoint) deleteEndpointLocked(ep netaddr.IPPort) { delete(de.endpointState, ep) - if de.bestAddr == ep { - de.bestAddr = netaddr.IPPort{} + if de.bestAddr.IPPort == ep { + de.bestAddr = addrLatency{} } } @@ -3283,7 +3256,7 @@ func (de *discoEndpoint) DstToBytes() []byte { return packIPPort(de.fakeWGAddr) // // de.mu must be held. 
func (de *discoEndpoint) addrForSendLocked(now time.Time) (udpAddr, derpAddr netaddr.IPPort) { - udpAddr = de.bestAddr + udpAddr = de.bestAddr.IPPort if udpAddr.IsZero() || now.After(de.trustBestAddrUntil) { // We had a bestAddr but it expired so send both to it // and DERP. @@ -3336,7 +3309,7 @@ func (de *discoEndpoint) wantFullPingLocked(now time.Time) bool { if now.After(de.trustBestAddrUntil) { return true } - if de.bestAddrLatency <= goodEnoughLatency { + if de.bestAddr.latency <= goodEnoughLatency { return false } if now.Sub(de.lastFullPing) >= upgradeInterval { @@ -3589,7 +3562,7 @@ func (de *discoEndpoint) addCandidateEndpoint(ep netaddr.IPPort) { } // Newly discovered endpoint. Exciting! - de.c.logf("magicsock: disco: adding %v as candidate endpoint for %v (%s)", ep, de.discoShort, de.publicKey.ShortString()) + de.c.logf("[v1] magicsock: disco: adding %v as candidate endpoint for %v (%s)", ep, de.discoShort, de.publicKey.ShortString()) de.endpointState[ep] = &endpointState{ lastGotPing: time.Now(), } @@ -3602,7 +3575,7 @@ func (de *discoEndpoint) addCandidateEndpoint(ep netaddr.IPPort) { } } size2 := len(de.endpointState) - de.c.logf("magicsock: disco: addCandidateEndpoint pruned %v candidate set from %v to %v entries", size, size2) + de.c.logf("[v1] magicsock: disco: addCandidateEndpoint pruned %v candidate set from %v to %v entries", size, size2) } } @@ -3668,20 +3641,50 @@ func (de *discoEndpoint) handlePongConnLocked(m *disco.Pong, src netaddr.IPPort) // Promote this pong response to our current best address if it's lower latency. // TODO(bradfitz): decide how latency vs. 
preference order affects decision if !isDerp { - if de.bestAddr.IsZero() || latency < de.bestAddrLatency { - if de.bestAddr != sp.to { - de.c.logf("magicsock: disco: node %v %v now using %v", de.publicKey.ShortString(), de.discoShort, sp.to) - de.bestAddr = sp.to - } + thisPong := addrLatency{sp.to, latency} + if betterAddr(thisPong, de.bestAddr) { + de.c.logf("magicsock: disco: node %v %v now using %v", de.publicKey.ShortString(), de.discoShort, sp.to) + de.bestAddr = thisPong } - if de.bestAddr == sp.to { - de.bestAddrLatency = latency + if de.bestAddr.IPPort == thisPong.IPPort { + de.bestAddr.latency = latency de.bestAddrAt = now de.trustBestAddrUntil = now.Add(trustUDPAddrDuration) } } } +// addrLatency is an IPPort with an associated latency. +type addrLatency struct { + netaddr.IPPort + latency time.Duration +} + +// betterAddr reports whether a is a better addr to use than b. +func betterAddr(a, b addrLatency) bool { + if a.IPPort == b.IPPort { + return false + } + if b.IsZero() { + return true + } + if a.IsZero() { + return false + } + if a.IP.Is6() && b.IP.Is4() { + // Prefer IPv6 for being a bit more robust, as long as + // the latencies are roughly equivalent. + if a.latency/10*9 < b.latency { + return true + } + } else if a.IP.Is4() && b.IP.Is6() { + if betterAddr(b, a) { + return false + } + } + return a.latency < b.latency +} + // discoEndpoint.mu must be held. 
func (st *endpointState) addPongReplyLocked(r pongReply) { if n := len(st.recentPongs); n < pongHistoryCount { @@ -3729,7 +3732,7 @@ func (de *discoEndpoint) handleCallMeMaybe(m *disco.CallMeMaybe) { } } if len(newEPs) > 0 { - de.c.logf("magicsock: disco: call-me-maybe from %v %v added new endpoints: %v", + de.c.logf("[v1] magicsock: disco: call-me-maybe from %v %v added new endpoints: %v", de.publicKey.ShortString(), de.discoShort, logger.ArgWriter(func(w *bufio.Writer) { for i, ep := range newEPs { @@ -3788,8 +3791,7 @@ func (de *discoEndpoint) stopAndReset() { // state isn't a mix of before & after two sessions. de.lastSend = time.Time{} de.lastFullPing = time.Time{} - de.bestAddr = netaddr.IPPort{} - de.bestAddrLatency = 0 + de.bestAddr = addrLatency{} de.bestAddrAt = time.Time{} de.trustBestAddrUntil = time.Time{} for _, es := range de.endpointState { diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go index 8e64a2696..2e7e29635 100644 --- a/wgengine/magicsock/magicsock_test.go +++ b/wgengine/magicsock/magicsock_test.go @@ -37,6 +37,7 @@ import ( "tailscale.com/derp/derpmap" "tailscale.com/ipn/ipnstate" "tailscale.com/net/stun/stuntest" + "tailscale.com/net/tstun" "tailscale.com/tailcfg" "tailscale.com/tstest" "tailscale.com/tstest/natlab" @@ -47,7 +48,6 @@ import ( "tailscale.com/types/wgkey" "tailscale.com/util/cibuild" "tailscale.com/wgengine/filter" - "tailscale.com/wgengine/tstun" "tailscale.com/wgengine/wgcfg" "tailscale.com/wgengine/wgcfg/nmcfg" "tailscale.com/wgengine/wglog" @@ -130,7 +130,7 @@ type magicStack struct { epCh chan []string // endpoint updates produced by this peer conn *Conn // the magicsock itself tun *tuntest.ChannelTUN // TUN device to send/receive packets - tsTun *tstun.TUN // wrapped tun that implements filtering and wgengine hooks + tsTun *tstun.Wrapper // wrapped tun that implements filtering and wgengine hooks dev *device.Device // the wireguard-go Device that connects the previous things 
wgLogger *wglog.Logger // wireguard-go log wrapper } @@ -166,16 +166,16 @@ func newMagicStack(t testing.TB, logf logger.Logf, l nettype.PacketListener, der } tun := tuntest.NewChannelTUN() - tsTun := tstun.WrapTUN(logf, tun.TUN()) + tsTun := tstun.Wrap(logf, tun.TUN()) tsTun.SetFilter(filter.NewAllowAllForTest(logf)) wgLogger := wglog.NewLogger(logf) - dev := device.NewDevice(tsTun, &device.DeviceOptions{ - Logger: wgLogger.DeviceLogger, + opts := &device.DeviceOptions{ CreateEndpoint: conn.CreateEndpoint, CreateBind: conn.CreateBind, SkipBindUpdate: true, - }) + } + dev := device.NewDevice(tsTun, wgLogger.DeviceLogger, opts) dev.Up() // Wait for magicsock to connect up to DERP. @@ -522,12 +522,13 @@ func TestDeviceStartStop(t *testing.T) { defer conn.Close() tun := tuntest.NewChannelTUN() - dev := device.NewDevice(tun.TUN(), &device.DeviceOptions{ - Logger: wglog.NewLogger(t.Logf).DeviceLogger, + wgLogger := wglog.NewLogger(t.Logf) + opts := &device.DeviceOptions{ CreateEndpoint: conn.CreateEndpoint, CreateBind: conn.CreateBind, SkipBindUpdate: true, - }) + } + dev := device.NewDevice(tun.TUN(), wgLogger.DeviceLogger, opts) dev.Up() dev.Close() } @@ -1431,7 +1432,7 @@ func TestDerpReceiveFromIPv4(t *testing.T) { t.Fatal(err) } defer sendConn.Close() - nodeKey, _ := addTestEndpoint(conn, sendConn) + nodeKey, _ := addTestEndpoint(t, conn, sendConn) var sends int = 250e3 // takes about a second if testing.Short() { @@ -1509,7 +1510,7 @@ func TestDerpReceiveFromIPv4(t *testing.T) { // addTestEndpoint sets conn's network map to a single peer expected // to receive packets from sendConn (or DERP), and returns that peer's // nodekey and discokey. 
-func addTestEndpoint(conn *Conn, sendConn net.PacketConn) (tailcfg.NodeKey, tailcfg.DiscoKey) { +func addTestEndpoint(tb testing.TB, conn *Conn, sendConn net.PacketConn) (tailcfg.NodeKey, tailcfg.DiscoKey) { // Give conn just enough state that it'll recognize sendConn as a // valid peer and not fall through to the legacy magicsock // codepath. @@ -1525,7 +1526,10 @@ func addTestEndpoint(conn *Conn, sendConn net.PacketConn) (tailcfg.NodeKey, tail }, }) conn.SetPrivateKey(wgkey.Private{0: 1}) - conn.CreateEndpoint([32]byte(nodeKey), "0000000000000000000000000000000000000000000000000000000000000001.disco.tailscale:12345") + _, err := conn.CreateEndpoint([32]byte(nodeKey), "0000000000000000000000000000000000000000000000000000000000000001.disco.tailscale:12345") + if err != nil { + tb.Fatal(err) + } conn.addValidDiscoPathForTest(discoKey, netaddr.MustParseIPPort(sendConn.LocalAddr().String())) return nodeKey, discoKey } @@ -1541,7 +1545,7 @@ func setUpReceiveFrom(tb testing.TB) (roundTrip func()) { } tb.Cleanup(func() { sendConn.Close() }) - addTestEndpoint(conn, sendConn) + addTestEndpoint(tb, conn, sendConn) var dstAddr net.Addr = conn.pconn4.LocalAddr() sendBuf := make([]byte, 1<<10) @@ -1793,3 +1797,114 @@ func TestRebindStress(t *testing.T) { t.Fatalf("Got ReceiveIPv4 error: %v (is closed = %v). 
Log:\n%s", err, errors.Is(err, net.ErrClosed), logBuf.Bytes()) } } + +func TestStringSetsEqual(t *testing.T) { + s := func(nn ...int) (ret []string) { + for _, n := range nn { + ret = append(ret, strconv.Itoa(n)) + } + return + } + tests := []struct { + a, b []string + want bool + }{ + { + want: true, + }, + { + a: s(1, 2, 3), + b: s(1, 2, 3), + want: true, + }, + { + a: s(1, 2), + b: s(2, 1), + want: true, + }, + { + a: s(1, 2), + b: s(2, 1, 1), + want: true, + }, + { + a: s(1, 2, 2), + b: s(2, 1), + want: true, + }, + { + a: s(1, 2, 2), + b: s(2, 1, 1), + want: true, + }, + { + a: s(1, 2, 2, 3), + b: s(2, 1, 1), + want: false, + }, + { + a: s(1, 2, 2), + b: s(2, 1, 1, 3), + want: false, + }, + } + for _, tt := range tests { + if got := stringSetsEqual(tt.a, tt.b); got != tt.want { + t.Errorf("%q vs %q = %v; want %v", tt.a, tt.b, got, tt.want) + } + } + +} + +func TestBetterAddr(t *testing.T) { + const ms = time.Millisecond + al := func(ipps string, d time.Duration) addrLatency { + return addrLatency{netaddr.MustParseIPPort(ipps), d} + } + zero := addrLatency{} + tests := []struct { + a, b addrLatency + want bool + }{ + {a: zero, b: zero, want: false}, + {a: al("10.0.0.2:123", 5*ms), b: zero, want: true}, + {a: zero, b: al("10.0.0.2:123", 5*ms), want: false}, + {a: al("10.0.0.2:123", 5*ms), b: al("1.2.3.4:555", 6*ms), want: true}, + {a: al("10.0.0.2:123", 5*ms), b: al("10.0.0.2:123", 10*ms), want: false}, // same IPPort + + // Prefer IPv6 if roughly equivalent: + { + a: al("[2001::5]:123", 100*ms), + b: al("1.2.3.4:555", 91*ms), + want: true, + }, + { + a: al("1.2.3.4:555", 91*ms), + b: al("[2001::5]:123", 100*ms), + want: false, + }, + // But not if IPv4 is much faster: + { + a: al("[2001::5]:123", 100*ms), + b: al("1.2.3.4:555", 30*ms), + want: false, + }, + { + a: al("1.2.3.4:555", 30*ms), + b: al("[2001::5]:123", 100*ms), + want: true, + }, + } + for _, tt := range tests { + got := betterAddr(tt.a, tt.b) + if got != tt.want { + t.Errorf("betterAddr(%+v, %+v) = 
%v; want %v", tt.a, tt.b, got, tt.want) + continue + } + gotBack := betterAddr(tt.b, tt.a) + if got && gotBack { + t.Errorf("betterAddr(%+v, %+v) and betterAddr(%+v, %+v) both unexpectedly true", tt.a, tt.b, tt.b, tt.a) + } + } + +} diff --git a/wgengine/monitor/monitor.go b/wgengine/monitor/monitor.go index 254df7cb6..8ee7087ce 100644 --- a/wgengine/monitor/monitor.go +++ b/wgengine/monitor/monitor.go @@ -10,6 +10,7 @@ package monitor import ( "encoding/json" "errors" + "runtime" "sync" "time" @@ -18,6 +19,13 @@ import ( "tailscale.com/types/logger" ) +// pollWallTimeInterval is how often we check the time to check +// for big jumps in wall (non-monotonic) time as a backup mechanism +// to get notified of a sleeping device waking back up. +// Usually there are also minor network change events on wake that let +// us check the wall time sooner than this. +const pollWallTimeInterval = 15 * time.Second + // message represents a message returned from an osMon. type message interface { // Ignore is whether we should ignore this message. 
@@ -50,18 +58,20 @@ type Mon struct { logf logger.Logf om osMon // nil means not supported on this platform change chan struct{} - stop chan struct{} - - mu sync.Mutex // guards cbs - cbs map[*callbackHandle]ChangeFunc - ifState *interfaces.State - gwValid bool // whether gw and gwSelfIP are valid (cached)x - gw netaddr.IP - gwSelfIP netaddr.IP + stop chan struct{} // closed on Stop - onceStart sync.Once + mu sync.Mutex // guards all following fields + cbs map[*callbackHandle]ChangeFunc + ifState *interfaces.State + gwValid bool // whether gw and gwSelfIP are valid + gw netaddr.IP // our gateway's IP + gwSelfIP netaddr.IP // our own IP address (that corresponds to gw) started bool + closed bool goroutines sync.WaitGroup + wallTimer *time.Timer // nil until Started; re-armed AfterFunc per tick + lastWall time.Time + timeJumped bool // whether we need to send a changed=true after a big time jump } // New instantiates and starts a monitoring instance. @@ -70,10 +80,11 @@ type Mon struct { func New(logf logger.Logf) (*Mon, error) { logf = logger.WithPrefix(logf, "monitor: ") m := &Mon{ - logf: logf, - cbs: map[*callbackHandle]ChangeFunc{}, - change: make(chan struct{}, 1), - stop: make(chan struct{}), + logf: logf, + cbs: map[*callbackHandle]ChangeFunc{}, + change: make(chan struct{}, 1), + stop: make(chan struct{}), + lastWall: wallTime(), } st, err := m.interfaceStateUncached() if err != nil { @@ -101,12 +112,7 @@ func (m *Mon) InterfaceState() *interfaces.State { } func (m *Mon) interfaceStateUncached() (*interfaces.State, error) { - s, err := interfaces.GetState() - if s != nil { - s.RemoveTailscaleInterfaces() - s.RemoveUninterestingInterfacesAndAddresses() - } - return s, err + return interfaces.GetState() } // GatewayAndSelfIP returns the current network's default gateway, and @@ -145,28 +151,54 @@ func (m *Mon) RegisterChangeCallback(callback ChangeFunc) (unregister func()) { // Start starts the monitor. // A monitor can only be started & closed once. 
func (m *Mon) Start() { - m.onceStart.Do(func() { - if m.om == nil { - return - } - m.started = true - m.goroutines.Add(2) - go m.pump() - go m.debounce() - }) + m.mu.Lock() + defer m.mu.Unlock() + if m.started || m.closed { + return + } + m.started = true + + switch runtime.GOOS { + case "ios", "android": + // For battery reasons, and because these platforms + // don't really sleep in the same way, don't poll + // for the wall time to detect for wake-for-sleep + // walltime jumps. + default: + m.wallTimer = time.AfterFunc(pollWallTimeInterval, m.pollWallTime) + } + + if m.om == nil { + return + } + m.goroutines.Add(2) + go m.pump() + go m.debounce() } // Close closes the monitor. -// It may only be called once. func (m *Mon) Close() error { + m.mu.Lock() + if m.closed { + m.mu.Unlock() + return nil + } + m.closed = true close(m.stop) + + if m.wallTimer != nil { + m.wallTimer.Stop() + } + var err error if m.om != nil { err = m.om.Close() } - // If it was previously started, wait for those goroutines to finish. - m.onceStart.Do(func() {}) - if m.started { + + started := m.started + m.mu.Unlock() + + if started { m.goroutines.Wait() } return err @@ -232,9 +264,17 @@ func (m *Mon) debounce() { m.logf("interfaces.State: %v", err) } else { m.mu.Lock() + + // See if we have a queued or new time jump signal. 
+ m.checkWallTimeAdvanceLocked() + timeJumped := m.timeJumped + if timeJumped { + m.logf("time jumped (probably wake from sleep); synthesizing major change event") + } + oldState := m.ifState - changed := !curState.Equal(oldState) - if changed { + ifChanged := !curState.EqualFiltered(oldState, interfaces.FilterInteresting) + if ifChanged { m.gwValid = false m.ifState = curState @@ -243,6 +283,10 @@ func (m *Mon) debounce() { jsonSummary(oldState), jsonSummary(curState)) } } + changed := ifChanged || timeJumped + if changed { + m.timeJumped = false + } for _, cb := range m.cbs { go cb(changed, m.ifState) } @@ -264,3 +308,33 @@ func jsonSummary(x interface{}) interface{} { } return j } + +func wallTime() time.Time { + // From time package's docs: "The canonical way to strip a + // monotonic clock reading is to use t = t.Round(0)." + return time.Now().Round(0) +} + +func (m *Mon) pollWallTime() { + m.mu.Lock() + defer m.mu.Unlock() + if m.closed { + return + } + m.checkWallTimeAdvanceLocked() + if m.timeJumped { + m.InjectEvent() + } + m.wallTimer.Reset(pollWallTimeInterval) +} + +// checkWallTimeAdvanceLocked updates m.timeJumped, if wall time jumped +// more than 150% of pollWallTimeInterval, indicating we probably just +// came out of sleep. 
+func (m *Mon) checkWallTimeAdvanceLocked() { + now := wallTime() + if now.Sub(m.lastWall) > pollWallTimeInterval*3/2 { + m.timeJumped = true + } + m.lastWall = now +} diff --git a/wgengine/monitor/monitor_polling.go b/wgengine/monitor/monitor_polling.go index 079c956bb..cdc995ca4 100644 --- a/wgengine/monitor/monitor_polling.go +++ b/wgengine/monitor/monitor_polling.go @@ -12,6 +12,7 @@ import ( "sync" "time" + "tailscale.com/net/interfaces" "tailscale.com/types/logger" ) @@ -53,7 +54,7 @@ func (pm *pollingMon) Receive() (message, error) { defer ticker.Stop() base := pm.m.InterfaceState() for { - if cur, err := pm.m.interfaceStateUncached(); err == nil && !cur.Equal(base) { + if cur, err := pm.m.interfaceStateUncached(); err == nil && !cur.EqualFiltered(base, interfaces.FilterInteresting) { return unspecifiedMessage{}, nil } select { diff --git a/wgengine/netstack/netstack.go b/wgengine/netstack/netstack.go index ad647fa6f..92881decd 100644 --- a/wgengine/netstack/netstack.go +++ b/wgengine/netstack/netstack.go @@ -32,13 +32,13 @@ import ( "inet.af/netstack/waiter" "tailscale.com/net/packet" "tailscale.com/net/tsaddr" + "tailscale.com/net/tstun" "tailscale.com/types/logger" "tailscale.com/types/netmap" "tailscale.com/util/dnsname" "tailscale.com/wgengine" "tailscale.com/wgengine/filter" "tailscale.com/wgengine/magicsock" - "tailscale.com/wgengine/tstun" ) const debugNetstack = false @@ -49,7 +49,7 @@ const debugNetstack = false type Impl struct { ipstack *stack.Stack linkEP *channel.Endpoint - tundev *tstun.TUN + tundev *tstun.Wrapper e wgengine.Engine mc *magicsock.Conn logf logger.Logf @@ -67,7 +67,7 @@ const nicID = 1 const mtu = 1500 // Create creates and populates a new Impl. 
-func Create(logf logger.Logf, tundev *tstun.TUN, e wgengine.Engine, mc *magicsock.Conn) (*Impl, error) { +func Create(logf logger.Logf, tundev *tstun.Wrapper, e wgengine.Engine, mc *magicsock.Conn) (*Impl, error) { if mc == nil { return nil, errors.New("nil magicsock.Conn") } @@ -300,7 +300,7 @@ func (m DNSMap) Resolve(ctx context.Context, addr string) (netaddr.IPPort, error return netaddr.IPPort{IP: ip, Port: uint16(port16)}, nil } - // No Magic DNS name so try real DNS. + // No MagicDNS name so try real DNS. var r net.Resolver ips, err := r.LookupIP(ctx, "ip", host) if err != nil { @@ -363,7 +363,7 @@ func (ns *Impl) injectOutbound() { } } -func (ns *Impl) injectInbound(p *packet.Parsed, t *tstun.TUN) filter.Response { +func (ns *Impl) injectInbound(p *packet.Parsed, t *tstun.Wrapper) filter.Response { var pn tcpip.NetworkProtocolNumber switch p.IPVersion { case 4: diff --git a/wgengine/pendopen.go b/wgengine/pendopen.go index be1fa1468..2951a0c7e 100644 --- a/wgengine/pendopen.go +++ b/wgengine/pendopen.go @@ -14,8 +14,9 @@ import ( "tailscale.com/net/flowtrack" "tailscale.com/net/packet" "tailscale.com/net/tsaddr" + "tailscale.com/net/tstun" + "tailscale.com/types/ipproto" "tailscale.com/wgengine/filter" - "tailscale.com/wgengine/tstun" ) const tcpTimeoutBeforeDebug = 5 * time.Second @@ -65,10 +66,10 @@ func (e *userspaceEngine) noteFlowProblemFromPeer(f flowtrack.Tuple, problem pac of.problem = problem } -func (e *userspaceEngine) trackOpenPreFilterIn(pp *packet.Parsed, t *tstun.TUN) (res filter.Response) { +func (e *userspaceEngine) trackOpenPreFilterIn(pp *packet.Parsed, t *tstun.Wrapper) (res filter.Response) { res = filter.Accept // always - if pp.IPProto == packet.TSMP { + if pp.IPProto == ipproto.TSMP { res = filter.DropSilently rh, ok := pp.AsTailscaleRejectedHeader() if !ok { @@ -83,14 +84,14 @@ func (e *userspaceEngine) trackOpenPreFilterIn(pp *packet.Parsed, t *tstun.TUN) } if pp.IPVersion == 0 || - pp.IPProto != packet.TCP || + pp.IPProto != 
ipproto.TCP || pp.TCPFlags&(packet.TCPSyn|packet.TCPRst) == 0 { return } // Either a SYN or a RST came back. Remove it in either case. - f := flowtrack.Tuple{Dst: pp.Src, Src: pp.Dst} // src/dst reversed + f := flowtrack.Tuple{Proto: pp.IPProto, Dst: pp.Src, Src: pp.Dst} // src/dst reversed removed := e.removeFlow(f) if removed && pp.TCPFlags&packet.TCPRst != 0 { e.logf("open-conn-track: flow TCP %v got RST by peer", f) @@ -98,23 +99,23 @@ func (e *userspaceEngine) trackOpenPreFilterIn(pp *packet.Parsed, t *tstun.TUN) return } -func (e *userspaceEngine) trackOpenPostFilterOut(pp *packet.Parsed, t *tstun.TUN) (res filter.Response) { +func (e *userspaceEngine) trackOpenPostFilterOut(pp *packet.Parsed, t *tstun.Wrapper) (res filter.Response) { res = filter.Accept // always if pp.IPVersion == 0 || - pp.IPProto != packet.TCP || + pp.IPProto != ipproto.TCP || pp.TCPFlags&packet.TCPSyn == 0 { return } - flow := flowtrack.Tuple{Src: pp.Src, Dst: pp.Dst} + flow := flowtrack.Tuple{Proto: pp.IPProto, Src: pp.Src, Dst: pp.Dst} // iOS likes to probe Apple IPs on all interfaces to check for connectivity. // Don't start timers tracking those. They won't succeed anyway. Avoids log spam // like: // open-conn-track: timeout opening (100.115.73.60:52501 => 17.125.252.5:443); no associated peer node if runtime.GOOS == "ios" && flow.Dst.Port == 443 && !tsaddr.IsTailscaleIP(flow.Dst.IP) { - if _, ok := e.magicConn.PeerForIP(flow.Dst.IP); !ok { + if _, err := e.peerForIP(flow.Dst.IP); err != nil { return } } @@ -154,8 +155,12 @@ func (e *userspaceEngine) onOpenTimeout(flow flowtrack.Tuple) { } // Diagnose why it might've timed out. 
- n, ok := e.magicConn.PeerForIP(flow.Dst.IP) - if !ok { + n, err := e.peerForIP(flow.Dst.IP) + if err != nil { + e.logf("open-conn-track: timeout opening %v; peerForIP: %v", flow, err) + return + } + if n == nil { e.logf("open-conn-track: timeout opening %v; no associated peer node", flow) return } diff --git a/wgengine/router/dns/config.go b/wgengine/router/dns/config.go deleted file mode 100644 index 2b6ff615a..000000000 --- a/wgengine/router/dns/config.go +++ /dev/null @@ -1,77 +0,0 @@ -// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package dns - -import ( - "inet.af/netaddr" - - "tailscale.com/types/logger" -) - -// Config is the set of parameters that uniquely determine -// the state to which a manager should bring system DNS settings. -type Config struct { - // Nameservers are the IP addresses of the nameservers to use. - Nameservers []netaddr.IP - // Domains are the search domains to use. - Domains []string - // PerDomain indicates whether it is preferred to use Nameservers - // only for DNS queries for subdomains of Domains. - // Note that Nameservers may still be applied to all queries - // if the manager does not support per-domain settings. - PerDomain bool - // Proxied indicates whether DNS requests are proxied through a tsdns.Resolver. - // This enables Magic DNS. - Proxied bool -} - -// Equal determines whether its argument and receiver -// represent equivalent DNS configurations (then DNS reconfig is a no-op). -func (lhs Config) Equal(rhs Config) bool { - if lhs.Proxied != rhs.Proxied || lhs.PerDomain != rhs.PerDomain { - return false - } - - if len(lhs.Nameservers) != len(rhs.Nameservers) { - return false - } - - if len(lhs.Domains) != len(rhs.Domains) { - return false - } - - // With how we perform resolution order shouldn't matter, - // but it is unlikely that we will encounter different orders. 
- for i, server := range lhs.Nameservers { - if rhs.Nameservers[i] != server { - return false - } - } - - // The order of domains, on the other hand, is significant. - for i, domain := range lhs.Domains { - if rhs.Domains[i] != domain { - return false - } - } - - return true -} - -// ManagerConfig is the set of parameters from which -// a manager implementation is chosen and initialized. -type ManagerConfig struct { - // Logf is the logger for the manager to use. - // It is wrapped with a "dns: " prefix. - Logf logger.Logf - // InterfaceName is the name of the interface with which DNS settings should be associated. - InterfaceName string - // Cleanup indicates that the manager is created for cleanup only. - // A no-op manager will be instantiated if the system needs no cleanup. - Cleanup bool - // PerDomain indicates that a manager capable of per-domain configuration is preferred. - // Certain managers are per-domain only; they will not be considered if this is false. - PerDomain bool -} diff --git a/wgengine/router/dns/direct.go b/wgengine/router/dns/direct.go deleted file mode 100644 index bd1c03b9d..000000000 --- a/wgengine/router/dns/direct.go +++ /dev/null @@ -1,188 +0,0 @@ -// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// +build linux freebsd openbsd - -package dns - -import ( - "bufio" - "bytes" - "errors" - "fmt" - "io" - "io/ioutil" - "os" - "os/exec" - "runtime" - "strings" - - "inet.af/netaddr" - "tailscale.com/atomicfile" -) - -const ( - tsConf = "/etc/resolv.tailscale.conf" - backupConf = "/etc/resolv.pre-tailscale-backup.conf" - resolvConf = "/etc/resolv.conf" -) - -// writeResolvConf writes DNS configuration in resolv.conf format to the given writer. 
-func writeResolvConf(w io.Writer, servers []netaddr.IP, domains []string) { - io.WriteString(w, "# resolv.conf(5) file generated by tailscale\n") - io.WriteString(w, "# DO NOT EDIT THIS FILE BY HAND -- CHANGES WILL BE OVERWRITTEN\n\n") - for _, ns := range servers { - io.WriteString(w, "nameserver ") - io.WriteString(w, ns.String()) - io.WriteString(w, "\n") - } - if len(domains) > 0 { - io.WriteString(w, "search") - for _, domain := range domains { - io.WriteString(w, " ") - io.WriteString(w, domain) - } - io.WriteString(w, "\n") - } -} - -// readResolvConf reads DNS configuration from /etc/resolv.conf. -func readResolvConf() (Config, error) { - var config Config - - f, err := os.Open("/etc/resolv.conf") - if err != nil { - return config, err - } - - scanner := bufio.NewScanner(f) - for scanner.Scan() { - line := strings.TrimSpace(scanner.Text()) - - if strings.HasPrefix(line, "nameserver") { - nameserver := strings.TrimPrefix(line, "nameserver") - nameserver = strings.TrimSpace(nameserver) - ip, err := netaddr.ParseIP(nameserver) - if err != nil { - return config, err - } - config.Nameservers = append(config.Nameservers, ip) - continue - } - - if strings.HasPrefix(line, "search") { - domain := strings.TrimPrefix(line, "search") - domain = strings.TrimSpace(domain) - config.Domains = append(config.Domains, domain) - continue - } - } - - return config, nil -} - -// isResolvedRunning reports whether systemd-resolved is running on the system, -// even if it is not managing the system DNS settings. -func isResolvedRunning() bool { - if runtime.GOOS != "linux" { - return false - } - - // systemd-resolved is never installed without systemd. - _, err := exec.LookPath("systemctl") - if err != nil { - return false - } - - // is-active exits with code 3 if the service is not active. 
- err = exec.Command("systemctl", "is-active", "systemd-resolved.service").Run() - - return err == nil -} - -// directManager is a managerImpl which replaces /etc/resolv.conf with a file -// generated from the given configuration, creating a backup of its old state. -// -// This way of configuring DNS is precarious, since it does not react -// to the disappearance of the Tailscale interface. -// The caller must call Down before program shutdown -// or as cleanup if the program terminates unexpectedly. -type directManager struct{} - -func newDirectManager(mconfig ManagerConfig) managerImpl { - return directManager{} -} - -// Up implements managerImpl. -func (m directManager) Up(config Config) error { - // Write the tsConf file. - buf := new(bytes.Buffer) - writeResolvConf(buf, config.Nameservers, config.Domains) - if err := atomicfile.WriteFile(tsConf, buf.Bytes(), 0644); err != nil { - return err - } - - if linkPath, err := os.Readlink(resolvConf); err != nil { - // Remove any old backup that may exist. - os.Remove(backupConf) - - // Backup the existing /etc/resolv.conf file. - contents, err := ioutil.ReadFile(resolvConf) - // If the original did not exist, still back up an empty file. - // The presence of a backup file is the way we know that Up ran. - if err != nil && !errors.Is(err, os.ErrNotExist) { - return err - } - if err := atomicfile.WriteFile(backupConf, contents, 0644); err != nil { - return err - } - } else if linkPath != tsConf { - // Backup the existing symlink. - os.Remove(backupConf) - if err := os.Symlink(linkPath, backupConf); err != nil { - return err - } - } else { - // Nothing to do, resolvConf already points to tsConf. - return nil - } - - os.Remove(resolvConf) - if err := os.Symlink(tsConf, resolvConf); err != nil { - return err - } - - if isResolvedRunning() { - exec.Command("systemctl", "restart", "systemd-resolved.service").Run() // Best-effort. - } - - return nil -} - -// Down implements managerImpl. 
-func (m directManager) Down() error { - if _, err := os.Stat(backupConf); err != nil { - // If the backup file does not exist, then Up never ran successfully. - if os.IsNotExist(err) { - return nil - } - return err - } - - if ln, err := os.Readlink(resolvConf); err != nil { - return err - } else if ln != tsConf { - return fmt.Errorf("resolv.conf is not a symlink to %s", tsConf) - } - if err := os.Rename(backupConf, resolvConf); err != nil { - return err - } - os.Remove(tsConf) - - if isResolvedRunning() { - exec.Command("systemctl", "restart", "systemd-resolved.service").Run() // Best-effort. - } - - return nil -} diff --git a/wgengine/router/dns/manager.go b/wgengine/router/dns/manager.go deleted file mode 100644 index 8e2fc9d71..000000000 --- a/wgengine/router/dns/manager.go +++ /dev/null @@ -1,100 +0,0 @@ -// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package dns - -import ( - "time" - - "tailscale.com/types/logger" -) - -// We use file-ignore below instead of ignore because on some platforms, -// the lint exception is necessary and on others it is not, -// and plain ignore complains if the exception is unnecessary. - -//lint:file-ignore U1000 reconfigTimeout is used on some platforms but not others - -// reconfigTimeout is the time interval within which Manager.{Up,Down} should complete. -// -// This is particularly useful because certain conditions can cause indefinite hangs -// (such as improper dbus auth followed by contextless dbus.Object.Call). -// Such operations should be wrapped in a timeout context. -const reconfigTimeout = time.Second - -type managerImpl interface { - // Up updates system DNS settings to match the given configuration. - Up(Config) error - // Down undoes the effects of Up. - // It is idempotent and performs no action if Up has never been called. - Down() error -} - -// Manager manages system DNS settings. 
-type Manager struct { - logf logger.Logf - - impl managerImpl - - config Config - mconfig ManagerConfig -} - -// NewManagers created a new manager from the given config. -func NewManager(mconfig ManagerConfig) *Manager { - mconfig.Logf = logger.WithPrefix(mconfig.Logf, "dns: ") - m := &Manager{ - logf: mconfig.Logf, - impl: newManager(mconfig), - - config: Config{PerDomain: mconfig.PerDomain}, - mconfig: mconfig, - } - - m.logf("using %T", m.impl) - return m -} - -func (m *Manager) Set(config Config) error { - if config.Equal(m.config) { - return nil - } - - m.logf("Set: %+v", config) - - if len(config.Nameservers) == 0 { - err := m.impl.Down() - // If we save the config, we will not retry next time. Only do this on success. - if err == nil { - m.config = config - } - return err - } - - // Switching to and from per-domain mode may require a change of manager. - if config.PerDomain != m.config.PerDomain { - if err := m.impl.Down(); err != nil { - return err - } - m.mconfig.PerDomain = config.PerDomain - m.impl = newManager(m.mconfig) - m.logf("switched to %T", m.impl) - } - - err := m.impl.Up(config) - // If we save the config, we will not retry next time. Only do this on success. - if err == nil { - m.config = config - } - - return err -} - -func (m *Manager) Up() error { - return m.impl.Up(m.config) -} - -func (m *Manager) Down() error { - return m.impl.Down() -} diff --git a/wgengine/router/dns/manager_default.go b/wgengine/router/dns/manager_default.go deleted file mode 100644 index 04c8bb811..000000000 --- a/wgengine/router/dns/manager_default.go +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// +build !linux,!freebsd,!openbsd,!windows - -package dns - -func newManager(mconfig ManagerConfig) managerImpl { - // TODO(dmytro): on darwin, we should use a macOS-specific method such as scutil. 
- // This is currently not implemented. Editing /etc/resolv.conf does not work, - // as most applications use the system resolver, which disregards it. - return newNoopManager(mconfig) -} diff --git a/wgengine/router/dns/manager_freebsd.go b/wgengine/router/dns/manager_freebsd.go deleted file mode 100644 index 232635f7e..000000000 --- a/wgengine/router/dns/manager_freebsd.go +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package dns - -func newManager(mconfig ManagerConfig) managerImpl { - switch { - case isResolvconfActive(): - return newResolvconfManager(mconfig) - default: - return newDirectManager(mconfig) - } -} diff --git a/wgengine/router/dns/manager_linux.go b/wgengine/router/dns/manager_linux.go deleted file mode 100644 index f53aed7d3..000000000 --- a/wgengine/router/dns/manager_linux.go +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package dns - -func newManager(mconfig ManagerConfig) managerImpl { - switch { - // systemd-resolved should only activate per-domain. - case isResolvedActive() && mconfig.PerDomain: - if mconfig.Cleanup { - return newNoopManager(mconfig) - } else { - return newResolvedManager(mconfig) - } - case isNMActive(): - if mconfig.Cleanup { - return newNoopManager(mconfig) - } else { - return newNMManager(mconfig) - } - case isResolvconfActive(): - return newResolvconfManager(mconfig) - default: - return newDirectManager(mconfig) - } -} diff --git a/wgengine/router/dns/manager_openbsd.go b/wgengine/router/dns/manager_openbsd.go deleted file mode 100644 index 228e3cca5..000000000 --- a/wgengine/router/dns/manager_openbsd.go +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. 
-// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package dns - -func newManager(mconfig ManagerConfig) managerImpl { - return newDirectManager(mconfig) -} diff --git a/wgengine/router/dns/manager_windows.go b/wgengine/router/dns/manager_windows.go deleted file mode 100644 index 5940404e7..000000000 --- a/wgengine/router/dns/manager_windows.go +++ /dev/null @@ -1,118 +0,0 @@ -// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package dns - -import ( - "fmt" - "os/exec" - "strings" - "syscall" - "time" - - "golang.org/x/sys/windows/registry" - "tailscale.com/types/logger" -) - -const ( - ipv4RegBase = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters` - ipv6RegBase = `SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters` -) - -type windowsManager struct { - logf logger.Logf - guid string -} - -func newManager(mconfig ManagerConfig) managerImpl { - return windowsManager{ - logf: mconfig.Logf, - guid: mconfig.InterfaceName, - } -} - -// keyOpenTimeout is how long we wait for a registry key to -// appear. For some reason, registry keys tied to ephemeral interfaces -// can take a long while to appear after interface creation, and we -// can end up racing with that. 
-const keyOpenTimeout = 20 * time.Second - -func setRegistryString(path, name, value string) error { - key, err := openKeyWait(registry.LOCAL_MACHINE, path, registry.SET_VALUE, keyOpenTimeout) - if err != nil { - return fmt.Errorf("opening %s: %w", path, err) - } - defer key.Close() - - err = key.SetStringValue(name, value) - if err != nil { - return fmt.Errorf("setting %s[%s]: %w", path, name, err) - } - return nil -} - -func (m windowsManager) setNameservers(basePath string, nameservers []string) error { - path := fmt.Sprintf(`%s\Interfaces\%s`, basePath, m.guid) - value := strings.Join(nameservers, ",") - return setRegistryString(path, "NameServer", value) -} - -func (m windowsManager) setDomains(basePath string, domains []string) error { - path := fmt.Sprintf(`%s\Interfaces\%s`, basePath, m.guid) - value := strings.Join(domains, ",") - return setRegistryString(path, "SearchList", value) -} - -func (m windowsManager) Up(config Config) error { - var ipsv4 []string - var ipsv6 []string - - for _, ip := range config.Nameservers { - if ip.Is4() { - ipsv4 = append(ipsv4, ip.String()) - } else { - ipsv6 = append(ipsv6, ip.String()) - } - } - - if err := m.setNameservers(ipv4RegBase, ipsv4); err != nil { - return err - } - if err := m.setDomains(ipv4RegBase, config.Domains); err != nil { - return err - } - - if err := m.setNameservers(ipv6RegBase, ipsv6); err != nil { - return err - } - if err := m.setDomains(ipv6RegBase, config.Domains); err != nil { - return err - } - - // Force DNS re-registration in Active Directory. What we actually - // care about is that this command invokes the undocumented hidden - // function that forces Windows to notice that adapter settings - // have changed, which makes the DNS settings actually take - // effect. - // - // This command can take a few seconds to run, so run it async, best effort. 
- go func() { - t0 := time.Now() - m.logf("running ipconfig /registerdns ...") - cmd := exec.Command("ipconfig", "/registerdns") - cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true} - d := time.Since(t0).Round(time.Millisecond) - if err := cmd.Run(); err != nil { - m.logf("error running ipconfig /registerdns after %v: %v", d, err) - } else { - m.logf("ran ipconfig /registerdns in %v", d) - } - }() - - return nil -} - -func (m windowsManager) Down() error { - return m.Up(Config{Nameservers: nil, Domains: nil}) -} diff --git a/wgengine/router/dns/nm.go b/wgengine/router/dns/nm.go deleted file mode 100644 index a597fa60d..000000000 --- a/wgengine/router/dns/nm.go +++ /dev/null @@ -1,205 +0,0 @@ -// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// +build linux - -package dns - -import ( - "bufio" - "bytes" - "context" - "fmt" - "os" - "os/exec" - - "github.com/godbus/dbus/v5" - "tailscale.com/util/endian" -) - -// isNMActive determines if NetworkManager is currently managing system DNS settings. -func isNMActive() bool { - // This is somewhat tricky because NetworkManager supports a number - // of DNS configuration modes. In all cases, we expect it to be installed - // and /etc/resolv.conf to contain a mention of NetworkManager in the comments. - _, err := exec.LookPath("NetworkManager") - if err != nil { - return false - } - - f, err := os.Open("/etc/resolv.conf") - if err != nil { - return false - } - defer f.Close() - - scanner := bufio.NewScanner(f) - for scanner.Scan() { - line := scanner.Bytes() - // Look for the word "NetworkManager" until comments end. - if len(line) > 0 && line[0] != '#' { - return false - } - if bytes.Contains(line, []byte("NetworkManager")) { - return true - } - } - return false -} - -// nmManager uses the NetworkManager DBus API. 
-type nmManager struct { - interfaceName string -} - -func newNMManager(mconfig ManagerConfig) managerImpl { - return nmManager{ - interfaceName: mconfig.InterfaceName, - } -} - -type nmConnectionSettings map[string]map[string]dbus.Variant - -// Up implements managerImpl. -func (m nmManager) Up(config Config) error { - ctx, cancel := context.WithTimeout(context.Background(), reconfigTimeout) - defer cancel() - - // conn is a shared connection whose lifecycle is managed by the dbus package. - // We should not interfere with that by closing it. - conn, err := dbus.SystemBus() - if err != nil { - return fmt.Errorf("connecting to system bus: %w", err) - } - - // This is how we get at the DNS settings: - // - // org.freedesktop.NetworkManager - // | - // [GetDeviceByIpIface] - // | - // v - // org.freedesktop.NetworkManager.Device <--------\ - // (describes a network interface) | - // | | - // [GetAppliedConnection] [Reapply] - // | | - // v | - // org.freedesktop.NetworkManager.Connection | - // (connection settings) ------/ - // contains {dns, dns-priority, dns-search} - // - // Ref: https://developer.gnome.org/NetworkManager/stable/settings-ipv4.html. 
- - nm := conn.Object( - "org.freedesktop.NetworkManager", - dbus.ObjectPath("/org/freedesktop/NetworkManager"), - ) - - var devicePath dbus.ObjectPath - err = nm.CallWithContext( - ctx, "org.freedesktop.NetworkManager.GetDeviceByIpIface", 0, - m.interfaceName, - ).Store(&devicePath) - if err != nil { - return fmt.Errorf("getDeviceByIpIface: %w", err) - } - device := conn.Object("org.freedesktop.NetworkManager", devicePath) - - var ( - settings nmConnectionSettings - version uint64 - ) - err = device.CallWithContext( - ctx, "org.freedesktop.NetworkManager.Device.GetAppliedConnection", 0, - uint32(0), - ).Store(&settings, &version) - if err != nil { - return fmt.Errorf("getAppliedConnection: %w", err) - } - - // Frustratingly, NetworkManager represents IPv4 addresses as uint32s, - // although IPv6 addresses are represented as byte arrays. - // Perform the conversion here. - var ( - dnsv4 []uint32 - dnsv6 [][]byte - ) - for _, ip := range config.Nameservers { - b := ip.As16() - if ip.Is4() { - dnsv4 = append(dnsv4, endian.Native.Uint32(b[12:])) - } else { - dnsv6 = append(dnsv6, b[:]) - } - } - - ipv4Map := settings["ipv4"] - ipv4Map["dns"] = dbus.MakeVariant(dnsv4) - ipv4Map["dns-search"] = dbus.MakeVariant(config.Domains) - // We should only request priority if we have nameservers to set. - if len(dnsv4) == 0 { - ipv4Map["dns-priority"] = dbus.MakeVariant(100) - } else { - // dns-priority = -1 ensures that we have priority - // over other interfaces, except those exploiting this same trick. - // Ref: https://bugs.launchpad.net/ubuntu/+source/network-manager/+bug/1211110/comments/92. - ipv4Map["dns-priority"] = dbus.MakeVariant(-1) - } - // In principle, we should not need set this to true, - // as our interface does not configure any automatic DNS settings (presumably via DHCP). - // All the same, better to be safe. - ipv4Map["ignore-auto-dns"] = dbus.MakeVariant(true) - - ipv6Map := settings["ipv6"] - // This is a hack. 
- // Methods "disabled", "ignore", "link-local" (IPv6 default) prevent us from setting DNS. - // It seems that our only recourse is "manual" or "auto". - // "manual" requires addresses, so we use "auto", which will assign us a random IPv6 /64. - ipv6Map["method"] = dbus.MakeVariant("auto") - // Our IPv6 config is a fake, so it should never become the default route. - ipv6Map["never-default"] = dbus.MakeVariant(true) - // Moreover, we should ignore all autoconfigured routes (hopefully none), as they are bogus. - ipv6Map["ignore-auto-routes"] = dbus.MakeVariant(true) - - // Finally, set the actual DNS config. - ipv6Map["dns"] = dbus.MakeVariant(dnsv6) - ipv6Map["dns-search"] = dbus.MakeVariant(config.Domains) - if len(dnsv6) == 0 { - ipv6Map["dns-priority"] = dbus.MakeVariant(100) - } else { - ipv6Map["dns-priority"] = dbus.MakeVariant(-1) - } - ipv6Map["ignore-auto-dns"] = dbus.MakeVariant(true) - - // deprecatedProperties are the properties in interface settings - // that are deprecated by NetworkManager. - // - // In practice, this means that they are returned for reading, - // but submitting a settings object with them present fails - // with hard-to-diagnose errors. They must be removed. - deprecatedProperties := []string{ - "addresses", "routes", - } - - for _, property := range deprecatedProperties { - delete(ipv4Map, property) - delete(ipv6Map, property) - } - - err = device.CallWithContext( - ctx, "org.freedesktop.NetworkManager.Device.Reapply", 0, - settings, version, uint32(0), - ).Store() - if err != nil { - return fmt.Errorf("reapply: %w", err) - } - - return nil -} - -// Down implements managerImpl. -func (m nmManager) Down() error { - return m.Up(Config{Nameservers: nil, Domains: nil}) -} diff --git a/wgengine/router/dns/noop.go b/wgengine/router/dns/noop.go deleted file mode 100644 index 35c07a232..000000000 --- a/wgengine/router/dns/noop.go +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. 
-// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package dns - -type noopManager struct{} - -// Up implements managerImpl. -func (m noopManager) Up(Config) error { return nil } - -// Down implements managerImpl. -func (m noopManager) Down() error { return nil } - -func newNoopManager(mconfig ManagerConfig) managerImpl { - return noopManager{} -} diff --git a/wgengine/router/dns/registry_windows.go b/wgengine/router/dns/registry_windows.go deleted file mode 100644 index f8e1f514a..000000000 --- a/wgengine/router/dns/registry_windows.go +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. -// -// The code in this file originates from https://git.zx2c4.com/wireguard-go: -// Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved. -// Copying license: https://git.zx2c4.com/wireguard-go/tree/COPYING - -package dns - -import ( - "fmt" - "runtime" - "strings" - "time" - - "golang.org/x/sys/windows" - "golang.org/x/sys/windows/registry" -) - -func openKeyWait(k registry.Key, path string, access uint32, timeout time.Duration) (registry.Key, error) { - runtime.LockOSThread() - defer runtime.UnlockOSThread() - - deadline := time.Now().Add(timeout) - pathSpl := strings.Split(path, "\\") - for i := 0; ; i++ { - keyName := pathSpl[i] - isLast := i+1 == len(pathSpl) - - event, err := windows.CreateEvent(nil, 0, 0, nil) - if err != nil { - return 0, fmt.Errorf("windows.CreateEvent: %v", err) - } - defer windows.CloseHandle(event) - - var key registry.Key - for { - err = windows.RegNotifyChangeKeyValue(windows.Handle(k), false, windows.REG_NOTIFY_CHANGE_NAME, event, true) - if err != nil { - return 0, fmt.Errorf("windows.RegNotifyChangeKeyValue: %v", err) - } - - var accessFlags uint32 - if isLast { - accessFlags = access - } else { - accessFlags = registry.NOTIFY - } - 
key, err = registry.OpenKey(k, keyName, accessFlags) - if err == windows.ERROR_FILE_NOT_FOUND || err == windows.ERROR_PATH_NOT_FOUND { - timeout := time.Until(deadline) / time.Millisecond - if timeout < 0 { - timeout = 0 - } - s, err := windows.WaitForSingleObject(event, uint32(timeout)) - if err != nil { - return 0, fmt.Errorf("windows.WaitForSingleObject: %v", err) - } - if s == uint32(windows.WAIT_TIMEOUT) { // windows.WAIT_TIMEOUT status const is misclassified as error in golang.org/x/sys/windows - return 0, fmt.Errorf("timeout waiting for registry key") - } - } else if err != nil { - return 0, fmt.Errorf("registry.OpenKey(%v): %v", path, err) - } else { - if isLast { - return key, nil - } - defer key.Close() - break - } - } - - k = key - } -} diff --git a/wgengine/router/dns/resolvconf.go b/wgengine/router/dns/resolvconf.go deleted file mode 100644 index 8bf97ee88..000000000 --- a/wgengine/router/dns/resolvconf.go +++ /dev/null @@ -1,157 +0,0 @@ -// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// +build linux freebsd - -package dns - -import ( - "bufio" - "bytes" - "fmt" - "os" - "os/exec" -) - -// isResolvconfActive indicates whether the system appears to be using resolvconf. -// If this is true, then directManager should be avoided: -// resolvconf has exclusive ownership of /etc/resolv.conf. -func isResolvconfActive() bool { - // Sanity-check first: if there is no resolvconf binary, then this is fruitless. - // - // However, this binary may be a shim like the one systemd-resolved provides. - // Such a shim may not behave as expected: in particular, systemd-resolved - // does not seem to respect the exclusive mode -x, saying: - // -x Send DNS traffic preferably over this interface - // whereas e.g. 
openresolv sends DNS traffix _exclusively_ over that interface, - // or not at all (in case of another exclusive-mode request later in time). - // - // Moreover, resolvconf may be installed but unused, in which case we should - // not use it either, lest we clobber existing configuration. - // - // To handle all the above correctly, we scan the comments in /etc/resolv.conf - // to ensure that it was generated by a resolvconf implementation. - _, err := exec.LookPath("resolvconf") - if err != nil { - return false - } - - f, err := os.Open("/etc/resolv.conf") - if err != nil { - return false - } - defer f.Close() - - scanner := bufio.NewScanner(f) - for scanner.Scan() { - line := scanner.Bytes() - // Look for the word "resolvconf" until comments end. - if len(line) > 0 && line[0] != '#' { - return false - } - if bytes.Contains(line, []byte("resolvconf")) { - return true - } - } - return false -} - -// resolvconfImpl enumerates supported implementations of the resolvconf CLI. -type resolvconfImpl uint8 - -const ( - // resolvconfOpenresolv is the implementation packaged as "openresolv" on Ubuntu. - // It supports exclusive mode and interface metrics. - resolvconfOpenresolv resolvconfImpl = iota - // resolvconfLegacy is the implementation by Thomas Hood packaged as "resolvconf" on Ubuntu. - // It does not support exclusive mode or interface metrics. - resolvconfLegacy -) - -func (impl resolvconfImpl) String() string { - switch impl { - case resolvconfOpenresolv: - return "openresolv" - case resolvconfLegacy: - return "legacy" - default: - return "unknown" - } -} - -// getResolvconfImpl returns the implementation of resolvconf that appears to be in use. -func getResolvconfImpl() resolvconfImpl { - err := exec.Command("resolvconf", "-v").Run() - if err != nil { - if exitErr, ok := err.(*exec.ExitError); ok { - // Thomas Hood's resolvconf has a minimal flag set - // and exits with code 99 when passed an unknown flag. 
- if exitErr.ExitCode() == 99 { - return resolvconfLegacy - } - } - } - return resolvconfOpenresolv -} - -type resolvconfManager struct { - impl resolvconfImpl -} - -func newResolvconfManager(mconfig ManagerConfig) managerImpl { - impl := getResolvconfImpl() - mconfig.Logf("resolvconf implementation is %s", impl) - - return resolvconfManager{ - impl: impl, - } -} - -// resolvconfConfigName is the name of the config submitted to resolvconf. -// It has this form to match the "tun*" rule in interface-order -// when running resolvconfLegacy, hopefully placing our config first. -const resolvconfConfigName = "tun-tailscale.inet" - -// Up implements managerImpl. -func (m resolvconfManager) Up(config Config) error { - stdin := new(bytes.Buffer) - writeResolvConf(stdin, config.Nameservers, config.Domains) // dns_direct.go - - var cmd *exec.Cmd - switch m.impl { - case resolvconfOpenresolv: - // Request maximal priority (metric 0) and exclusive mode. - cmd = exec.Command("resolvconf", "-m", "0", "-x", "-a", resolvconfConfigName) - case resolvconfLegacy: - // This does not quite give us the desired behavior (queries leak), - // but there is nothing else we can do without messing with other interfaces' settings. - cmd = exec.Command("resolvconf", "-a", resolvconfConfigName) - } - cmd.Stdin = stdin - out, err := cmd.CombinedOutput() - if err != nil { - return fmt.Errorf("running %s: %s", cmd, out) - } - - return nil -} - -// Down implements managerImpl. -func (m resolvconfManager) Down() error { - var cmd *exec.Cmd - switch m.impl { - case resolvconfOpenresolv: - cmd = exec.Command("resolvconf", "-f", "-d", resolvconfConfigName) - case resolvconfLegacy: - // resolvconfLegacy lacks the -f flag. - // Instead, it succeeds even when the config does not exist. 
- cmd = exec.Command("resolvconf", "-d", resolvconfConfigName) - } - out, err := cmd.CombinedOutput() - if err != nil { - return fmt.Errorf("running %s: %s", cmd, out) - } - - return nil -} diff --git a/wgengine/router/dns/resolved.go b/wgengine/router/dns/resolved.go deleted file mode 100644 index 9d8c40d90..000000000 --- a/wgengine/router/dns/resolved.go +++ /dev/null @@ -1,188 +0,0 @@ -// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// +build linux - -package dns - -import ( - "context" - "errors" - "fmt" - "os/exec" - - "github.com/godbus/dbus/v5" - "golang.org/x/sys/unix" - "inet.af/netaddr" - "tailscale.com/net/interfaces" -) - -// resolvedListenAddr is the listen address of the resolved stub resolver. -// -// We only consider resolved to be the system resolver if the stub resolver is; -// that is, if this address is the sole nameserver in /etc/resolved.conf. -// In other cases, resolved may be managing the system DNS configuration directly. -// Then the nameserver list will be a concatenation of those for all -// the interfaces that register their interest in being a default resolver with -// SetLinkDomains([]{{"~.", true}, ...}) -// which includes at least the interface with the default route, i.e. not us. -// This does not work for us: there is a possibility of getting NXDOMAIN -// from the other nameservers before we are asked or get a chance to respond. -// We consider this case as lacking resolved support and fall through to dnsDirect. -// -// While it may seem that we need to read a config option to get at this, -// this address is, in fact, hard-coded into resolved. 
-var resolvedListenAddr = netaddr.IPv4(127, 0, 0, 53) - -var errNotReady = errors.New("interface not ready") - -type resolvedLinkNameserver struct { - Family int32 - Address []byte -} - -type resolvedLinkDomain struct { - Domain string - RoutingOnly bool -} - -// isResolvedActive determines if resolved is currently managing system DNS settings. -func isResolvedActive() bool { - // systemd-resolved is never installed without systemd. - _, err := exec.LookPath("systemctl") - if err != nil { - return false - } - - // is-active exits with code 3 if the service is not active. - err = exec.Command("systemctl", "is-active", "systemd-resolved").Run() - if err != nil { - return false - } - - config, err := readResolvConf() - if err != nil { - return false - } - - // The sole nameserver must be the systemd-resolved stub. - if len(config.Nameservers) == 1 && config.Nameservers[0] == resolvedListenAddr { - return true - } - - return false -} - -// resolvedManager uses the systemd-resolved DBus API. -type resolvedManager struct{} - -func newResolvedManager(mconfig ManagerConfig) managerImpl { - return resolvedManager{} -} - -// Up implements managerImpl. -func (m resolvedManager) Up(config Config) error { - ctx, cancel := context.WithTimeout(context.Background(), reconfigTimeout) - defer cancel() - - // conn is a shared connection whose lifecycle is managed by the dbus package. - // We should not interfere with that by closing it. - conn, err := dbus.SystemBus() - if err != nil { - return fmt.Errorf("connecting to system bus: %w", err) - } - - resolved := conn.Object( - "org.freedesktop.resolve1", - dbus.ObjectPath("/org/freedesktop/resolve1"), - ) - - // In principle, we could persist this in the manager struct - // if we knew that interface indices are persistent. This does not seem to be the case. 
- _, iface, err := interfaces.Tailscale() - if err != nil { - return fmt.Errorf("getting interface index: %w", err) - } - if iface == nil { - return errNotReady - } - - var linkNameservers = make([]resolvedLinkNameserver, len(config.Nameservers)) - for i, server := range config.Nameservers { - ip := server.As16() - if server.Is4() { - linkNameservers[i] = resolvedLinkNameserver{ - Family: unix.AF_INET, - Address: ip[12:], - } - } else { - linkNameservers[i] = resolvedLinkNameserver{ - Family: unix.AF_INET6, - Address: ip[:], - } - } - } - - err = resolved.CallWithContext( - ctx, "org.freedesktop.resolve1.Manager.SetLinkDNS", 0, - iface.Index, linkNameservers, - ).Store() - if err != nil { - return fmt.Errorf("setLinkDNS: %w", err) - } - - var linkDomains = make([]resolvedLinkDomain, len(config.Domains)) - for i, domain := range config.Domains { - linkDomains[i] = resolvedLinkDomain{ - Domain: domain, - RoutingOnly: false, - } - } - - err = resolved.CallWithContext( - ctx, "org.freedesktop.resolve1.Manager.SetLinkDomains", 0, - iface.Index, linkDomains, - ).Store() - if err != nil { - return fmt.Errorf("setLinkDomains: %w", err) - } - - return nil -} - -// Down implements managerImpl. -func (m resolvedManager) Down() error { - ctx, cancel := context.WithTimeout(context.Background(), reconfigTimeout) - defer cancel() - - // conn is a shared connection whose lifecycle is managed by the dbus package. - // We should not interfere with that by closing it. 
- conn, err := dbus.SystemBus() - if err != nil { - return fmt.Errorf("connecting to system bus: %w", err) - } - - resolved := conn.Object( - "org.freedesktop.resolve1", - dbus.ObjectPath("/org/freedesktop/resolve1"), - ) - - _, iface, err := interfaces.Tailscale() - if err != nil { - return fmt.Errorf("getting interface index: %w", err) - } - if iface == nil { - return errNotReady - } - - err = resolved.CallWithContext( - ctx, "org.freedesktop.resolve1.Manager.RevertLink", 0, - iface.Index, - ).Store() - if err != nil { - return fmt.Errorf("RevertLink: %w", err) - } - - return nil -} diff --git a/wgengine/router/router.go b/wgengine/router/router.go index 9c3f1003f..2e53363cf 100644 --- a/wgengine/router/router.go +++ b/wgengine/router/router.go @@ -7,12 +7,11 @@ package router import ( - "github.com/tailscale/wireguard-go/device" "github.com/tailscale/wireguard-go/tun" "inet.af/netaddr" + "tailscale.com/net/dns" "tailscale.com/types/logger" "tailscale.com/types/preftype" - "tailscale.com/wgengine/router/dns" ) // Router is responsible for managing the system network stack. @@ -33,9 +32,9 @@ type Router interface { // New returns a new Router for the current platform, using the // provided tun device. 
-func New(logf logger.Logf, wgdev *device.Device, tundev tun.Device) (Router, error) { +func New(logf logger.Logf, tundev tun.Device) (Router, error) { logf = logger.WithPrefix(logf, "router: ") - return newUserspaceRouter(logf, wgdev, tundev) + return newUserspaceRouter(logf, tundev) } // Cleanup restores the system network configuration to its original state diff --git a/wgengine/router/router_darwin.go b/wgengine/router/router_darwin.go index 26b689355..58ba8e6d3 100644 --- a/wgengine/router/router_darwin.go +++ b/wgengine/router/router_darwin.go @@ -5,13 +5,12 @@ package router import ( - "github.com/tailscale/wireguard-go/device" "github.com/tailscale/wireguard-go/tun" "tailscale.com/types/logger" ) -func newUserspaceRouter(logf logger.Logf, wgdev *device.Device, tundev tun.Device) (Router, error) { - return newUserspaceBSDRouter(logf, wgdev, tundev) +func newUserspaceRouter(logf logger.Logf, tundev tun.Device) (Router, error) { + return newUserspaceBSDRouter(logf, tundev) } func cleanup(logger.Logf, string) { diff --git a/wgengine/router/router_default.go b/wgengine/router/router_default.go index 4d7365e04..7f05da42f 100644 --- a/wgengine/router/router_default.go +++ b/wgengine/router/router_default.go @@ -7,13 +7,12 @@ package router import ( - "github.com/tailscale/wireguard-go/device" "github.com/tailscale/wireguard-go/tun" "tailscale.com/types/logger" ) -func newUserspaceRouter(logf logger.Logf, tunname string, dev *device.Device, tunDev tun.Device, netChanged func()) Router { - return NewFakeRouter(logf, tunname, dev, tunDev, netChanged) +func newUserspaceRouter(logf logger.Logf, tunname string, tunDev tun.Device, netChanged func()) Router { + return NewFakeRouter(logf, tunname, tunDev, netChanged) } func cleanup(logf logger.Logf, interfaceName string) { diff --git a/wgengine/router/router_fake.go b/wgengine/router/router_fake.go index 0d14e5000..add4b576b 100644 --- a/wgengine/router/router_fake.go +++ b/wgengine/router/router_fake.go @@ -5,15 +5,13 @@ 
package router import ( - "github.com/tailscale/wireguard-go/device" - "github.com/tailscale/wireguard-go/tun" "tailscale.com/types/logger" ) -// NewFakeRouter returns a Router that does nothing when called and -// always returns nil errors. -func NewFake(logf logger.Logf, _ *device.Device, _ tun.Device) (Router, error) { - return fakeRouter{logf: logf}, nil +// NewFake returns a Router that does nothing when called and always +// returns nil errors. +func NewFake(logf logger.Logf) Router { + return fakeRouter{logf: logf} } type fakeRouter struct { diff --git a/wgengine/router/router_freebsd.go b/wgengine/router/router_freebsd.go index e56e3f82d..6e9380299 100644 --- a/wgengine/router/router_freebsd.go +++ b/wgengine/router/router_freebsd.go @@ -5,7 +5,6 @@ package router import ( - "github.com/tailscale/wireguard-go/device" "github.com/tailscale/wireguard-go/tun" "tailscale.com/types/logger" ) @@ -15,8 +14,8 @@ import ( // Work is currently underway for an in-kernel FreeBSD implementation of wireguard // https://svnweb.freebsd.org/base?view=revision&revision=357986 -func newUserspaceRouter(logf logger.Logf, _ *device.Device, tundev tun.Device) (Router, error) { - return newUserspaceBSDRouter(logf, nil, tundev) +func newUserspaceRouter(logf logger.Logf, tundev tun.Device) (Router, error) { + return newUserspaceBSDRouter(logf, tundev) } func cleanup(logf logger.Logf, interfaceName string) { diff --git a/wgengine/router/router_linux.go b/wgengine/router/router_linux.go index b700efccc..311681ab0 100644 --- a/wgengine/router/router_linux.go +++ b/wgengine/router/router_linux.go @@ -16,14 +16,13 @@ import ( "github.com/coreos/go-iptables/iptables" "github.com/go-multierror/multierror" - "github.com/tailscale/wireguard-go/device" "github.com/tailscale/wireguard-go/tun" "inet.af/netaddr" + "tailscale.com/net/dns" "tailscale.com/net/tsaddr" "tailscale.com/types/logger" "tailscale.com/types/preftype" "tailscale.com/version/distro" - "tailscale.com/wgengine/router/dns" ) 
const ( @@ -110,7 +109,7 @@ type linuxRouter struct { cmd commandRunner } -func newUserspaceRouter(logf logger.Logf, _ *device.Device, tunDev tun.Device) (Router, error) { +func newUserspaceRouter(logf logger.Logf, tunDev tun.Device) (Router, error) { tunname, err := tunDev.Name() if err != nil { return nil, err diff --git a/wgengine/router/router_linux_test.go b/wgengine/router/router_linux_test.go index 7bacb2e2f..45109d35c 100644 --- a/wgengine/router/router_linux_test.go +++ b/wgengine/router/router_linux_test.go @@ -627,7 +627,7 @@ func TestDelRouteIdempotent(t *testing.T) { } } - r, err := newUserspaceRouter(logf, nil, tun) + r, err := newUserspaceRouter(logf, tun) if err != nil { t.Fatal(err) } diff --git a/wgengine/router/router_openbsd.go b/wgengine/router/router_openbsd.go index 8c7269658..a6dbf9282 100644 --- a/wgengine/router/router_openbsd.go +++ b/wgengine/router/router_openbsd.go @@ -10,11 +10,10 @@ import ( "log" "os/exec" - "github.com/tailscale/wireguard-go/device" "github.com/tailscale/wireguard-go/tun" "inet.af/netaddr" + "tailscale.com/net/dns" "tailscale.com/types/logger" - "tailscale.com/wgengine/router/dns" ) // For now this router only supports the WireGuard userspace implementation. 
@@ -31,7 +30,7 @@ type openbsdRouter struct { dns *dns.Manager } -func newUserspaceRouter(logf logger.Logf, _ *device.Device, tundev tun.Device) (Router, error) { +func newUserspaceRouter(logf logger.Logf, tundev tun.Device) (Router, error) { tunname, err := tundev.Name() if err != nil { return nil, err diff --git a/wgengine/router/router_userspace_bsd.go b/wgengine/router/router_userspace_bsd.go index 71ccd1706..79a81de03 100644 --- a/wgengine/router/router_userspace_bsd.go +++ b/wgengine/router/router_userspace_bsd.go @@ -12,12 +12,11 @@ import ( "os/exec" "runtime" - "github.com/tailscale/wireguard-go/device" "github.com/tailscale/wireguard-go/tun" "inet.af/netaddr" + "tailscale.com/net/dns" "tailscale.com/types/logger" "tailscale.com/version" - "tailscale.com/wgengine/router/dns" ) type userspaceBSDRouter struct { @@ -29,7 +28,7 @@ type userspaceBSDRouter struct { dns *dns.Manager } -func newUserspaceBSDRouter(logf logger.Logf, _ *device.Device, tundev tun.Device) (Router, error) { +func newUserspaceBSDRouter(logf logger.Logf, tundev tun.Device) (Router, error) { tunname, err := tundev.Name() if err != nil { return nil, err diff --git a/wgengine/router/router_windows.go b/wgengine/router/router_windows.go index 89c686b95..2efdce7ed 100644 --- a/wgengine/router/router_windows.go +++ b/wgengine/router/router_windows.go @@ -16,21 +16,19 @@ import ( "syscall" "time" - "github.com/tailscale/wireguard-go/device" "github.com/tailscale/wireguard-go/tun" "golang.org/x/sys/windows" "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" "inet.af/netaddr" "tailscale.com/logtail/backoff" + "tailscale.com/net/dns" "tailscale.com/types/logger" - "tailscale.com/wgengine/router/dns" ) type winRouter struct { logf func(fmt string, args ...interface{}) tunname string nativeTun *tun.NativeTun - wgdev *device.Device routeChangeCallback *winipcfg.RouteChangeCallback dns *dns.Manager firewall *firewallTweaker @@ -45,7 +43,7 @@ type winRouter struct { firewallSubproc *exec.Cmd } -func 
newUserspaceRouter(logf logger.Logf, wgdev *device.Device, tundev tun.Device) (Router, error) { +func newUserspaceRouter(logf logger.Logf, tundev tun.Device) (Router, error) { tunname, err := tundev.Name() if err != nil { return nil, err @@ -65,7 +63,6 @@ func newUserspaceRouter(logf logger.Logf, wgdev *device.Device, tundev tun.Devic return &winRouter{ logf: logf, - wgdev: wgdev, tunname: tunname, nativeTun: nativeTun, dns: dns.NewManager(mconfig), @@ -112,11 +109,8 @@ func (r *winRouter) Set(cfg *Config) error { } // Flush DNS on router config change to clear cached DNS entries (solves #1430) - out, err := exec.Command("ipconfig", "/flushdns").CombinedOutput() - if err != nil { - r.logf("flushdns error: %v; output: %s", err, out) - } else { - r.logf("flushdns successful") + if err := dns.Flush(); err != nil { + r.logf("flushdns error: %v", err) } return nil diff --git a/wgengine/tsdns/forwarder.go b/wgengine/tsdns/forwarder.go deleted file mode 100644 index 470748c4a..000000000 --- a/wgengine/tsdns/forwarder.go +++ /dev/null @@ -1,474 +0,0 @@ -// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package tsdns - -import ( - "bytes" - "context" - "encoding/binary" - "errors" - "fmt" - "hash/crc32" - "math/rand" - "net" - "os" - "sync" - "time" - - "inet.af/netaddr" - "tailscale.com/logtail/backoff" - "tailscale.com/net/netns" - "tailscale.com/types/logger" -) - -// headerBytes is the number of bytes in a DNS message header. -const headerBytes = 12 - -// connCount is the number of UDP connections to use for forwarding. -const connCount = 32 - -const ( - // cleanupInterval is the interval between purged of timed-out entries from txMap. - cleanupInterval = 30 * time.Second - // responseTimeout is the maximal amount of time to wait for a DNS response. 
- responseTimeout = 5 * time.Second -) - -var errNoUpstreams = errors.New("upstream nameservers not set") - -var aLongTimeAgo = time.Unix(0, 1) - -type forwardingRecord struct { - src netaddr.IPPort - createdAt time.Time -} - -// txid identifies a DNS transaction. -// -// As the standard DNS Request ID is only 16 bits, we extend it: -// the lower 32 bits are the zero-extended bits of the DNS Request ID; -// the upper 32 bits are the CRC32 checksum of the first question in the request. -// This makes probability of txid collision negligible. -type txid uint64 - -// getTxID computes the txid of the given DNS packet. -func getTxID(packet []byte) txid { - if len(packet) < headerBytes { - return 0 - } - - dnsid := binary.BigEndian.Uint16(packet[0:2]) - qcount := binary.BigEndian.Uint16(packet[4:6]) - if qcount == 0 { - return txid(dnsid) - } - - offset := headerBytes - for i := uint16(0); i < qcount; i++ { - // Note: this relies on the fact that names are not compressed in questions, - // so they are guaranteed to end with a NUL byte. - // - // Justification: - // RFC 1035 doesn't seem to explicitly prohibit compressing names in questions, - // but this is exceedingly unlikely to be done in practice. A DNS request - // with multiple questions is ill-defined (which questions do the header flags apply to?) - // and a single question would have to contain a pointer to an *answer*, - // which would be excessively smart, pointless (an answer can just as well refer to the question) - // and perhaps even prohibited: a draft RFC (draft-ietf-dnsind-local-compression-05) states: - // - // > It is important that these pointers always point backwards. - // - // This is said in summarizing RFC 1035, although that phrase does not appear in the original RFC. 
- // Additionally, (https://cr.yp.to/djbdns/notes.html) states: - // - // > The precise rule is that a name can be compressed if it is a response owner name, - // > the name in NS data, the name in CNAME data, the name in PTR data, the name in MX data, - // > or one of the names in SOA data. - namebytes := bytes.IndexByte(packet[offset:], 0) - // ... | name | NUL | type | class - // ?? 1 2 2 - offset = offset + namebytes + 5 - if len(packet) < offset { - // Corrupt packet; don't crash. - return txid(dnsid) - } - } - - hash := crc32.ChecksumIEEE(packet[headerBytes:offset]) - return (txid(hash) << 32) | txid(dnsid) -} - -// forwarder forwards DNS packets to a number of upstream nameservers. -type forwarder struct { - logf logger.Logf - - // responses is a channel by which responses are returned. - responses chan Packet - // closed signals all goroutines to stop. - closed chan struct{} - // wg signals when all goroutines have stopped. - wg sync.WaitGroup - - // conns are the UDP connections used for forwarding. - // A random one is selected for each request, regardless of the target upstream. - conns []*fwdConn - - mu sync.Mutex - // upstreams are the nameserver addresses that should be used for forwarding. - upstreams []net.Addr - // txMap maps DNS txids to active forwarding records. 
- txMap map[txid]forwardingRecord -} - -func init() { - rand.Seed(time.Now().UnixNano()) -} - -func newForwarder(logf logger.Logf, responses chan Packet) *forwarder { - return &forwarder{ - logf: logger.WithPrefix(logf, "forward: "), - responses: responses, - closed: make(chan struct{}), - conns: make([]*fwdConn, connCount), - txMap: make(map[txid]forwardingRecord), - } -} - -func (f *forwarder) Start() error { - f.wg.Add(connCount + 1) - for idx := range f.conns { - f.conns[idx] = newFwdConn(f.logf, idx) - go f.recv(f.conns[idx]) - } - go f.cleanMap() - - return nil -} - -func (f *forwarder) Close() { - select { - case <-f.closed: - return - default: - // continue - } - close(f.closed) - - for _, conn := range f.conns { - conn.close() - } - - f.wg.Wait() -} - -func (f *forwarder) rebindFromNetworkChange() { - for _, c := range f.conns { - c.mu.Lock() - c.reconnectLocked() - c.mu.Unlock() - } -} - -func (f *forwarder) setUpstreams(upstreams []net.Addr) { - f.mu.Lock() - f.upstreams = upstreams - f.mu.Unlock() -} - -// send sends packet to dst. It is best effort. -func (f *forwarder) send(packet []byte, dst net.Addr) { - connIdx := rand.Intn(connCount) - conn := f.conns[connIdx] - conn.send(packet, dst) -} - -func (f *forwarder) recv(conn *fwdConn) { - defer f.wg.Done() - - for { - select { - case <-f.closed: - return - default: - } - out := make([]byte, maxResponseBytes) - n := conn.read(out) - if n == 0 { - continue - } - if n < headerBytes { - f.logf("recv: packet too small (%d bytes)", n) - } - - out = out[:n] - txid := getTxID(out) - - f.mu.Lock() - - record, found := f.txMap[txid] - // At most one nameserver will return a response: - // the first one to do so will delete txid from the map. 
- if !found { - f.mu.Unlock() - continue - } - delete(f.txMap, txid) - - f.mu.Unlock() - - packet := Packet{ - Payload: out, - Addr: record.src, - } - select { - case <-f.closed: - return - case f.responses <- packet: - // continue - } - } -} - -// cleanMap periodically deletes timed-out forwarding records from f.txMap to bound growth. -func (f *forwarder) cleanMap() { - defer f.wg.Done() - - t := time.NewTicker(cleanupInterval) - defer t.Stop() - - var now time.Time - for { - select { - case <-f.closed: - return - case now = <-t.C: - // continue - } - - f.mu.Lock() - for k, v := range f.txMap { - if now.Sub(v.createdAt) > responseTimeout { - delete(f.txMap, k) - } - } - f.mu.Unlock() - } -} - -// forward forwards the query to all upstream nameservers and returns the first response. -func (f *forwarder) forward(query Packet) error { - txid := getTxID(query.Payload) - - f.mu.Lock() - - upstreams := f.upstreams - if len(upstreams) == 0 { - f.mu.Unlock() - return errNoUpstreams - } - f.txMap[txid] = forwardingRecord{ - src: query.Addr, - createdAt: time.Now(), - } - - f.mu.Unlock() - - for _, upstream := range upstreams { - f.send(query.Payload, upstream) - } - - return nil -} - -// A fwdConn manages a single connection used to forward DNS requests. -// Net link changes can cause a *net.UDPConn to become permanently unusable, particularly on macOS. -// fwdConn detects such situations and transparently creates new connections. -type fwdConn struct { - // logf allows a fwdConn to log. - logf logger.Logf - - // wg tracks the number of outstanding conn.Read and conn.Write calls. - wg sync.WaitGroup - // change allows calls to read to block until a the network connection has been replaced. - change *sync.Cond - - // mu protects fields that follow it; it is also change's Locker. - mu sync.Mutex - // closed tracks whether fwdConn has been permanently closed. - closed bool - // conn is the current active connection. 
- conn net.PacketConn -} - -func newFwdConn(logf logger.Logf, idx int) *fwdConn { - c := new(fwdConn) - c.logf = logger.WithPrefix(logf, fmt.Sprintf("fwdConn %d: ", idx)) - c.change = sync.NewCond(&c.mu) - // c.conn is created lazily in send - return c -} - -// send sends packet to dst using c's connection. -// It is best effort. It is UDP, after all. Failures are logged. -func (c *fwdConn) send(packet []byte, dst net.Addr) { - var b *backoff.Backoff // lazily initialized, since it is not needed in the common case - backOff := func(err error) { - if b == nil { - b = backoff.NewBackoff("tsdns-fwdConn-send", c.logf, 30*time.Second) - } - b.BackOff(context.Background(), err) - } - - for { - // Gather the current connection. - // We can't hold the lock while we call WriteTo. - c.mu.Lock() - conn := c.conn - closed := c.closed - if closed { - c.mu.Unlock() - return - } - if conn == nil { - c.reconnectLocked() - c.mu.Unlock() - continue - } - c.mu.Unlock() - - c.wg.Add(1) - _, err := conn.WriteTo(packet, dst) - c.wg.Done() - if err == nil { - // Success - return - } - if errors.Is(err, os.ErrDeadlineExceeded) { - // We intentionally closed this connection. - // It has been replaced by a new connection. Try again. - continue - } - // Something else went wrong. - // We have three choices here: try again, give up, or create a new connection. - var opErr *net.OpError - if !errors.As(err, &opErr) { - // Weird. All errors from the net package should be *net.OpError. Bail. - c.logf("send: non-*net.OpErr %v (%T)", err, err) - return - } - if opErr.Temporary() || opErr.Timeout() { - // I doubt that either of these can happen (this is UDP), - // but go ahead and try again. - backOff(err) - continue - } - if networkIsDown(err) { - // Fail. - c.logf("send: network is down") - return - } - if networkIsUnreachable(err) { - // This can be caused by a link change. - // Replace the existing connection with a new one. 
- c.mu.Lock() - // It's possible that multiple senders discovered simultaneously - // that the network is unreachable. Avoid reconnecting multiple times: - // Only reconnect if the current connection is the one that we - // discovered to be problematic. - if c.conn == conn { - backOff(err) - c.reconnectLocked() - } - c.mu.Unlock() - // Try again with our new network connection. - continue - } - // Unrecognized error. Fail. - c.logf("send: unrecognized error: %v", err) - return - } -} - -// read waits for a response from c's connection. -// It returns the number of bytes read, which may be 0 -// in case of an error or a closed connection. -func (c *fwdConn) read(out []byte) int { - for { - // Gather the current connection. - // We can't hold the lock while we call ReadFrom. - c.mu.Lock() - conn := c.conn - closed := c.closed - if closed { - c.mu.Unlock() - return 0 - } - if conn == nil { - // There is no current connection. - // Wait for the connection to change, then try again. - c.change.Wait() - c.mu.Unlock() - continue - } - c.mu.Unlock() - - c.wg.Add(1) - n, _, err := conn.ReadFrom(out) - c.wg.Done() - if err == nil { - // Success. - return n - } - if errors.Is(err, os.ErrDeadlineExceeded) { - // We intentionally closed this connection. - // It has been replaced by a new connection. Try again. - continue - } - - c.logf("read: unrecognized error: %v", err) - return 0 - } -} - -// reconnectLocked replaces the current connection with a new one. -// c.mu must be locked. -func (c *fwdConn) reconnectLocked() { - c.closeConnLocked() - // Make a new connection. - conn, err := netns.Listener().ListenPacket(context.Background(), "udp", "") - if err != nil { - c.logf("ListenPacket failed: %v", err) - } else { - c.conn = conn - } - // Broadcast that a new connection is available. - c.change.Broadcast() -} - -// closeCurrentConn closes the current connection. -// c.mu must be locked. 
-func (c *fwdConn) closeConnLocked() { - if c.conn == nil { - return - } - // Unblock all readers/writers, wait for them, close the connection. - c.conn.SetDeadline(aLongTimeAgo) - c.wg.Wait() - c.conn.Close() - c.conn = nil -} - -// close permanently closes c. -func (c *fwdConn) close() { - c.mu.Lock() - defer c.mu.Unlock() - if c.closed { - return - } - c.closed = true - c.closeConnLocked() - // Unblock any remaining readers. - c.change.Broadcast() -} diff --git a/wgengine/tsdns/map.go b/wgengine/tsdns/map.go deleted file mode 100644 index c51dbf59b..000000000 --- a/wgengine/tsdns/map.go +++ /dev/null @@ -1,160 +0,0 @@ -// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package tsdns - -import ( - "sort" - "strings" - - "inet.af/netaddr" -) - -// Map is all the data Resolver needs to resolve DNS queries within the Tailscale network. -type Map struct { - // nameToIP is a mapping of Tailscale domain names to their IP addresses. - // For example, monitoring.tailscale.us -> 100.64.0.1. - nameToIP map[string]netaddr.IP - // ipToName is the inverse of nameToIP. - ipToName map[netaddr.IP]string - // names are the keys of nameToIP in sorted order. - names []string - // rootDomains are the domains whose subdomains should always - // be resolved locally to prevent leakage of sensitive names. - rootDomains []string // e.g. "user.provider.beta.tailscale.net." -} - -// NewMap returns a new Map with name to address mapping given by nameToIP. -// -// rootDomains are the domains whose subdomains should always be -// resolved locally to prevent leakage of sensitive names. They should -// end in a period ("user-foo.tailscale.net."). -func NewMap(initNameToIP map[string]netaddr.IP, rootDomains []string) *Map { - // TODO(dmytro): we have to allocate names and ipToName, but nameToIP can be avoided. 
- // It is here because control sends us names not in canonical form. Change this. - names := make([]string, 0, len(initNameToIP)) - nameToIP := make(map[string]netaddr.IP, len(initNameToIP)) - ipToName := make(map[netaddr.IP]string, len(initNameToIP)) - - for name, ip := range initNameToIP { - if len(name) == 0 { - // Nothing useful can be done with empty names. - continue - } - if name[len(name)-1] != '.' { - name += "." - } - names = append(names, name) - nameToIP[name] = ip - ipToName[ip] = name - } - sort.Strings(names) - - return &Map{ - nameToIP: nameToIP, - ipToName: ipToName, - names: names, - - rootDomains: rootDomains, - } -} - -func printSingleNameIP(buf *strings.Builder, name string, ip netaddr.IP) { - buf.WriteString(name) - buf.WriteByte('\t') - buf.WriteString(ip.String()) - buf.WriteByte('\n') -} - -func (m *Map) Pretty() string { - buf := new(strings.Builder) - for _, name := range m.names { - printSingleNameIP(buf, name, m.nameToIP[name]) - } - return buf.String() -} - -func (m *Map) PrettyDiffFrom(old *Map) string { - var ( - oldNameToIP map[string]netaddr.IP - newNameToIP map[string]netaddr.IP - oldNames []string - newNames []string - ) - if old != nil { - oldNameToIP = old.nameToIP - oldNames = old.names - } - if m != nil { - newNameToIP = m.nameToIP - newNames = m.names - } - - buf := new(strings.Builder) - space := func() bool { - return buf.Len() < (1 << 10) - } - - for len(oldNames) > 0 && len(newNames) > 0 { - var name string - - newName, oldName := newNames[0], oldNames[0] - switch { - case oldName < newName: - name = oldName - oldNames = oldNames[1:] - case oldName > newName: - name = newName - newNames = newNames[1:] - case oldNames[0] == newNames[0]: - name = oldNames[0] - oldNames = oldNames[1:] - newNames = newNames[1:] - } - if !space() { - continue - } - - ipOld, inOld := oldNameToIP[name] - ipNew, inNew := newNameToIP[name] - switch { - case !inOld: - buf.WriteByte('+') - printSingleNameIP(buf, name, ipNew) - case !inNew: - 
buf.WriteByte('-') - printSingleNameIP(buf, name, ipOld) - case ipOld != ipNew: - buf.WriteByte('-') - printSingleNameIP(buf, name, ipOld) - buf.WriteByte('+') - printSingleNameIP(buf, name, ipNew) - } - } - - for _, name := range oldNames { - if !space() { - break - } - if _, ok := newNameToIP[name]; !ok { - buf.WriteByte('-') - printSingleNameIP(buf, name, oldNameToIP[name]) - } - } - - for _, name := range newNames { - if !space() { - break - } - if _, ok := oldNameToIP[name]; !ok { - buf.WriteByte('+') - printSingleNameIP(buf, name, newNameToIP[name]) - } - } - if !space() { - buf.WriteString("... [truncated]\n") - } - - return buf.String() -} diff --git a/wgengine/tsdns/map_test.go b/wgengine/tsdns/map_test.go deleted file mode 100644 index dba9bb586..000000000 --- a/wgengine/tsdns/map_test.go +++ /dev/null @@ -1,156 +0,0 @@ -// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. 
- -package tsdns - -import ( - "fmt" - "strings" - "testing" - - "inet.af/netaddr" -) - -func TestPretty(t *testing.T) { - tests := []struct { - name string - dmap *Map - want string - }{ - {"empty", NewMap(nil, nil), ""}, - { - "single", - NewMap(map[string]netaddr.IP{ - "hello.ipn.dev.": netaddr.IPv4(100, 101, 102, 103), - }, nil), - "hello.ipn.dev.\t100.101.102.103\n", - }, - { - "multiple", - NewMap(map[string]netaddr.IP{ - "test1.domain.": netaddr.IPv4(100, 101, 102, 103), - "test2.sub.domain.": netaddr.IPv4(100, 99, 9, 1), - }, nil), - "test1.domain.\t100.101.102.103\ntest2.sub.domain.\t100.99.9.1\n", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := tt.dmap.Pretty() - if tt.want != got { - t.Errorf("want %v; got %v", tt.want, got) - } - }) - } -} - -func TestPrettyDiffFrom(t *testing.T) { - tests := []struct { - name string - map1 *Map - map2 *Map - want string - }{ - { - "from_empty", - nil, - NewMap(map[string]netaddr.IP{ - "test1.ipn.dev.": netaddr.IPv4(100, 101, 102, 103), - "test2.ipn.dev.": netaddr.IPv4(100, 103, 102, 101), - }, nil), - "+test1.ipn.dev.\t100.101.102.103\n+test2.ipn.dev.\t100.103.102.101\n", - }, - { - "equal", - NewMap(map[string]netaddr.IP{ - "test1.ipn.dev.": netaddr.IPv4(100, 101, 102, 103), - "test2.ipn.dev.": netaddr.IPv4(100, 103, 102, 101), - }, nil), - NewMap(map[string]netaddr.IP{ - "test2.ipn.dev.": netaddr.IPv4(100, 103, 102, 101), - "test1.ipn.dev.": netaddr.IPv4(100, 101, 102, 103), - }, nil), - "", - }, - { - "changed_ip", - NewMap(map[string]netaddr.IP{ - "test1.ipn.dev.": netaddr.IPv4(100, 101, 102, 103), - "test2.ipn.dev.": netaddr.IPv4(100, 103, 102, 101), - }, nil), - NewMap(map[string]netaddr.IP{ - "test2.ipn.dev.": netaddr.IPv4(100, 104, 102, 101), - "test1.ipn.dev.": netaddr.IPv4(100, 101, 102, 103), - }, nil), - "-test2.ipn.dev.\t100.103.102.101\n+test2.ipn.dev.\t100.104.102.101\n", - }, - { - "new_domain", - NewMap(map[string]netaddr.IP{ - "test1.ipn.dev.": netaddr.IPv4(100, 
101, 102, 103), - "test2.ipn.dev.": netaddr.IPv4(100, 103, 102, 101), - }, nil), - NewMap(map[string]netaddr.IP{ - "test3.ipn.dev.": netaddr.IPv4(100, 105, 106, 107), - "test2.ipn.dev.": netaddr.IPv4(100, 103, 102, 101), - "test1.ipn.dev.": netaddr.IPv4(100, 101, 102, 103), - }, nil), - "+test3.ipn.dev.\t100.105.106.107\n", - }, - { - "gone_domain", - NewMap(map[string]netaddr.IP{ - "test1.ipn.dev.": netaddr.IPv4(100, 101, 102, 103), - "test2.ipn.dev.": netaddr.IPv4(100, 103, 102, 101), - }, nil), - NewMap(map[string]netaddr.IP{ - "test1.ipn.dev.": netaddr.IPv4(100, 101, 102, 103), - }, nil), - "-test2.ipn.dev.\t100.103.102.101\n", - }, - { - "mixed", - NewMap(map[string]netaddr.IP{ - "test1.ipn.dev.": netaddr.IPv4(100, 101, 102, 103), - "test4.ipn.dev.": netaddr.IPv4(100, 107, 106, 105), - "test5.ipn.dev.": netaddr.IPv4(100, 64, 1, 1), - "test2.ipn.dev.": netaddr.IPv4(100, 103, 102, 101), - }, nil), - NewMap(map[string]netaddr.IP{ - "test2.ipn.dev.": netaddr.IPv4(100, 104, 102, 101), - "test1.ipn.dev.": netaddr.IPv4(100, 100, 101, 102), - "test3.ipn.dev.": netaddr.IPv4(100, 64, 1, 1), - }, nil), - "-test1.ipn.dev.\t100.101.102.103\n+test1.ipn.dev.\t100.100.101.102\n" + - "-test2.ipn.dev.\t100.103.102.101\n+test2.ipn.dev.\t100.104.102.101\n" + - "+test3.ipn.dev.\t100.64.1.1\n-test4.ipn.dev.\t100.107.106.105\n-test5.ipn.dev.\t100.64.1.1\n", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := tt.map2.PrettyDiffFrom(tt.map1) - if tt.want != got { - t.Errorf("want %v; got %v", tt.want, got) - } - }) - } - - t.Run("truncated", func(t *testing.T) { - small := NewMap(nil, nil) - m := map[string]netaddr.IP{} - for i := 0; i < 5000; i++ { - m[fmt.Sprintf("host%d.ipn.dev.", i)] = netaddr.IPv4(100, 64, 1, 1) - } - veryBig := NewMap(m, nil) - diff := veryBig.PrettyDiffFrom(small) - if len(diff) > 3<<10 { - t.Errorf("pretty diff too large: %d bytes", len(diff)) - } - if !strings.Contains(diff, "truncated") { - t.Errorf("big diff not 
truncated") - } - }) -} diff --git a/wgengine/tsdns/neterr_darwin.go b/wgengine/tsdns/neterr_darwin.go deleted file mode 100644 index 62bab6488..000000000 --- a/wgengine/tsdns/neterr_darwin.go +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package tsdns - -import ( - "errors" - "syscall" -) - -// Avoid allocation when calling errors.Is below -// by converting syscall.Errno to error here. -var ( - networkDown error = syscall.ENETDOWN - networkUnreachable error = syscall.ENETUNREACH -) - -func networkIsDown(err error) bool { - return errors.Is(err, networkDown) -} - -func networkIsUnreachable(err error) bool { - return errors.Is(err, networkUnreachable) -} diff --git a/wgengine/tsdns/neterr_other.go b/wgengine/tsdns/neterr_other.go deleted file mode 100644 index d245d0c38..000000000 --- a/wgengine/tsdns/neterr_other.go +++ /dev/null @@ -1,10 +0,0 @@ -// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// +build !darwin,!windows - -package tsdns - -func networkIsDown(err error) bool { return false } -func networkIsUnreachable(err error) bool { return false } diff --git a/wgengine/tsdns/neterr_windows.go b/wgengine/tsdns/neterr_windows.go deleted file mode 100644 index 90f0db2ab..000000000 --- a/wgengine/tsdns/neterr_windows.go +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. 
- -package tsdns - -import ( - "net" - "os" - - "golang.org/x/sys/windows" -) - -func networkIsDown(err error) bool { - if oe, ok := err.(*net.OpError); ok && oe.Op == "write" { - if se, ok := oe.Err.(*os.SyscallError); ok { - if se.Syscall == "wsasendto" && se.Err == windows.WSAENETUNREACH { - return true - } - } - } - return false -} - -func networkIsUnreachable(err error) bool { - // TODO(bradfitz,josharian): something here? what is the - // difference between down and unreachable? Add comments. - return false -} diff --git a/wgengine/tsdns/tsdns.go b/wgengine/tsdns/tsdns.go deleted file mode 100644 index b68b8c04e..000000000 --- a/wgengine/tsdns/tsdns.go +++ /dev/null @@ -1,662 +0,0 @@ -// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// Package tsdns provides a Resolver capable of resolving -// domains on a Tailscale network. -package tsdns - -import ( - "encoding/hex" - "errors" - "net" - "strings" - "sync" - "time" - - dns "golang.org/x/net/dns/dnsmessage" - "inet.af/netaddr" - "tailscale.com/net/interfaces" - "tailscale.com/types/logger" - "tailscale.com/util/dnsname" - "tailscale.com/wgengine/monitor" -) - -// maxResponseBytes is the maximum size of a response from a Resolver. -const maxResponseBytes = 512 - -// queueSize is the maximal number of DNS requests that can await polling. -// If EnqueueRequest is called when this many requests are already pending, -// the request will be dropped to avoid blocking the caller. -const queueSize = 64 - -// defaultTTL is the TTL of all responses from Resolver. -const defaultTTL = 600 * time.Second - -// ErrClosed indicates that the resolver has been closed and readers should exit. 
-var ErrClosed = errors.New("closed") - -var ( - errFullQueue = errors.New("request queue full") - errMapNotSet = errors.New("domain map not set") - errNotForwarding = errors.New("forwarding disabled") - errNotImplemented = errors.New("query type not implemented") - errNotQuery = errors.New("not a DNS query") - errNotOurName = errors.New("not a Tailscale DNS name") -) - -// Packet represents a DNS payload together with the address of its origin. -type Packet struct { - // Payload is the application layer DNS payload. - // Resolver assumes ownership of the request payload when it is enqueued - // and cedes ownership of the response payload when it is returned from NextResponse. - Payload []byte - // Addr is the source address for a request and the destination address for a response. - Addr netaddr.IPPort -} - -// Resolver is a DNS resolver for nodes on the Tailscale network, -// associating them with domain names of the form <mynode>.<mydomain>.<root>. -// If it is asked to resolve a domain that is not of that form, -// it delegates to upstream nameservers if any are set. -type Resolver struct { - logf logger.Logf - linkMon *monitor.Mon // or nil - unregLinkMon func() // or nil - // forwarder forwards requests to upstream nameservers. - forwarder *forwarder - - // queue is a buffered channel holding DNS requests queued for resolution. - queue chan Packet - // responses is an unbuffered channel to which responses are returned. - responses chan Packet - // errors is an unbuffered channel to which errors are returned. - errors chan error - // closed signals all goroutines to stop. - closed chan struct{} - // wg signals when all goroutines have stopped. - wg sync.WaitGroup - - // mu guards the following fields from being updated while used. - mu sync.Mutex - // dnsMap is the map most recently received from the control server. - dnsMap *Map -} - -// ResolverConfig is the set of configuration options for a Resolver. 
-type ResolverConfig struct { - // Logf is the logger to use throughout the Resolver. - Logf logger.Logf - // Forward determines whether the resolver will forward packets to - // nameservers set with SetUpstreams if the domain name is not of a Tailscale node. - Forward bool - // LinkMonitor optionally provides a link monitor to use to rebind - // connections on link changes. - // If nil, rebinds are not performend. - LinkMonitor *monitor.Mon -} - -// NewResolver constructs a resolver associated with the given root domain. -// The root domain must be in canonical form (with a trailing period). -func NewResolver(config ResolverConfig) *Resolver { - r := &Resolver{ - logf: logger.WithPrefix(config.Logf, "tsdns: "), - linkMon: config.LinkMonitor, - queue: make(chan Packet, queueSize), - responses: make(chan Packet), - errors: make(chan error), - closed: make(chan struct{}), - } - - if config.Forward { - r.forwarder = newForwarder(r.logf, r.responses) - } - if r.linkMon != nil { - r.unregLinkMon = r.linkMon.RegisterChangeCallback(r.onLinkMonitorChange) - } - - return r -} - -func (r *Resolver) Start() error { - if r.forwarder != nil { - if err := r.forwarder.Start(); err != nil { - return err - } - } - - r.wg.Add(1) - go r.poll() - - return nil -} - -// Close shuts down the resolver and ensures poll goroutines have exited. -// The Resolver cannot be used again after Close is called. -func (r *Resolver) Close() { - select { - case <-r.closed: - return - default: - // continue - } - close(r.closed) - - if r.unregLinkMon != nil { - r.unregLinkMon() - } - - if r.forwarder != nil { - r.forwarder.Close() - } - - r.wg.Wait() -} - -func (r *Resolver) onLinkMonitorChange(changed bool, state *interfaces.State) { - if !changed { - return - } - if r.forwarder != nil { - r.forwarder.rebindFromNetworkChange() - } -} - -// SetMap sets the resolver's DNS map, taking ownership of it. 
-func (r *Resolver) SetMap(m *Map) { - r.mu.Lock() - oldMap := r.dnsMap - r.dnsMap = m - r.mu.Unlock() - r.logf("map diff:\n%s", m.PrettyDiffFrom(oldMap)) -} - -// SetUpstreams sets the addresses of the resolver's -// upstream nameservers, taking ownership of the argument. -func (r *Resolver) SetUpstreams(upstreams []net.Addr) { - if r.forwarder != nil { - r.forwarder.setUpstreams(upstreams) - } - r.logf("set upstreams: %v", upstreams) -} - -// EnqueueRequest places the given DNS request in the resolver's queue. -// It takes ownership of the payload and does not block. -// If the queue is full, the request will be dropped and an error will be returned. -func (r *Resolver) EnqueueRequest(request Packet) error { - select { - case <-r.closed: - return ErrClosed - case r.queue <- request: - return nil - default: - return errFullQueue - } -} - -// NextResponse returns a DNS response to a previously enqueued request. -// It blocks until a response is available and gives up ownership of the response payload. -func (r *Resolver) NextResponse() (Packet, error) { - select { - case <-r.closed: - return Packet{}, ErrClosed - case resp := <-r.responses: - return resp, nil - case err := <-r.errors: - return Packet{}, err - } -} - -// Resolve maps a given domain name to the IP address of the host that owns it, -// if the IP address conforms to the DNS resource type given by tp (one of A, AAAA, ALL). -// The domain name must be in canonical form (with a trailing period). -func (r *Resolver) Resolve(domain string, tp dns.Type) (netaddr.IP, dns.RCode, error) { - r.mu.Lock() - dnsMap := r.dnsMap - r.mu.Unlock() - - if dnsMap == nil { - return netaddr.IP{}, dns.RCodeServerFailure, errMapNotSet - } - - // Reject .onion domains per RFC 7686. 
- if dnsname.HasSuffix(domain, ".onion") { - return netaddr.IP{}, dns.RCodeNameError, nil - } - - anyHasSuffix := false - for _, suffix := range dnsMap.rootDomains { - if dnsname.HasSuffix(domain, suffix) { - anyHasSuffix = true - break - } - } - addr, found := dnsMap.nameToIP[domain] - if !found { - if !anyHasSuffix { - return netaddr.IP{}, dns.RCodeRefused, nil - } - return netaddr.IP{}, dns.RCodeNameError, nil - } - - // Refactoring note: this must happen after we check suffixes, - // otherwise we will respond with NOTIMP to requests that should be forwarded. - switch tp { - case dns.TypeA: - if !addr.Is4() { - return netaddr.IP{}, dns.RCodeSuccess, nil - } - return addr, dns.RCodeSuccess, nil - case dns.TypeAAAA: - if !addr.Is6() { - return netaddr.IP{}, dns.RCodeSuccess, nil - } - return addr, dns.RCodeSuccess, nil - case dns.TypeALL: - // Answer with whatever we've got. - // It could be IPv4, IPv6, or a zero addr. - // TODO: Return all available resolutions (A and AAAA, if we have them). - return addr, dns.RCodeSuccess, nil - - // Leave some some record types explicitly unimplemented. - // These types relate to recursive resolution or special - // DNS sematics and might be implemented in the future. - case dns.TypeNS, dns.TypeSOA, dns.TypeAXFR, dns.TypeHINFO: - return netaddr.IP{}, dns.RCodeNotImplemented, errNotImplemented - - // For everything except for the few types above that are explictly not implemented, return no records. - // This is what other DNS systems do: always return NOERROR - // without any records whenever the requested record type is unknown. - // You can try this with: - // dig -t TYPE9824 example.com - // and note that NOERROR is returned, despite that record type being made up. - default: - // no records exist of this type - return netaddr.IP{}, dns.RCodeSuccess, nil - } -} - -// ResolveReverse returns the unique domain name that maps to the given address. -// The returned domain name is in canonical form (with a trailing period). 
-func (r *Resolver) ResolveReverse(ip netaddr.IP) (string, dns.RCode, error) { - r.mu.Lock() - dnsMap := r.dnsMap - r.mu.Unlock() - - if dnsMap == nil { - return "", dns.RCodeServerFailure, errMapNotSet - } - name, found := dnsMap.ipToName[ip] - if !found { - return "", dns.RCodeNameError, nil - } - return name, dns.RCodeSuccess, nil -} - -func (r *Resolver) poll() { - defer r.wg.Done() - - var packet Packet - for { - select { - case <-r.closed: - return - case packet = <-r.queue: - // continue - } - - out, err := r.respond(packet.Payload) - - if err == errNotOurName { - if r.forwarder != nil { - err = r.forwarder.forward(packet) - if err == nil { - // forward will send response into r.responses, nothing to do. - continue - } - } else { - err = errNotForwarding - } - } - - if err != nil { - select { - case <-r.closed: - return - case r.errors <- err: - // continue - } - } else { - packet.Payload = out - select { - case <-r.closed: - return - case r.responses <- packet: - // continue - } - } - } -} - -type response struct { - Header dns.Header - Question dns.Question - // Name is the response to a PTR query. - Name string - // IP is the response to an A, AAAA, or ALL query. - IP netaddr.IP -} - -// parseQuery parses the query in given packet into a response struct. -func parseQuery(query []byte, resp *response) error { - var parser dns.Parser - var err error - - resp.Header, err = parser.Start(query) - if err != nil { - return err - } - - if resp.Header.Response { - return errNotQuery - } - - resp.Question, err = parser.Question() - if err != nil { - return err - } - - return nil -} - -// marshalARecord serializes an A record into an active builder. -// The caller may continue using the builder following the call. 
-func marshalARecord(name dns.Name, ip netaddr.IP, builder *dns.Builder) error { - var answer dns.AResource - - answerHeader := dns.ResourceHeader{ - Name: name, - Type: dns.TypeA, - Class: dns.ClassINET, - TTL: uint32(defaultTTL / time.Second), - } - ipbytes := ip.As4() - copy(answer.A[:], ipbytes[:]) - return builder.AResource(answerHeader, answer) -} - -// marshalAAAARecord serializes an AAAA record into an active builder. -// The caller may continue using the builder following the call. -func marshalAAAARecord(name dns.Name, ip netaddr.IP, builder *dns.Builder) error { - var answer dns.AAAAResource - - answerHeader := dns.ResourceHeader{ - Name: name, - Type: dns.TypeAAAA, - Class: dns.ClassINET, - TTL: uint32(defaultTTL / time.Second), - } - ipbytes := ip.As16() - copy(answer.AAAA[:], ipbytes[:]) - return builder.AAAAResource(answerHeader, answer) -} - -// marshalPTRRecord serializes a PTR record into an active builder. -// The caller may continue using the builder following the call. -func marshalPTRRecord(queryName dns.Name, name string, builder *dns.Builder) error { - var answer dns.PTRResource - var err error - - answerHeader := dns.ResourceHeader{ - Name: queryName, - Type: dns.TypePTR, - Class: dns.ClassINET, - TTL: uint32(defaultTTL / time.Second), - } - answer.PTR, err = dns.NewName(name) - if err != nil { - return err - } - return builder.PTRResource(answerHeader, answer) -} - -// marshalResponse serializes the DNS response into a new buffer. 
-func marshalResponse(resp *response) ([]byte, error) { - resp.Header.Response = true - resp.Header.Authoritative = true - if resp.Header.RecursionDesired { - resp.Header.RecursionAvailable = true - } - - builder := dns.NewBuilder(nil, resp.Header) - - isSuccess := resp.Header.RCode == dns.RCodeSuccess - - if resp.Question.Type != 0 || isSuccess { - err := builder.StartQuestions() - if err != nil { - return nil, err - } - - err = builder.Question(resp.Question) - if err != nil { - return nil, err - } - } - - // Only successful responses contain answers. - if !isSuccess { - return builder.Finish() - } - - err := builder.StartAnswers() - if err != nil { - return nil, err - } - - switch resp.Question.Type { - case dns.TypeA, dns.TypeAAAA, dns.TypeALL: - if resp.IP.Is4() { - err = marshalARecord(resp.Question.Name, resp.IP, &builder) - } else if resp.IP.Is6() { - err = marshalAAAARecord(resp.Question.Name, resp.IP, &builder) - } - case dns.TypePTR: - err = marshalPTRRecord(resp.Question.Name, resp.Name, &builder) - } - if err != nil { - return nil, err - } - - return builder.Finish() -} - -const ( - rdnsv4Suffix = ".in-addr.arpa." - rdnsv6Suffix = ".ip6.arpa." -) - -// hasRDNSBonjourPrefix reports whether name has a Bonjour Service Prefix.. -// -// https://tools.ietf.org/html/rfc6763 lists -// "five special RR names" for Bonjour service discovery: -// -// b._dns-sd._udp.<domain>. -// db._dns-sd._udp.<domain>. -// r._dns-sd._udp.<domain>. -// dr._dns-sd._udp.<domain>. -// lb._dns-sd._udp.<domain>. -func hasRDNSBonjourPrefix(s string) bool { - // Even the shortest name containing a Bonjour prefix is long, - // so check length (cheap) and bail early if possible. 
- if len(s) < len("*._dns-sd._udp.0.0.0.0.in-addr.arpa.") { - return false - } - dot := strings.IndexByte(s, '.') - if dot == -1 { - return false // shouldn't happen - } - switch s[:dot] { - case "b", "db", "r", "dr", "lb": - default: - return false - } - - return strings.HasPrefix(s[dot:], "._dns-sd._udp.") -} - -// rawNameToLower converts a raw DNS name to a string, lowercasing it. -func rawNameToLower(name []byte) string { - var sb strings.Builder - sb.Grow(len(name)) - - for _, b := range name { - if 'A' <= b && b <= 'Z' { - b = b - 'A' + 'a' - } - sb.WriteByte(b) - } - - return sb.String() -} - -// ptrNameToIPv4 transforms a PTR name representing an IPv4 address to said address. -// Such names are IPv4 labels in reverse order followed by .in-addr.arpa. -// For example, -// 4.3.2.1.in-addr.arpa -// is transformed to -// 1.2.3.4 -func rdnsNameToIPv4(name string) (ip netaddr.IP, ok bool) { - name = strings.TrimSuffix(name, rdnsv4Suffix) - ip, err := netaddr.ParseIP(string(name)) - if err != nil { - return netaddr.IP{}, false - } - if !ip.Is4() { - return netaddr.IP{}, false - } - b := ip.As4() - return netaddr.IPv4(b[3], b[2], b[1], b[0]), true -} - -// ptrNameToIPv6 transforms a PTR name representing an IPv6 address to said address. -// Such names are dot-separated nibbles in reverse order followed by .ip6.arpa. -// For example, -// b.a.9.8.7.6.5.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa. -// is transformed to -// 2001:db8::567:89ab -func rdnsNameToIPv6(name string) (ip netaddr.IP, ok bool) { - var b [32]byte - var ipb [16]byte - - name = strings.TrimSuffix(name, rdnsv6Suffix) - // 32 nibbles and 31 dots between them. - if len(name) != 63 { - return netaddr.IP{}, false - } - - // Dots and hex digits alternate. - prevDot := true - // i ranges over name backward; j ranges over b forward. 
- for i, j := len(name)-1, 0; i >= 0; i-- { - thisDot := (name[i] == '.') - if prevDot == thisDot { - return netaddr.IP{}, false - } - prevDot = thisDot - - if !thisDot { - // This is safe assuming alternation. - // We do not check that non-dots are hex digits: hex.Decode below will do that. - b[j] = name[i] - j++ - } - } - - _, err := hex.Decode(ipb[:], b[:]) - if err != nil { - return netaddr.IP{}, false - } - - return netaddr.IPFrom16(ipb), true -} - -// respondReverse returns a DNS response to a PTR query. -// It is assumed that resp.Question is populated by respond before this is called. -func (r *Resolver) respondReverse(query []byte, name string, resp *response) ([]byte, error) { - if hasRDNSBonjourPrefix(name) { - return nil, errNotOurName - } - - var ip netaddr.IP - var ok bool - switch { - case strings.HasSuffix(name, rdnsv4Suffix): - ip, ok = rdnsNameToIPv4(name) - case strings.HasSuffix(name, rdnsv6Suffix): - ip, ok = rdnsNameToIPv6(name) - default: - return nil, errNotOurName - } - - // It is more likely that we failed in parsing the name than that it is actually malformed. - // To avoid frustrating users, just log and delegate. - if !ok { - r.logf("parsing rdns: malformed name: %s", name) - return nil, errNotOurName - } - - var err error - resp.Name, resp.Header.RCode, err = r.ResolveReverse(ip) - if err != nil { - r.logf("resolving rdns: %v", ip, err) - } - if resp.Header.RCode == dns.RCodeNameError { - return nil, errNotOurName - } - - return marshalResponse(resp) -} - -// respond returns a DNS response to query if it can be resolved locally. -// Otherwise, it returns errNotOurName. -func (r *Resolver) respond(query []byte) ([]byte, error) { - resp := new(response) - - // ParseQuery is sufficiently fast to run on every DNS packet. - // This is considerably simpler than extracting the name by hand - // to shave off microseconds in case of delegation. - err := parseQuery(query, resp) - // We will not return this error: it is the sender's fault. 
- if err != nil { - if errors.Is(err, dns.ErrSectionDone) { - r.logf("parseQuery(%02x): no DNS questions", query) - } else { - r.logf("parseQuery(%02x): %v", query, err) - } - resp.Header.RCode = dns.RCodeFormatError - return marshalResponse(resp) - } - rawName := resp.Question.Name.Data[:resp.Question.Name.Length] - name := rawNameToLower(rawName) - - // Always try to handle reverse lookups; delegate inside when not found. - // This way, queries for existent nodes do not leak, - // but we behave gracefully if non-Tailscale nodes exist in CGNATRange. - if resp.Question.Type == dns.TypePTR { - return r.respondReverse(query, name, resp) - } - - resp.IP, resp.Header.RCode, err = r.Resolve(name, resp.Question.Type) - // This return code is special: it requests forwarding. - if resp.Header.RCode == dns.RCodeRefused { - return nil, errNotOurName - } - - // We will not return this error: it is the sender's fault. - if err != nil { - r.logf("resolving: %v", err) - } - - return marshalResponse(resp) -} diff --git a/wgengine/tsdns/tsdns_server_test.go b/wgengine/tsdns/tsdns_server_test.go deleted file mode 100644 index df9047fc6..000000000 --- a/wgengine/tsdns/tsdns_server_test.go +++ /dev/null @@ -1,95 +0,0 @@ -// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package tsdns - -import ( - "log" - "testing" - - "github.com/miekg/dns" - "inet.af/netaddr" -) - -// This file exists to isolate the test infrastructure -// that depends on github.com/miekg/dns -// from the rest, which only depends on dnsmessage. - -var dnsHandleFunc = dns.HandleFunc - -// resolveToIP returns a handler function which responds -// to queries of type A it receives with an A record containing ipv4, -// to queries of type AAAA with an AAAA record containing ipv6, -// to queries of type NS with an NS record containg name. 
-func resolveToIP(ipv4, ipv6 netaddr.IP, ns string) dns.HandlerFunc { - return func(w dns.ResponseWriter, req *dns.Msg) { - m := new(dns.Msg) - m.SetReply(req) - - if len(req.Question) != 1 { - panic("not a single-question request") - } - question := req.Question[0] - - var ans dns.RR - switch question.Qtype { - case dns.TypeA: - ans = &dns.A{ - Hdr: dns.RR_Header{ - Name: question.Name, - Rrtype: dns.TypeA, - Class: dns.ClassINET, - }, - A: ipv4.IPAddr().IP, - } - case dns.TypeAAAA: - ans = &dns.AAAA{ - Hdr: dns.RR_Header{ - Name: question.Name, - Rrtype: dns.TypeAAAA, - Class: dns.ClassINET, - }, - AAAA: ipv6.IPAddr().IP, - } - case dns.TypeNS: - ans = &dns.NS{ - Hdr: dns.RR_Header{ - Name: question.Name, - Rrtype: dns.TypeNS, - Class: dns.ClassINET, - }, - Ns: ns, - } - } - - m.Answer = append(m.Answer, ans) - w.WriteMsg(m) - } -} - -func resolveToNXDOMAIN(w dns.ResponseWriter, req *dns.Msg) { - m := new(dns.Msg) - m.SetRcode(req, dns.RcodeNameError) - w.WriteMsg(m) -} - -func serveDNS(tb testing.TB, addr string) (*dns.Server, chan error) { - server := &dns.Server{Addr: addr, Net: "udp"} - - waitch := make(chan struct{}) - server.NotifyStartedFunc = func() { close(waitch) } - - errch := make(chan error, 1) - go func() { - err := server.ListenAndServe() - if err != nil { - log.Printf("ListenAndServe(%q): %v", addr, err) - } - errch <- err - close(errch) - }() - - <-waitch - return server, errch -} diff --git a/wgengine/tsdns/tsdns_test.go b/wgengine/tsdns/tsdns_test.go deleted file mode 100644 index 66a62d107..000000000 --- a/wgengine/tsdns/tsdns_test.go +++ /dev/null @@ -1,816 +0,0 @@ -// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. 
- -package tsdns - -import ( - "bytes" - "errors" - "net" - "sync" - "testing" - - dns "golang.org/x/net/dns/dnsmessage" - "inet.af/netaddr" - "tailscale.com/tstest" -) - -var testipv4 = netaddr.IPv4(1, 2, 3, 4) -var testipv6 = netaddr.IPv6Raw([16]byte{ - 0x00, 0x01, 0x02, 0x03, - 0x04, 0x05, 0x06, 0x07, - 0x08, 0x09, 0x0a, 0x0b, - 0x0c, 0x0d, 0x0e, 0x0f, -}) - -var dnsMap = NewMap( - map[string]netaddr.IP{ - "test1.ipn.dev.": testipv4, - "test2.ipn.dev.": testipv6, - }, - []string{"ipn.dev."}, -) - -func dnspacket(domain string, tp dns.Type) []byte { - var dnsHeader dns.Header - question := dns.Question{ - Name: dns.MustNewName(domain), - Type: tp, - Class: dns.ClassINET, - } - - builder := dns.NewBuilder(nil, dnsHeader) - builder.StartQuestions() - builder.Question(question) - payload, _ := builder.Finish() - - return payload -} - -type dnsResponse struct { - ip netaddr.IP - name string - rcode dns.RCode -} - -func unpackResponse(payload []byte) (dnsResponse, error) { - var response dnsResponse - var parser dns.Parser - - h, err := parser.Start(payload) - if err != nil { - return response, err - } - - if !h.Response { - return response, errors.New("not a response") - } - - response.rcode = h.RCode - if response.rcode != dns.RCodeSuccess { - return response, nil - } - - err = parser.SkipAllQuestions() - if err != nil { - return response, err - } - - ah, err := parser.AnswerHeader() - if err != nil { - return response, err - } - - switch ah.Type { - case dns.TypeA: - res, err := parser.AResource() - if err != nil { - return response, err - } - response.ip = netaddr.IPv4(res.A[0], res.A[1], res.A[2], res.A[3]) - case dns.TypeAAAA: - res, err := parser.AAAAResource() - if err != nil { - return response, err - } - response.ip = netaddr.IPv6Raw(res.AAAA) - case dns.TypeNS: - res, err := parser.NSResource() - if err != nil { - return response, err - } - response.name = res.NS.String() - default: - return response, errors.New("type not in {A, AAAA, NS}") - } - - return 
response, nil -} - -func syncRespond(r *Resolver, query []byte) ([]byte, error) { - request := Packet{Payload: query} - r.EnqueueRequest(request) - resp, err := r.NextResponse() - return resp.Payload, err -} - -func mustIP(str string) netaddr.IP { - ip, err := netaddr.ParseIP(str) - if err != nil { - panic(err) - } - return ip -} - -func TestRDNSNameToIPv4(t *testing.T) { - tests := []struct { - name string - input string - wantIP netaddr.IP - wantOK bool - }{ - {"valid", "4.123.24.1.in-addr.arpa.", netaddr.IPv4(1, 24, 123, 4), true}, - {"double_dot", "1..2.3.in-addr.arpa.", netaddr.IP{}, false}, - {"overflow", "1.256.3.4.in-addr.arpa.", netaddr.IP{}, false}, - {"not_ip", "sub.do.ma.in.in-addr.arpa.", netaddr.IP{}, false}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - ip, ok := rdnsNameToIPv4(tt.input) - if ok != tt.wantOK { - t.Errorf("ok = %v; want %v", ok, tt.wantOK) - } else if ok && ip != tt.wantIP { - t.Errorf("ip = %v; want %v", ip, tt.wantIP) - } - }) - } -} - -func TestRDNSNameToIPv6(t *testing.T) { - tests := []struct { - name string - input string - wantIP netaddr.IP - wantOK bool - }{ - { - "valid", - "b.a.9.8.7.6.5.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.", - mustIP("2001:db8::567:89ab"), - true, - }, - { - "double_dot", - "b..9.8.7.6.5.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.", - netaddr.IP{}, - false, - }, - { - "double_hex", - "b.a.98.0.7.6.5.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.", - netaddr.IP{}, - false, - }, - { - "not_hex", - "b.a.g.0.7.6.5.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.", - netaddr.IP{}, - false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - ip, ok := rdnsNameToIPv6(tt.input) - if ok != tt.wantOK { - t.Errorf("ok = %v; want %v", ok, tt.wantOK) - } else if ok && ip != tt.wantIP { - t.Errorf("ip = %v; want %v", ip, tt.wantIP) - } - }) - } -} - -func TestResolve(t *testing.T) { - r := 
NewResolver(ResolverConfig{Logf: t.Logf, Forward: false}) - r.SetMap(dnsMap) - - if err := r.Start(); err != nil { - t.Fatalf("start: %v", err) - } - defer r.Close() - - tests := []struct { - name string - qname string - qtype dns.Type - ip netaddr.IP - code dns.RCode - }{ - {"ipv4", "test1.ipn.dev.", dns.TypeA, testipv4, dns.RCodeSuccess}, - {"ipv6", "test2.ipn.dev.", dns.TypeAAAA, testipv6, dns.RCodeSuccess}, - {"no-ipv6", "test1.ipn.dev.", dns.TypeAAAA, netaddr.IP{}, dns.RCodeSuccess}, - {"nxdomain", "test3.ipn.dev.", dns.TypeA, netaddr.IP{}, dns.RCodeNameError}, - {"foreign domain", "google.com.", dns.TypeA, netaddr.IP{}, dns.RCodeRefused}, - {"all", "test1.ipn.dev.", dns.TypeA, testipv4, dns.RCodeSuccess}, - {"mx-ipv4", "test1.ipn.dev.", dns.TypeMX, netaddr.IP{}, dns.RCodeSuccess}, - {"mx-ipv6", "test2.ipn.dev.", dns.TypeMX, netaddr.IP{}, dns.RCodeSuccess}, - {"mx-nxdomain", "test3.ipn.dev.", dns.TypeMX, netaddr.IP{}, dns.RCodeNameError}, - {"ns-nxdomain", "test3.ipn.dev.", dns.TypeNS, netaddr.IP{}, dns.RCodeNameError}, - {"onion-domain", "footest.onion.", dns.TypeA, netaddr.IP{}, dns.RCodeNameError}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - ip, code, err := r.Resolve(tt.qname, tt.qtype) - if err != nil { - t.Errorf("err = %v; want nil", err) - } - if code != tt.code { - t.Errorf("code = %v; want %v", code, tt.code) - } - // Only check ip for non-err - if ip != tt.ip { - t.Errorf("ip = %v; want %v", ip, tt.ip) - } - }) - } -} - -func TestResolveReverse(t *testing.T) { - r := NewResolver(ResolverConfig{Logf: t.Logf, Forward: false}) - r.SetMap(dnsMap) - - if err := r.Start(); err != nil { - t.Fatalf("start: %v", err) - } - defer r.Close() - - tests := []struct { - name string - ip netaddr.IP - want string - code dns.RCode - }{ - {"ipv4", testipv4, "test1.ipn.dev.", dns.RCodeSuccess}, - {"ipv6", testipv6, "test2.ipn.dev.", dns.RCodeSuccess}, - {"nxdomain", netaddr.IPv4(4, 3, 2, 1), "", dns.RCodeNameError}, - } - - for _, tt := 
range tests { - t.Run(tt.name, func(t *testing.T) { - name, code, err := r.ResolveReverse(tt.ip) - if err != nil { - t.Errorf("err = %v; want nil", err) - } - if code != tt.code { - t.Errorf("code = %v; want %v", code, tt.code) - } - if name != tt.want { - t.Errorf("ip = %v; want %v", name, tt.want) - } - }) - } -} - -func ipv6Works() bool { - c, err := net.Listen("tcp", "[::1]:0") - if err != nil { - return false - } - c.Close() - return true -} - -func TestDelegate(t *testing.T) { - tstest.ResourceCheck(t) - - if !ipv6Works() { - t.Skip("skipping test that requires localhost IPv6") - } - - dnsHandleFunc("test.site.", resolveToIP(testipv4, testipv6, "dns.test.site.")) - dnsHandleFunc("nxdomain.site.", resolveToNXDOMAIN) - - v4server, v4errch := serveDNS(t, "127.0.0.1:0") - v6server, v6errch := serveDNS(t, "[::1]:0") - - defer func() { - if err := <-v4errch; err != nil { - t.Errorf("v4 server error: %v", err) - } - if err := <-v6errch; err != nil { - t.Errorf("v6 server error: %v", err) - } - }() - if v4server != nil { - defer v4server.Shutdown() - } - if v6server != nil { - defer v6server.Shutdown() - } - - if v4server == nil || v6server == nil { - // There is an error in at least one of the channels - // and we cannot proceed; return to see it. 
- return - } - - r := NewResolver(ResolverConfig{Logf: t.Logf, Forward: true}) - r.SetMap(dnsMap) - r.SetUpstreams([]net.Addr{ - v4server.PacketConn.LocalAddr(), - v6server.PacketConn.LocalAddr(), - }) - - if err := r.Start(); err != nil { - t.Fatalf("start: %v", err) - } - defer r.Close() - - tests := []struct { - title string - query []byte - response dnsResponse - }{ - { - "ipv4", - dnspacket("test.site.", dns.TypeA), - dnsResponse{ip: testipv4, rcode: dns.RCodeSuccess}, - }, - { - "ipv6", - dnspacket("test.site.", dns.TypeAAAA), - dnsResponse{ip: testipv6, rcode: dns.RCodeSuccess}, - }, - { - "ns", - dnspacket("test.site.", dns.TypeNS), - dnsResponse{name: "dns.test.site.", rcode: dns.RCodeSuccess}, - }, - { - "nxdomain", - dnspacket("nxdomain.site.", dns.TypeA), - dnsResponse{rcode: dns.RCodeNameError}, - }, - } - - for _, tt := range tests { - t.Run(tt.title, func(t *testing.T) { - payload, err := syncRespond(r, tt.query) - if err != nil { - t.Errorf("err = %v; want nil", err) - return - } - response, err := unpackResponse(payload) - if err != nil { - t.Errorf("extract: err = %v; want nil (in %x)", err, payload) - return - } - if response.rcode != tt.response.rcode { - t.Errorf("rcode = %v; want %v", response.rcode, tt.response.rcode) - } - if response.ip != tt.response.ip { - t.Errorf("ip = %v; want %v", response.ip, tt.response.ip) - } - if response.name != tt.response.name { - t.Errorf("name = %v; want %v", response.name, tt.response.name) - } - }) - } -} - -func TestDelegateCollision(t *testing.T) { - dnsHandleFunc("test.site.", resolveToIP(testipv4, testipv6, "dns.test.site.")) - - server, errch := serveDNS(t, "127.0.0.1:0") - defer func() { - if err := <-errch; err != nil { - t.Errorf("server error: %v", err) - } - }() - - if server == nil { - return - } - defer server.Shutdown() - - r := NewResolver(ResolverConfig{Logf: t.Logf, Forward: true}) - r.SetMap(dnsMap) - r.SetUpstreams([]net.Addr{server.PacketConn.LocalAddr()}) - - if err := r.Start(); err != 
nil { - t.Fatalf("start: %v", err) - } - defer r.Close() - - packets := []struct { - qname string - qtype dns.Type - addr netaddr.IPPort - }{ - {"test.site.", dns.TypeA, netaddr.IPPort{IP: netaddr.IPv4(1, 1, 1, 1), Port: 1001}}, - {"test.site.", dns.TypeAAAA, netaddr.IPPort{IP: netaddr.IPv4(1, 1, 1, 1), Port: 1002}}, - } - - // packets will have the same dns txid. - for _, p := range packets { - payload := dnspacket(p.qname, p.qtype) - req := Packet{Payload: payload, Addr: p.addr} - err := r.EnqueueRequest(req) - if err != nil { - t.Error(err) - } - } - - // Despite the txid collision, the answer(s) should still match the query. - resp, err := r.NextResponse() - if err != nil { - t.Error(err) - } - - var p dns.Parser - _, err = p.Start(resp.Payload) - if err != nil { - t.Error(err) - } - err = p.SkipAllQuestions() - if err != nil { - t.Error(err) - } - ans, err := p.AllAnswers() - if err != nil { - t.Error(err) - } - - var wantType dns.Type - switch ans[0].Body.(type) { - case *dns.AResource: - wantType = dns.TypeA - case *dns.AAAAResource: - wantType = dns.TypeAAAA - default: - t.Errorf("unexpected answer type: %T", ans[0].Body) - } - - for _, p := range packets { - if p.qtype == wantType && p.addr != resp.Addr { - t.Errorf("addr = %v; want %v", resp.Addr, p.addr) - } - } -} - -func TestConcurrentSetMap(t *testing.T) { - r := NewResolver(ResolverConfig{Logf: t.Logf, Forward: false}) - - if err := r.Start(); err != nil { - t.Fatalf("start: %v", err) - } - defer r.Close() - - // This is purely to ensure that Resolve does not race with SetMap. 
- var wg sync.WaitGroup - wg.Add(2) - go func() { - defer wg.Done() - r.SetMap(dnsMap) - }() - go func() { - defer wg.Done() - r.Resolve("test1.ipn.dev", dns.TypeA) - }() - wg.Wait() -} - -func TestConcurrentSetUpstreams(t *testing.T) { - dnsHandleFunc("test.site.", resolveToIP(testipv4, testipv6, "dns.test.site.")) - - server, errch := serveDNS(t, "127.0.0.1:0") - defer func() { - if err := <-errch; err != nil { - t.Errorf("server error: %v", err) - } - }() - - if server == nil { - return - } - defer server.Shutdown() - - r := NewResolver(ResolverConfig{Logf: t.Logf, Forward: true}) - r.SetMap(dnsMap) - - if err := r.Start(); err != nil { - t.Fatalf("start: %v", err) - } - defer r.Close() - - packet := dnspacket("test.site.", dns.TypeA) - // This is purely to ensure that delegation does not race with SetUpstreams. - var wg sync.WaitGroup - wg.Add(2) - go func() { - defer wg.Done() - r.SetUpstreams([]net.Addr{server.PacketConn.LocalAddr()}) - }() - go func() { - defer wg.Done() - syncRespond(r, packet) - }() - wg.Wait() -} - -var allResponse = []byte{ - 0x00, 0x00, // transaction id: 0 - 0x84, 0x00, // flags: response, authoritative, no error - 0x00, 0x01, // one question - 0x00, 0x01, // one answer - 0x00, 0x00, 0x00, 0x00, // no authority or additional RRs - // Question: - 0x05, 0x74, 0x65, 0x73, 0x74, 0x31, 0x03, 0x69, 0x70, 0x6e, 0x03, 0x64, 0x65, 0x76, 0x00, // name - 0x00, 0xff, 0x00, 0x01, // type ALL, class IN - // Answer: - 0x05, 0x74, 0x65, 0x73, 0x74, 0x31, 0x03, 0x69, 0x70, 0x6e, 0x03, 0x64, 0x65, 0x76, 0x00, // name - 0x00, 0x01, 0x00, 0x01, // type A, class IN - 0x00, 0x00, 0x02, 0x58, // TTL: 600 - 0x00, 0x04, // length: 4 bytes - 0x01, 0x02, 0x03, 0x04, // A: 1.2.3.4 -} - -var ipv4Response = []byte{ - 0x00, 0x00, // transaction id: 0 - 0x84, 0x00, // flags: response, authoritative, no error - 0x00, 0x01, // one question - 0x00, 0x01, // one answer - 0x00, 0x00, 0x00, 0x00, // no authority or additional RRs - // Question: - 0x05, 0x74, 0x65, 0x73, 
0x74, 0x31, 0x03, 0x69, 0x70, 0x6e, 0x03, 0x64, 0x65, 0x76, 0x00, // name - 0x00, 0x01, 0x00, 0x01, // type A, class IN - // Answer: - 0x05, 0x74, 0x65, 0x73, 0x74, 0x31, 0x03, 0x69, 0x70, 0x6e, 0x03, 0x64, 0x65, 0x76, 0x00, // name - 0x00, 0x01, 0x00, 0x01, // type A, class IN - 0x00, 0x00, 0x02, 0x58, // TTL: 600 - 0x00, 0x04, // length: 4 bytes - 0x01, 0x02, 0x03, 0x04, // A: 1.2.3.4 -} - -var ipv6Response = []byte{ - 0x00, 0x00, // transaction id: 0 - 0x84, 0x00, // flags: response, authoritative, no error - 0x00, 0x01, // one question - 0x00, 0x01, // one answer - 0x00, 0x00, 0x00, 0x00, // no authority or additional RRs - // Question: - 0x05, 0x74, 0x65, 0x73, 0x74, 0x32, 0x03, 0x69, 0x70, 0x6e, 0x03, 0x64, 0x65, 0x76, 0x00, // name - 0x00, 0x1c, 0x00, 0x01, // type AAAA, class IN - // Answer: - 0x05, 0x74, 0x65, 0x73, 0x74, 0x32, 0x03, 0x69, 0x70, 0x6e, 0x03, 0x64, 0x65, 0x76, 0x00, // name - 0x00, 0x1c, 0x00, 0x01, // type AAAA, class IN - 0x00, 0x00, 0x02, 0x58, // TTL: 600 - 0x00, 0x10, // length: 16 bytes - // AAAA: 0001:0203:0405:0607:0809:0A0B:0C0D:0E0F - 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0xb, 0xc, 0xd, 0xe, 0xf, -} - -var ipv4UppercaseResponse = []byte{ - 0x00, 0x00, // transaction id: 0 - 0x84, 0x00, // flags: response, authoritative, no error - 0x00, 0x01, // one question - 0x00, 0x01, // one answer - 0x00, 0x00, 0x00, 0x00, // no authority or additional RRs - // Question: - 0x05, 0x54, 0x45, 0x53, 0x54, 0x31, 0x03, 0x49, 0x50, 0x4e, 0x03, 0x44, 0x45, 0x56, 0x00, // name - 0x00, 0x01, 0x00, 0x01, // type A, class IN - // Answer: - 0x05, 0x54, 0x45, 0x53, 0x54, 0x31, 0x03, 0x49, 0x50, 0x4e, 0x03, 0x44, 0x45, 0x56, 0x00, // name - 0x00, 0x01, 0x00, 0x01, // type A, class IN - 0x00, 0x00, 0x02, 0x58, // TTL: 600 - 0x00, 0x04, // length: 4 bytes - 0x01, 0x02, 0x03, 0x04, // A: 1.2.3.4 -} - -var ptrResponse = []byte{ - 0x00, 0x00, // transaction id: 0 - 0x84, 0x00, // flags: response, authoritative, no error - 0x00, 0x01, 
// one question - 0x00, 0x01, // one answer - 0x00, 0x00, 0x00, 0x00, // no authority or additional RRs - // Question: 4.3.2.1.in-addr.arpa - 0x01, 0x34, 0x01, 0x33, 0x01, 0x32, 0x01, 0x31, 0x07, - 0x69, 0x6e, 0x2d, 0x61, 0x64, 0x64, 0x72, 0x04, 0x61, 0x72, 0x70, 0x61, 0x00, - 0x00, 0x0c, 0x00, 0x01, // type PTR, class IN - // Answer: 4.3.2.1.in-addr.arpa - 0x01, 0x34, 0x01, 0x33, 0x01, 0x32, 0x01, 0x31, 0x07, - 0x69, 0x6e, 0x2d, 0x61, 0x64, 0x64, 0x72, 0x04, 0x61, 0x72, 0x70, 0x61, 0x00, - 0x00, 0x0c, 0x00, 0x01, // type PTR, class IN - 0x00, 0x00, 0x02, 0x58, // TTL: 600 - 0x00, 0x0f, // length: 15 bytes - // PTR: test1.ipn.dev - 0x05, 0x74, 0x65, 0x73, 0x74, 0x31, 0x03, 0x69, 0x70, 0x6e, 0x03, 0x64, 0x65, 0x76, 0x00, -} - -var ptrResponse6 = []byte{ - 0x00, 0x00, // transaction id: 0 - 0x84, 0x00, // flags: response, authoritative, no error - 0x00, 0x01, // one question - 0x00, 0x01, // one answer - 0x00, 0x00, 0x00, 0x00, // no authority or additional RRs - // Question: f.0.e.0.d.0.c.0.b.0.a.0.9.0.8.0.7.0.6.0.5.0.4.0.3.0.2.0.1.0.0.0.ip6.arpa - 0x01, 0x66, 0x01, 0x30, 0x01, 0x65, 0x01, 0x30, - 0x01, 0x64, 0x01, 0x30, 0x01, 0x63, 0x01, 0x30, - 0x01, 0x62, 0x01, 0x30, 0x01, 0x61, 0x01, 0x30, - 0x01, 0x39, 0x01, 0x30, 0x01, 0x38, 0x01, 0x30, - 0x01, 0x37, 0x01, 0x30, 0x01, 0x36, 0x01, 0x30, - 0x01, 0x35, 0x01, 0x30, 0x01, 0x34, 0x01, 0x30, - 0x01, 0x33, 0x01, 0x30, 0x01, 0x32, 0x01, 0x30, - 0x01, 0x31, 0x01, 0x30, 0x01, 0x30, 0x01, 0x30, - 0x03, 0x69, 0x70, 0x36, - 0x04, 0x61, 0x72, 0x70, 0x61, 0x00, - 0x00, 0x0c, 0x00, 0x01, // type PTR, class IN6 - // Answer: f.0.e.0.d.0.c.0.b.0.a.0.9.0.8.0.7.0.6.0.5.0.4.0.3.0.2.0.1.0.0.0.ip6.arpa - 0x01, 0x66, 0x01, 0x30, 0x01, 0x65, 0x01, 0x30, - 0x01, 0x64, 0x01, 0x30, 0x01, 0x63, 0x01, 0x30, - 0x01, 0x62, 0x01, 0x30, 0x01, 0x61, 0x01, 0x30, - 0x01, 0x39, 0x01, 0x30, 0x01, 0x38, 0x01, 0x30, - 0x01, 0x37, 0x01, 0x30, 0x01, 0x36, 0x01, 0x30, - 0x01, 0x35, 0x01, 0x30, 0x01, 0x34, 0x01, 0x30, - 0x01, 0x33, 0x01, 0x30, 0x01, 0x32, 
0x01, 0x30, - 0x01, 0x31, 0x01, 0x30, 0x01, 0x30, 0x01, 0x30, - 0x03, 0x69, 0x70, 0x36, - 0x04, 0x61, 0x72, 0x70, 0x61, 0x00, - 0x00, 0x0c, 0x00, 0x01, // type PTR, class IN - 0x00, 0x00, 0x02, 0x58, // TTL: 600 - 0x00, 0x0f, // length: 15 bytes - // PTR: test2.ipn.dev - 0x05, 0x74, 0x65, 0x73, 0x74, 0x32, 0x03, 0x69, 0x70, 0x6e, 0x03, 0x64, 0x65, 0x76, 0x00, -} - -var nxdomainResponse = []byte{ - 0x00, 0x00, // transaction id: 0 - 0x84, 0x03, // flags: response, authoritative, error: nxdomain - 0x00, 0x01, // one question - 0x00, 0x00, // no answers - 0x00, 0x00, 0x00, 0x00, // no authority or additional RRs - // Question: - 0x05, 0x74, 0x65, 0x73, 0x74, 0x33, 0x03, 0x69, 0x70, 0x6e, 0x03, 0x64, 0x65, 0x76, 0x00, // name - 0x00, 0x01, 0x00, 0x01, // type A, class IN -} - -var emptyResponse = []byte{ - 0x00, 0x00, // transaction id: 0 - 0x84, 0x00, // flags: response, authoritative, no error - 0x00, 0x01, // one question - 0x00, 0x00, // no answers - 0x00, 0x00, 0x00, 0x00, // no authority or additional RRs - // Question: - 0x05, 0x74, 0x65, 0x73, 0x74, 0x31, 0x03, 0x69, 0x70, 0x6e, 0x03, 0x64, 0x65, 0x76, 0x00, // name - 0x00, 0x1c, 0x00, 0x01, // type AAAA, class IN -} - -func TestFull(t *testing.T) { - r := NewResolver(ResolverConfig{Logf: t.Logf, Forward: false}) - r.SetMap(dnsMap) - - if err := r.Start(); err != nil { - t.Fatalf("start: %v", err) - } - defer r.Close() - - // One full packet and one error packet - tests := []struct { - name string - request []byte - response []byte - }{ - {"all", dnspacket("test1.ipn.dev.", dns.TypeALL), allResponse}, - {"ipv4", dnspacket("test1.ipn.dev.", dns.TypeA), ipv4Response}, - {"ipv6", dnspacket("test2.ipn.dev.", dns.TypeAAAA), ipv6Response}, - {"no-ipv6", dnspacket("test1.ipn.dev.", dns.TypeAAAA), emptyResponse}, - {"upper", dnspacket("TEST1.IPN.DEV.", dns.TypeA), ipv4UppercaseResponse}, - {"ptr", dnspacket("4.3.2.1.in-addr.arpa.", dns.TypePTR), ptrResponse}, - {"ptr", 
dnspacket("f.0.e.0.d.0.c.0.b.0.a.0.9.0.8.0.7.0.6.0.5.0.4.0.3.0.2.0.1.0.0.0.ip6.arpa.", - dns.TypePTR), ptrResponse6}, - {"nxdomain", dnspacket("test3.ipn.dev.", dns.TypeA), nxdomainResponse}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - response, err := syncRespond(r, tt.request) - if err != nil { - t.Errorf("err = %v; want nil", err) - } - if !bytes.Equal(response, tt.response) { - t.Errorf("response = %x; want %x", response, tt.response) - } - }) - } -} - -func TestAllocs(t *testing.T) { - r := NewResolver(ResolverConfig{Logf: t.Logf, Forward: false}) - r.SetMap(dnsMap) - - if err := r.Start(); err != nil { - t.Fatalf("start: %v", err) - } - defer r.Close() - - // It is seemingly pointless to test allocs in the delegate path, - // as dialer.Dial -> Read -> Write alone comprise 12 allocs. - tests := []struct { - name string - query []byte - want int - }{ - // Name lowercasing and response slice created by dns.NewBuilder. - {"forward", dnspacket("test1.ipn.dev.", dns.TypeA), 2}, - // 3 extra allocs in rdnsNameToIPv4 and one in marshalPTRRecord (dns.NewName). 
- {"reverse", dnspacket("4.3.2.1.in-addr.arpa.", dns.TypePTR), 5}, - } - - for _, tt := range tests { - allocs := testing.AllocsPerRun(100, func() { - syncRespond(r, tt.query) - }) - if int(allocs) > tt.want { - t.Errorf("%s: allocs = %v; want %v", tt.name, allocs, tt.want) - } - } -} - -func TestTrimRDNSBonjourPrefix(t *testing.T) { - tests := []struct { - in string - want bool - }{ - {"b._dns-sd._udp.0.10.20.172.in-addr.arpa.", true}, - {"db._dns-sd._udp.0.10.20.172.in-addr.arpa.", true}, - {"r._dns-sd._udp.0.10.20.172.in-addr.arpa.", true}, - {"dr._dns-sd._udp.0.10.20.172.in-addr.arpa.", true}, - {"lb._dns-sd._udp.0.10.20.172.in-addr.arpa.", true}, - {"qq._dns-sd._udp.0.10.20.172.in-addr.arpa.", false}, - {"0.10.20.172.in-addr.arpa.", false}, - {"i-have-no-dot", false}, - } - - for _, test := range tests { - got := hasRDNSBonjourPrefix(test.in) - if got != test.want { - t.Errorf("trimRDNSBonjourPrefix(%q) = %v, want %v", test.in, got, test.want) - } - } -} - -func BenchmarkFull(b *testing.B) { - dnsHandleFunc("test.site.", resolveToIP(testipv4, testipv6, "dns.test.site.")) - - server, errch := serveDNS(b, "127.0.0.1:0") - defer func() { - if err := <-errch; err != nil { - b.Errorf("server error: %v", err) - } - }() - - if server == nil { - return - } - defer server.Shutdown() - - r := NewResolver(ResolverConfig{Logf: b.Logf, Forward: true}) - r.SetMap(dnsMap) - r.SetUpstreams([]net.Addr{server.PacketConn.LocalAddr()}) - - if err := r.Start(); err != nil { - b.Fatalf("start: %v", err) - } - defer r.Close() - - tests := []struct { - name string - request []byte - }{ - {"forward", dnspacket("test1.ipn.dev.", dns.TypeA)}, - {"reverse", dnspacket("4.3.2.1.in-addr.arpa.", dns.TypePTR)}, - {"delegated", dnspacket("test.site.", dns.TypeA)}, - } - - for _, tt := range tests { - b.Run(tt.name, func(b *testing.B) { - for i := 0; i < b.N; i++ { - syncRespond(r, tt.request) - } - }) - } -} - -func TestMarshalResponseFormatError(t *testing.T) { - resp := new(response) - 
resp.Header.RCode = dns.RCodeFormatError - v, err := marshalResponse(resp) - if err != nil { - t.Errorf("marshal error: %v", err) - } - t.Logf("response: %q", v) -} diff --git a/wgengine/tstun/faketun.go b/wgengine/tstun/faketun.go deleted file mode 100644 index 50880131a..000000000 --- a/wgengine/tstun/faketun.go +++ /dev/null @@ -1,56 +0,0 @@ -// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package tstun - -import ( - "io" - "os" - - "github.com/tailscale/wireguard-go/tun" -) - -type fakeTUN struct { - evchan chan tun.Event - closechan chan struct{} -} - -// NewFakeTUN returns a fake TUN device that does not depend on the -// operating system or any special permissions. -// It primarily exists for testing. -func NewFakeTUN() tun.Device { - return &fakeTUN{ - evchan: make(chan tun.Event), - closechan: make(chan struct{}), - } -} - -func (t *fakeTUN) File() *os.File { - panic("fakeTUN.File() called, which makes no sense") -} - -func (t *fakeTUN) Close() error { - close(t.closechan) - close(t.evchan) - return nil -} - -func (t *fakeTUN) Read(out []byte, offset int) (int, error) { - <-t.closechan - return 0, io.EOF -} - -func (t *fakeTUN) Write(b []byte, n int) (int, error) { - select { - case <-t.closechan: - return 0, ErrClosed - default: - } - return len(b), nil -} - -func (t *fakeTUN) Flush() error { return nil } -func (t *fakeTUN) MTU() (int, error) { return 1500, nil } -func (t *fakeTUN) Name() (string, error) { return "FakeTUN", nil } -func (t *fakeTUN) Events() chan tun.Event { return t.evchan } diff --git a/wgengine/tstun/tun.go b/wgengine/tstun/tun.go deleted file mode 100644 index 92af1b8b0..000000000 --- a/wgengine/tstun/tun.go +++ /dev/null @@ -1,465 +0,0 @@ -// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. 
- -// Package tstun provides a TUN struct implementing the tun.Device interface -// with additional features as required by wgengine. -package tstun - -import ( - "errors" - "io" - "os" - "sync" - "sync/atomic" - "time" - - "github.com/tailscale/wireguard-go/device" - "github.com/tailscale/wireguard-go/tun" - "inet.af/netaddr" - "tailscale.com/net/packet" - "tailscale.com/types/logger" - "tailscale.com/wgengine/filter" -) - -const maxBufferSize = device.MaxMessageSize - -// PacketStartOffset is the minimal amount of leading space that must exist -// before &packet[offset] in a packet passed to Read, Write, or InjectInboundDirect. -// This is necessary to avoid reallocation in wireguard-go internals. -const PacketStartOffset = device.MessageTransportHeaderSize - -// MaxPacketSize is the maximum size (in bytes) -// of a packet that can be injected into a tstun.TUN. -const MaxPacketSize = device.MaxContentSize - -var ( - // ErrClosed is returned when attempting an operation on a closed TUN. - ErrClosed = errors.New("device closed") - // ErrFiltered is returned when the acted-on packet is rejected by a filter. - ErrFiltered = errors.New("packet dropped by filter") -) - -var ( - errPacketTooBig = errors.New("packet too big") - errOffsetTooBig = errors.New("offset larger than buffer length") - errOffsetTooSmall = errors.New("offset smaller than PacketStartOffset") -) - -// parsedPacketPool holds a pool of Parsed structs for use in filtering. -// This is needed because escape analysis cannot see that parsed packets -// do not escape through {Pre,Post}Filter{In,Out}. -var parsedPacketPool = sync.Pool{New: func() interface{} { return new(packet.Parsed) }} - -// FilterFunc is a packet-filtering function with access to the TUN device. -// It must not hold onto the packet struct, as its backing storage will be reused. 
-type FilterFunc func(*packet.Parsed, *TUN) filter.Response - -// TUN wraps a tun.Device from wireguard-go, -// augmenting it with filtering and packet injection. -// All the added work happens in Read and Write: -// the other methods delegate to the underlying tdev. -type TUN struct { - logf logger.Logf - // tdev is the underlying TUN device. - tdev tun.Device - - closeOnce sync.Once - - lastActivityAtomic int64 // unix seconds of last send or receive - - destIPActivity atomic.Value // of map[netaddr.IP]func() - - // buffer stores the oldest unconsumed packet from tdev. - // It is made a static buffer in order to avoid allocations. - buffer [maxBufferSize]byte - // bufferConsumed synchronizes access to buffer (shared by Read and poll). - bufferConsumed chan struct{} - - // closed signals poll (by closing) when the device is closed. - closed chan struct{} - // errors is the error queue populated by poll. - errors chan error - // outbound is the queue by which packets leave the TUN device. - // - // The directions are relative to the network, not the device: - // inbound packets arrive via UDP and are written into the TUN device; - // outbound packets are read from the TUN device and sent out via UDP. - // This queue is needed because although inbound writes are synchronous, - // the other direction must wait on a Wireguard goroutine to poll it. - // - // Empty reads are skipped by Wireguard, so it is always legal - // to discard an empty packet instead of sending it through t.outbound. - outbound chan []byte - - // fitler stores the currently active package filter - filter atomic.Value // of *filter.Filter - // filterFlags control the verbosity of logging packet drops/accepts. - filterFlags filter.RunFlags - - // PreFilterIn is the inbound filter function that runs before the main filter - // and therefore sees the packets that may be later dropped by it. - PreFilterIn FilterFunc - // PostFilterIn is the inbound filter function that runs after the main filter. 
- PostFilterIn FilterFunc - // PreFilterOut is the outbound filter function that runs before the main filter - // and therefore sees the packets that may be later dropped by it. - PreFilterOut FilterFunc - // PostFilterOut is the outbound filter function that runs after the main filter. - PostFilterOut FilterFunc - - // disableFilter disables all filtering when set. This should only be used in tests. - disableFilter bool -} - -func WrapTUN(logf logger.Logf, tdev tun.Device) *TUN { - tun := &TUN{ - logf: logger.WithPrefix(logf, "tstun: "), - tdev: tdev, - // bufferConsumed is conceptually a condition variable: - // a goroutine should not block when setting it, even with no listeners. - bufferConsumed: make(chan struct{}, 1), - closed: make(chan struct{}), - errors: make(chan error), - outbound: make(chan []byte), - // TODO(dmytro): (highly rate-limited) hexdumps should happen on unknown packets. - filterFlags: filter.LogAccepts | filter.LogDrops, - } - - go tun.poll() - // The buffer starts out consumed. - tun.bufferConsumed <- struct{}{} - - return tun -} - -// SetDestIPActivityFuncs sets a map of funcs to run per packet -// destination (the map keys). -// -// The map ownership passes to the TUN. It must be non-nil. -func (t *TUN) SetDestIPActivityFuncs(m map[netaddr.IP]func()) { - t.destIPActivity.Store(m) -} - -func (t *TUN) Close() error { - var err error - t.closeOnce.Do(func() { - // Other channels need not be closed: poll will exit gracefully after this. - close(t.closed) - - err = t.tdev.Close() - }) - return err -} - -func (t *TUN) Events() chan tun.Event { - return t.tdev.Events() -} - -func (t *TUN) File() *os.File { - return t.tdev.File() -} - -func (t *TUN) Flush() error { - return t.tdev.Flush() -} - -func (t *TUN) MTU() (int, error) { - return t.tdev.MTU() -} - -func (t *TUN) Name() (string, error) { - return t.tdev.Name() -} - -// poll polls t.tdev.Read, placing the oldest unconsumed packet into t.buffer. 
-// This is needed because t.tdev.Read in general may block (it does on Windows), -// so packets may be stuck in t.outbound if t.Read called t.tdev.Read directly. -func (t *TUN) poll() { - for { - select { - case <-t.closed: - return - case <-t.bufferConsumed: - // continue - } - - // Read may use memory in t.buffer before PacketStartOffset for mandatory headers. - // This is the rationale behind the tun.TUN.{Read,Write} interfaces - // and the reason t.buffer has size MaxMessageSize and not MaxContentSize. - n, err := t.tdev.Read(t.buffer[:], PacketStartOffset) - if err != nil { - select { - case <-t.closed: - return - case t.errors <- err: - // In principle, read errors are not fatal (but wireguard-go disagrees). - t.bufferConsumed <- struct{}{} - } - continue - } - - // Wireguard will skip an empty read, - // so we might as well do it here to avoid the send through t.outbound. - if n == 0 { - t.bufferConsumed <- struct{}{} - continue - } - - select { - case <-t.closed: - return - case t.outbound <- t.buffer[PacketStartOffset : PacketStartOffset+n]: - // continue - } - } -} - -var magicDNSIPPort = netaddr.MustParseIPPort("100.100.100.100:0") - -func (t *TUN) filterOut(p *packet.Parsed) filter.Response { - // Fake ICMP echo responses to MagicDNS (100.100.100.100). 
- if p.IsEchoRequest() && p.Dst == magicDNSIPPort { - header := p.ICMP4Header() - header.ToResponse() - outp := packet.Generate(&header, p.Payload()) - t.InjectInboundCopy(outp) - return filter.DropSilently // don't pass on to OS; already handled - } - - if t.PreFilterOut != nil { - if res := t.PreFilterOut(p, t); res.IsDrop() { - return res - } - } - - filt, _ := t.filter.Load().(*filter.Filter) - - if filt == nil { - return filter.Drop - } - - if filt.RunOut(p, t.filterFlags) != filter.Accept { - return filter.Drop - } - - if t.PostFilterOut != nil { - if res := t.PostFilterOut(p, t); res.IsDrop() { - return res - } - } - - return filter.Accept -} - -// noteActivity records that there was a read or write at the current time. -func (t *TUN) noteActivity() { - atomic.StoreInt64(&t.lastActivityAtomic, time.Now().Unix()) -} - -// IdleDuration reports how long it's been since the last read or write to this device. -// -// Its value is only accurate to roughly second granularity. -// If there's never been activity, the duration is since 1970. -func (t *TUN) IdleDuration() time.Duration { - sec := atomic.LoadInt64(&t.lastActivityAtomic) - return time.Since(time.Unix(sec, 0)) -} - -func (t *TUN) Read(buf []byte, offset int) (int, error) { - var n int - - wasInjectedPacket := false - - select { - case <-t.closed: - return 0, io.EOF - case err := <-t.errors: - return 0, err - case pkt := <-t.outbound: - n = copy(buf[offset:], pkt) - // t.buffer has a fixed location in memory, - // so this is the easiest way to tell when it has been consumed. - // &pkt[0] can be used because empty packets do not reach t.outbound. - if &pkt[0] == &t.buffer[PacketStartOffset] { - t.bufferConsumed <- struct{}{} - } else { - // If the packet is not from t.buffer, then it is an injected packet. 
- wasInjectedPacket = true - } - } - - p := parsedPacketPool.Get().(*packet.Parsed) - defer parsedPacketPool.Put(p) - p.Decode(buf[offset : offset+n]) - - if m, ok := t.destIPActivity.Load().(map[netaddr.IP]func()); ok { - if fn := m[p.Dst.IP]; fn != nil { - fn() - } - } - - // For injected packets, we return early to bypass filtering. - if wasInjectedPacket { - t.noteActivity() - return n, nil - } - - if !t.disableFilter { - response := t.filterOut(p) - if response != filter.Accept { - // Wireguard considers read errors fatal; pretend nothing was read - return 0, nil - } - } - - t.noteActivity() - return n, nil -} - -func (t *TUN) filterIn(buf []byte) filter.Response { - p := parsedPacketPool.Get().(*packet.Parsed) - defer parsedPacketPool.Put(p) - p.Decode(buf) - - if t.PreFilterIn != nil { - if res := t.PreFilterIn(p, t); res.IsDrop() { - return res - } - } - - filt, _ := t.filter.Load().(*filter.Filter) - - if filt == nil { - return filter.Drop - } - - if filt.RunIn(p, t.filterFlags) != filter.Accept { - - // Tell them, via TSMP, we're dropping them due to the ACL. - // Their host networking stack can translate this into ICMP - // or whatnot as required. But notably, their GUI or tailscale CLI - // can show them a rejection history with reasons. - if p.IPVersion == 4 && p.IPProto == packet.TCP && p.TCPFlags&packet.TCPSyn != 0 { - rj := packet.TailscaleRejectedHeader{ - IPSrc: p.Dst.IP, - IPDst: p.Src.IP, - Src: p.Src, - Dst: p.Dst, - Proto: p.IPProto, - Reason: packet.RejectedDueToACLs, - } - if filt.ShieldsUp() { - rj.Reason = packet.RejectedDueToShieldsUp - } - pkt := packet.Generate(rj, nil) - t.InjectOutbound(pkt) - - // TODO(bradfitz): also send a TCP RST, after the TSMP message. - } - - return filter.Drop - } - - if t.PostFilterIn != nil { - if res := t.PostFilterIn(p, t); res.IsDrop() { - return res - } - } - - return filter.Accept -} - -// Write accepts an incoming packet. The packet begins at buf[offset:], -// like wireguard-go/tun.Device.Write. 
-func (t *TUN) Write(buf []byte, offset int) (int, error) { - if !t.disableFilter { - res := t.filterIn(buf[offset:]) - if res == filter.DropSilently { - return len(buf), nil - } - if res != filter.Accept { - return 0, ErrFiltered - } - } - - t.noteActivity() - return t.tdev.Write(buf, offset) -} - -func (t *TUN) GetFilter() *filter.Filter { - filt, _ := t.filter.Load().(*filter.Filter) - return filt -} - -func (t *TUN) SetFilter(filt *filter.Filter) { - t.filter.Store(filt) -} - -// InjectInboundDirect makes the TUN device behave as if a packet -// with the given contents was received from the network. -// It blocks and does not take ownership of the packet. -// The injected packet will not pass through inbound filters. -// -// The packet contents are to start at &buf[offset]. -// offset must be greater or equal to PacketStartOffset. -// The space before &buf[offset] will be used by Wireguard. -func (t *TUN) InjectInboundDirect(buf []byte, offset int) error { - if len(buf) > MaxPacketSize { - return errPacketTooBig - } - if len(buf) < offset { - return errOffsetTooBig - } - if offset < PacketStartOffset { - return errOffsetTooSmall - } - - // Write to the underlying device to skip filters. - _, err := t.tdev.Write(buf, offset) - return err -} - -// InjectInboundCopy takes a packet without leading space, -// reallocates it to conform to the InjectInboundDirect interface -// and calls InjectInboundDirect on it. Injecting a nil packet is a no-op. -func (t *TUN) InjectInboundCopy(packet []byte) error { - // We duplicate this check from InjectInboundDirect here - // to avoid wasting an allocation on an oversized packet. 
- if len(packet) > MaxPacketSize { - return errPacketTooBig - } - if len(packet) == 0 { - return nil - } - - buf := make([]byte, PacketStartOffset+len(packet)) - copy(buf[PacketStartOffset:], packet) - - return t.InjectInboundDirect(buf, PacketStartOffset) -} - -// InjectOutbound makes the TUN device behave as if a packet -// with the given contents was sent to the network. -// It does not block, but takes ownership of the packet. -// The injected packet will not pass through outbound filters. -// Injecting an empty packet is a no-op. -func (t *TUN) InjectOutbound(packet []byte) error { - if len(packet) > MaxPacketSize { - return errPacketTooBig - } - if len(packet) == 0 { - return nil - } - select { - case <-t.closed: - return ErrClosed - case t.outbound <- packet: - return nil - } -} - -// Unwrap returns the underlying TUN device. -func (t *TUN) Unwrap() tun.Device { - return t.tdev -} diff --git a/wgengine/tstun/tun_test.go b/wgengine/tstun/tun_test.go deleted file mode 100644 index 365d56b4d..000000000 --- a/wgengine/tstun/tun_test.go +++ /dev/null @@ -1,382 +0,0 @@ -// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. 
- -package tstun - -import ( - "bytes" - "fmt" - "strconv" - "strings" - "sync/atomic" - "testing" - "unsafe" - - "github.com/tailscale/wireguard-go/tun/tuntest" - "inet.af/netaddr" - "tailscale.com/net/packet" - "tailscale.com/types/logger" - "tailscale.com/wgengine/filter" -) - -func udp4(src, dst string, sport, dport uint16) []byte { - sip, err := netaddr.ParseIP(src) - if err != nil { - panic(err) - } - dip, err := netaddr.ParseIP(dst) - if err != nil { - panic(err) - } - header := &packet.UDP4Header{ - IP4Header: packet.IP4Header{ - Src: sip, - Dst: dip, - IPID: 0, - }, - SrcPort: sport, - DstPort: dport, - } - return packet.Generate(header, []byte("udp_payload")) -} - -func nets(nets ...string) (ret []netaddr.IPPrefix) { - for _, s := range nets { - if i := strings.IndexByte(s, '/'); i == -1 { - ip, err := netaddr.ParseIP(s) - if err != nil { - panic(err) - } - bits := uint8(32) - if ip.Is6() { - bits = 128 - } - ret = append(ret, netaddr.IPPrefix{IP: ip, Bits: bits}) - } else { - pfx, err := netaddr.ParseIPPrefix(s) - if err != nil { - panic(err) - } - ret = append(ret, pfx) - } - } - return ret -} - -func ports(s string) filter.PortRange { - if s == "*" { - return filter.PortRange{First: 0, Last: 65535} - } - - var fs, ls string - i := strings.IndexByte(s, '-') - if i == -1 { - fs = s - ls = fs - } else { - fs = s[:i] - ls = s[i+1:] - } - first, err := strconv.ParseInt(fs, 10, 16) - if err != nil { - panic(fmt.Sprintf("invalid NetPortRange %q", s)) - } - last, err := strconv.ParseInt(ls, 10, 16) - if err != nil { - panic(fmt.Sprintf("invalid NetPortRange %q", s)) - } - return filter.PortRange{First: uint16(first), Last: uint16(last)} -} - -func netports(netPorts ...string) (ret []filter.NetPortRange) { - for _, s := range netPorts { - i := strings.LastIndexByte(s, ':') - if i == -1 { - panic(fmt.Sprintf("invalid NetPortRange %q", s)) - } - - npr := filter.NetPortRange{ - Net: nets(s[:i])[0], - Ports: ports(s[i+1:]), - } - ret = append(ret, npr) - } - return 
ret -} - -func setfilter(logf logger.Logf, tun *TUN) { - matches := []filter.Match{ - {Srcs: nets("5.6.7.8"), Dsts: netports("1.2.3.4:89-90")}, - {Srcs: nets("1.2.3.4"), Dsts: netports("5.6.7.8:98")}, - } - var sb netaddr.IPSetBuilder - sb.AddPrefix(netaddr.MustParseIPPrefix("1.2.0.0/16")) - tun.SetFilter(filter.New(matches, sb.IPSet(), sb.IPSet(), nil, logf)) -} - -func newChannelTUN(logf logger.Logf, secure bool) (*tuntest.ChannelTUN, *TUN) { - chtun := tuntest.NewChannelTUN() - tun := WrapTUN(logf, chtun.TUN()) - if secure { - setfilter(logf, tun) - } else { - tun.disableFilter = true - } - return chtun, tun -} - -func newFakeTUN(logf logger.Logf, secure bool) (*fakeTUN, *TUN) { - ftun := NewFakeTUN() - tun := WrapTUN(logf, ftun) - if secure { - setfilter(logf, tun) - } else { - tun.disableFilter = true - } - return ftun.(*fakeTUN), tun -} - -func TestReadAndInject(t *testing.T) { - chtun, tun := newChannelTUN(t.Logf, false) - defer tun.Close() - - const size = 2 // all payloads have this size - written := []string{"w0", "w1"} - injected := []string{"i0", "i1"} - - go func() { - for _, packet := range written { - payload := []byte(packet) - chtun.Outbound <- payload - } - }() - - for _, packet := range injected { - go func(packet string) { - payload := []byte(packet) - err := tun.InjectOutbound(payload) - if err != nil { - t.Errorf("%s: error: %v", packet, err) - } - }(packet) - } - - var buf [MaxPacketSize]byte - var seen = make(map[string]bool) - // We expect the same packets back, in no particular order. 
- for i := 0; i < len(written)+len(injected); i++ { - n, err := tun.Read(buf[:], 0) - if err != nil { - t.Errorf("read %d: error: %v", i, err) - } - if n != size { - t.Errorf("read %d: got size %d; want %d", i, n, size) - } - got := string(buf[:n]) - t.Logf("read %d: got %s", i, got) - seen[got] = true - } - - for _, packet := range written { - if !seen[packet] { - t.Errorf("%s not received", packet) - } - } - for _, packet := range injected { - if !seen[packet] { - t.Errorf("%s not received", packet) - } - } -} - -func TestWriteAndInject(t *testing.T) { - chtun, tun := newChannelTUN(t.Logf, false) - defer tun.Close() - - const size = 2 // all payloads have this size - written := []string{"w0", "w1"} - injected := []string{"i0", "i1"} - - go func() { - for _, packet := range written { - payload := []byte(packet) - n, err := tun.Write(payload, 0) - if err != nil { - t.Errorf("%s: error: %v", packet, err) - } - if n != size { - t.Errorf("%s: got size %d; want %d", packet, n, size) - } - } - }() - - for _, packet := range injected { - go func(packet string) { - payload := []byte(packet) - err := tun.InjectInboundCopy(payload) - if err != nil { - t.Errorf("%s: error: %v", packet, err) - } - }(packet) - } - - seen := make(map[string]bool) - // We expect the same packets back, in no particular order. 
- for i := 0; i < len(written)+len(injected); i++ { - packet := <-chtun.Inbound - got := string(packet) - t.Logf("read %d: got %s", i, got) - seen[got] = true - } - - for _, packet := range written { - if !seen[packet] { - t.Errorf("%s not received", packet) - } - } - for _, packet := range injected { - if !seen[packet] { - t.Errorf("%s not received", packet) - } - } -} - -func TestFilter(t *testing.T) { - chtun, tun := newChannelTUN(t.Logf, true) - defer tun.Close() - - type direction int - - const ( - in direction = iota - out - ) - - tests := []struct { - name string - dir direction - drop bool - data []byte - }{ - {"junk_in", in, true, []byte("\x45not a valid IPv4 packet")}, - {"junk_out", out, true, []byte("\x45not a valid IPv4 packet")}, - {"bad_port_in", in, true, udp4("5.6.7.8", "1.2.3.4", 22, 22)}, - {"bad_port_out", out, false, udp4("1.2.3.4", "5.6.7.8", 22, 22)}, - {"bad_ip_in", in, true, udp4("8.1.1.1", "1.2.3.4", 89, 89)}, - {"bad_ip_out", out, false, udp4("1.2.3.4", "8.1.1.1", 98, 98)}, - {"good_packet_in", in, false, udp4("5.6.7.8", "1.2.3.4", 89, 89)}, - {"good_packet_out", out, false, udp4("1.2.3.4", "5.6.7.8", 98, 98)}, - } - - // A reader on the other end of the TUN. - go func() { - var recvbuf []byte - for { - select { - case <-tun.closed: - return - case recvbuf = <-chtun.Inbound: - // continue - } - for _, tt := range tests { - if tt.drop && bytes.Equal(recvbuf, tt.data) { - t.Errorf("did not drop %s", tt.name) - } - } - } - }() - - var buf [MaxPacketSize]byte - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - var n int - var err error - var filtered bool - - if tt.dir == in { - _, err = tun.Write(tt.data, 0) - if err == ErrFiltered { - filtered = true - err = nil - } - } else { - chtun.Outbound <- tt.data - n, err = tun.Read(buf[:], 0) - // In the read direction, errors are fatal, so we return n = 0 instead. 
- filtered = (n == 0) - } - - if err != nil { - t.Errorf("got err %v; want nil", err) - } - - if filtered { - if !tt.drop { - t.Errorf("got drop; want accept") - } - } else { - if tt.drop { - t.Errorf("got accept; want drop") - } - } - }) - } -} - -func TestAllocs(t *testing.T) { - ftun, tun := newFakeTUN(t.Logf, false) - defer tun.Close() - - buf := []byte{0x00} - allocs := testing.AllocsPerRun(100, func() { - _, err := ftun.Write(buf, 0) - if err != nil { - t.Errorf("write: error: %v", err) - return - } - }) - - if allocs > 0 { - t.Errorf("read allocs = %v; want 0", allocs) - } -} - -func TestClose(t *testing.T) { - ftun, tun := newFakeTUN(t.Logf, false) - - data := udp4("1.2.3.4", "5.6.7.8", 98, 98) - _, err := ftun.Write(data, 0) - if err != nil { - t.Error(err) - } - - tun.Close() - _, err = ftun.Write(data, 0) - if err == nil { - t.Error("Expected error from ftun.Write() after Close()") - } -} - -func BenchmarkWrite(b *testing.B) { - ftun, tun := newFakeTUN(b.Logf, true) - defer tun.Close() - - packet := udp4("5.6.7.8", "1.2.3.4", 89, 89) - for i := 0; i < b.N; i++ { - _, err := ftun.Write(packet, 0) - if err != nil { - b.Errorf("err = %v; want nil", err) - } - } -} - -func TestAtomic64Alignment(t *testing.T) { - off := unsafe.Offsetof(TUN{}.lastActivityAtomic) - if off%8 != 0 { - t.Errorf("offset %v not 8-byte aligned", off) - } - - c := new(TUN) - atomic.StoreInt64(&c.lastActivityAtomic, 123) -} diff --git a/wgengine/tstun/tun_windows.go b/wgengine/tstun/tun_windows.go deleted file mode 100644 index dc5fc2d79..000000000 --- a/wgengine/tstun/tun_windows.go +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. 
- -package tstun - -import ( - "github.com/tailscale/wireguard-go/tun" - "github.com/tailscale/wireguard-go/tun/wintun" - "golang.org/x/sys/windows" -) - -func init() { - var err error - tun.WintunPool, err = wintun.MakePool("Tailscale") - if err != nil { - panic(err) - } - guid, err := windows.GUIDFromString("{37217669-42da-4657-a55b-0d995d328250}") - if err != nil { - panic(err) - } - tun.WintunStaticRequestedGUID = &guid -} diff --git a/wgengine/userspace.go b/wgengine/userspace.go index fccca8c8b..9202e4c7a 100644 --- a/wgengine/userspace.go +++ b/wgengine/userspace.go @@ -8,12 +8,12 @@ import ( "bufio" "bytes" "context" + crand "crypto/rand" "errors" "fmt" "io" "net" "os" - "os/exec" "runtime" "strconv" "strings" @@ -29,38 +29,28 @@ import ( "tailscale.com/health" "tailscale.com/internal/deepprint" "tailscale.com/ipn/ipnstate" + "tailscale.com/net/dns" "tailscale.com/net/flowtrack" "tailscale.com/net/interfaces" "tailscale.com/net/packet" "tailscale.com/net/tsaddr" "tailscale.com/net/tshttpproxy" + "tailscale.com/net/tstun" "tailscale.com/tailcfg" + "tailscale.com/types/ipproto" "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/types/netmap" "tailscale.com/types/wgkey" "tailscale.com/version" - "tailscale.com/version/distro" "tailscale.com/wgengine/filter" "tailscale.com/wgengine/magicsock" "tailscale.com/wgengine/monitor" "tailscale.com/wgengine/router" - "tailscale.com/wgengine/tsdns" - "tailscale.com/wgengine/tstun" "tailscale.com/wgengine/wgcfg" "tailscale.com/wgengine/wglog" ) -// minimalMTU is the MTU we set on tailscale's TUN -// interface. wireguard-go defaults to 1420 bytes, which only works if -// the "outer" MTU is 1500 bytes. This breaks on DSL connections -// (typically 1492 MTU) and on GCE (1460 MTU?!). -// -// 1280 is the smallest MTU allowed for IPv6, which is a sensible -// "probably works everywhere" setting until we develop proper PMTU -// discovery. 
-const minimalMTU = 1280 - const magicDNSPort = 53 var magicDNSIP = netaddr.IPv4(100, 100, 100, 100) @@ -90,10 +80,10 @@ type userspaceEngine struct { reqCh chan struct{} waitCh chan struct{} // chan is closed when first Close call completes; contrast with closing bool timeNow func() time.Time - tundev *tstun.TUN + tundev *tstun.Wrapper wgdev *device.Device router router.Router - resolver *tsdns.Resolver + resolver *dns.Resolver magicConn *magicsock.Conn linkMon *monitor.Mon linkMonOwned bool // whether we created linkMon (and thus need to close it) @@ -101,10 +91,10 @@ type userspaceEngine struct { testMaybeReconfigHook func() // for tests; if non-nil, fires if maybeReconfigWireguardLocked called - // localAddrs is the set of IP addresses assigned to the local + // isLocalAddr reports the whether an IP is assigned to the local // tunnel interface. It's used to reflect local packets // incorrectly sent to us. - localAddrs atomic.Value // of map[netaddr.IP]bool + isLocalAddr atomic.Value // of func(netaddr.IP)bool wgLock sync.Mutex // serializes all wgdev operations; see lock order comment below lastCfgFull wgcfg.Config @@ -117,8 +107,9 @@ type userspaceEngine struct { destIPActivityFuncs map[netaddr.IP]func() statusBufioReader *bufio.Reader // reusable for UAPI - mu sync.Mutex // guards following; see lock order comment below - closing bool // Close was called (even if we're still closing) + mu sync.Mutex // guards following; see lock order comment below + netMap *netmap.NetworkMap // or nil + closing bool // Close was called (even if we're still closing) statusCallback StatusCallback peerSequence []wgkey.Key endpoints []string @@ -126,36 +117,30 @@ type userspaceEngine struct { pendOpen map[flowtrack.Tuple]*pendingOpenFlow // see pendopen.go networkMapCallbacks map[*someHandle]NetworkMapCallback tsIPByIPPort map[netaddr.IPPort]netaddr.IP // allows registration of IP:ports as belonging to a certain Tailscale IP for whois lookups + pongCallback map[[8]byte]func() // 
for TSMP pong responses // Lock ordering: magicsock.Conn.mu, wgLock, then mu. } // InternalsGetter is implemented by Engines that can export their internals. type InternalsGetter interface { - GetInternals() (*tstun.TUN, *magicsock.Conn) + GetInternals() (*tstun.Wrapper, *magicsock.Conn) } -func (e *userspaceEngine) GetInternals() (*tstun.TUN, *magicsock.Conn) { +func (e *userspaceEngine) GetInternals() (*tstun.Wrapper, *magicsock.Conn) { return e.tundev, e.magicConn } -// RouterGen is the signature for a function that creates a -// router.Router. -type RouterGen func(logf logger.Logf, wgdev *device.Device, tundev tun.Device) (router.Router, error) - // Config is the engine configuration. type Config struct { - // TUN is the TUN device used by the engine. - // Exactly one of either TUN or TUNName must be specified. - TUN tun.Device - - // TUNName is the TUN device to create. - // Exactly one of either TUN or TUNName must be specified. - TUNName string + // Tun is the device used by the Engine to exchange packets with + // the OS. + // If nil, a fake Device that does nothing is used. + Tun tun.Device - // RouterGen is the function used to instantiate the router. - // If nil, wgengine/router.New is used. - RouterGen RouterGen + // Router interfaces the Engine to the OS network stack. + // If nil, a fake Router that does nothing is used. + Router router.Router // LinkMonitor optionally provides an existing link monitor to re-use. // If nil, a new link monitor is created. @@ -165,59 +150,36 @@ type Config struct { // If zero, a port is automatically selected. ListenPort uint16 - // Fake determines whether this engine should automatically - // reply to ICMP pings. - Fake bool + // RespondToPing determines whether this engine should internally + // reply to ICMP pings, without involving the OS. + // Used in "fake" mode for development. 
+ RespondToPing bool } func NewFakeUserspaceEngine(logf logger.Logf, listenPort uint16) (Engine, error) { logf("Starting userspace wireguard engine (with fake TUN device)") return NewUserspaceEngine(logf, Config{ - TUN: tstun.NewFakeTUN(), - RouterGen: router.NewFake, - ListenPort: listenPort, - Fake: true, + ListenPort: listenPort, + RespondToPing: true, }) } // NewUserspaceEngine creates the named tun device and returns a // Tailscale Engine running on it. -func NewUserspaceEngine(logf logger.Logf, conf Config) (Engine, error) { - if conf.TUN != nil && conf.TUNName != "" { - return nil, errors.New("TUN and TUNName are mutually exclusive") - } - if conf.TUN == nil && conf.TUNName == "" { - return nil, errors.New("either TUN or TUNName are required") - } - tunDev := conf.TUN - var err error - if tunName := conf.TUNName; tunName != "" { - logf("Starting userspace wireguard engine with tun device %q", tunName) - tunDev, err = tun.CreateTUN(tunName, minimalMTU) - if err != nil { - diagnoseTUNFailure(tunName, logf) - logf("CreateTUN: %v", err) - return nil, err - } - logf("CreateTUN ok.") +func NewUserspaceEngine(logf logger.Logf, conf Config) (_ Engine, reterr error) { + var closePool closeOnErrorPool + defer closePool.closeAllIfError(&reterr) - if err := waitInterfaceUp(tunDev, 90*time.Second, logf); err != nil { - return nil, err - } + if conf.Tun == nil { + logf("[v1] using fake (no-op) tun device") + conf.Tun = tstun.NewFake() } - - if conf.RouterGen == nil { - conf.RouterGen = router.New + if conf.Router == nil { + logf("[v1] using fake (no-op) OS network configurator") + conf.Router = router.NewFake(logf) } - return newUserspaceEngine(logf, tunDev, conf) -} - -func newUserspaceEngine(logf logger.Logf, rawTUNDev tun.Device, conf Config) (_ Engine, reterr error) { - var closePool closeOnErrorPool - defer closePool.closeAllIfError(&reterr) - - tsTUNDev := tstun.WrapTUN(logf, rawTUNDev) + tsTUNDev := tstun.Wrap(logf, conf.Tun) closePool.add(tsTUNDev) e := 
&userspaceEngine{ @@ -226,9 +188,10 @@ func newUserspaceEngine(logf logger.Logf, rawTUNDev tun.Device, conf Config) (_ reqCh: make(chan struct{}, 1), waitCh: make(chan struct{}), tundev: tsTUNDev, + router: conf.Router, pingers: make(map[wgkey.Key]*pinger), } - e.localAddrs.Store(map[netaddr.IP]bool{}) + e.isLocalAddr.Store(genLocalAddrFunc(nil)) if conf.LinkMonitor != nil { e.linkMon = conf.LinkMonitor @@ -242,7 +205,7 @@ func newUserspaceEngine(logf logger.Logf, rawTUNDev tun.Device, conf Config) (_ e.linkMonOwned = true } - e.resolver = tsdns.NewResolver(tsdns.ResolverConfig{ + e.resolver = dns.NewResolver(dns.ResolverConfig{ Logf: logf, Forward: true, LinkMonitor: e.linkMon, @@ -281,8 +244,7 @@ func newUserspaceEngine(logf logger.Logf, rawTUNDev tun.Device, conf Config) (_ closePool.add(e.magicConn) e.magicConn.SetNetworkUp(e.linkMon.InterfaceState().AnyInterfaceUp()) - // Respond to all pings only in fake mode. - if conf.Fake { + if conf.RespondToPing { e.tundev.PostFilterIn = echoRespondToAll } e.tundev.PreFilterOut = e.handleLocalPackets @@ -300,7 +262,6 @@ func newUserspaceEngine(logf logger.Logf, rawTUNDev tun.Device, conf Config) (_ e.wgLogger = wglog.NewLogger(logf) opts := &device.DeviceOptions{ - Logger: e.wgLogger.DeviceLogger, HandshakeDone: func(peerKey device.NoisePublicKey, peer *device.Peer, deviceAllowedIPs *device.AllowedIPs) { // Send an unsolicited status event every time a // handshake completes. This makes sure our UI can @@ -349,20 +310,21 @@ func newUserspaceEngine(logf logger.Logf, rawTUNDev tun.Device, conf Config) (_ SkipBindUpdate: true, } + e.tundev.OnTSMPPongReceived = func(data [8]byte) { + e.mu.Lock() + defer e.mu.Unlock() + cb := e.pongCallback[data] + e.logf("wgengine: got TSMP pong %02x; cb=%v", data, cb != nil) + if cb != nil { + go cb() + } + } + // wgdev takes ownership of tundev, will close it when closed. 
e.logf("Creating wireguard device...") - e.wgdev = device.NewDevice(e.tundev, opts) + e.wgdev = device.NewDevice(e.tundev, e.wgLogger.DeviceLogger, opts) closePool.addFunc(e.wgdev.Close) - // Pass the underlying tun.(*NativeDevice) to the router: - // routers do not Read or Write, but do access native interfaces. - e.logf("Creating router...") - e.router, err = conf.RouterGen(logf, e.wgdev, e.tundev.Unwrap()) - if err != nil { - return nil, err - } - closePool.add(e.router) - go func() { up := false for event := range e.tundev.Events() { @@ -411,7 +373,7 @@ func newUserspaceEngine(logf logger.Logf, rawTUNDev tun.Device, conf Config) (_ } // echoRespondToAll is an inbound post-filter responding to all echo requests. -func echoRespondToAll(p *packet.Parsed, t *tstun.TUN) filter.Response { +func echoRespondToAll(p *packet.Parsed, t *tstun.Wrapper) filter.Response { if p.IsEchoRequest() { header := p.ICMP4Header() header.ToResponse() @@ -432,44 +394,40 @@ func echoRespondToAll(p *packet.Parsed, t *tstun.TUN) filter.Response { // stack, and intercepts any packets that should be handled by // tailscaled directly. Other packets are allowed to proceed into the // main ACL filter. -func (e *userspaceEngine) handleLocalPackets(p *packet.Parsed, t *tstun.TUN) filter.Response { +func (e *userspaceEngine) handleLocalPackets(p *packet.Parsed, t *tstun.Wrapper) filter.Response { if verdict := e.handleDNS(p, t); verdict == filter.Drop { // local DNS handled the packet. return filter.Drop } - if (runtime.GOOS == "darwin" || runtime.GOOS == "ios") && e.isLocalAddr(p.Dst.IP) { - // macOS NetworkExtension directs packets destined to the - // tunnel's local IP address into the tunnel, instead of - // looping back within the kernel network stack. We have to - // notice that an outbound packet is actually destined for - // ourselves, and loop it back into macOS. 
- t.InjectInboundCopy(p.Buffer()) - return filter.Drop + if runtime.GOOS == "darwin" || runtime.GOOS == "ios" { + isLocalAddr, ok := e.isLocalAddr.Load().(func(netaddr.IP) bool) + if !ok { + e.logf("[unexpected] e.isLocalAddr was nil, can't check for loopback packet") + } else if isLocalAddr(p.Dst.IP) { + // macOS NetworkExtension directs packets destined to the + // tunnel's local IP address into the tunnel, instead of + // looping back within the kernel network stack. We have to + // notice that an outbound packet is actually destined for + // ourselves, and loop it back into macOS. + t.InjectInboundCopy(p.Buffer()) + return filter.Drop + } } return filter.Accept } -func (e *userspaceEngine) isLocalAddr(ip netaddr.IP) bool { - localAddrs, ok := e.localAddrs.Load().(map[netaddr.IP]bool) - if !ok { - e.logf("[unexpected] e.localAddrs was nil, can't check for loopback packet") - return false - } - return localAddrs[ip] -} - // handleDNS is an outbound pre-filter resolving Tailscale domains. 
-func (e *userspaceEngine) handleDNS(p *packet.Parsed, t *tstun.TUN) filter.Response { - if p.Dst.IP == magicDNSIP && p.Dst.Port == magicDNSPort && p.IPProto == packet.UDP { - request := tsdns.Packet{ +func (e *userspaceEngine) handleDNS(p *packet.Parsed, t *tstun.Wrapper) filter.Response { + if p.Dst.IP == magicDNSIP && p.Dst.Port == magicDNSPort && p.IPProto == ipproto.UDP { + request := dns.Packet{ Payload: append([]byte(nil), p.Payload()...), Addr: netaddr.IPPort{IP: p.Src.IP, Port: p.Src.Port}, } err := e.resolver.EnqueueRequest(request) if err != nil { - e.logf("tsdns: enqueue: %v", err) + e.logf("dns: enqueue: %v", err) } return filter.Drop } @@ -480,11 +438,11 @@ func (e *userspaceEngine) handleDNS(p *packet.Parsed, t *tstun.TUN) filter.Respo func (e *userspaceEngine) pollResolver() { for { resp, err := e.resolver.NextResponse() - if err == tsdns.ErrClosed { + if err == dns.ErrClosed { return } if err != nil { - e.logf("tsdns: error: %v", err) + e.logf("dns: error: %v", err) continue } @@ -498,7 +456,7 @@ func (e *userspaceEngine) pollResolver() { } hlen := h.Len() - // TODO(dmytro): avoid this allocation without importing tstun quirks into tsdns. + // TODO(dmytro): avoid this allocation without importing tstun quirks into dns. const offset = tstun.PacketStartOffset buf := make([]byte, offset+hlen+len(resp.Payload)) copy(buf[offset+hlen:], resp.Payload) @@ -925,16 +883,34 @@ func (e *userspaceEngine) updateActivityMapsLocked(trackDisco []tailcfg.DiscoKey e.tundev.SetDestIPActivityFuncs(e.destIPActivityFuncs) } +// genLocalAddrFunc returns a func that reports whether an IP is in addrs. +// addrs is assumed to be all /32 or /128 entries. +func genLocalAddrFunc(addrs []netaddr.IPPrefix) func(netaddr.IP) bool { + // Specialize the three common cases: no address, just IPv4 + // (or just IPv6), and both IPv4 and IPv6. 
+ if len(addrs) == 0 { + return func(netaddr.IP) bool { return false } + } + if len(addrs) == 1 { + return func(t netaddr.IP) bool { return t == addrs[0].IP } + } + if len(addrs) == 2 { + return func(t netaddr.IP) bool { return t == addrs[0].IP || t == addrs[1].IP } + } + // Otherwise, the general implementation: a map lookup. + m := map[netaddr.IP]bool{} + for _, a := range addrs { + m[a.IP] = true + } + return func(t netaddr.IP) bool { return m[t] } +} + func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config) error { if routerCfg == nil { panic("routerCfg must not be nil") } - localAddrs := map[netaddr.IP]bool{} - for _, addr := range routerCfg.LocalAddrs { - localAddrs[addr.IP] = true - } - e.localAddrs.Store(localAddrs) + e.isLocalAddr.Store(genLocalAddrFunc(routerCfg.LocalAddrs)) e.wgLock.Lock() defer e.wgLock.Unlock() @@ -1034,7 +1010,7 @@ func (e *userspaceEngine) SetFilter(filt *filter.Filter) { e.tundev.SetFilter(filt) } -func (e *userspaceEngine) SetDNSMap(dm *tsdns.Map) { +func (e *userspaceEngine) SetDNSMap(dm *dns.Map) { e.resolver.SetMap(dm) } @@ -1270,6 +1246,7 @@ func (e *userspaceEngine) linkChange(changed bool, cur *interfaces.State) { e.logf("[v1] LinkChange: minor") } + health.SetAnyInterfaceUp(up) e.magicConn.SetNetworkUp(up) why := "link-change-minor" @@ -1306,6 +1283,7 @@ func (e *userspaceEngine) SetDERPMap(dm *tailcfg.DERPMap) { func (e *userspaceEngine) SetNetworkMap(nm *netmap.NetworkMap) { e.magicConn.SetNetworkMap(nm) e.mu.Lock() + e.netMap = nm callbacks := make([]NetworkMapCallback, 0, 4) for _, fn := range e.networkMapCallbacks { callbacks = append(callbacks, fn) @@ -1338,8 +1316,107 @@ func (e *userspaceEngine) UpdateStatus(sb *ipnstate.StatusBuilder) { e.magicConn.UpdateStatus(sb) } -func (e *userspaceEngine) Ping(ip netaddr.IP, cb func(*ipnstate.PingResult)) { - e.magicConn.Ping(ip, cb) +func (e *userspaceEngine) Ping(ip netaddr.IP, useTSMP bool, cb func(*ipnstate.PingResult)) { + res := 
&ipnstate.PingResult{IP: ip.String()} + peer, err := e.peerForIP(ip) + if err != nil { + e.logf("ping(%v): %v", ip, err) + res.Err = err.Error() + cb(res) + return + } + if peer == nil { + e.logf("ping(%v): no matching peer", ip) + res.Err = "no matching peer" + cb(res) + return + } + pingType := "disco" + if useTSMP { + pingType = "TSMP" + } + e.logf("ping(%v): sending %v ping to %v %v ...", ip, pingType, peer.Key.ShortString(), peer.ComputedName) + if useTSMP { + e.sendTSMPPing(ip, peer, res, cb) + } else { + e.magicConn.Ping(peer, res, cb) + } +} + +func (e *userspaceEngine) mySelfIPMatchingFamily(dst netaddr.IP) (src netaddr.IP, err error) { + e.mu.Lock() + defer e.mu.Unlock() + if e.netMap == nil { + return netaddr.IP{}, errors.New("no netmap") + } + for _, a := range e.netMap.Addresses { + if a.IsSingleIP() && a.IP.BitLen() == dst.BitLen() { + return a.IP, nil + } + } + if len(e.netMap.Addresses) == 0 { + return netaddr.IP{}, errors.New("no self address in netmap") + } + return netaddr.IP{}, errors.New("no self address in netmap matching address family") +} + +func (e *userspaceEngine) sendTSMPPing(ip netaddr.IP, peer *tailcfg.Node, res *ipnstate.PingResult, cb func(*ipnstate.PingResult)) { + srcIP, err := e.mySelfIPMatchingFamily(ip) + if err != nil { + res.Err = err.Error() + cb(res) + return + } + var iph packet.Header + if srcIP.Is4() { + iph = packet.IP4Header{ + IPProto: ipproto.TSMP, + Src: srcIP, + Dst: ip, + } + } else { + iph = packet.IP6Header{ + IPProto: ipproto.TSMP, + Src: srcIP, + Dst: ip, + } + } + + var data [8]byte + crand.Read(data[:]) + + expireTimer := time.AfterFunc(10*time.Second, func() { + e.setTSMPPongCallback(data, nil) + }) + t0 := time.Now() + e.setTSMPPongCallback(data, func() { + expireTimer.Stop() + d := time.Since(t0) + res.LatencySeconds = d.Seconds() + res.NodeIP = ip.String() + res.NodeName = peer.ComputedName + cb(res) + }) + + var tsmpPayload [9]byte + tsmpPayload[0] = byte(packet.TSMPTypePing) + copy(tsmpPayload[1:], 
data[:]) + + tsmpPing := packet.Generate(iph, tsmpPayload[:]) + e.tundev.InjectOutbound(tsmpPing) +} + +func (e *userspaceEngine) setTSMPPongCallback(data [8]byte, cb func()) { + e.mu.Lock() + defer e.mu.Unlock() + if e.pongCallback == nil { + e.pongCallback = map[[8]byte]func(){} + } + if cb == nil { + delete(e.pongCallback, data) + } else { + e.pongCallback[data] = cb + } } func (e *userspaceEngine) RegisterIPPortIdentity(ipport netaddr.IPPort, tsIP netaddr.IP) { @@ -1367,92 +1444,77 @@ func (e *userspaceEngine) WhoIsIPPort(ipport netaddr.IPPort) (tsIP netaddr.IP, o return tsIP, ok } -// diagnoseTUNFailure is called if tun.CreateTUN fails, to poke around -// the system and log some diagnostic info that might help debug why -// TUN failed. Because TUN's already failed and things the program's -// about to end, we might as well log a lot. -func diagnoseTUNFailure(tunName string, logf logger.Logf) { - switch runtime.GOOS { - case "linux": - diagnoseLinuxTUNFailure(tunName, logf) - case "darwin": - diagnoseDarwinTUNFailure(tunName, logf) - default: - logf("no TUN failure diagnostics for OS %q", runtime.GOOS) +// peerForIP returns the Node in the wireguard config +// that's responsible for handling the given IP address. +// +// If none is found in the wireguard config but one is found in +// the netmap, it's described in an error. +// +// If none is found in either place, (nil, nil) is returned. +// +// peerForIP acquires both e.mu and e.wgLock, but neither at the same +// time. 
+func (e *userspaceEngine) peerForIP(ip netaddr.IP) (n *tailcfg.Node, err error) { + e.mu.Lock() + nm := e.netMap + e.mu.Unlock() + if nm == nil { + return nil, errors.New("no network map") } -} -func diagnoseDarwinTUNFailure(tunName string, logf logger.Logf) { - if os.Getuid() != 0 { - logf("failed to create TUN device as non-root user; use 'sudo tailscaled', or run under launchd with 'sudo tailscaled install-system-daemon'") - } - if tunName != "utun" { - logf("failed to create TUN device %q; try using tun device \"utun\" instead for automatic selection", tunName) + // Check for exact matches before looking for subnet matches. + var bestInNMPrefix netaddr.IPPrefix + var bestInNM *tailcfg.Node + for _, p := range nm.Peers { + for _, a := range p.Addresses { + if a.IP == ip && a.IsSingleIP() && tsaddr.IsTailscaleIP(ip) { + return p, nil + } + } + for _, cidr := range p.AllowedIPs { + if !cidr.Contains(ip) { + continue + } + if bestInNMPrefix.IsZero() || cidr.Bits > bestInNMPrefix.Bits { + bestInNMPrefix = cidr + bestInNM = p + } + } } -} -func diagnoseLinuxTUNFailure(tunName string, logf logger.Logf) { - kernel, err := exec.Command("uname", "-r").Output() - kernel = bytes.TrimSpace(kernel) - if err != nil { - logf("no TUN, and failed to look up kernel version: %v", err) - return - } - logf("Linux kernel version: %s", kernel) + e.wgLock.Lock() + defer e.wgLock.Unlock() - modprobeOut, err := exec.Command("/sbin/modprobe", "tun").CombinedOutput() - if err == nil { - logf("'modprobe tun' successful") - // Either tun is currently loaded, or it's statically - // compiled into the kernel (which modprobe checks - // with /lib/modules/$(uname -r)/modules.builtin) - // - // So if there's a problem at this point, it's - // probably because /dev/net/tun doesn't exist. 
- const dev = "/dev/net/tun" - if fi, err := os.Stat(dev); err != nil { - logf("tun module loaded in kernel, but %s does not exist", dev) - } else { - logf("%s: %v", dev, fi.Mode()) + // TODO(bradfitz): this is O(n peers). Add ART to netaddr? + var best netaddr.IPPrefix + var bestKey tailcfg.NodeKey + for _, p := range e.lastCfgFull.Peers { + for _, cidr := range p.AllowedIPs { + if !cidr.Contains(ip) { + continue + } + if best.IsZero() || cidr.Bits > best.Bits { + best = cidr + bestKey = tailcfg.NodeKey(p.PublicKey) + } } - - // We failed to find why it failed. Just let our - // caller report the error it got from wireguard-go. - return } - logf("is CONFIG_TUN enabled in your kernel? `modprobe tun` failed with: %s", modprobeOut) - - switch distro.Get() { - case distro.Debian: - dpkgOut, err := exec.Command("dpkg", "-S", "kernel/drivers/net/tun.ko").CombinedOutput() - if len(bytes.TrimSpace(dpkgOut)) == 0 || err != nil { - logf("tun module not loaded nor found on disk") - return - } - if !bytes.Contains(dpkgOut, kernel) { - logf("kernel/drivers/net/tun.ko found on disk, but not for current kernel; are you in middle of a system update and haven't rebooted? found: %s", dpkgOut) - } - case distro.Arch: - findOut, err := exec.Command("find", "/lib/modules/", "-path", "*/net/tun.ko*").CombinedOutput() - if len(bytes.TrimSpace(findOut)) == 0 || err != nil { - logf("tun module not loaded nor found on disk") - return - } - if !bytes.Contains(findOut, kernel) { - logf("kernel/drivers/net/tun.ko found on disk, but not for current kernel; are you in middle of a system update and haven't rebooted? 
found: %s", findOut) - } - case distro.OpenWrt: - out, err := exec.Command("opkg", "list-installed").CombinedOutput() - if err != nil { - logf("error querying OpenWrt installed packages: %s", out) - return - } - for _, pkg := range []string{"kmod-tun", "ca-bundle"} { - if !bytes.Contains(out, []byte(pkg+" - ")) { - logf("Missing required package %s; run: opkg install %s", pkg, pkg) + // And another pass. Probably better than allocating a map per peerForIP + // call. But TODO(bradfitz): add a lookup map to netmap.NetworkMap. + if !bestKey.IsZero() { + for _, p := range nm.Peers { + if p.Key == bestKey { + return p, nil } } } + if bestInNM == nil { + return nil, nil + } + if bestInNMPrefix.Bits == 0 { + return nil, errors.New("exit node found but not enabled") + } + return nil, fmt.Errorf("node %q found, but not using its %v route", bestInNM.ComputedNameWithHost, bestInNMPrefix) } type closeOnErrorPool []func() diff --git a/wgengine/userspace_test.go b/wgengine/userspace_test.go index e9b83389a..ac081d83e 100644 --- a/wgengine/userspace_test.go +++ b/wgengine/userspace_test.go @@ -13,10 +13,10 @@ import ( "go4.org/mem" "inet.af/netaddr" + "tailscale.com/net/tstun" "tailscale.com/tailcfg" "tailscale.com/types/key" "tailscale.com/wgengine/router" - "tailscale.com/wgengine/tstun" "tailscale.com/wgengine/wgcfg" ) @@ -39,7 +39,7 @@ func TestNoteReceiveActivity(t *testing.T) { logf: func(format string, a ...interface{}) { fmt.Fprintf(&logBuf, format, a...) }, - tundev: new(tstun.TUN), + tundev: new(tstun.Wrapper), testMaybeReconfigHook: func() { confc <- true }, trimmedDisco: map[tailcfg.DiscoKey]bool{}, } @@ -139,3 +139,50 @@ func dkFromHex(hex string) tailcfg.DiscoKey { } return tailcfg.DiscoKey(k) } + +// an experiment to see if genLocalAddrFunc was worth it. As of Go +// 1.16, it still very much is. 
(30-40x faster) +func BenchmarkGenLocalAddrFunc(b *testing.B) { + la1 := netaddr.MustParseIP("1.2.3.4") + la2 := netaddr.MustParseIP("::4") + lanot := netaddr.MustParseIP("5.5.5.5") + var x bool + b.Run("map1", func(b *testing.B) { + m := map[netaddr.IP]bool{ + la1: true, + } + for i := 0; i < b.N; i++ { + x = m[la1] + x = m[lanot] + } + }) + b.Run("map2", func(b *testing.B) { + m := map[netaddr.IP]bool{ + la1: true, + la2: true, + } + for i := 0; i < b.N; i++ { + x = m[la1] + x = m[lanot] + } + }) + b.Run("or1", func(b *testing.B) { + f := func(t netaddr.IP) bool { + return t == la1 + } + for i := 0; i < b.N; i++ { + x = f(la1) + x = f(lanot) + } + }) + b.Run("or2", func(b *testing.B) { + f := func(t netaddr.IP) bool { + return t == la1 || t == la2 + } + for i := 0; i < b.N; i++ { + x = f(la1) + x = f(lanot) + } + }) + b.Logf("x = %v", x) +} diff --git a/wgengine/watchdog.go b/wgengine/watchdog.go index f4f7d3085..f3248d6fc 100644 --- a/wgengine/watchdog.go +++ b/wgengine/watchdog.go @@ -14,12 +14,12 @@ import ( "inet.af/netaddr" "tailscale.com/ipn/ipnstate" + "tailscale.com/net/dns" "tailscale.com/tailcfg" "tailscale.com/types/netmap" "tailscale.com/wgengine/filter" "tailscale.com/wgengine/monitor" "tailscale.com/wgengine/router" - "tailscale.com/wgengine/tsdns" "tailscale.com/wgengine/wgcfg" ) @@ -84,7 +84,7 @@ func (e *watchdogEngine) GetFilter() *filter.Filter { func (e *watchdogEngine) SetFilter(filt *filter.Filter) { e.watchdog("SetFilter", func() { e.wrap.SetFilter(filt) }) } -func (e *watchdogEngine) SetDNSMap(dm *tsdns.Map) { +func (e *watchdogEngine) SetDNSMap(dm *dns.Map) { e.watchdog("SetDNSMap", func() { e.wrap.SetDNSMap(dm) }) } func (e *watchdogEngine) SetStatusCallback(cb StatusCallback) { @@ -117,8 +117,8 @@ func (e *watchdogEngine) DiscoPublicKey() (k tailcfg.DiscoKey) { e.watchdog("DiscoPublicKey", func() { k = e.wrap.DiscoPublicKey() }) return k } -func (e *watchdogEngine) Ping(ip netaddr.IP, cb func(*ipnstate.PingResult)) { - 
e.watchdog("Ping", func() { e.wrap.Ping(ip, cb) }) +func (e *watchdogEngine) Ping(ip netaddr.IP, useTSMP bool, cb func(*ipnstate.PingResult)) { + e.watchdog("Ping", func() { e.wrap.Ping(ip, useTSMP, cb) }) } func (e *watchdogEngine) RegisterIPPortIdentity(ipp netaddr.IPPort, tsIP netaddr.IP) { e.watchdog("RegisterIPPortIdentity", func() { e.wrap.RegisterIPPortIdentity(ipp, tsIP) }) diff --git a/wgengine/watchdog_test.go b/wgengine/watchdog_test.go index 7487d4827..0a699f848 100644 --- a/wgengine/watchdog_test.go +++ b/wgengine/watchdog_test.go @@ -7,6 +7,7 @@ package wgengine import ( "bytes" "fmt" + "runtime" "strings" "testing" "time" @@ -15,6 +16,13 @@ import ( func TestWatchdog(t *testing.T) { t.Parallel() + var maxWaitMultiple time.Duration = 1 + if runtime.GOOS == "darwin" { + // Work around slow close syscalls on Big Sur with content filter Network Extensions installed. + // See https://github.com/tailscale/tailscale/issues/1598. + maxWaitMultiple = 15 + } + t.Run("default watchdog does not fire", func(t *testing.T) { t.Parallel() e, err := NewFakeUserspaceEngine(t.Logf, 0) @@ -23,7 +31,7 @@ func TestWatchdog(t *testing.T) { } e = NewWatchdog(e) - e.(*watchdogEngine).maxWait = 150 * time.Millisecond + e.(*watchdogEngine).maxWait = maxWaitMultiple * 150 * time.Millisecond e.(*watchdogEngine).logf = t.Logf e.(*watchdogEngine).fatalf = t.Fatalf @@ -42,7 +50,7 @@ func TestWatchdog(t *testing.T) { usEngine := e.(*userspaceEngine) e = NewWatchdog(e) wdEngine := e.(*watchdogEngine) - wdEngine.maxWait = 100 * time.Millisecond + wdEngine.maxWait = maxWaitMultiple * 100 * time.Millisecond logBuf := new(bytes.Buffer) fatalCalled := make(chan struct{}) diff --git a/wgengine/wgcfg/device_test.go b/wgengine/wgcfg/device_test.go index d48da7c52..6bab065a5 100644 --- a/wgengine/wgcfg/device_test.go +++ b/wgengine/wgcfg/device_test.go @@ -55,12 +55,8 @@ func TestDeviceConfig(t *testing.T) { }}, } - device1 := device.NewDevice(newNilTun(), &device.DeviceOptions{ - Logger: 
device.NewLogger(device.LogLevelError, "device1"), - }) - device2 := device.NewDevice(newNilTun(), &device.DeviceOptions{ - Logger: device.NewLogger(device.LogLevelError, "device2"), - }) + device1 := device.NewDevice(newNilTun(), device.NewLogger(device.LogLevelError, "device1")) + device2 := device.NewDevice(newNilTun(), device.NewLogger(device.LogLevelError, "device2")) defer device1.Close() defer device2.Close() diff --git a/wgengine/wgengine.go b/wgengine/wgengine.go index c8e7963db..5946e1e77 100644 --- a/wgengine/wgengine.go +++ b/wgengine/wgengine.go @@ -9,12 +9,12 @@ import ( "inet.af/netaddr" "tailscale.com/ipn/ipnstate" + "tailscale.com/net/dns" "tailscale.com/tailcfg" "tailscale.com/types/netmap" "tailscale.com/wgengine/filter" "tailscale.com/wgengine/monitor" "tailscale.com/wgengine/router" - "tailscale.com/wgengine/tsdns" "tailscale.com/wgengine/wgcfg" ) @@ -66,7 +66,7 @@ type Engine interface { SetFilter(*filter.Filter) // SetDNSMap updates the DNS map. - SetDNSMap(*tsdns.Map) + SetDNSMap(*dns.Map) // SetStatusCallback sets the function to call when the // WireGuard status changes. @@ -136,7 +136,7 @@ type Engine interface { // Ping is a request to start a discovery ping with the peer handling // the given IP and then call cb with its ping latency & method. - Ping(ip netaddr.IP, cb func(*ipnstate.PingResult)) + Ping(ip netaddr.IP, useTSMP bool, cb func(*ipnstate.PingResult)) // RegisterIPPortIdentity registers a given node (identified by its // Tailscale IP) as temporarily having the given IP:port for whois lookups. |
