summaryrefslogtreecommitdiffhomepage
path: root/net
diff options
context:
space:
mode:
authorNaman Sood <mail@nsood.in>2021-03-29 14:28:08 -0400
committerNaman Sood <mail@nsood.in>2021-03-29 14:28:08 -0400
commitc0a88a0129ebf0f9886b93b1f4e4f04a7c3bb86f (patch)
tree57d5aef2985e3424e5bb6f4c810628aa3ccbf5d0 /net
parent47bd3c4cf5543fd7ecb049302c37c1001fa9f2d6 (diff)
parenta4c679e64691a3f0ba41ad9078312ca67e5e67fd (diff)
downloadtailscale-naman/netstack-subnet-routing.tar.xz
tailscale-naman/netstack-subnet-routing.zip
Signed-off-by: Naman Sood <mail@nsood.in>
Diffstat (limited to 'net')
-rw-r--r--net/dns/config.go77
-rw-r--r--net/dns/direct.go188
-rw-r--r--net/dns/flush_windows.go19
-rw-r--r--net/dns/forwarder.go474
-rw-r--r--net/dns/manager.go100
-rw-r--r--net/dns/manager_default.go14
-rw-r--r--net/dns/manager_freebsd.go14
-rw-r--r--net/dns/manager_linux.go27
-rw-r--r--net/dns/manager_openbsd.go9
-rw-r--r--net/dns/manager_windows.go118
-rw-r--r--net/dns/map.go160
-rw-r--r--net/dns/map_test.go156
-rw-r--r--net/dns/neterr_darwin.go25
-rw-r--r--net/dns/neterr_other.go10
-rw-r--r--net/dns/neterr_windows.go29
-rw-r--r--net/dns/nm.go205
-rw-r--r--net/dns/noop.go17
-rw-r--r--net/dns/registry_windows.go76
-rw-r--r--net/dns/resolvconf.go157
-rw-r--r--net/dns/resolved.go188
-rw-r--r--net/dns/tsdns.go662
-rw-r--r--net/dns/tsdns_server_test.go95
-rw-r--r--net/dns/tsdns_test.go816
-rw-r--r--net/flowtrack/flowtrack.go10
-rw-r--r--net/interfaces/interfaces.go117
-rw-r--r--net/interfaces/interfaces_test.go12
-rw-r--r--net/interfaces/interfaces_windows.go108
-rw-r--r--net/packet/header.go1
-rw-r--r--net/packet/icmp4.go8
-rw-r--r--net/packet/ip.go66
-rw-r--r--net/packet/ip4.go3
-rw-r--r--net/packet/ip6.go3
-rw-r--r--net/packet/packet.go104
-rw-r--r--net/packet/packet_test.go48
-rw-r--r--net/packet/tsmp.go72
-rw-r--r--net/packet/udp4.go8
-rw-r--r--net/packet/udp6.go8
-rw-r--r--net/tsaddr/tsaddr.go2
-rw-r--r--net/tstun/fake.go54
-rw-r--r--net/tstun/ifstatus_noop.go19
-rw-r--r--net/tstun/ifstatus_windows.go111
-rw-r--r--net/tstun/tun.go129
-rw-r--r--net/tstun/tun_windows.go24
-rw-r--r--net/tstun/wrap.go498
-rw-r--r--net/tstun/wrap_test.go387
45 files changed, 5216 insertions, 212 deletions
diff --git a/net/dns/config.go b/net/dns/config.go
new file mode 100644
index 000000000..db08df7aa
--- /dev/null
+++ b/net/dns/config.go
@@ -0,0 +1,77 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package dns
+
+import (
+ "inet.af/netaddr"
+
+ "tailscale.com/types/logger"
+)
+
+// Config is the set of parameters that uniquely determine
+// the state to which a manager should bring system DNS settings.
+type Config struct {
+ // Nameservers are the IP addresses of the nameservers to use.
+ Nameservers []netaddr.IP
+ // Domains are the search domains to use.
+ Domains []string
+ // PerDomain indicates whether it is preferred to use Nameservers
+ // only for DNS queries for subdomains of Domains.
+ // Note that Nameservers may still be applied to all queries
+ // if the manager does not support per-domain settings.
+ PerDomain bool
+ // Proxied indicates whether DNS requests are proxied through a dns.Resolver.
+ // This enables MagicDNS.
+ Proxied bool
+}
+
+// Equal determines whether its argument and receiver
+// represent equivalent DNS configurations (then DNS reconfig is a no-op).
+func (lhs Config) Equal(rhs Config) bool {
+ if lhs.Proxied != rhs.Proxied || lhs.PerDomain != rhs.PerDomain {
+ return false
+ }
+
+ if len(lhs.Nameservers) != len(rhs.Nameservers) {
+ return false
+ }
+
+ if len(lhs.Domains) != len(rhs.Domains) {
+ return false
+ }
+
+ // With how we perform resolution order shouldn't matter,
+ // but it is unlikely that we will encounter different orders.
+ for i, server := range lhs.Nameservers {
+ if rhs.Nameservers[i] != server {
+ return false
+ }
+ }
+
+ // The order of domains, on the other hand, is significant.
+ for i, domain := range lhs.Domains {
+ if rhs.Domains[i] != domain {
+ return false
+ }
+ }
+
+ return true
+}
+
+// ManagerConfig is the set of parameters from which
+// a manager implementation is chosen and initialized.
+type ManagerConfig struct {
+ // Logf is the logger for the manager to use.
+ // It is wrapped with a "dns: " prefix.
+ Logf logger.Logf
+ // InterfaceName is the name of the interface with which DNS settings should be associated.
+ InterfaceName string
+ // Cleanup indicates that the manager is created for cleanup only.
+ // A no-op manager will be instantiated if the system needs no cleanup.
+ Cleanup bool
+ // PerDomain indicates that a manager capable of per-domain configuration is preferred.
+ // Certain managers are per-domain only; they will not be considered if this is false.
+ PerDomain bool
+}
diff --git a/net/dns/direct.go b/net/dns/direct.go
new file mode 100644
index 000000000..bd1c03b9d
--- /dev/null
+++ b/net/dns/direct.go
@@ -0,0 +1,188 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build linux freebsd openbsd
+
+package dns
+
+import (
+ "bufio"
+ "bytes"
+ "errors"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "os"
+ "os/exec"
+ "runtime"
+ "strings"
+
+ "inet.af/netaddr"
+ "tailscale.com/atomicfile"
+)
+
+const (
+ tsConf = "/etc/resolv.tailscale.conf"
+ backupConf = "/etc/resolv.pre-tailscale-backup.conf"
+ resolvConf = "/etc/resolv.conf"
+)
+
+// writeResolvConf writes DNS configuration in resolv.conf format to the given writer.
+func writeResolvConf(w io.Writer, servers []netaddr.IP, domains []string) {
+ io.WriteString(w, "# resolv.conf(5) file generated by tailscale\n")
+ io.WriteString(w, "# DO NOT EDIT THIS FILE BY HAND -- CHANGES WILL BE OVERWRITTEN\n\n")
+ for _, ns := range servers {
+ io.WriteString(w, "nameserver ")
+ io.WriteString(w, ns.String())
+ io.WriteString(w, "\n")
+ }
+ if len(domains) > 0 {
+ io.WriteString(w, "search")
+ for _, domain := range domains {
+ io.WriteString(w, " ")
+ io.WriteString(w, domain)
+ }
+ io.WriteString(w, "\n")
+ }
+}
+
+// readResolvConf reads DNS configuration from /etc/resolv.conf.
+func readResolvConf() (Config, error) {
+ var config Config
+
+ f, err := os.Open("/etc/resolv.conf")
+ if err != nil {
+ return config, err
+ }
+
+ scanner := bufio.NewScanner(f)
+ for scanner.Scan() {
+ line := strings.TrimSpace(scanner.Text())
+
+ if strings.HasPrefix(line, "nameserver") {
+ nameserver := strings.TrimPrefix(line, "nameserver")
+ nameserver = strings.TrimSpace(nameserver)
+ ip, err := netaddr.ParseIP(nameserver)
+ if err != nil {
+ return config, err
+ }
+ config.Nameservers = append(config.Nameservers, ip)
+ continue
+ }
+
+ if strings.HasPrefix(line, "search") {
+ domain := strings.TrimPrefix(line, "search")
+ domain = strings.TrimSpace(domain)
+ config.Domains = append(config.Domains, domain)
+ continue
+ }
+ }
+
+ return config, nil
+}
+
+// isResolvedRunning reports whether systemd-resolved is running on the system,
+// even if it is not managing the system DNS settings.
+func isResolvedRunning() bool {
+ if runtime.GOOS != "linux" {
+ return false
+ }
+
+ // systemd-resolved is never installed without systemd.
+ _, err := exec.LookPath("systemctl")
+ if err != nil {
+ return false
+ }
+
+ // is-active exits with code 3 if the service is not active.
+ err = exec.Command("systemctl", "is-active", "systemd-resolved.service").Run()
+
+ return err == nil
+}
+
+// directManager is a managerImpl which replaces /etc/resolv.conf with a file
+// generated from the given configuration, creating a backup of its old state.
+//
+// This way of configuring DNS is precarious, since it does not react
+// to the disappearance of the Tailscale interface.
+// The caller must call Down before program shutdown
+// or as cleanup if the program terminates unexpectedly.
+type directManager struct{}
+
+func newDirectManager(mconfig ManagerConfig) managerImpl {
+ return directManager{}
+}
+
+// Up implements managerImpl.
+func (m directManager) Up(config Config) error {
+ // Write the tsConf file.
+ buf := new(bytes.Buffer)
+ writeResolvConf(buf, config.Nameservers, config.Domains)
+ if err := atomicfile.WriteFile(tsConf, buf.Bytes(), 0644); err != nil {
+ return err
+ }
+
+ if linkPath, err := os.Readlink(resolvConf); err != nil {
+ // Remove any old backup that may exist.
+ os.Remove(backupConf)
+
+ // Backup the existing /etc/resolv.conf file.
+ contents, err := ioutil.ReadFile(resolvConf)
+ // If the original did not exist, still back up an empty file.
+ // The presence of a backup file is the way we know that Up ran.
+ if err != nil && !errors.Is(err, os.ErrNotExist) {
+ return err
+ }
+ if err := atomicfile.WriteFile(backupConf, contents, 0644); err != nil {
+ return err
+ }
+ } else if linkPath != tsConf {
+ // Backup the existing symlink.
+ os.Remove(backupConf)
+ if err := os.Symlink(linkPath, backupConf); err != nil {
+ return err
+ }
+ } else {
+ // Nothing to do, resolvConf already points to tsConf.
+ return nil
+ }
+
+ os.Remove(resolvConf)
+ if err := os.Symlink(tsConf, resolvConf); err != nil {
+ return err
+ }
+
+ if isResolvedRunning() {
+ exec.Command("systemctl", "restart", "systemd-resolved.service").Run() // Best-effort.
+ }
+
+ return nil
+}
+
+// Down implements managerImpl.
+func (m directManager) Down() error {
+ if _, err := os.Stat(backupConf); err != nil {
+ // If the backup file does not exist, then Up never ran successfully.
+ if os.IsNotExist(err) {
+ return nil
+ }
+ return err
+ }
+
+ if ln, err := os.Readlink(resolvConf); err != nil {
+ return err
+ } else if ln != tsConf {
+ return fmt.Errorf("resolv.conf is not a symlink to %s", tsConf)
+ }
+ if err := os.Rename(backupConf, resolvConf); err != nil {
+ return err
+ }
+ os.Remove(tsConf)
+
+ if isResolvedRunning() {
+ exec.Command("systemctl", "restart", "systemd-resolved.service").Run() // Best-effort.
+ }
+
+ return nil
+}
diff --git a/net/dns/flush_windows.go b/net/dns/flush_windows.go
new file mode 100644
index 000000000..3c7e7d645
--- /dev/null
+++ b/net/dns/flush_windows.go
@@ -0,0 +1,19 @@
+// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package dns
+
+import (
+ "fmt"
+ "os/exec"
+)
+
+// Flush clears the local resolver cache.
+func Flush() error {
+ out, err := exec.Command("ipconfig", "/flushdns").CombinedOutput()
+ if err != nil {
+ return fmt.Errorf("%v (output: %s)", err, out)
+ }
+ return nil
+}
diff --git a/net/dns/forwarder.go b/net/dns/forwarder.go
new file mode 100644
index 000000000..519c00027
--- /dev/null
+++ b/net/dns/forwarder.go
@@ -0,0 +1,474 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package dns
+
+import (
+ "bytes"
+ "context"
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "hash/crc32"
+ "math/rand"
+ "net"
+ "os"
+ "sync"
+ "time"
+
+ "inet.af/netaddr"
+ "tailscale.com/logtail/backoff"
+ "tailscale.com/net/netns"
+ "tailscale.com/types/logger"
+)
+
+// headerBytes is the number of bytes in a DNS message header.
+const headerBytes = 12
+
+// connCount is the number of UDP connections to use for forwarding.
+const connCount = 32
+
+const (
+ // cleanupInterval is the interval between purged of timed-out entries from txMap.
+ cleanupInterval = 30 * time.Second
+ // responseTimeout is the maximal amount of time to wait for a DNS response.
+ responseTimeout = 5 * time.Second
+)
+
+var errNoUpstreams = errors.New("upstream nameservers not set")
+
+var aLongTimeAgo = time.Unix(0, 1)
+
+type forwardingRecord struct {
+ src netaddr.IPPort
+ createdAt time.Time
+}
+
+// txid identifies a DNS transaction.
+//
+// As the standard DNS Request ID is only 16 bits, we extend it:
+// the lower 32 bits are the zero-extended bits of the DNS Request ID;
+// the upper 32 bits are the CRC32 checksum of the first question in the request.
+// This makes probability of txid collision negligible.
+type txid uint64
+
+// getTxID computes the txid of the given DNS packet.
+func getTxID(packet []byte) txid {
+ if len(packet) < headerBytes {
+ return 0
+ }
+
+ dnsid := binary.BigEndian.Uint16(packet[0:2])
+ qcount := binary.BigEndian.Uint16(packet[4:6])
+ if qcount == 0 {
+ return txid(dnsid)
+ }
+
+ offset := headerBytes
+ for i := uint16(0); i < qcount; i++ {
+ // Note: this relies on the fact that names are not compressed in questions,
+ // so they are guaranteed to end with a NUL byte.
+ //
+ // Justification:
+ // RFC 1035 doesn't seem to explicitly prohibit compressing names in questions,
+ // but this is exceedingly unlikely to be done in practice. A DNS request
+ // with multiple questions is ill-defined (which questions do the header flags apply to?)
+ // and a single question would have to contain a pointer to an *answer*,
+ // which would be excessively smart, pointless (an answer can just as well refer to the question)
+ // and perhaps even prohibited: a draft RFC (draft-ietf-dnsind-local-compression-05) states:
+ //
+ // > It is important that these pointers always point backwards.
+ //
+ // This is said in summarizing RFC 1035, although that phrase does not appear in the original RFC.
+ // Additionally, (https://cr.yp.to/djbdns/notes.html) states:
+ //
+ // > The precise rule is that a name can be compressed if it is a response owner name,
+ // > the name in NS data, the name in CNAME data, the name in PTR data, the name in MX data,
+ // > or one of the names in SOA data.
+ namebytes := bytes.IndexByte(packet[offset:], 0)
+ // ... | name | NUL | type | class
+ // ?? 1 2 2
+ offset = offset + namebytes + 5
+ if len(packet) < offset {
+ // Corrupt packet; don't crash.
+ return txid(dnsid)
+ }
+ }
+
+ hash := crc32.ChecksumIEEE(packet[headerBytes:offset])
+ return (txid(hash) << 32) | txid(dnsid)
+}
+
+// forwarder forwards DNS packets to a number of upstream nameservers.
+type forwarder struct {
+ logf logger.Logf
+
+ // responses is a channel by which responses are returned.
+ responses chan Packet
+ // closed signals all goroutines to stop.
+ closed chan struct{}
+ // wg signals when all goroutines have stopped.
+ wg sync.WaitGroup
+
+ // conns are the UDP connections used for forwarding.
+ // A random one is selected for each request, regardless of the target upstream.
+ conns []*fwdConn
+
+ mu sync.Mutex
+ // upstreams are the nameserver addresses that should be used for forwarding.
+ upstreams []net.Addr
+ // txMap maps DNS txids to active forwarding records.
+ txMap map[txid]forwardingRecord
+}
+
+func init() {
+ rand.Seed(time.Now().UnixNano())
+}
+
+func newForwarder(logf logger.Logf, responses chan Packet) *forwarder {
+ return &forwarder{
+ logf: logger.WithPrefix(logf, "forward: "),
+ responses: responses,
+ closed: make(chan struct{}),
+ conns: make([]*fwdConn, connCount),
+ txMap: make(map[txid]forwardingRecord),
+ }
+}
+
+func (f *forwarder) Start() error {
+ f.wg.Add(connCount + 1)
+ for idx := range f.conns {
+ f.conns[idx] = newFwdConn(f.logf, idx)
+ go f.recv(f.conns[idx])
+ }
+ go f.cleanMap()
+
+ return nil
+}
+
+func (f *forwarder) Close() {
+ select {
+ case <-f.closed:
+ return
+ default:
+ // continue
+ }
+ close(f.closed)
+
+ for _, conn := range f.conns {
+ conn.close()
+ }
+
+ f.wg.Wait()
+}
+
+func (f *forwarder) rebindFromNetworkChange() {
+ for _, c := range f.conns {
+ c.mu.Lock()
+ c.reconnectLocked()
+ c.mu.Unlock()
+ }
+}
+
+func (f *forwarder) setUpstreams(upstreams []net.Addr) {
+ f.mu.Lock()
+ f.upstreams = upstreams
+ f.mu.Unlock()
+}
+
+// send sends packet to dst. It is best effort.
+func (f *forwarder) send(packet []byte, dst net.Addr) {
+ connIdx := rand.Intn(connCount)
+ conn := f.conns[connIdx]
+ conn.send(packet, dst)
+}
+
+func (f *forwarder) recv(conn *fwdConn) {
+ defer f.wg.Done()
+
+ for {
+ select {
+ case <-f.closed:
+ return
+ default:
+ }
+ out := make([]byte, maxResponseBytes)
+ n := conn.read(out)
+ if n == 0 {
+ continue
+ }
+ if n < headerBytes {
+ f.logf("recv: packet too small (%d bytes)", n)
+ }
+
+ out = out[:n]
+ txid := getTxID(out)
+
+ f.mu.Lock()
+
+ record, found := f.txMap[txid]
+ // At most one nameserver will return a response:
+ // the first one to do so will delete txid from the map.
+ if !found {
+ f.mu.Unlock()
+ continue
+ }
+ delete(f.txMap, txid)
+
+ f.mu.Unlock()
+
+ packet := Packet{
+ Payload: out,
+ Addr: record.src,
+ }
+ select {
+ case <-f.closed:
+ return
+ case f.responses <- packet:
+ // continue
+ }
+ }
+}
+
+// cleanMap periodically deletes timed-out forwarding records from f.txMap to bound growth.
+func (f *forwarder) cleanMap() {
+ defer f.wg.Done()
+
+ t := time.NewTicker(cleanupInterval)
+ defer t.Stop()
+
+ var now time.Time
+ for {
+ select {
+ case <-f.closed:
+ return
+ case now = <-t.C:
+ // continue
+ }
+
+ f.mu.Lock()
+ for k, v := range f.txMap {
+ if now.Sub(v.createdAt) > responseTimeout {
+ delete(f.txMap, k)
+ }
+ }
+ f.mu.Unlock()
+ }
+}
+
+// forward forwards the query to all upstream nameservers and returns the first response.
+func (f *forwarder) forward(query Packet) error {
+ txid := getTxID(query.Payload)
+
+ f.mu.Lock()
+
+ upstreams := f.upstreams
+ if len(upstreams) == 0 {
+ f.mu.Unlock()
+ return errNoUpstreams
+ }
+ f.txMap[txid] = forwardingRecord{
+ src: query.Addr,
+ createdAt: time.Now(),
+ }
+
+ f.mu.Unlock()
+
+ for _, upstream := range upstreams {
+ f.send(query.Payload, upstream)
+ }
+
+ return nil
+}
+
+// A fwdConn manages a single connection used to forward DNS requests.
+// Net link changes can cause a *net.UDPConn to become permanently unusable, particularly on macOS.
+// fwdConn detects such situations and transparently creates new connections.
+type fwdConn struct {
+ // logf allows a fwdConn to log.
+ logf logger.Logf
+
+ // wg tracks the number of outstanding conn.Read and conn.Write calls.
+ wg sync.WaitGroup
+ // change allows calls to read to block until a the network connection has been replaced.
+ change *sync.Cond
+
+ // mu protects fields that follow it; it is also change's Locker.
+ mu sync.Mutex
+ // closed tracks whether fwdConn has been permanently closed.
+ closed bool
+ // conn is the current active connection.
+ conn net.PacketConn
+}
+
+func newFwdConn(logf logger.Logf, idx int) *fwdConn {
+ c := new(fwdConn)
+ c.logf = logger.WithPrefix(logf, fmt.Sprintf("fwdConn %d: ", idx))
+ c.change = sync.NewCond(&c.mu)
+ // c.conn is created lazily in send
+ return c
+}
+
+// send sends packet to dst using c's connection.
+// It is best effort. It is UDP, after all. Failures are logged.
+func (c *fwdConn) send(packet []byte, dst net.Addr) {
+ var b *backoff.Backoff // lazily initialized, since it is not needed in the common case
+ backOff := func(err error) {
+ if b == nil {
+ b = backoff.NewBackoff("dns-fwdConn-send", c.logf, 30*time.Second)
+ }
+ b.BackOff(context.Background(), err)
+ }
+
+ for {
+ // Gather the current connection.
+ // We can't hold the lock while we call WriteTo.
+ c.mu.Lock()
+ conn := c.conn
+ closed := c.closed
+ if closed {
+ c.mu.Unlock()
+ return
+ }
+ if conn == nil {
+ c.reconnectLocked()
+ c.mu.Unlock()
+ continue
+ }
+ c.mu.Unlock()
+
+ c.wg.Add(1)
+ _, err := conn.WriteTo(packet, dst)
+ c.wg.Done()
+ if err == nil {
+ // Success
+ return
+ }
+ if errors.Is(err, os.ErrDeadlineExceeded) {
+ // We intentionally closed this connection.
+ // It has been replaced by a new connection. Try again.
+ continue
+ }
+ // Something else went wrong.
+ // We have three choices here: try again, give up, or create a new connection.
+ var opErr *net.OpError
+ if !errors.As(err, &opErr) {
+ // Weird. All errors from the net package should be *net.OpError. Bail.
+ c.logf("send: non-*net.OpErr %v (%T)", err, err)
+ return
+ }
+ if opErr.Temporary() || opErr.Timeout() {
+ // I doubt that either of these can happen (this is UDP),
+ // but go ahead and try again.
+ backOff(err)
+ continue
+ }
+ if networkIsDown(err) {
+ // Fail.
+ c.logf("send: network is down")
+ return
+ }
+ if networkIsUnreachable(err) {
+ // This can be caused by a link change.
+ // Replace the existing connection with a new one.
+ c.mu.Lock()
+ // It's possible that multiple senders discovered simultaneously
+ // that the network is unreachable. Avoid reconnecting multiple times:
+ // Only reconnect if the current connection is the one that we
+ // discovered to be problematic.
+ if c.conn == conn {
+ backOff(err)
+ c.reconnectLocked()
+ }
+ c.mu.Unlock()
+ // Try again with our new network connection.
+ continue
+ }
+ // Unrecognized error. Fail.
+ c.logf("send: unrecognized error: %v", err)
+ return
+ }
+}
+
+// read waits for a response from c's connection.
+// It returns the number of bytes read, which may be 0
+// in case of an error or a closed connection.
+func (c *fwdConn) read(out []byte) int {
+ for {
+ // Gather the current connection.
+ // We can't hold the lock while we call ReadFrom.
+ c.mu.Lock()
+ conn := c.conn
+ closed := c.closed
+ if closed {
+ c.mu.Unlock()
+ return 0
+ }
+ if conn == nil {
+ // There is no current connection.
+ // Wait for the connection to change, then try again.
+ c.change.Wait()
+ c.mu.Unlock()
+ continue
+ }
+ c.mu.Unlock()
+
+ c.wg.Add(1)
+ n, _, err := conn.ReadFrom(out)
+ c.wg.Done()
+ if err == nil {
+ // Success.
+ return n
+ }
+ if errors.Is(err, os.ErrDeadlineExceeded) {
+ // We intentionally closed this connection.
+ // It has been replaced by a new connection. Try again.
+ continue
+ }
+
+ c.logf("read: unrecognized error: %v", err)
+ return 0
+ }
+}
+
+// reconnectLocked replaces the current connection with a new one.
+// c.mu must be locked.
+func (c *fwdConn) reconnectLocked() {
+ c.closeConnLocked()
+ // Make a new connection.
+ conn, err := netns.Listener().ListenPacket(context.Background(), "udp", "")
+ if err != nil {
+ c.logf("ListenPacket failed: %v", err)
+ } else {
+ c.conn = conn
+ }
+ // Broadcast that a new connection is available.
+ c.change.Broadcast()
+}
+
+// closeCurrentConn closes the current connection.
+// c.mu must be locked.
+func (c *fwdConn) closeConnLocked() {
+ if c.conn == nil {
+ return
+ }
+ // Unblock all readers/writers, wait for them, close the connection.
+ c.conn.SetDeadline(aLongTimeAgo)
+ c.wg.Wait()
+ c.conn.Close()
+ c.conn = nil
+}
+
+// close permanently closes c.
+func (c *fwdConn) close() {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ if c.closed {
+ return
+ }
+ c.closed = true
+ c.closeConnLocked()
+ // Unblock any remaining readers.
+ c.change.Broadcast()
+}
diff --git a/net/dns/manager.go b/net/dns/manager.go
new file mode 100644
index 000000000..8e2fc9d71
--- /dev/null
+++ b/net/dns/manager.go
@@ -0,0 +1,100 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package dns
+
+import (
+ "time"
+
+ "tailscale.com/types/logger"
+)
+
+// We use file-ignore below instead of ignore because on some platforms,
+// the lint exception is necessary and on others it is not,
+// and plain ignore complains if the exception is unnecessary.
+
+//lint:file-ignore U1000 reconfigTimeout is used on some platforms but not others
+
+// reconfigTimeout is the time interval within which Manager.{Up,Down} should complete.
+//
+// This is particularly useful because certain conditions can cause indefinite hangs
+// (such as improper dbus auth followed by contextless dbus.Object.Call).
+// Such operations should be wrapped in a timeout context.
+const reconfigTimeout = time.Second
+
+type managerImpl interface {
+ // Up updates system DNS settings to match the given configuration.
+ Up(Config) error
+ // Down undoes the effects of Up.
+ // It is idempotent and performs no action if Up has never been called.
+ Down() error
+}
+
+// Manager manages system DNS settings.
+type Manager struct {
+ logf logger.Logf
+
+ impl managerImpl
+
+ config Config
+ mconfig ManagerConfig
+}
+
+// NewManagers created a new manager from the given config.
+func NewManager(mconfig ManagerConfig) *Manager {
+ mconfig.Logf = logger.WithPrefix(mconfig.Logf, "dns: ")
+ m := &Manager{
+ logf: mconfig.Logf,
+ impl: newManager(mconfig),
+
+ config: Config{PerDomain: mconfig.PerDomain},
+ mconfig: mconfig,
+ }
+
+ m.logf("using %T", m.impl)
+ return m
+}
+
+func (m *Manager) Set(config Config) error {
+ if config.Equal(m.config) {
+ return nil
+ }
+
+ m.logf("Set: %+v", config)
+
+ if len(config.Nameservers) == 0 {
+ err := m.impl.Down()
+ // If we save the config, we will not retry next time. Only do this on success.
+ if err == nil {
+ m.config = config
+ }
+ return err
+ }
+
+ // Switching to and from per-domain mode may require a change of manager.
+ if config.PerDomain != m.config.PerDomain {
+ if err := m.impl.Down(); err != nil {
+ return err
+ }
+ m.mconfig.PerDomain = config.PerDomain
+ m.impl = newManager(m.mconfig)
+ m.logf("switched to %T", m.impl)
+ }
+
+ err := m.impl.Up(config)
+ // If we save the config, we will not retry next time. Only do this on success.
+ if err == nil {
+ m.config = config
+ }
+
+ return err
+}
+
+func (m *Manager) Up() error {
+ return m.impl.Up(m.config)
+}
+
+func (m *Manager) Down() error {
+ return m.impl.Down()
+}
diff --git a/net/dns/manager_default.go b/net/dns/manager_default.go
new file mode 100644
index 000000000..04c8bb811
--- /dev/null
+++ b/net/dns/manager_default.go
@@ -0,0 +1,14 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build !linux,!freebsd,!openbsd,!windows
+
+package dns
+
+func newManager(mconfig ManagerConfig) managerImpl {
+ // TODO(dmytro): on darwin, we should use a macOS-specific method such as scutil.
+ // This is currently not implemented. Editing /etc/resolv.conf does not work,
+ // as most applications use the system resolver, which disregards it.
+ return newNoopManager(mconfig)
+}
diff --git a/net/dns/manager_freebsd.go b/net/dns/manager_freebsd.go
new file mode 100644
index 000000000..232635f7e
--- /dev/null
+++ b/net/dns/manager_freebsd.go
@@ -0,0 +1,14 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package dns
+
+func newManager(mconfig ManagerConfig) managerImpl {
+ switch {
+ case isResolvconfActive():
+ return newResolvconfManager(mconfig)
+ default:
+ return newDirectManager(mconfig)
+ }
+}
diff --git a/net/dns/manager_linux.go b/net/dns/manager_linux.go
new file mode 100644
index 000000000..f53aed7d3
--- /dev/null
+++ b/net/dns/manager_linux.go
@@ -0,0 +1,27 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package dns
+
+func newManager(mconfig ManagerConfig) managerImpl {
+ switch {
+ // systemd-resolved should only activate per-domain.
+ case isResolvedActive() && mconfig.PerDomain:
+ if mconfig.Cleanup {
+ return newNoopManager(mconfig)
+ } else {
+ return newResolvedManager(mconfig)
+ }
+ case isNMActive():
+ if mconfig.Cleanup {
+ return newNoopManager(mconfig)
+ } else {
+ return newNMManager(mconfig)
+ }
+ case isResolvconfActive():
+ return newResolvconfManager(mconfig)
+ default:
+ return newDirectManager(mconfig)
+ }
+}
diff --git a/net/dns/manager_openbsd.go b/net/dns/manager_openbsd.go
new file mode 100644
index 000000000..228e3cca5
--- /dev/null
+++ b/net/dns/manager_openbsd.go
@@ -0,0 +1,9 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package dns
+
+func newManager(mconfig ManagerConfig) managerImpl {
+ return newDirectManager(mconfig)
+}
diff --git a/net/dns/manager_windows.go b/net/dns/manager_windows.go
new file mode 100644
index 000000000..5940404e7
--- /dev/null
+++ b/net/dns/manager_windows.go
@@ -0,0 +1,118 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package dns
+
+import (
+ "fmt"
+ "os/exec"
+ "strings"
+ "syscall"
+ "time"
+
+ "golang.org/x/sys/windows/registry"
+ "tailscale.com/types/logger"
+)
+
+const (
+ ipv4RegBase = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters`
+ ipv6RegBase = `SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters`
+)
+
+type windowsManager struct {
+ logf logger.Logf
+ guid string
+}
+
+func newManager(mconfig ManagerConfig) managerImpl {
+ return windowsManager{
+ logf: mconfig.Logf,
+ guid: mconfig.InterfaceName,
+ }
+}
+
+// keyOpenTimeout is how long we wait for a registry key to
+// appear. For some reason, registry keys tied to ephemeral interfaces
+// can take a long while to appear after interface creation, and we
+// can end up racing with that.
+const keyOpenTimeout = 20 * time.Second
+
+func setRegistryString(path, name, value string) error {
+ key, err := openKeyWait(registry.LOCAL_MACHINE, path, registry.SET_VALUE, keyOpenTimeout)
+ if err != nil {
+ return fmt.Errorf("opening %s: %w", path, err)
+ }
+ defer key.Close()
+
+ err = key.SetStringValue(name, value)
+ if err != nil {
+ return fmt.Errorf("setting %s[%s]: %w", path, name, err)
+ }
+ return nil
+}
+
+func (m windowsManager) setNameservers(basePath string, nameservers []string) error {
+ path := fmt.Sprintf(`%s\Interfaces\%s`, basePath, m.guid)
+ value := strings.Join(nameservers, ",")
+ return setRegistryString(path, "NameServer", value)
+}
+
+func (m windowsManager) setDomains(basePath string, domains []string) error {
+ path := fmt.Sprintf(`%s\Interfaces\%s`, basePath, m.guid)
+ value := strings.Join(domains, ",")
+ return setRegistryString(path, "SearchList", value)
+}
+
+func (m windowsManager) Up(config Config) error {
+ var ipsv4 []string
+ var ipsv6 []string
+
+ for _, ip := range config.Nameservers {
+ if ip.Is4() {
+ ipsv4 = append(ipsv4, ip.String())
+ } else {
+ ipsv6 = append(ipsv6, ip.String())
+ }
+ }
+
+ if err := m.setNameservers(ipv4RegBase, ipsv4); err != nil {
+ return err
+ }
+ if err := m.setDomains(ipv4RegBase, config.Domains); err != nil {
+ return err
+ }
+
+ if err := m.setNameservers(ipv6RegBase, ipsv6); err != nil {
+ return err
+ }
+ if err := m.setDomains(ipv6RegBase, config.Domains); err != nil {
+ return err
+ }
+
+ // Force DNS re-registration in Active Directory. What we actually
+ // care about is that this command invokes the undocumented hidden
+ // function that forces Windows to notice that adapter settings
+ // have changed, which makes the DNS settings actually take
+ // effect.
+ //
+ // This command can take a few seconds to run, so run it async, best effort.
+ go func() {
+ t0 := time.Now()
+ m.logf("running ipconfig /registerdns ...")
+ cmd := exec.Command("ipconfig", "/registerdns")
+ cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
+ d := time.Since(t0).Round(time.Millisecond)
+ if err := cmd.Run(); err != nil {
+ m.logf("error running ipconfig /registerdns after %v: %v", d, err)
+ } else {
+ m.logf("ran ipconfig /registerdns in %v", d)
+ }
+ }()
+
+ return nil
+}
+
+func (m windowsManager) Down() error {
+ return m.Up(Config{Nameservers: nil, Domains: nil})
+}
diff --git a/net/dns/map.go b/net/dns/map.go
new file mode 100644
index 000000000..119b6cc0a
--- /dev/null
+++ b/net/dns/map.go
@@ -0,0 +1,160 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package dns
+
+import (
+ "sort"
+ "strings"
+
+ "inet.af/netaddr"
+)
+
+// Map is all the data Resolver needs to resolve DNS queries within the Tailscale network.
+type Map struct {
+ // nameToIP is a mapping of Tailscale domain names to their IP addresses.
+ // For example, monitoring.tailscale.us -> 100.64.0.1.
+ nameToIP map[string]netaddr.IP
+ // ipToName is the inverse of nameToIP.
+ ipToName map[netaddr.IP]string
+ // names are the keys of nameToIP in sorted order.
+ names []string
+ // rootDomains are the domains whose subdomains should always
+ // be resolved locally to prevent leakage of sensitive names.
+ rootDomains []string // e.g. "user.provider.beta.tailscale.net."
+}
+
+// NewMap returns a new Map with name to address mapping given by nameToIP.
+//
+// rootDomains are the domains whose subdomains should always be
+// resolved locally to prevent leakage of sensitive names. They should
+// end in a period ("user-foo.tailscale.net.").
+func NewMap(initNameToIP map[string]netaddr.IP, rootDomains []string) *Map {
+ // TODO(dmytro): we have to allocate names and ipToName, but nameToIP can be avoided.
+ // It is here because control sends us names not in canonical form. Change this.
+ names := make([]string, 0, len(initNameToIP))
+ nameToIP := make(map[string]netaddr.IP, len(initNameToIP))
+ ipToName := make(map[netaddr.IP]string, len(initNameToIP))
+
+ for name, ip := range initNameToIP {
+ if len(name) == 0 {
+ // Nothing useful can be done with empty names.
+ continue
+ }
+ if name[len(name)-1] != '.' {
+ name += "."
+ }
+ names = append(names, name)
+ nameToIP[name] = ip
+ ipToName[ip] = name
+ }
+ sort.Strings(names)
+
+ return &Map{
+ nameToIP: nameToIP,
+ ipToName: ipToName,
+ names: names,
+
+ rootDomains: rootDomains,
+ }
+}
+
+func printSingleNameIP(buf *strings.Builder, name string, ip netaddr.IP) {
+ buf.WriteString(name)
+ buf.WriteByte('\t')
+ buf.WriteString(ip.String())
+ buf.WriteByte('\n')
+}
+
+func (m *Map) Pretty() string {
+ buf := new(strings.Builder)
+ for _, name := range m.names {
+ printSingleNameIP(buf, name, m.nameToIP[name])
+ }
+ return buf.String()
+}
+
+func (m *Map) PrettyDiffFrom(old *Map) string {
+ var (
+ oldNameToIP map[string]netaddr.IP
+ newNameToIP map[string]netaddr.IP
+ oldNames []string
+ newNames []string
+ )
+ if old != nil {
+ oldNameToIP = old.nameToIP
+ oldNames = old.names
+ }
+ if m != nil {
+ newNameToIP = m.nameToIP
+ newNames = m.names
+ }
+
+ buf := new(strings.Builder)
+ space := func() bool {
+ return buf.Len() < (1 << 10)
+ }
+
+ for len(oldNames) > 0 && len(newNames) > 0 {
+ var name string
+
+ newName, oldName := newNames[0], oldNames[0]
+ switch {
+ case oldName < newName:
+ name = oldName
+ oldNames = oldNames[1:]
+ case oldName > newName:
+ name = newName
+ newNames = newNames[1:]
+ case oldNames[0] == newNames[0]:
+ name = oldNames[0]
+ oldNames = oldNames[1:]
+ newNames = newNames[1:]
+ }
+ if !space() {
+ continue
+ }
+
+ ipOld, inOld := oldNameToIP[name]
+ ipNew, inNew := newNameToIP[name]
+ switch {
+ case !inOld:
+ buf.WriteByte('+')
+ printSingleNameIP(buf, name, ipNew)
+ case !inNew:
+ buf.WriteByte('-')
+ printSingleNameIP(buf, name, ipOld)
+ case ipOld != ipNew:
+ buf.WriteByte('-')
+ printSingleNameIP(buf, name, ipOld)
+ buf.WriteByte('+')
+ printSingleNameIP(buf, name, ipNew)
+ }
+ }
+
+ for _, name := range oldNames {
+ if !space() {
+ break
+ }
+ if _, ok := newNameToIP[name]; !ok {
+ buf.WriteByte('-')
+ printSingleNameIP(buf, name, oldNameToIP[name])
+ }
+ }
+
+ for _, name := range newNames {
+ if !space() {
+ break
+ }
+ if _, ok := oldNameToIP[name]; !ok {
+ buf.WriteByte('+')
+ printSingleNameIP(buf, name, newNameToIP[name])
+ }
+ }
+ if !space() {
+ buf.WriteString("... [truncated]\n")
+ }
+
+ return buf.String()
+}
diff --git a/net/dns/map_test.go b/net/dns/map_test.go
new file mode 100644
index 000000000..c438f95a0
--- /dev/null
+++ b/net/dns/map_test.go
@@ -0,0 +1,156 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package dns
+
+import (
+ "fmt"
+ "strings"
+ "testing"
+
+ "inet.af/netaddr"
+)
+
+func TestPretty(t *testing.T) {
+ tests := []struct {
+ name string
+ dmap *Map
+ want string
+ }{
+ {"empty", NewMap(nil, nil), ""},
+ {
+ "single",
+ NewMap(map[string]netaddr.IP{
+ "hello.ipn.dev.": netaddr.IPv4(100, 101, 102, 103),
+ }, nil),
+ "hello.ipn.dev.\t100.101.102.103\n",
+ },
+ {
+ "multiple",
+ NewMap(map[string]netaddr.IP{
+ "test1.domain.": netaddr.IPv4(100, 101, 102, 103),
+ "test2.sub.domain.": netaddr.IPv4(100, 99, 9, 1),
+ }, nil),
+ "test1.domain.\t100.101.102.103\ntest2.sub.domain.\t100.99.9.1\n",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := tt.dmap.Pretty()
+ if tt.want != got {
+ t.Errorf("want %v; got %v", tt.want, got)
+ }
+ })
+ }
+}
+
+func TestPrettyDiffFrom(t *testing.T) {
+ tests := []struct {
+ name string
+ map1 *Map
+ map2 *Map
+ want string
+ }{
+ {
+ "from_empty",
+ nil,
+ NewMap(map[string]netaddr.IP{
+ "test1.ipn.dev.": netaddr.IPv4(100, 101, 102, 103),
+ "test2.ipn.dev.": netaddr.IPv4(100, 103, 102, 101),
+ }, nil),
+ "+test1.ipn.dev.\t100.101.102.103\n+test2.ipn.dev.\t100.103.102.101\n",
+ },
+ {
+ "equal",
+ NewMap(map[string]netaddr.IP{
+ "test1.ipn.dev.": netaddr.IPv4(100, 101, 102, 103),
+ "test2.ipn.dev.": netaddr.IPv4(100, 103, 102, 101),
+ }, nil),
+ NewMap(map[string]netaddr.IP{
+ "test2.ipn.dev.": netaddr.IPv4(100, 103, 102, 101),
+ "test1.ipn.dev.": netaddr.IPv4(100, 101, 102, 103),
+ }, nil),
+ "",
+ },
+ {
+ "changed_ip",
+ NewMap(map[string]netaddr.IP{
+ "test1.ipn.dev.": netaddr.IPv4(100, 101, 102, 103),
+ "test2.ipn.dev.": netaddr.IPv4(100, 103, 102, 101),
+ }, nil),
+ NewMap(map[string]netaddr.IP{
+ "test2.ipn.dev.": netaddr.IPv4(100, 104, 102, 101),
+ "test1.ipn.dev.": netaddr.IPv4(100, 101, 102, 103),
+ }, nil),
+ "-test2.ipn.dev.\t100.103.102.101\n+test2.ipn.dev.\t100.104.102.101\n",
+ },
+ {
+ "new_domain",
+ NewMap(map[string]netaddr.IP{
+ "test1.ipn.dev.": netaddr.IPv4(100, 101, 102, 103),
+ "test2.ipn.dev.": netaddr.IPv4(100, 103, 102, 101),
+ }, nil),
+ NewMap(map[string]netaddr.IP{
+ "test3.ipn.dev.": netaddr.IPv4(100, 105, 106, 107),
+ "test2.ipn.dev.": netaddr.IPv4(100, 103, 102, 101),
+ "test1.ipn.dev.": netaddr.IPv4(100, 101, 102, 103),
+ }, nil),
+ "+test3.ipn.dev.\t100.105.106.107\n",
+ },
+ {
+ "gone_domain",
+ NewMap(map[string]netaddr.IP{
+ "test1.ipn.dev.": netaddr.IPv4(100, 101, 102, 103),
+ "test2.ipn.dev.": netaddr.IPv4(100, 103, 102, 101),
+ }, nil),
+ NewMap(map[string]netaddr.IP{
+ "test1.ipn.dev.": netaddr.IPv4(100, 101, 102, 103),
+ }, nil),
+ "-test2.ipn.dev.\t100.103.102.101\n",
+ },
+ {
+ "mixed",
+ NewMap(map[string]netaddr.IP{
+ "test1.ipn.dev.": netaddr.IPv4(100, 101, 102, 103),
+ "test4.ipn.dev.": netaddr.IPv4(100, 107, 106, 105),
+ "test5.ipn.dev.": netaddr.IPv4(100, 64, 1, 1),
+ "test2.ipn.dev.": netaddr.IPv4(100, 103, 102, 101),
+ }, nil),
+ NewMap(map[string]netaddr.IP{
+ "test2.ipn.dev.": netaddr.IPv4(100, 104, 102, 101),
+ "test1.ipn.dev.": netaddr.IPv4(100, 100, 101, 102),
+ "test3.ipn.dev.": netaddr.IPv4(100, 64, 1, 1),
+ }, nil),
+ "-test1.ipn.dev.\t100.101.102.103\n+test1.ipn.dev.\t100.100.101.102\n" +
+ "-test2.ipn.dev.\t100.103.102.101\n+test2.ipn.dev.\t100.104.102.101\n" +
+ "+test3.ipn.dev.\t100.64.1.1\n-test4.ipn.dev.\t100.107.106.105\n-test5.ipn.dev.\t100.64.1.1\n",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := tt.map2.PrettyDiffFrom(tt.map1)
+ if tt.want != got {
+ t.Errorf("want %v; got %v", tt.want, got)
+ }
+ })
+ }
+
+ t.Run("truncated", func(t *testing.T) {
+ small := NewMap(nil, nil)
+ m := map[string]netaddr.IP{}
+ for i := 0; i < 5000; i++ {
+ m[fmt.Sprintf("host%d.ipn.dev.", i)] = netaddr.IPv4(100, 64, 1, 1)
+ }
+ veryBig := NewMap(m, nil)
+ diff := veryBig.PrettyDiffFrom(small)
+ if len(diff) > 3<<10 {
+ t.Errorf("pretty diff too large: %d bytes", len(diff))
+ }
+ if !strings.Contains(diff, "truncated") {
+ t.Errorf("big diff not truncated")
+ }
+ })
+}
diff --git a/net/dns/neterr_darwin.go b/net/dns/neterr_darwin.go
new file mode 100644
index 000000000..7fd621fc7
--- /dev/null
+++ b/net/dns/neterr_darwin.go
@@ -0,0 +1,25 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package dns
+
+import (
+ "errors"
+ "syscall"
+)
+
+// Avoid allocation when calling errors.Is below
+// by converting syscall.Errno to error here.
+var (
+ networkDown error = syscall.ENETDOWN
+ networkUnreachable error = syscall.ENETUNREACH
+)
+
+func networkIsDown(err error) bool {
+ return errors.Is(err, networkDown)
+}
+
+func networkIsUnreachable(err error) bool {
+ return errors.Is(err, networkUnreachable)
+}
diff --git a/net/dns/neterr_other.go b/net/dns/neterr_other.go
new file mode 100644
index 000000000..b652f6e8b
--- /dev/null
+++ b/net/dns/neterr_other.go
@@ -0,0 +1,10 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build !darwin,!windows
+
+package dns
+
+func networkIsDown(err error) bool { return false }
+func networkIsUnreachable(err error) bool { return false }
diff --git a/net/dns/neterr_windows.go b/net/dns/neterr_windows.go
new file mode 100644
index 000000000..2b197ee2b
--- /dev/null
+++ b/net/dns/neterr_windows.go
@@ -0,0 +1,29 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package dns
+
+import (
+ "net"
+ "os"
+
+ "golang.org/x/sys/windows"
+)
+
+func networkIsDown(err error) bool {
+ if oe, ok := err.(*net.OpError); ok && oe.Op == "write" {
+ if se, ok := oe.Err.(*os.SyscallError); ok {
+ if se.Syscall == "wsasendto" && se.Err == windows.WSAENETUNREACH {
+ return true
+ }
+ }
+ }
+ return false
+}
+
+func networkIsUnreachable(err error) bool {
+ // TODO(bradfitz,josharian): something here? what is the
+ // difference between down and unreachable? Add comments.
+ return false
+}
diff --git a/net/dns/nm.go b/net/dns/nm.go
new file mode 100644
index 000000000..a597fa60d
--- /dev/null
+++ b/net/dns/nm.go
@@ -0,0 +1,205 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build linux
+
+package dns
+
+import (
+ "bufio"
+ "bytes"
+ "context"
+ "fmt"
+ "os"
+ "os/exec"
+
+ "github.com/godbus/dbus/v5"
+ "tailscale.com/util/endian"
+)
+
+// isNMActive determines if NetworkManager is currently managing system DNS settings.
+func isNMActive() bool {
+ // This is somewhat tricky because NetworkManager supports a number
+ // of DNS configuration modes. In all cases, we expect it to be installed
+ // and /etc/resolv.conf to contain a mention of NetworkManager in the comments.
+ _, err := exec.LookPath("NetworkManager")
+ if err != nil {
+ return false
+ }
+
+ f, err := os.Open("/etc/resolv.conf")
+ if err != nil {
+ return false
+ }
+ defer f.Close()
+
+ scanner := bufio.NewScanner(f)
+ for scanner.Scan() {
+ line := scanner.Bytes()
+ // Look for the word "NetworkManager" until comments end.
+ if len(line) > 0 && line[0] != '#' {
+ return false
+ }
+ if bytes.Contains(line, []byte("NetworkManager")) {
+ return true
+ }
+ }
+ return false
+}
+
+// nmManager uses the NetworkManager DBus API.
+type nmManager struct {
+ interfaceName string
+}
+
+func newNMManager(mconfig ManagerConfig) managerImpl {
+ return nmManager{
+ interfaceName: mconfig.InterfaceName,
+ }
+}
+
+type nmConnectionSettings map[string]map[string]dbus.Variant
+
+// Up implements managerImpl.
+func (m nmManager) Up(config Config) error {
+ ctx, cancel := context.WithTimeout(context.Background(), reconfigTimeout)
+ defer cancel()
+
+ // conn is a shared connection whose lifecycle is managed by the dbus package.
+ // We should not interfere with that by closing it.
+ conn, err := dbus.SystemBus()
+ if err != nil {
+ return fmt.Errorf("connecting to system bus: %w", err)
+ }
+
+ // This is how we get at the DNS settings:
+ //
+ // org.freedesktop.NetworkManager
+ // |
+ // [GetDeviceByIpIface]
+ // |
+ // v
+ // org.freedesktop.NetworkManager.Device <--------\
+ // (describes a network interface) |
+ // | |
+ // [GetAppliedConnection] [Reapply]
+ // | |
+ // v |
+ // org.freedesktop.NetworkManager.Connection |
+ // (connection settings) ------/
+ // contains {dns, dns-priority, dns-search}
+ //
+ // Ref: https://developer.gnome.org/NetworkManager/stable/settings-ipv4.html.
+
+ nm := conn.Object(
+ "org.freedesktop.NetworkManager",
+ dbus.ObjectPath("/org/freedesktop/NetworkManager"),
+ )
+
+ var devicePath dbus.ObjectPath
+ err = nm.CallWithContext(
+ ctx, "org.freedesktop.NetworkManager.GetDeviceByIpIface", 0,
+ m.interfaceName,
+ ).Store(&devicePath)
+ if err != nil {
+ return fmt.Errorf("getDeviceByIpIface: %w", err)
+ }
+ device := conn.Object("org.freedesktop.NetworkManager", devicePath)
+
+ var (
+ settings nmConnectionSettings
+ version uint64
+ )
+ err = device.CallWithContext(
+ ctx, "org.freedesktop.NetworkManager.Device.GetAppliedConnection", 0,
+ uint32(0),
+ ).Store(&settings, &version)
+ if err != nil {
+ return fmt.Errorf("getAppliedConnection: %w", err)
+ }
+
+ // Frustratingly, NetworkManager represents IPv4 addresses as uint32s,
+ // although IPv6 addresses are represented as byte arrays.
+ // Perform the conversion here.
+ var (
+ dnsv4 []uint32
+ dnsv6 [][]byte
+ )
+ for _, ip := range config.Nameservers {
+ b := ip.As16()
+ if ip.Is4() {
+ dnsv4 = append(dnsv4, endian.Native.Uint32(b[12:]))
+ } else {
+ dnsv6 = append(dnsv6, b[:])
+ }
+ }
+
+ ipv4Map := settings["ipv4"]
+ ipv4Map["dns"] = dbus.MakeVariant(dnsv4)
+ ipv4Map["dns-search"] = dbus.MakeVariant(config.Domains)
+ // We should only request priority if we have nameservers to set.
+ if len(dnsv4) == 0 {
+ ipv4Map["dns-priority"] = dbus.MakeVariant(100)
+ } else {
+ // dns-priority = -1 ensures that we have priority
+ // over other interfaces, except those exploiting this same trick.
+ // Ref: https://bugs.launchpad.net/ubuntu/+source/network-manager/+bug/1211110/comments/92.
+ ipv4Map["dns-priority"] = dbus.MakeVariant(-1)
+ }
+ // In principle, we should not need set this to true,
+ // as our interface does not configure any automatic DNS settings (presumably via DHCP).
+ // All the same, better to be safe.
+ ipv4Map["ignore-auto-dns"] = dbus.MakeVariant(true)
+
+ ipv6Map := settings["ipv6"]
+ // This is a hack.
+ // Methods "disabled", "ignore", "link-local" (IPv6 default) prevent us from setting DNS.
+ // It seems that our only recourse is "manual" or "auto".
+ // "manual" requires addresses, so we use "auto", which will assign us a random IPv6 /64.
+ ipv6Map["method"] = dbus.MakeVariant("auto")
+ // Our IPv6 config is a fake, so it should never become the default route.
+ ipv6Map["never-default"] = dbus.MakeVariant(true)
+ // Moreover, we should ignore all autoconfigured routes (hopefully none), as they are bogus.
+ ipv6Map["ignore-auto-routes"] = dbus.MakeVariant(true)
+
+ // Finally, set the actual DNS config.
+ ipv6Map["dns"] = dbus.MakeVariant(dnsv6)
+ ipv6Map["dns-search"] = dbus.MakeVariant(config.Domains)
+ if len(dnsv6) == 0 {
+ ipv6Map["dns-priority"] = dbus.MakeVariant(100)
+ } else {
+ ipv6Map["dns-priority"] = dbus.MakeVariant(-1)
+ }
+ ipv6Map["ignore-auto-dns"] = dbus.MakeVariant(true)
+
+ // deprecatedProperties are the properties in interface settings
+ // that are deprecated by NetworkManager.
+ //
+ // In practice, this means that they are returned for reading,
+ // but submitting a settings object with them present fails
+ // with hard-to-diagnose errors. They must be removed.
+ deprecatedProperties := []string{
+ "addresses", "routes",
+ }
+
+ for _, property := range deprecatedProperties {
+ delete(ipv4Map, property)
+ delete(ipv6Map, property)
+ }
+
+ err = device.CallWithContext(
+ ctx, "org.freedesktop.NetworkManager.Device.Reapply", 0,
+ settings, version, uint32(0),
+ ).Store()
+ if err != nil {
+ return fmt.Errorf("reapply: %w", err)
+ }
+
+ return nil
+}
+
+// Down implements managerImpl.
+func (m nmManager) Down() error {
+ return m.Up(Config{Nameservers: nil, Domains: nil})
+}
diff --git a/net/dns/noop.go b/net/dns/noop.go
new file mode 100644
index 000000000..35c07a232
--- /dev/null
+++ b/net/dns/noop.go
@@ -0,0 +1,17 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package dns
+
+type noopManager struct{}
+
+// Up implements managerImpl.
+func (m noopManager) Up(Config) error { return nil }
+
+// Down implements managerImpl.
+func (m noopManager) Down() error { return nil }
+
+func newNoopManager(mconfig ManagerConfig) managerImpl {
+ return noopManager{}
+}
diff --git a/net/dns/registry_windows.go b/net/dns/registry_windows.go
new file mode 100644
index 000000000..f8e1f514a
--- /dev/null
+++ b/net/dns/registry_windows.go
@@ -0,0 +1,76 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+//
+// The code in this file originates from https://git.zx2c4.com/wireguard-go:
+// Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
+// Copying license: https://git.zx2c4.com/wireguard-go/tree/COPYING
+
+package dns
+
+import (
+ "fmt"
+ "runtime"
+ "strings"
+ "time"
+
+ "golang.org/x/sys/windows"
+ "golang.org/x/sys/windows/registry"
+)
+
+func openKeyWait(k registry.Key, path string, access uint32, timeout time.Duration) (registry.Key, error) {
+ runtime.LockOSThread()
+ defer runtime.UnlockOSThread()
+
+ deadline := time.Now().Add(timeout)
+ pathSpl := strings.Split(path, "\\")
+ for i := 0; ; i++ {
+ keyName := pathSpl[i]
+ isLast := i+1 == len(pathSpl)
+
+ event, err := windows.CreateEvent(nil, 0, 0, nil)
+ if err != nil {
+ return 0, fmt.Errorf("windows.CreateEvent: %v", err)
+ }
+ defer windows.CloseHandle(event)
+
+ var key registry.Key
+ for {
+ err = windows.RegNotifyChangeKeyValue(windows.Handle(k), false, windows.REG_NOTIFY_CHANGE_NAME, event, true)
+ if err != nil {
+ return 0, fmt.Errorf("windows.RegNotifyChangeKeyValue: %v", err)
+ }
+
+ var accessFlags uint32
+ if isLast {
+ accessFlags = access
+ } else {
+ accessFlags = registry.NOTIFY
+ }
+ key, err = registry.OpenKey(k, keyName, accessFlags)
+ if err == windows.ERROR_FILE_NOT_FOUND || err == windows.ERROR_PATH_NOT_FOUND {
+ timeout := time.Until(deadline) / time.Millisecond
+ if timeout < 0 {
+ timeout = 0
+ }
+ s, err := windows.WaitForSingleObject(event, uint32(timeout))
+ if err != nil {
+ return 0, fmt.Errorf("windows.WaitForSingleObject: %v", err)
+ }
+ if s == uint32(windows.WAIT_TIMEOUT) { // windows.WAIT_TIMEOUT status const is misclassified as error in golang.org/x/sys/windows
+ return 0, fmt.Errorf("timeout waiting for registry key")
+ }
+ } else if err != nil {
+ return 0, fmt.Errorf("registry.OpenKey(%v): %v", path, err)
+ } else {
+ if isLast {
+ return key, nil
+ }
+ defer key.Close()
+ break
+ }
+ }
+
+ k = key
+ }
+}
diff --git a/net/dns/resolvconf.go b/net/dns/resolvconf.go
new file mode 100644
index 000000000..8bf97ee88
--- /dev/null
+++ b/net/dns/resolvconf.go
@@ -0,0 +1,157 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build linux freebsd
+
+package dns
+
+import (
+ "bufio"
+ "bytes"
+ "fmt"
+ "os"
+ "os/exec"
+)
+
+// isResolvconfActive indicates whether the system appears to be using resolvconf.
+// If this is true, then directManager should be avoided:
+// resolvconf has exclusive ownership of /etc/resolv.conf.
+func isResolvconfActive() bool {
+ // Sanity-check first: if there is no resolvconf binary, then this is fruitless.
+ //
+ // However, this binary may be a shim like the one systemd-resolved provides.
+ // Such a shim may not behave as expected: in particular, systemd-resolved
+ // does not seem to respect the exclusive mode -x, saying:
+ // -x Send DNS traffic preferably over this interface
+ // whereas e.g. openresolv sends DNS traffix _exclusively_ over that interface,
+ // or not at all (in case of another exclusive-mode request later in time).
+ //
+ // Moreover, resolvconf may be installed but unused, in which case we should
+ // not use it either, lest we clobber existing configuration.
+ //
+ // To handle all the above correctly, we scan the comments in /etc/resolv.conf
+ // to ensure that it was generated by a resolvconf implementation.
+ _, err := exec.LookPath("resolvconf")
+ if err != nil {
+ return false
+ }
+
+ f, err := os.Open("/etc/resolv.conf")
+ if err != nil {
+ return false
+ }
+ defer f.Close()
+
+ scanner := bufio.NewScanner(f)
+ for scanner.Scan() {
+ line := scanner.Bytes()
+ // Look for the word "resolvconf" until comments end.
+ if len(line) > 0 && line[0] != '#' {
+ return false
+ }
+ if bytes.Contains(line, []byte("resolvconf")) {
+ return true
+ }
+ }
+ return false
+}
+
+// resolvconfImpl enumerates supported implementations of the resolvconf CLI.
+type resolvconfImpl uint8
+
+const (
+ // resolvconfOpenresolv is the implementation packaged as "openresolv" on Ubuntu.
+ // It supports exclusive mode and interface metrics.
+ resolvconfOpenresolv resolvconfImpl = iota
+ // resolvconfLegacy is the implementation by Thomas Hood packaged as "resolvconf" on Ubuntu.
+ // It does not support exclusive mode or interface metrics.
+ resolvconfLegacy
+)
+
+func (impl resolvconfImpl) String() string {
+ switch impl {
+ case resolvconfOpenresolv:
+ return "openresolv"
+ case resolvconfLegacy:
+ return "legacy"
+ default:
+ return "unknown"
+ }
+}
+
+// getResolvconfImpl returns the implementation of resolvconf that appears to be in use.
+func getResolvconfImpl() resolvconfImpl {
+ err := exec.Command("resolvconf", "-v").Run()
+ if err != nil {
+ if exitErr, ok := err.(*exec.ExitError); ok {
+ // Thomas Hood's resolvconf has a minimal flag set
+ // and exits with code 99 when passed an unknown flag.
+ if exitErr.ExitCode() == 99 {
+ return resolvconfLegacy
+ }
+ }
+ }
+ return resolvconfOpenresolv
+}
+
+type resolvconfManager struct {
+ impl resolvconfImpl
+}
+
+func newResolvconfManager(mconfig ManagerConfig) managerImpl {
+ impl := getResolvconfImpl()
+ mconfig.Logf("resolvconf implementation is %s", impl)
+
+ return resolvconfManager{
+ impl: impl,
+ }
+}
+
+// resolvconfConfigName is the name of the config submitted to resolvconf.
+// It has this form to match the "tun*" rule in interface-order
+// when running resolvconfLegacy, hopefully placing our config first.
+const resolvconfConfigName = "tun-tailscale.inet"
+
+// Up implements managerImpl.
+func (m resolvconfManager) Up(config Config) error {
+ stdin := new(bytes.Buffer)
+ writeResolvConf(stdin, config.Nameservers, config.Domains) // dns_direct.go
+
+ var cmd *exec.Cmd
+ switch m.impl {
+ case resolvconfOpenresolv:
+ // Request maximal priority (metric 0) and exclusive mode.
+ cmd = exec.Command("resolvconf", "-m", "0", "-x", "-a", resolvconfConfigName)
+ case resolvconfLegacy:
+ // This does not quite give us the desired behavior (queries leak),
+ // but there is nothing else we can do without messing with other interfaces' settings.
+ cmd = exec.Command("resolvconf", "-a", resolvconfConfigName)
+ }
+ cmd.Stdin = stdin
+ out, err := cmd.CombinedOutput()
+ if err != nil {
+ return fmt.Errorf("running %s: %s", cmd, out)
+ }
+
+ return nil
+}
+
+// Down implements managerImpl.
+func (m resolvconfManager) Down() error {
+ var cmd *exec.Cmd
+ switch m.impl {
+ case resolvconfOpenresolv:
+ cmd = exec.Command("resolvconf", "-f", "-d", resolvconfConfigName)
+ case resolvconfLegacy:
+ // resolvconfLegacy lacks the -f flag.
+ // Instead, it succeeds even when the config does not exist.
+ cmd = exec.Command("resolvconf", "-d", resolvconfConfigName)
+ }
+ out, err := cmd.CombinedOutput()
+ if err != nil {
+ return fmt.Errorf("running %s: %s", cmd, out)
+ }
+
+ return nil
+}
diff --git a/net/dns/resolved.go b/net/dns/resolved.go
new file mode 100644
index 000000000..9d8c40d90
--- /dev/null
+++ b/net/dns/resolved.go
@@ -0,0 +1,188 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build linux
+
+package dns
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "os/exec"
+
+ "github.com/godbus/dbus/v5"
+ "golang.org/x/sys/unix"
+ "inet.af/netaddr"
+ "tailscale.com/net/interfaces"
+)
+
+// resolvedListenAddr is the listen address of the resolved stub resolver.
+//
+// We only consider resolved to be the system resolver if the stub resolver is;
+// that is, if this address is the sole nameserver in /etc/resolved.conf.
+// In other cases, resolved may be managing the system DNS configuration directly.
+// Then the nameserver list will be a concatenation of those for all
+// the interfaces that register their interest in being a default resolver with
+// SetLinkDomains([]{{"~.", true}, ...})
+// which includes at least the interface with the default route, i.e. not us.
+// This does not work for us: there is a possibility of getting NXDOMAIN
+// from the other nameservers before we are asked or get a chance to respond.
+// We consider this case as lacking resolved support and fall through to dnsDirect.
+//
+// While it may seem that we need to read a config option to get at this,
+// this address is, in fact, hard-coded into resolved.
+var resolvedListenAddr = netaddr.IPv4(127, 0, 0, 53)
+
+var errNotReady = errors.New("interface not ready")
+
+type resolvedLinkNameserver struct {
+ Family int32
+ Address []byte
+}
+
+type resolvedLinkDomain struct {
+ Domain string
+ RoutingOnly bool
+}
+
+// isResolvedActive determines if resolved is currently managing system DNS settings.
+func isResolvedActive() bool {
+ // systemd-resolved is never installed without systemd.
+ _, err := exec.LookPath("systemctl")
+ if err != nil {
+ return false
+ }
+
+ // is-active exits with code 3 if the service is not active.
+ err = exec.Command("systemctl", "is-active", "systemd-resolved").Run()
+ if err != nil {
+ return false
+ }
+
+ config, err := readResolvConf()
+ if err != nil {
+ return false
+ }
+
+ // The sole nameserver must be the systemd-resolved stub.
+ if len(config.Nameservers) == 1 && config.Nameservers[0] == resolvedListenAddr {
+ return true
+ }
+
+ return false
+}
+
+// resolvedManager uses the systemd-resolved DBus API.
+type resolvedManager struct{}
+
+func newResolvedManager(mconfig ManagerConfig) managerImpl {
+ return resolvedManager{}
+}
+
+// Up implements managerImpl.
+func (m resolvedManager) Up(config Config) error {
+ ctx, cancel := context.WithTimeout(context.Background(), reconfigTimeout)
+ defer cancel()
+
+ // conn is a shared connection whose lifecycle is managed by the dbus package.
+ // We should not interfere with that by closing it.
+ conn, err := dbus.SystemBus()
+ if err != nil {
+ return fmt.Errorf("connecting to system bus: %w", err)
+ }
+
+ resolved := conn.Object(
+ "org.freedesktop.resolve1",
+ dbus.ObjectPath("/org/freedesktop/resolve1"),
+ )
+
+ // In principle, we could persist this in the manager struct
+ // if we knew that interface indices are persistent. This does not seem to be the case.
+ _, iface, err := interfaces.Tailscale()
+ if err != nil {
+ return fmt.Errorf("getting interface index: %w", err)
+ }
+ if iface == nil {
+ return errNotReady
+ }
+
+ var linkNameservers = make([]resolvedLinkNameserver, len(config.Nameservers))
+ for i, server := range config.Nameservers {
+ ip := server.As16()
+ if server.Is4() {
+ linkNameservers[i] = resolvedLinkNameserver{
+ Family: unix.AF_INET,
+ Address: ip[12:],
+ }
+ } else {
+ linkNameservers[i] = resolvedLinkNameserver{
+ Family: unix.AF_INET6,
+ Address: ip[:],
+ }
+ }
+ }
+
+ err = resolved.CallWithContext(
+ ctx, "org.freedesktop.resolve1.Manager.SetLinkDNS", 0,
+ iface.Index, linkNameservers,
+ ).Store()
+ if err != nil {
+ return fmt.Errorf("setLinkDNS: %w", err)
+ }
+
+ var linkDomains = make([]resolvedLinkDomain, len(config.Domains))
+ for i, domain := range config.Domains {
+ linkDomains[i] = resolvedLinkDomain{
+ Domain: domain,
+ RoutingOnly: false,
+ }
+ }
+
+ err = resolved.CallWithContext(
+ ctx, "org.freedesktop.resolve1.Manager.SetLinkDomains", 0,
+ iface.Index, linkDomains,
+ ).Store()
+ if err != nil {
+ return fmt.Errorf("setLinkDomains: %w", err)
+ }
+
+ return nil
+}
+
+// Down implements managerImpl.
+func (m resolvedManager) Down() error {
+ ctx, cancel := context.WithTimeout(context.Background(), reconfigTimeout)
+ defer cancel()
+
+ // conn is a shared connection whose lifecycle is managed by the dbus package.
+ // We should not interfere with that by closing it.
+ conn, err := dbus.SystemBus()
+ if err != nil {
+ return fmt.Errorf("connecting to system bus: %w", err)
+ }
+
+ resolved := conn.Object(
+ "org.freedesktop.resolve1",
+ dbus.ObjectPath("/org/freedesktop/resolve1"),
+ )
+
+ _, iface, err := interfaces.Tailscale()
+ if err != nil {
+ return fmt.Errorf("getting interface index: %w", err)
+ }
+ if iface == nil {
+ return errNotReady
+ }
+
+ err = resolved.CallWithContext(
+ ctx, "org.freedesktop.resolve1.Manager.RevertLink", 0,
+ iface.Index,
+ ).Store()
+ if err != nil {
+ return fmt.Errorf("RevertLink: %w", err)
+ }
+
+ return nil
+}
diff --git a/net/dns/tsdns.go b/net/dns/tsdns.go
new file mode 100644
index 000000000..2b530b81e
--- /dev/null
+++ b/net/dns/tsdns.go
@@ -0,0 +1,662 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package dns provides a Resolver capable of resolving
+// domains on a Tailscale network.
+package dns
+
+import (
+ "encoding/hex"
+ "errors"
+ "net"
+ "strings"
+ "sync"
+ "time"
+
+ dns "golang.org/x/net/dns/dnsmessage"
+ "inet.af/netaddr"
+ "tailscale.com/net/interfaces"
+ "tailscale.com/types/logger"
+ "tailscale.com/util/dnsname"
+ "tailscale.com/wgengine/monitor"
+)
+
+// maxResponseBytes is the maximum size of a response from a Resolver.
+const maxResponseBytes = 512
+
+// queueSize is the maximal number of DNS requests that can await polling.
+// If EnqueueRequest is called when this many requests are already pending,
+// the request will be dropped to avoid blocking the caller.
+const queueSize = 64
+
+// defaultTTL is the TTL of all responses from Resolver.
+const defaultTTL = 600 * time.Second
+
+// ErrClosed indicates that the resolver has been closed and readers should exit.
+var ErrClosed = errors.New("closed")
+
+var (
+ errFullQueue = errors.New("request queue full")
+ errMapNotSet = errors.New("domain map not set")
+ errNotForwarding = errors.New("forwarding disabled")
+ errNotImplemented = errors.New("query type not implemented")
+ errNotQuery = errors.New("not a DNS query")
+ errNotOurName = errors.New("not a Tailscale DNS name")
+)
+
+// Packet represents a DNS payload together with the address of its origin.
+type Packet struct {
+ // Payload is the application layer DNS payload.
+ // Resolver assumes ownership of the request payload when it is enqueued
+ // and cedes ownership of the response payload when it is returned from NextResponse.
+ Payload []byte
+ // Addr is the source address for a request and the destination address for a response.
+ Addr netaddr.IPPort
+}
+
+// Resolver is a DNS resolver for nodes on the Tailscale network,
+// associating them with domain names of the form <mynode>.<mydomain>.<root>.
+// If it is asked to resolve a domain that is not of that form,
+// it delegates to upstream nameservers if any are set.
+type Resolver struct {
+ logf logger.Logf
+ linkMon *monitor.Mon // or nil
+ unregLinkMon func() // or nil
+ // forwarder forwards requests to upstream nameservers.
+ forwarder *forwarder
+
+ // queue is a buffered channel holding DNS requests queued for resolution.
+ queue chan Packet
+ // responses is an unbuffered channel to which responses are returned.
+ responses chan Packet
+ // errors is an unbuffered channel to which errors are returned.
+ errors chan error
+ // closed signals all goroutines to stop.
+ closed chan struct{}
+ // wg signals when all goroutines have stopped.
+ wg sync.WaitGroup
+
+ // mu guards the following fields from being updated while used.
+ mu sync.Mutex
+ // dnsMap is the map most recently received from the control server.
+ dnsMap *Map
+}
+
+// ResolverConfig is the set of configuration options for a Resolver.
+type ResolverConfig struct {
+ // Logf is the logger to use throughout the Resolver.
+ Logf logger.Logf
+ // Forward determines whether the resolver will forward packets to
+ // nameservers set with SetUpstreams if the domain name is not of a Tailscale node.
+ Forward bool
+ // LinkMonitor optionally provides a link monitor to use to rebind
+ // connections on link changes.
+ // If nil, rebinds are not performend.
+ LinkMonitor *monitor.Mon
+}
+
+// NewResolver constructs a resolver associated with the given root domain.
+// The root domain must be in canonical form (with a trailing period).
+func NewResolver(config ResolverConfig) *Resolver {
+ r := &Resolver{
+ logf: logger.WithPrefix(config.Logf, "dns: "),
+ linkMon: config.LinkMonitor,
+ queue: make(chan Packet, queueSize),
+ responses: make(chan Packet),
+ errors: make(chan error),
+ closed: make(chan struct{}),
+ }
+
+ if config.Forward {
+ r.forwarder = newForwarder(r.logf, r.responses)
+ }
+ if r.linkMon != nil {
+ r.unregLinkMon = r.linkMon.RegisterChangeCallback(r.onLinkMonitorChange)
+ }
+
+ return r
+}
+
+func (r *Resolver) Start() error {
+ if r.forwarder != nil {
+ if err := r.forwarder.Start(); err != nil {
+ return err
+ }
+ }
+
+ r.wg.Add(1)
+ go r.poll()
+
+ return nil
+}
+
+// Close shuts down the resolver and ensures poll goroutines have exited.
+// The Resolver cannot be used again after Close is called.
+func (r *Resolver) Close() {
+ select {
+ case <-r.closed:
+ return
+ default:
+ // continue
+ }
+ close(r.closed)
+
+ if r.unregLinkMon != nil {
+ r.unregLinkMon()
+ }
+
+ if r.forwarder != nil {
+ r.forwarder.Close()
+ }
+
+ r.wg.Wait()
+}
+
+func (r *Resolver) onLinkMonitorChange(changed bool, state *interfaces.State) {
+ if !changed {
+ return
+ }
+ if r.forwarder != nil {
+ r.forwarder.rebindFromNetworkChange()
+ }
+}
+
+// SetMap sets the resolver's DNS map, taking ownership of it.
+func (r *Resolver) SetMap(m *Map) {
+ r.mu.Lock()
+ oldMap := r.dnsMap
+ r.dnsMap = m
+ r.mu.Unlock()
+ r.logf("map diff:\n%s", m.PrettyDiffFrom(oldMap))
+}
+
+// SetUpstreams sets the addresses of the resolver's
+// upstream nameservers, taking ownership of the argument.
+func (r *Resolver) SetUpstreams(upstreams []net.Addr) {
+ if r.forwarder != nil {
+ r.forwarder.setUpstreams(upstreams)
+ }
+ r.logf("set upstreams: %v", upstreams)
+}
+
+// EnqueueRequest places the given DNS request in the resolver's queue.
+// It takes ownership of the payload and does not block.
+// If the queue is full, the request will be dropped and an error will be returned.
+func (r *Resolver) EnqueueRequest(request Packet) error {
+ select {
+ case <-r.closed:
+ return ErrClosed
+ case r.queue <- request:
+ return nil
+ default:
+ return errFullQueue
+ }
+}
+
+// NextResponse returns a DNS response to a previously enqueued request.
+// It blocks until a response is available and gives up ownership of the response payload.
+func (r *Resolver) NextResponse() (Packet, error) {
+ select {
+ case <-r.closed:
+ return Packet{}, ErrClosed
+ case resp := <-r.responses:
+ return resp, nil
+ case err := <-r.errors:
+ return Packet{}, err
+ }
+}
+
+// Resolve maps a given domain name to the IP address of the host that owns it,
+// if the IP address conforms to the DNS resource type given by tp (one of A, AAAA, ALL).
+// The domain name must be in canonical form (with a trailing period).
+func (r *Resolver) Resolve(domain string, tp dns.Type) (netaddr.IP, dns.RCode, error) {
+ r.mu.Lock()
+ dnsMap := r.dnsMap
+ r.mu.Unlock()
+
+ if dnsMap == nil {
+ return netaddr.IP{}, dns.RCodeServerFailure, errMapNotSet
+ }
+
+ // Reject .onion domains per RFC 7686.
+ if dnsname.HasSuffix(domain, ".onion") {
+ return netaddr.IP{}, dns.RCodeNameError, nil
+ }
+
+ anyHasSuffix := false
+ for _, suffix := range dnsMap.rootDomains {
+ if dnsname.HasSuffix(domain, suffix) {
+ anyHasSuffix = true
+ break
+ }
+ }
+ addr, found := dnsMap.nameToIP[domain]
+ if !found {
+ if !anyHasSuffix {
+ return netaddr.IP{}, dns.RCodeRefused, nil
+ }
+ return netaddr.IP{}, dns.RCodeNameError, nil
+ }
+
+ // Refactoring note: this must happen after we check suffixes,
+ // otherwise we will respond with NOTIMP to requests that should be forwarded.
+ switch tp {
+ case dns.TypeA:
+ if !addr.Is4() {
+ return netaddr.IP{}, dns.RCodeSuccess, nil
+ }
+ return addr, dns.RCodeSuccess, nil
+ case dns.TypeAAAA:
+ if !addr.Is6() {
+ return netaddr.IP{}, dns.RCodeSuccess, nil
+ }
+ return addr, dns.RCodeSuccess, nil
+ case dns.TypeALL:
+ // Answer with whatever we've got.
+ // It could be IPv4, IPv6, or a zero addr.
+ // TODO: Return all available resolutions (A and AAAA, if we have them).
+ return addr, dns.RCodeSuccess, nil
+
+ // Leave some some record types explicitly unimplemented.
+ // These types relate to recursive resolution or special
+ // DNS sematics and might be implemented in the future.
+ case dns.TypeNS, dns.TypeSOA, dns.TypeAXFR, dns.TypeHINFO:
+ return netaddr.IP{}, dns.RCodeNotImplemented, errNotImplemented
+
+ // For everything except for the few types above that are explictly not implemented, return no records.
+ // This is what other DNS systems do: always return NOERROR
+ // without any records whenever the requested record type is unknown.
+ // You can try this with:
+ // dig -t TYPE9824 example.com
+ // and note that NOERROR is returned, despite that record type being made up.
+ default:
+ // no records exist of this type
+ return netaddr.IP{}, dns.RCodeSuccess, nil
+ }
+}
+
+// ResolveReverse returns the unique domain name that maps to the given address.
+// The returned domain name is in canonical form (with a trailing period).
+func (r *Resolver) ResolveReverse(ip netaddr.IP) (string, dns.RCode, error) {
+ r.mu.Lock()
+ dnsMap := r.dnsMap
+ r.mu.Unlock()
+
+ if dnsMap == nil {
+ return "", dns.RCodeServerFailure, errMapNotSet
+ }
+ name, found := dnsMap.ipToName[ip]
+ if !found {
+ return "", dns.RCodeNameError, nil
+ }
+ return name, dns.RCodeSuccess, nil
+}
+
+func (r *Resolver) poll() {
+ defer r.wg.Done()
+
+ var packet Packet
+ for {
+ select {
+ case <-r.closed:
+ return
+ case packet = <-r.queue:
+ // continue
+ }
+
+ out, err := r.respond(packet.Payload)
+
+ if err == errNotOurName {
+ if r.forwarder != nil {
+ err = r.forwarder.forward(packet)
+ if err == nil {
+ // forward will send response into r.responses, nothing to do.
+ continue
+ }
+ } else {
+ err = errNotForwarding
+ }
+ }
+
+ if err != nil {
+ select {
+ case <-r.closed:
+ return
+ case r.errors <- err:
+ // continue
+ }
+ } else {
+ packet.Payload = out
+ select {
+ case <-r.closed:
+ return
+ case r.responses <- packet:
+ // continue
+ }
+ }
+ }
+}
+
+type response struct {
+ Header dns.Header
+ Question dns.Question
+ // Name is the response to a PTR query.
+ Name string
+ // IP is the response to an A, AAAA, or ALL query.
+ IP netaddr.IP
+}
+
+// parseQuery parses the query in given packet into a response struct.
+func parseQuery(query []byte, resp *response) error {
+ var parser dns.Parser
+ var err error
+
+ resp.Header, err = parser.Start(query)
+ if err != nil {
+ return err
+ }
+
+ if resp.Header.Response {
+ return errNotQuery
+ }
+
+ resp.Question, err = parser.Question()
+ if err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// marshalARecord serializes an A record into an active builder.
+// The caller may continue using the builder following the call.
+func marshalARecord(name dns.Name, ip netaddr.IP, builder *dns.Builder) error {
+ var answer dns.AResource
+
+ answerHeader := dns.ResourceHeader{
+ Name: name,
+ Type: dns.TypeA,
+ Class: dns.ClassINET,
+ TTL: uint32(defaultTTL / time.Second),
+ }
+ ipbytes := ip.As4()
+ copy(answer.A[:], ipbytes[:])
+ return builder.AResource(answerHeader, answer)
+}
+
+// marshalAAAARecord serializes an AAAA record into an active builder.
+// The caller may continue using the builder following the call.
+func marshalAAAARecord(name dns.Name, ip netaddr.IP, builder *dns.Builder) error {
+ var answer dns.AAAAResource
+
+ answerHeader := dns.ResourceHeader{
+ Name: name,
+ Type: dns.TypeAAAA,
+ Class: dns.ClassINET,
+ TTL: uint32(defaultTTL / time.Second),
+ }
+ ipbytes := ip.As16()
+ copy(answer.AAAA[:], ipbytes[:])
+ return builder.AAAAResource(answerHeader, answer)
+}
+
+// marshalPTRRecord serializes a PTR record into an active builder.
+// The caller may continue using the builder following the call.
+func marshalPTRRecord(queryName dns.Name, name string, builder *dns.Builder) error {
+ var answer dns.PTRResource
+ var err error
+
+ answerHeader := dns.ResourceHeader{
+ Name: queryName,
+ Type: dns.TypePTR,
+ Class: dns.ClassINET,
+ TTL: uint32(defaultTTL / time.Second),
+ }
+ answer.PTR, err = dns.NewName(name)
+ if err != nil {
+ return err
+ }
+ return builder.PTRResource(answerHeader, answer)
+}
+
+// marshalResponse serializes the DNS response into a new buffer.
+func marshalResponse(resp *response) ([]byte, error) {
+ resp.Header.Response = true
+ resp.Header.Authoritative = true
+ if resp.Header.RecursionDesired {
+ resp.Header.RecursionAvailable = true
+ }
+
+ builder := dns.NewBuilder(nil, resp.Header)
+
+ isSuccess := resp.Header.RCode == dns.RCodeSuccess
+
+ if resp.Question.Type != 0 || isSuccess {
+ err := builder.StartQuestions()
+ if err != nil {
+ return nil, err
+ }
+
+ err = builder.Question(resp.Question)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ // Only successful responses contain answers.
+ if !isSuccess {
+ return builder.Finish()
+ }
+
+ err := builder.StartAnswers()
+ if err != nil {
+ return nil, err
+ }
+
+ switch resp.Question.Type {
+ case dns.TypeA, dns.TypeAAAA, dns.TypeALL:
+ if resp.IP.Is4() {
+ err = marshalARecord(resp.Question.Name, resp.IP, &builder)
+ } else if resp.IP.Is6() {
+ err = marshalAAAARecord(resp.Question.Name, resp.IP, &builder)
+ }
+ case dns.TypePTR:
+ err = marshalPTRRecord(resp.Question.Name, resp.Name, &builder)
+ }
+ if err != nil {
+ return nil, err
+ }
+
+ return builder.Finish()
+}
+
+const (
+ rdnsv4Suffix = ".in-addr.arpa."
+ rdnsv6Suffix = ".ip6.arpa."
+)
+
+// hasRDNSBonjourPrefix reports whether name has a Bonjour Service Prefix..
+//
+// https://tools.ietf.org/html/rfc6763 lists
+// "five special RR names" for Bonjour service discovery:
+//
+// b._dns-sd._udp.<domain>.
+// db._dns-sd._udp.<domain>.
+// r._dns-sd._udp.<domain>.
+// dr._dns-sd._udp.<domain>.
+// lb._dns-sd._udp.<domain>.
+func hasRDNSBonjourPrefix(s string) bool {
+ // Even the shortest name containing a Bonjour prefix is long,
+ // so check length (cheap) and bail early if possible.
+ if len(s) < len("*._dns-sd._udp.0.0.0.0.in-addr.arpa.") {
+ return false
+ }
+ dot := strings.IndexByte(s, '.')
+ if dot == -1 {
+ return false // shouldn't happen
+ }
+ switch s[:dot] {
+ case "b", "db", "r", "dr", "lb":
+ default:
+ return false
+ }
+
+ return strings.HasPrefix(s[dot:], "._dns-sd._udp.")
+}
+
+// rawNameToLower converts a raw DNS name to a string, lowercasing it.
+func rawNameToLower(name []byte) string {
+ var sb strings.Builder
+ sb.Grow(len(name))
+
+ for _, b := range name {
+ if 'A' <= b && b <= 'Z' {
+ b = b - 'A' + 'a'
+ }
+ sb.WriteByte(b)
+ }
+
+ return sb.String()
+}
+
+// ptrNameToIPv4 transforms a PTR name representing an IPv4 address to said address.
+// Such names are IPv4 labels in reverse order followed by .in-addr.arpa.
+// For example,
+// 4.3.2.1.in-addr.arpa
+// is transformed to
+// 1.2.3.4
+func rdnsNameToIPv4(name string) (ip netaddr.IP, ok bool) {
+ name = strings.TrimSuffix(name, rdnsv4Suffix)
+ ip, err := netaddr.ParseIP(string(name))
+ if err != nil {
+ return netaddr.IP{}, false
+ }
+ if !ip.Is4() {
+ return netaddr.IP{}, false
+ }
+ b := ip.As4()
+ return netaddr.IPv4(b[3], b[2], b[1], b[0]), true
+}
+
+// ptrNameToIPv6 transforms a PTR name representing an IPv6 address to said address.
+// Such names are dot-separated nibbles in reverse order followed by .ip6.arpa.
+// For example,
+// b.a.9.8.7.6.5.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.
+// is transformed to
+// 2001:db8::567:89ab
+func rdnsNameToIPv6(name string) (ip netaddr.IP, ok bool) {
+ var b [32]byte
+ var ipb [16]byte
+
+ name = strings.TrimSuffix(name, rdnsv6Suffix)
+ // 32 nibbles and 31 dots between them.
+ if len(name) != 63 {
+ return netaddr.IP{}, false
+ }
+
+ // Dots and hex digits alternate.
+ prevDot := true
+ // i ranges over name backward; j ranges over b forward.
+ for i, j := len(name)-1, 0; i >= 0; i-- {
+ thisDot := (name[i] == '.')
+ if prevDot == thisDot {
+ return netaddr.IP{}, false
+ }
+ prevDot = thisDot
+
+ if !thisDot {
+ // This is safe assuming alternation.
+ // We do not check that non-dots are hex digits: hex.Decode below will do that.
+ b[j] = name[i]
+ j++
+ }
+ }
+
+ _, err := hex.Decode(ipb[:], b[:])
+ if err != nil {
+ return netaddr.IP{}, false
+ }
+
+ return netaddr.IPFrom16(ipb), true
+}
+
+// respondReverse returns a DNS response to a PTR query.
+// It is assumed that resp.Question is populated by respond before this is called.
+func (r *Resolver) respondReverse(query []byte, name string, resp *response) ([]byte, error) {
+ if hasRDNSBonjourPrefix(name) {
+ return nil, errNotOurName
+ }
+
+ var ip netaddr.IP
+ var ok bool
+ switch {
+ case strings.HasSuffix(name, rdnsv4Suffix):
+ ip, ok = rdnsNameToIPv4(name)
+ case strings.HasSuffix(name, rdnsv6Suffix):
+ ip, ok = rdnsNameToIPv6(name)
+ default:
+ return nil, errNotOurName
+ }
+
+ // It is more likely that we failed in parsing the name than that it is actually malformed.
+ // To avoid frustrating users, just log and delegate.
+ if !ok {
+ r.logf("parsing rdns: malformed name: %s", name)
+ return nil, errNotOurName
+ }
+
+ var err error
+ resp.Name, resp.Header.RCode, err = r.ResolveReverse(ip)
+ if err != nil {
+ r.logf("resolving rdns: %v", ip, err)
+ }
+ if resp.Header.RCode == dns.RCodeNameError {
+ return nil, errNotOurName
+ }
+
+ return marshalResponse(resp)
+}
+
+// respond returns a DNS response to query if it can be resolved locally.
+// Otherwise, it returns errNotOurName.
+func (r *Resolver) respond(query []byte) ([]byte, error) {
+ resp := new(response)
+
+ // ParseQuery is sufficiently fast to run on every DNS packet.
+ // This is considerably simpler than extracting the name by hand
+ // to shave off microseconds in case of delegation.
+ err := parseQuery(query, resp)
+ // We will not return this error: it is the sender's fault.
+ if err != nil {
+ if errors.Is(err, dns.ErrSectionDone) {
+ r.logf("parseQuery(%02x): no DNS questions", query)
+ } else {
+ r.logf("parseQuery(%02x): %v", query, err)
+ }
+ resp.Header.RCode = dns.RCodeFormatError
+ return marshalResponse(resp)
+ }
+ rawName := resp.Question.Name.Data[:resp.Question.Name.Length]
+ name := rawNameToLower(rawName)
+
+ // Always try to handle reverse lookups; delegate inside when not found.
+ // This way, queries for existent nodes do not leak,
+ // but we behave gracefully if non-Tailscale nodes exist in CGNATRange.
+ if resp.Question.Type == dns.TypePTR {
+ return r.respondReverse(query, name, resp)
+ }
+
+ resp.IP, resp.Header.RCode, err = r.Resolve(name, resp.Question.Type)
+ // This return code is special: it requests forwarding.
+ if resp.Header.RCode == dns.RCodeRefused {
+ return nil, errNotOurName
+ }
+
+ // We will not return this error: it is the sender's fault.
+ if err != nil {
+ r.logf("resolving: %v", err)
+ }
+
+ return marshalResponse(resp)
+}
diff --git a/net/dns/tsdns_server_test.go b/net/dns/tsdns_server_test.go
new file mode 100644
index 000000000..95544ba18
--- /dev/null
+++ b/net/dns/tsdns_server_test.go
@@ -0,0 +1,95 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package dns
+
+import (
+ "log"
+ "testing"
+
+ "github.com/miekg/dns"
+ "inet.af/netaddr"
+)
+
+// This file exists to isolate the test infrastructure
+// that depends on github.com/miekg/dns
+// from the rest, which only depends on dnsmessage.
+
+var dnsHandleFunc = dns.HandleFunc
+
+// resolveToIP returns a handler function which responds
+// to queries of type A it receives with an A record containing ipv4,
+// to queries of type AAAA with an AAAA record containing ipv6,
+// to queries of type NS with an NS record containg name.
+func resolveToIP(ipv4, ipv6 netaddr.IP, ns string) dns.HandlerFunc {
+ return func(w dns.ResponseWriter, req *dns.Msg) {
+ m := new(dns.Msg)
+ m.SetReply(req)
+
+ if len(req.Question) != 1 {
+ panic("not a single-question request")
+ }
+ question := req.Question[0]
+
+ var ans dns.RR
+ switch question.Qtype {
+ case dns.TypeA:
+ ans = &dns.A{
+ Hdr: dns.RR_Header{
+ Name: question.Name,
+ Rrtype: dns.TypeA,
+ Class: dns.ClassINET,
+ },
+ A: ipv4.IPAddr().IP,
+ }
+ case dns.TypeAAAA:
+ ans = &dns.AAAA{
+ Hdr: dns.RR_Header{
+ Name: question.Name,
+ Rrtype: dns.TypeAAAA,
+ Class: dns.ClassINET,
+ },
+ AAAA: ipv6.IPAddr().IP,
+ }
+ case dns.TypeNS:
+ ans = &dns.NS{
+ Hdr: dns.RR_Header{
+ Name: question.Name,
+ Rrtype: dns.TypeNS,
+ Class: dns.ClassINET,
+ },
+ Ns: ns,
+ }
+ }
+
+ m.Answer = append(m.Answer, ans)
+ w.WriteMsg(m)
+ }
+}
+
+func resolveToNXDOMAIN(w dns.ResponseWriter, req *dns.Msg) {
+ m := new(dns.Msg)
+ m.SetRcode(req, dns.RcodeNameError)
+ w.WriteMsg(m)
+}
+
+func serveDNS(tb testing.TB, addr string) (*dns.Server, chan error) {
+ server := &dns.Server{Addr: addr, Net: "udp"}
+
+ waitch := make(chan struct{})
+ server.NotifyStartedFunc = func() { close(waitch) }
+
+ errch := make(chan error, 1)
+ go func() {
+ err := server.ListenAndServe()
+ if err != nil {
+ log.Printf("ListenAndServe(%q): %v", addr, err)
+ }
+ errch <- err
+ close(errch)
+ }()
+
+ <-waitch
+ return server, errch
+}
diff --git a/net/dns/tsdns_test.go b/net/dns/tsdns_test.go
new file mode 100644
index 000000000..59bcd8ec1
--- /dev/null
+++ b/net/dns/tsdns_test.go
@@ -0,0 +1,816 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package dns
+
+import (
+ "bytes"
+ "errors"
+ "net"
+ "sync"
+ "testing"
+
+ dns "golang.org/x/net/dns/dnsmessage"
+ "inet.af/netaddr"
+ "tailscale.com/tstest"
+)
+
+var testipv4 = netaddr.IPv4(1, 2, 3, 4)
+var testipv6 = netaddr.IPv6Raw([16]byte{
+ 0x00, 0x01, 0x02, 0x03,
+ 0x04, 0x05, 0x06, 0x07,
+ 0x08, 0x09, 0x0a, 0x0b,
+ 0x0c, 0x0d, 0x0e, 0x0f,
+})
+
+var dnsMap = NewMap(
+ map[string]netaddr.IP{
+ "test1.ipn.dev.": testipv4,
+ "test2.ipn.dev.": testipv6,
+ },
+ []string{"ipn.dev."},
+)
+
+func dnspacket(domain string, tp dns.Type) []byte {
+ var dnsHeader dns.Header
+ question := dns.Question{
+ Name: dns.MustNewName(domain),
+ Type: tp,
+ Class: dns.ClassINET,
+ }
+
+ builder := dns.NewBuilder(nil, dnsHeader)
+ builder.StartQuestions()
+ builder.Question(question)
+ payload, _ := builder.Finish()
+
+ return payload
+}
+
+type dnsResponse struct {
+ ip netaddr.IP
+ name string
+ rcode dns.RCode
+}
+
+func unpackResponse(payload []byte) (dnsResponse, error) {
+ var response dnsResponse
+ var parser dns.Parser
+
+ h, err := parser.Start(payload)
+ if err != nil {
+ return response, err
+ }
+
+ if !h.Response {
+ return response, errors.New("not a response")
+ }
+
+ response.rcode = h.RCode
+ if response.rcode != dns.RCodeSuccess {
+ return response, nil
+ }
+
+ err = parser.SkipAllQuestions()
+ if err != nil {
+ return response, err
+ }
+
+ ah, err := parser.AnswerHeader()
+ if err != nil {
+ return response, err
+ }
+
+ switch ah.Type {
+ case dns.TypeA:
+ res, err := parser.AResource()
+ if err != nil {
+ return response, err
+ }
+ response.ip = netaddr.IPv4(res.A[0], res.A[1], res.A[2], res.A[3])
+ case dns.TypeAAAA:
+ res, err := parser.AAAAResource()
+ if err != nil {
+ return response, err
+ }
+ response.ip = netaddr.IPv6Raw(res.AAAA)
+ case dns.TypeNS:
+ res, err := parser.NSResource()
+ if err != nil {
+ return response, err
+ }
+ response.name = res.NS.String()
+ default:
+ return response, errors.New("type not in {A, AAAA, NS}")
+ }
+
+ return response, nil
+}
+
+func syncRespond(r *Resolver, query []byte) ([]byte, error) {
+ request := Packet{Payload: query}
+ r.EnqueueRequest(request)
+ resp, err := r.NextResponse()
+ return resp.Payload, err
+}
+
+func mustIP(str string) netaddr.IP {
+ ip, err := netaddr.ParseIP(str)
+ if err != nil {
+ panic(err)
+ }
+ return ip
+}
+
+func TestRDNSNameToIPv4(t *testing.T) {
+ tests := []struct {
+ name string
+ input string
+ wantIP netaddr.IP
+ wantOK bool
+ }{
+ {"valid", "4.123.24.1.in-addr.arpa.", netaddr.IPv4(1, 24, 123, 4), true},
+ {"double_dot", "1..2.3.in-addr.arpa.", netaddr.IP{}, false},
+ {"overflow", "1.256.3.4.in-addr.arpa.", netaddr.IP{}, false},
+ {"not_ip", "sub.do.ma.in.in-addr.arpa.", netaddr.IP{}, false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ ip, ok := rdnsNameToIPv4(tt.input)
+ if ok != tt.wantOK {
+ t.Errorf("ok = %v; want %v", ok, tt.wantOK)
+ } else if ok && ip != tt.wantIP {
+ t.Errorf("ip = %v; want %v", ip, tt.wantIP)
+ }
+ })
+ }
+}
+
+func TestRDNSNameToIPv6(t *testing.T) {
+ tests := []struct {
+ name string
+ input string
+ wantIP netaddr.IP
+ wantOK bool
+ }{
+ {
+ "valid",
+ "b.a.9.8.7.6.5.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.",
+ mustIP("2001:db8::567:89ab"),
+ true,
+ },
+ {
+ "double_dot",
+ "b..9.8.7.6.5.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.",
+ netaddr.IP{},
+ false,
+ },
+ {
+ "double_hex",
+ "b.a.98.0.7.6.5.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.",
+ netaddr.IP{},
+ false,
+ },
+ {
+ "not_hex",
+ "b.a.g.0.7.6.5.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.",
+ netaddr.IP{},
+ false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ ip, ok := rdnsNameToIPv6(tt.input)
+ if ok != tt.wantOK {
+ t.Errorf("ok = %v; want %v", ok, tt.wantOK)
+ } else if ok && ip != tt.wantIP {
+ t.Errorf("ip = %v; want %v", ip, tt.wantIP)
+ }
+ })
+ }
+}
+
+func TestResolve(t *testing.T) {
+ r := NewResolver(ResolverConfig{Logf: t.Logf, Forward: false})
+ r.SetMap(dnsMap)
+
+ if err := r.Start(); err != nil {
+ t.Fatalf("start: %v", err)
+ }
+ defer r.Close()
+
+ tests := []struct {
+ name string
+ qname string
+ qtype dns.Type
+ ip netaddr.IP
+ code dns.RCode
+ }{
+ {"ipv4", "test1.ipn.dev.", dns.TypeA, testipv4, dns.RCodeSuccess},
+ {"ipv6", "test2.ipn.dev.", dns.TypeAAAA, testipv6, dns.RCodeSuccess},
+ {"no-ipv6", "test1.ipn.dev.", dns.TypeAAAA, netaddr.IP{}, dns.RCodeSuccess},
+ {"nxdomain", "test3.ipn.dev.", dns.TypeA, netaddr.IP{}, dns.RCodeNameError},
+ {"foreign domain", "google.com.", dns.TypeA, netaddr.IP{}, dns.RCodeRefused},
+ {"all", "test1.ipn.dev.", dns.TypeA, testipv4, dns.RCodeSuccess},
+ {"mx-ipv4", "test1.ipn.dev.", dns.TypeMX, netaddr.IP{}, dns.RCodeSuccess},
+ {"mx-ipv6", "test2.ipn.dev.", dns.TypeMX, netaddr.IP{}, dns.RCodeSuccess},
+ {"mx-nxdomain", "test3.ipn.dev.", dns.TypeMX, netaddr.IP{}, dns.RCodeNameError},
+ {"ns-nxdomain", "test3.ipn.dev.", dns.TypeNS, netaddr.IP{}, dns.RCodeNameError},
+ {"onion-domain", "footest.onion.", dns.TypeA, netaddr.IP{}, dns.RCodeNameError},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ ip, code, err := r.Resolve(tt.qname, tt.qtype)
+ if err != nil {
+ t.Errorf("err = %v; want nil", err)
+ }
+ if code != tt.code {
+ t.Errorf("code = %v; want %v", code, tt.code)
+ }
+ // Only check ip for non-err
+ if ip != tt.ip {
+ t.Errorf("ip = %v; want %v", ip, tt.ip)
+ }
+ })
+ }
+}
+
+func TestResolveReverse(t *testing.T) {
+ r := NewResolver(ResolverConfig{Logf: t.Logf, Forward: false})
+ r.SetMap(dnsMap)
+
+ if err := r.Start(); err != nil {
+ t.Fatalf("start: %v", err)
+ }
+ defer r.Close()
+
+ tests := []struct {
+ name string
+ ip netaddr.IP
+ want string
+ code dns.RCode
+ }{
+ {"ipv4", testipv4, "test1.ipn.dev.", dns.RCodeSuccess},
+ {"ipv6", testipv6, "test2.ipn.dev.", dns.RCodeSuccess},
+ {"nxdomain", netaddr.IPv4(4, 3, 2, 1), "", dns.RCodeNameError},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ name, code, err := r.ResolveReverse(tt.ip)
+ if err != nil {
+ t.Errorf("err = %v; want nil", err)
+ }
+ if code != tt.code {
+ t.Errorf("code = %v; want %v", code, tt.code)
+ }
+ if name != tt.want {
+ t.Errorf("ip = %v; want %v", name, tt.want)
+ }
+ })
+ }
+}
+
+func ipv6Works() bool {
+ c, err := net.Listen("tcp", "[::1]:0")
+ if err != nil {
+ return false
+ }
+ c.Close()
+ return true
+}
+
+func TestDelegate(t *testing.T) {
+ tstest.ResourceCheck(t)
+
+ if !ipv6Works() {
+ t.Skip("skipping test that requires localhost IPv6")
+ }
+
+ dnsHandleFunc("test.site.", resolveToIP(testipv4, testipv6, "dns.test.site."))
+ dnsHandleFunc("nxdomain.site.", resolveToNXDOMAIN)
+
+ v4server, v4errch := serveDNS(t, "127.0.0.1:0")
+ v6server, v6errch := serveDNS(t, "[::1]:0")
+
+ defer func() {
+ if err := <-v4errch; err != nil {
+ t.Errorf("v4 server error: %v", err)
+ }
+ if err := <-v6errch; err != nil {
+ t.Errorf("v6 server error: %v", err)
+ }
+ }()
+ if v4server != nil {
+ defer v4server.Shutdown()
+ }
+ if v6server != nil {
+ defer v6server.Shutdown()
+ }
+
+ if v4server == nil || v6server == nil {
+ // There is an error in at least one of the channels
+ // and we cannot proceed; return to see it.
+ return
+ }
+
+ r := NewResolver(ResolverConfig{Logf: t.Logf, Forward: true})
+ r.SetMap(dnsMap)
+ r.SetUpstreams([]net.Addr{
+ v4server.PacketConn.LocalAddr(),
+ v6server.PacketConn.LocalAddr(),
+ })
+
+ if err := r.Start(); err != nil {
+ t.Fatalf("start: %v", err)
+ }
+ defer r.Close()
+
+ tests := []struct {
+ title string
+ query []byte
+ response dnsResponse
+ }{
+ {
+ "ipv4",
+ dnspacket("test.site.", dns.TypeA),
+ dnsResponse{ip: testipv4, rcode: dns.RCodeSuccess},
+ },
+ {
+ "ipv6",
+ dnspacket("test.site.", dns.TypeAAAA),
+ dnsResponse{ip: testipv6, rcode: dns.RCodeSuccess},
+ },
+ {
+ "ns",
+ dnspacket("test.site.", dns.TypeNS),
+ dnsResponse{name: "dns.test.site.", rcode: dns.RCodeSuccess},
+ },
+ {
+ "nxdomain",
+ dnspacket("nxdomain.site.", dns.TypeA),
+ dnsResponse{rcode: dns.RCodeNameError},
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.title, func(t *testing.T) {
+ payload, err := syncRespond(r, tt.query)
+ if err != nil {
+ t.Errorf("err = %v; want nil", err)
+ return
+ }
+ response, err := unpackResponse(payload)
+ if err != nil {
+ t.Errorf("extract: err = %v; want nil (in %x)", err, payload)
+ return
+ }
+ if response.rcode != tt.response.rcode {
+ t.Errorf("rcode = %v; want %v", response.rcode, tt.response.rcode)
+ }
+ if response.ip != tt.response.ip {
+ t.Errorf("ip = %v; want %v", response.ip, tt.response.ip)
+ }
+ if response.name != tt.response.name {
+ t.Errorf("name = %v; want %v", response.name, tt.response.name)
+ }
+ })
+ }
+}
+
+func TestDelegateCollision(t *testing.T) {
+ dnsHandleFunc("test.site.", resolveToIP(testipv4, testipv6, "dns.test.site."))
+
+ server, errch := serveDNS(t, "127.0.0.1:0")
+ defer func() {
+ if err := <-errch; err != nil {
+ t.Errorf("server error: %v", err)
+ }
+ }()
+
+ if server == nil {
+ return
+ }
+ defer server.Shutdown()
+
+ r := NewResolver(ResolverConfig{Logf: t.Logf, Forward: true})
+ r.SetMap(dnsMap)
+ r.SetUpstreams([]net.Addr{server.PacketConn.LocalAddr()})
+
+ if err := r.Start(); err != nil {
+ t.Fatalf("start: %v", err)
+ }
+ defer r.Close()
+
+ packets := []struct {
+ qname string
+ qtype dns.Type
+ addr netaddr.IPPort
+ }{
+ {"test.site.", dns.TypeA, netaddr.IPPort{IP: netaddr.IPv4(1, 1, 1, 1), Port: 1001}},
+ {"test.site.", dns.TypeAAAA, netaddr.IPPort{IP: netaddr.IPv4(1, 1, 1, 1), Port: 1002}},
+ }
+
+ // packets will have the same dns txid.
+ for _, p := range packets {
+ payload := dnspacket(p.qname, p.qtype)
+ req := Packet{Payload: payload, Addr: p.addr}
+ err := r.EnqueueRequest(req)
+ if err != nil {
+ t.Error(err)
+ }
+ }
+
+ // Despite the txid collision, the answer(s) should still match the query.
+ resp, err := r.NextResponse()
+ if err != nil {
+ t.Error(err)
+ }
+
+ var p dns.Parser
+ _, err = p.Start(resp.Payload)
+ if err != nil {
+ t.Error(err)
+ }
+ err = p.SkipAllQuestions()
+ if err != nil {
+ t.Error(err)
+ }
+ ans, err := p.AllAnswers()
+ if err != nil {
+ t.Error(err)
+ }
+
+ var wantType dns.Type
+ switch ans[0].Body.(type) {
+ case *dns.AResource:
+ wantType = dns.TypeA
+ case *dns.AAAAResource:
+ wantType = dns.TypeAAAA
+ default:
+ t.Errorf("unexpected answer type: %T", ans[0].Body)
+ }
+
+ for _, p := range packets {
+ if p.qtype == wantType && p.addr != resp.Addr {
+ t.Errorf("addr = %v; want %v", resp.Addr, p.addr)
+ }
+ }
+}
+
+func TestConcurrentSetMap(t *testing.T) {
+ r := NewResolver(ResolverConfig{Logf: t.Logf, Forward: false})
+
+ if err := r.Start(); err != nil {
+ t.Fatalf("start: %v", err)
+ }
+ defer r.Close()
+
+ // This is purely to ensure that Resolve does not race with SetMap.
+ var wg sync.WaitGroup
+ wg.Add(2)
+ go func() {
+ defer wg.Done()
+ r.SetMap(dnsMap)
+ }()
+ go func() {
+ defer wg.Done()
+ r.Resolve("test1.ipn.dev", dns.TypeA)
+ }()
+ wg.Wait()
+}
+
+func TestConcurrentSetUpstreams(t *testing.T) {
+ dnsHandleFunc("test.site.", resolveToIP(testipv4, testipv6, "dns.test.site."))
+
+ server, errch := serveDNS(t, "127.0.0.1:0")
+ defer func() {
+ if err := <-errch; err != nil {
+ t.Errorf("server error: %v", err)
+ }
+ }()
+
+ if server == nil {
+ return
+ }
+ defer server.Shutdown()
+
+ r := NewResolver(ResolverConfig{Logf: t.Logf, Forward: true})
+ r.SetMap(dnsMap)
+
+ if err := r.Start(); err != nil {
+ t.Fatalf("start: %v", err)
+ }
+ defer r.Close()
+
+ packet := dnspacket("test.site.", dns.TypeA)
+ // This is purely to ensure that delegation does not race with SetUpstreams.
+ var wg sync.WaitGroup
+ wg.Add(2)
+ go func() {
+ defer wg.Done()
+ r.SetUpstreams([]net.Addr{server.PacketConn.LocalAddr()})
+ }()
+ go func() {
+ defer wg.Done()
+ syncRespond(r, packet)
+ }()
+ wg.Wait()
+}
+
+var allResponse = []byte{
+ 0x00, 0x00, // transaction id: 0
+ 0x84, 0x00, // flags: response, authoritative, no error
+ 0x00, 0x01, // one question
+ 0x00, 0x01, // one answer
+ 0x00, 0x00, 0x00, 0x00, // no authority or additional RRs
+ // Question:
+ 0x05, 0x74, 0x65, 0x73, 0x74, 0x31, 0x03, 0x69, 0x70, 0x6e, 0x03, 0x64, 0x65, 0x76, 0x00, // name
+ 0x00, 0xff, 0x00, 0x01, // type ALL, class IN
+ // Answer:
+ 0x05, 0x74, 0x65, 0x73, 0x74, 0x31, 0x03, 0x69, 0x70, 0x6e, 0x03, 0x64, 0x65, 0x76, 0x00, // name
+ 0x00, 0x01, 0x00, 0x01, // type A, class IN
+ 0x00, 0x00, 0x02, 0x58, // TTL: 600
+ 0x00, 0x04, // length: 4 bytes
+ 0x01, 0x02, 0x03, 0x04, // A: 1.2.3.4
+}
+
+var ipv4Response = []byte{
+ 0x00, 0x00, // transaction id: 0
+ 0x84, 0x00, // flags: response, authoritative, no error
+ 0x00, 0x01, // one question
+ 0x00, 0x01, // one answer
+ 0x00, 0x00, 0x00, 0x00, // no authority or additional RRs
+ // Question:
+ 0x05, 0x74, 0x65, 0x73, 0x74, 0x31, 0x03, 0x69, 0x70, 0x6e, 0x03, 0x64, 0x65, 0x76, 0x00, // name
+ 0x00, 0x01, 0x00, 0x01, // type A, class IN
+ // Answer:
+ 0x05, 0x74, 0x65, 0x73, 0x74, 0x31, 0x03, 0x69, 0x70, 0x6e, 0x03, 0x64, 0x65, 0x76, 0x00, // name
+ 0x00, 0x01, 0x00, 0x01, // type A, class IN
+ 0x00, 0x00, 0x02, 0x58, // TTL: 600
+ 0x00, 0x04, // length: 4 bytes
+ 0x01, 0x02, 0x03, 0x04, // A: 1.2.3.4
+}
+
+var ipv6Response = []byte{
+ 0x00, 0x00, // transaction id: 0
+ 0x84, 0x00, // flags: response, authoritative, no error
+ 0x00, 0x01, // one question
+ 0x00, 0x01, // one answer
+ 0x00, 0x00, 0x00, 0x00, // no authority or additional RRs
+ // Question:
+ 0x05, 0x74, 0x65, 0x73, 0x74, 0x32, 0x03, 0x69, 0x70, 0x6e, 0x03, 0x64, 0x65, 0x76, 0x00, // name
+ 0x00, 0x1c, 0x00, 0x01, // type AAAA, class IN
+ // Answer:
+ 0x05, 0x74, 0x65, 0x73, 0x74, 0x32, 0x03, 0x69, 0x70, 0x6e, 0x03, 0x64, 0x65, 0x76, 0x00, // name
+ 0x00, 0x1c, 0x00, 0x01, // type AAAA, class IN
+ 0x00, 0x00, 0x02, 0x58, // TTL: 600
+ 0x00, 0x10, // length: 16 bytes
+ // AAAA: 0001:0203:0405:0607:0809:0A0B:0C0D:0E0F
+ 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0xb, 0xc, 0xd, 0xe, 0xf,
+}
+
+var ipv4UppercaseResponse = []byte{
+ 0x00, 0x00, // transaction id: 0
+ 0x84, 0x00, // flags: response, authoritative, no error
+ 0x00, 0x01, // one question
+ 0x00, 0x01, // one answer
+ 0x00, 0x00, 0x00, 0x00, // no authority or additional RRs
+ // Question:
+ 0x05, 0x54, 0x45, 0x53, 0x54, 0x31, 0x03, 0x49, 0x50, 0x4e, 0x03, 0x44, 0x45, 0x56, 0x00, // name
+ 0x00, 0x01, 0x00, 0x01, // type A, class IN
+ // Answer:
+ 0x05, 0x54, 0x45, 0x53, 0x54, 0x31, 0x03, 0x49, 0x50, 0x4e, 0x03, 0x44, 0x45, 0x56, 0x00, // name
+ 0x00, 0x01, 0x00, 0x01, // type A, class IN
+ 0x00, 0x00, 0x02, 0x58, // TTL: 600
+ 0x00, 0x04, // length: 4 bytes
+ 0x01, 0x02, 0x03, 0x04, // A: 1.2.3.4
+}
+
+var ptrResponse = []byte{
+ 0x00, 0x00, // transaction id: 0
+ 0x84, 0x00, // flags: response, authoritative, no error
+ 0x00, 0x01, // one question
+ 0x00, 0x01, // one answer
+ 0x00, 0x00, 0x00, 0x00, // no authority or additional RRs
+ // Question: 4.3.2.1.in-addr.arpa
+ 0x01, 0x34, 0x01, 0x33, 0x01, 0x32, 0x01, 0x31, 0x07,
+ 0x69, 0x6e, 0x2d, 0x61, 0x64, 0x64, 0x72, 0x04, 0x61, 0x72, 0x70, 0x61, 0x00,
+ 0x00, 0x0c, 0x00, 0x01, // type PTR, class IN
+ // Answer: 4.3.2.1.in-addr.arpa
+ 0x01, 0x34, 0x01, 0x33, 0x01, 0x32, 0x01, 0x31, 0x07,
+ 0x69, 0x6e, 0x2d, 0x61, 0x64, 0x64, 0x72, 0x04, 0x61, 0x72, 0x70, 0x61, 0x00,
+ 0x00, 0x0c, 0x00, 0x01, // type PTR, class IN
+ 0x00, 0x00, 0x02, 0x58, // TTL: 600
+ 0x00, 0x0f, // length: 15 bytes
+ // PTR: test1.ipn.dev
+ 0x05, 0x74, 0x65, 0x73, 0x74, 0x31, 0x03, 0x69, 0x70, 0x6e, 0x03, 0x64, 0x65, 0x76, 0x00,
+}
+
+var ptrResponse6 = []byte{
+ 0x00, 0x00, // transaction id: 0
+ 0x84, 0x00, // flags: response, authoritative, no error
+ 0x00, 0x01, // one question
+ 0x00, 0x01, // one answer
+ 0x00, 0x00, 0x00, 0x00, // no authority or additional RRs
+ // Question: f.0.e.0.d.0.c.0.b.0.a.0.9.0.8.0.7.0.6.0.5.0.4.0.3.0.2.0.1.0.0.0.ip6.arpa
+ 0x01, 0x66, 0x01, 0x30, 0x01, 0x65, 0x01, 0x30,
+ 0x01, 0x64, 0x01, 0x30, 0x01, 0x63, 0x01, 0x30,
+ 0x01, 0x62, 0x01, 0x30, 0x01, 0x61, 0x01, 0x30,
+ 0x01, 0x39, 0x01, 0x30, 0x01, 0x38, 0x01, 0x30,
+ 0x01, 0x37, 0x01, 0x30, 0x01, 0x36, 0x01, 0x30,
+ 0x01, 0x35, 0x01, 0x30, 0x01, 0x34, 0x01, 0x30,
+ 0x01, 0x33, 0x01, 0x30, 0x01, 0x32, 0x01, 0x30,
+ 0x01, 0x31, 0x01, 0x30, 0x01, 0x30, 0x01, 0x30,
+ 0x03, 0x69, 0x70, 0x36,
+ 0x04, 0x61, 0x72, 0x70, 0x61, 0x00,
+ 0x00, 0x0c, 0x00, 0x01, // type PTR, class IN6
+ // Answer: f.0.e.0.d.0.c.0.b.0.a.0.9.0.8.0.7.0.6.0.5.0.4.0.3.0.2.0.1.0.0.0.ip6.arpa
+ 0x01, 0x66, 0x01, 0x30, 0x01, 0x65, 0x01, 0x30,
+ 0x01, 0x64, 0x01, 0x30, 0x01, 0x63, 0x01, 0x30,
+ 0x01, 0x62, 0x01, 0x30, 0x01, 0x61, 0x01, 0x30,
+ 0x01, 0x39, 0x01, 0x30, 0x01, 0x38, 0x01, 0x30,
+ 0x01, 0x37, 0x01, 0x30, 0x01, 0x36, 0x01, 0x30,
+ 0x01, 0x35, 0x01, 0x30, 0x01, 0x34, 0x01, 0x30,
+ 0x01, 0x33, 0x01, 0x30, 0x01, 0x32, 0x01, 0x30,
+ 0x01, 0x31, 0x01, 0x30, 0x01, 0x30, 0x01, 0x30,
+ 0x03, 0x69, 0x70, 0x36,
+ 0x04, 0x61, 0x72, 0x70, 0x61, 0x00,
+ 0x00, 0x0c, 0x00, 0x01, // type PTR, class IN
+ 0x00, 0x00, 0x02, 0x58, // TTL: 600
+ 0x00, 0x0f, // length: 15 bytes
+ // PTR: test2.ipn.dev
+ 0x05, 0x74, 0x65, 0x73, 0x74, 0x32, 0x03, 0x69, 0x70, 0x6e, 0x03, 0x64, 0x65, 0x76, 0x00,
+}
+
+var nxdomainResponse = []byte{
+ 0x00, 0x00, // transaction id: 0
+ 0x84, 0x03, // flags: response, authoritative, error: nxdomain
+ 0x00, 0x01, // one question
+ 0x00, 0x00, // no answers
+ 0x00, 0x00, 0x00, 0x00, // no authority or additional RRs
+ // Question:
+ 0x05, 0x74, 0x65, 0x73, 0x74, 0x33, 0x03, 0x69, 0x70, 0x6e, 0x03, 0x64, 0x65, 0x76, 0x00, // name
+ 0x00, 0x01, 0x00, 0x01, // type A, class IN
+}
+
+var emptyResponse = []byte{
+ 0x00, 0x00, // transaction id: 0
+ 0x84, 0x00, // flags: response, authoritative, no error
+ 0x00, 0x01, // one question
+ 0x00, 0x00, // no answers
+ 0x00, 0x00, 0x00, 0x00, // no authority or additional RRs
+ // Question:
+ 0x05, 0x74, 0x65, 0x73, 0x74, 0x31, 0x03, 0x69, 0x70, 0x6e, 0x03, 0x64, 0x65, 0x76, 0x00, // name
+ 0x00, 0x1c, 0x00, 0x01, // type AAAA, class IN
+}
+
+func TestFull(t *testing.T) {
+ r := NewResolver(ResolverConfig{Logf: t.Logf, Forward: false})
+ r.SetMap(dnsMap)
+
+ if err := r.Start(); err != nil {
+ t.Fatalf("start: %v", err)
+ }
+ defer r.Close()
+
+ // One full packet and one error packet
+ tests := []struct {
+ name string
+ request []byte
+ response []byte
+ }{
+ {"all", dnspacket("test1.ipn.dev.", dns.TypeALL), allResponse},
+ {"ipv4", dnspacket("test1.ipn.dev.", dns.TypeA), ipv4Response},
+ {"ipv6", dnspacket("test2.ipn.dev.", dns.TypeAAAA), ipv6Response},
+ {"no-ipv6", dnspacket("test1.ipn.dev.", dns.TypeAAAA), emptyResponse},
+ {"upper", dnspacket("TEST1.IPN.DEV.", dns.TypeA), ipv4UppercaseResponse},
+ {"ptr", dnspacket("4.3.2.1.in-addr.arpa.", dns.TypePTR), ptrResponse},
+ {"ptr", dnspacket("f.0.e.0.d.0.c.0.b.0.a.0.9.0.8.0.7.0.6.0.5.0.4.0.3.0.2.0.1.0.0.0.ip6.arpa.",
+ dns.TypePTR), ptrResponse6},
+ {"nxdomain", dnspacket("test3.ipn.dev.", dns.TypeA), nxdomainResponse},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ response, err := syncRespond(r, tt.request)
+ if err != nil {
+ t.Errorf("err = %v; want nil", err)
+ }
+ if !bytes.Equal(response, tt.response) {
+ t.Errorf("response = %x; want %x", response, tt.response)
+ }
+ })
+ }
+}
+
+func TestAllocs(t *testing.T) {
+ r := NewResolver(ResolverConfig{Logf: t.Logf, Forward: false})
+ r.SetMap(dnsMap)
+
+ if err := r.Start(); err != nil {
+ t.Fatalf("start: %v", err)
+ }
+ defer r.Close()
+
+ // It is seemingly pointless to test allocs in the delegate path,
+ // as dialer.Dial -> Read -> Write alone comprise 12 allocs.
+ tests := []struct {
+ name string
+ query []byte
+ want int
+ }{
+ // Name lowercasing and response slice created by dns.NewBuilder.
+ {"forward", dnspacket("test1.ipn.dev.", dns.TypeA), 2},
+ // 3 extra allocs in rdnsNameToIPv4 and one in marshalPTRRecord (dns.NewName).
+ {"reverse", dnspacket("4.3.2.1.in-addr.arpa.", dns.TypePTR), 5},
+ }
+
+ for _, tt := range tests {
+ allocs := testing.AllocsPerRun(100, func() {
+ syncRespond(r, tt.query)
+ })
+ if int(allocs) > tt.want {
+ t.Errorf("%s: allocs = %v; want %v", tt.name, allocs, tt.want)
+ }
+ }
+}
+
+func TestTrimRDNSBonjourPrefix(t *testing.T) {
+ tests := []struct {
+ in string
+ want bool
+ }{
+ {"b._dns-sd._udp.0.10.20.172.in-addr.arpa.", true},
+ {"db._dns-sd._udp.0.10.20.172.in-addr.arpa.", true},
+ {"r._dns-sd._udp.0.10.20.172.in-addr.arpa.", true},
+ {"dr._dns-sd._udp.0.10.20.172.in-addr.arpa.", true},
+ {"lb._dns-sd._udp.0.10.20.172.in-addr.arpa.", true},
+ {"qq._dns-sd._udp.0.10.20.172.in-addr.arpa.", false},
+ {"0.10.20.172.in-addr.arpa.", false},
+ {"i-have-no-dot", false},
+ }
+
+ for _, test := range tests {
+ got := hasRDNSBonjourPrefix(test.in)
+ if got != test.want {
+ t.Errorf("trimRDNSBonjourPrefix(%q) = %v, want %v", test.in, got, test.want)
+ }
+ }
+}
+
+func BenchmarkFull(b *testing.B) {
+ dnsHandleFunc("test.site.", resolveToIP(testipv4, testipv6, "dns.test.site."))
+
+ server, errch := serveDNS(b, "127.0.0.1:0")
+ defer func() {
+ if err := <-errch; err != nil {
+ b.Errorf("server error: %v", err)
+ }
+ }()
+
+ if server == nil {
+ return
+ }
+ defer server.Shutdown()
+
+ r := NewResolver(ResolverConfig{Logf: b.Logf, Forward: true})
+ r.SetMap(dnsMap)
+ r.SetUpstreams([]net.Addr{server.PacketConn.LocalAddr()})
+
+ if err := r.Start(); err != nil {
+ b.Fatalf("start: %v", err)
+ }
+ defer r.Close()
+
+ tests := []struct {
+ name string
+ request []byte
+ }{
+ {"forward", dnspacket("test1.ipn.dev.", dns.TypeA)},
+ {"reverse", dnspacket("4.3.2.1.in-addr.arpa.", dns.TypePTR)},
+ {"delegated", dnspacket("test.site.", dns.TypeA)},
+ }
+
+ for _, tt := range tests {
+ b.Run(tt.name, func(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ syncRespond(r, tt.request)
+ }
+ })
+ }
+}
+
+func TestMarshalResponseFormatError(t *testing.T) {
+ resp := new(response)
+ resp.Header.RCode = dns.RCodeFormatError
+ v, err := marshalResponse(resp)
+ if err != nil {
+ t.Errorf("marshal error: %v", err)
+ }
+ t.Logf("response: %q", v)
+}
diff --git a/net/flowtrack/flowtrack.go b/net/flowtrack/flowtrack.go
index 5fcd4ab20..8387145d3 100644
--- a/net/flowtrack/flowtrack.go
+++ b/net/flowtrack/flowtrack.go
@@ -15,16 +15,18 @@ import (
"fmt"
"inet.af/netaddr"
+ "tailscale.com/types/ipproto"
)
-// Tuple is a 4-tuple of source and destination IP and port.
+// Tuple is a 5-tuple of proto, source and destination IP and port.
type Tuple struct {
- Src netaddr.IPPort
- Dst netaddr.IPPort
+ Proto ipproto.Proto
+ Src netaddr.IPPort
+ Dst netaddr.IPPort
}
func (t Tuple) String() string {
- return fmt.Sprintf("(%v => %v)", t.Src, t.Dst)
+ return fmt.Sprintf("(%v %v => %v)", t.Proto, t.Src, t.Dst)
}
// Cache is an LRU cache keyed by Tuple.
diff --git a/net/interfaces/interfaces.go b/net/interfaces/interfaces.go
index 3a0ffeb0b..f2988af34 100644
--- a/net/interfaces/interfaces.go
+++ b/net/interfaces/interfaces.go
@@ -6,10 +6,10 @@
package interfaces
import (
+ "bytes"
"fmt"
"net"
"net/http"
- "reflect"
"runtime"
"sort"
"strings"
@@ -190,6 +190,9 @@ func ForeachInterface(fn func(Interface, []netaddr.IPPrefix)) error {
}
}
}
+ sort.Slice(pfxs, func(i, j int) bool {
+ return pfxs[i].IP.Less(pfxs[j].IP)
+ })
fn(Interface{iface}, pfxs)
}
return nil
@@ -204,7 +207,7 @@ type State struct {
// IPPrefix, where the IP is the interface IP address and Bits is
// the subnet mask.
InterfaceIPs map[string][]netaddr.IPPrefix
- InterfaceUp map[string]bool
+ Interface map[string]Interface
// HaveV6Global is whether this machine has an IPv6 global address
// on some non-Tailscale interface that's up.
@@ -235,14 +238,14 @@ type State struct {
func (s *State) String() string {
var sb strings.Builder
fmt.Fprintf(&sb, "interfaces.State{defaultRoute=%v ifs={", s.DefaultRouteInterface)
- ifs := make([]string, 0, len(s.InterfaceUp))
- for k := range s.InterfaceUp {
+ ifs := make([]string, 0, len(s.Interface))
+ for k := range s.Interface {
if anyInterestingIP(s.InterfaceIPs[k]) {
ifs = append(ifs, k)
}
}
sort.Slice(ifs, func(i, j int) bool {
- upi, upj := s.InterfaceUp[ifs[i]], s.InterfaceUp[ifs[j]]
+ upi, upj := s.Interface[ifs[i]].IsUp(), s.Interface[ifs[j]].IsUp()
if upi != upj {
// Up sorts before down.
return upi
@@ -253,7 +256,7 @@ func (s *State) String() string {
if i > 0 {
sb.WriteString(" ")
}
- if s.InterfaceUp[ifName] {
+ if s.Interface[ifName].IsUp() {
fmt.Fprintf(&sb, "%s:[", ifName)
needSpace := false
for _, pfx := range s.InterfaceIPs[ifName] {
@@ -286,50 +289,76 @@ func (s *State) String() string {
return sb.String()
}
-func (s *State) Equal(s2 *State) bool {
- return reflect.DeepEqual(s, s2)
-}
-
-func (s *State) HasPAC() bool { return s != nil && s.PAC != "" }
-
-// AnyInterfaceUp reports whether any interface seems like it has Internet access.
-func (s *State) AnyInterfaceUp() bool {
- return s != nil && (s.HaveV4 || s.HaveV6Global)
-}
-
-// RemoveUninterestingInterfacesAndAddresses removes uninteresting IPs
-// from InterfaceIPs, also removing from both the InterfaceIPs and
-// InterfaceUp map any interfaces that don't have any interesting IPs.
-func (s *State) RemoveUninterestingInterfacesAndAddresses() {
- for ifName := range s.InterfaceUp {
- ips := s.InterfaceIPs[ifName]
- keep := ips[:0]
- for _, pfx := range ips {
- if isInterestingIP(pfx.IP) {
- keep = append(keep, pfx)
- }
- }
- if len(keep) == 0 {
- delete(s.InterfaceUp, ifName)
- delete(s.InterfaceIPs, ifName)
+// EqualFiltered reports whether s and s2 are equal,
+// considering only interfaces in s for which filter returns true.
+func (s *State) EqualFiltered(s2 *State, filter func(i Interface, ips []netaddr.IPPrefix) bool) bool {
+ if s == nil && s2 == nil {
+ return true
+ }
+ if s == nil || s2 == nil {
+ return false
+ }
+ if s.HaveV6Global != s2.HaveV6Global ||
+ s.HaveV4 != s2.HaveV4 ||
+ s.IsExpensive != s2.IsExpensive ||
+ s.DefaultRouteInterface != s2.DefaultRouteInterface ||
+ s.HTTPProxy != s2.HTTPProxy ||
+ s.PAC != s2.PAC {
+ return false
+ }
+ for iname, i := range s.Interface {
+ ips := s.InterfaceIPs[iname]
+ if !filter(i, ips) {
continue
}
- if len(keep) < len(ips) {
- s.InterfaceIPs[ifName] = keep
+ i2, ok := s2.Interface[iname]
+ if !ok {
+ return false
+ }
+ ips2, ok := s2.InterfaceIPs[iname]
+ if !ok {
+ return false
+ }
+ if !interfacesEqual(i, i2) || !prefixesEqual(ips, ips2) {
+ return false
}
}
+ return true
+}
+
+func interfacesEqual(a, b Interface) bool {
+ return a.Index == b.Index &&
+ a.MTU == b.MTU &&
+ a.Name == b.Name &&
+ a.Flags == b.Flags &&
+ bytes.Equal([]byte(a.HardwareAddr), []byte(b.HardwareAddr))
}
-// RemoveTailscaleInterfaces modifes s to remove any interfaces that
-// are owned by this process. (TODO: make this true; currently it
-// uses some heuristics)
-func (s *State) RemoveTailscaleInterfaces() {
- for name, pfxs := range s.InterfaceIPs {
- if isTailscaleInterface(name, pfxs) {
- delete(s.InterfaceIPs, name)
- delete(s.InterfaceUp, name)
+func prefixesEqual(a, b []netaddr.IPPrefix) bool {
+ if len(a) != len(b) {
+ return false
+ }
+ for i, v := range a {
+ if b[i] != v {
+ return false
}
}
+ return true
+}
+
+// FilterInteresting reports whether i is an interesting non-Tailscale interface.
+func FilterInteresting(i Interface, ips []netaddr.IPPrefix) bool {
+ return !isTailscaleInterface(i.Name, ips) && anyInterestingIP(ips)
+}
+
+// FilterAll always returns true, to use EqualFiltered against all interfaces.
+func FilterAll(i Interface, ips []netaddr.IPPrefix) bool { return true }
+
+func (s *State) HasPAC() bool { return s != nil && s.PAC != "" }
+
+// AnyInterfaceUp reports whether any interface seems like it has Internet access.
+func (s *State) AnyInterfaceUp() bool {
+ return s != nil && (s.HaveV4 || s.HaveV6Global)
}
func hasTailscaleIP(pfxs []netaddr.IPPrefix) bool {
@@ -364,11 +393,11 @@ var getPAC func() string
func GetState() (*State, error) {
s := &State{
InterfaceIPs: make(map[string][]netaddr.IPPrefix),
- InterfaceUp: make(map[string]bool),
+ Interface: make(map[string]Interface),
}
if err := ForeachInterface(func(ni Interface, pfxs []netaddr.IPPrefix) {
ifUp := ni.IsUp()
- s.InterfaceUp[ni.Name] = ifUp
+ s.Interface[ni.Name] = ni
s.InterfaceIPs[ni.Name] = append(s.InterfaceIPs[ni.Name], pfxs...)
if !ifUp || isTailscaleInterface(ni.Name, pfxs) {
return
diff --git a/net/interfaces/interfaces_test.go b/net/interfaces/interfaces_test.go
index 88948b579..ab0ee3734 100644
--- a/net/interfaces/interfaces_test.go
+++ b/net/interfaces/interfaces_test.go
@@ -5,6 +5,7 @@
package interfaces
import (
+ "encoding/json"
"testing"
)
@@ -13,7 +14,11 @@ func TestGetState(t *testing.T) {
if err != nil {
t.Fatal(err)
}
- t.Logf("Got: %#v", st)
+ j, err := json.MarshalIndent(st, "", "\t")
+ if err != nil {
+ t.Errorf("JSON: %v", err)
+ }
+ t.Logf("Got: %s", j)
t.Logf("As string: %s", st)
st2, err := GetState()
@@ -21,14 +26,13 @@ func TestGetState(t *testing.T) {
t.Fatal(err)
}
- if !st.Equal(st2) {
+ if !st.EqualFiltered(st2, FilterAll) {
// let's assume nobody was changing the system network interfaces between
// the two GetState calls.
t.Fatal("two States back-to-back were not equal")
}
- st.RemoveTailscaleInterfaces()
- t.Logf("As string without Tailscale:\n\t%s", st)
+ t.Logf("As string:\n\t%s", st)
}
func TestLikelyHomeRouterIP(t *testing.T) {
diff --git a/net/interfaces/interfaces_windows.go b/net/interfaces/interfaces_windows.go
index 19e9b48b4..91f679c97 100644
--- a/net/interfaces/interfaces_windows.go
+++ b/net/interfaces/interfaces_windows.go
@@ -7,17 +7,19 @@ package interfaces
import (
"fmt"
"log"
+ "net"
"net/url"
- "os/exec"
"syscall"
"unsafe"
- "go4.org/mem"
"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
"inet.af/netaddr"
"tailscale.com/tsconst"
- "tailscale.com/util/lineread"
+)
+
+const (
+ fallbackInterfaceMetric = uint32(0) // Used if we cannot get the actual interface metric
)
func init() {
@@ -25,58 +27,76 @@ func init() {
getPAC = getPACWindows
}
-/*
-Parse out 10.0.0.1 from:
-
-Z:\>route print -4
-===========================================================================
-Interface List
- 15...aa 15 48 ff 1c 72 ......Red Hat VirtIO Ethernet Adapter
- 5...........................Tailscale Tunnel
- 1...........................Software Loopback Interface 1
-===========================================================================
-
-IPv4 Route Table
-===========================================================================
-Active Routes:
-Network Destination Netmask Gateway Interface Metric
- 0.0.0.0 0.0.0.0 10.0.0.1 10.0.28.63 5
- 10.0.0.0 255.255.0.0 On-link 10.0.28.63 261
- 10.0.28.63 255.255.255.255 On-link 10.0.28.63 261
- 10.0.42.0 255.255.255.0 100.103.42.106 100.103.42.106 5
- 10.0.255.255 255.255.255.255 On-link 10.0.28.63 261
- 34.193.248.174 255.255.255.255 100.103.42.106 100.103.42.106 5
-
-*/
func likelyHomeRouterIPWindows() (ret netaddr.IP, ok bool) {
- cmd := exec.Command("route", "print", "-4")
- cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
- stdout, err := cmd.StdoutPipe()
+ rs, err := winipcfg.GetIPForwardTable2(windows.AF_INET)
if err != nil {
+ log.Printf("routerIP/GetIPForwardTable2 error: %v", err)
return
}
- if err := cmd.Start(); err != nil {
+
+ var ifaceMetricCache map[winipcfg.LUID]uint32
+
+ getIfaceMetric := func(luid winipcfg.LUID) (metric uint32) {
+ if ifaceMetricCache == nil {
+ ifaceMetricCache = make(map[winipcfg.LUID]uint32)
+ } else if m, ok := ifaceMetricCache[luid]; ok {
+ return m
+ }
+
+ if iface, err := luid.IPInterface(windows.AF_INET); err == nil {
+ metric = iface.Metric
+ } else {
+ log.Printf("routerIP/luid.IPInterface error: %v", err)
+ metric = fallbackInterfaceMetric
+ }
+
+ ifaceMetricCache[luid] = metric
return
}
- defer cmd.Wait()
- var f []mem.RO
- lineread.Reader(stdout, func(lineb []byte) error {
- line := mem.B(lineb)
- if !mem.Contains(line, mem.S("0.0.0.0")) {
- return nil
+ unspec := net.IPv4(0, 0, 0, 0)
+ var best *winipcfg.MibIPforwardRow2 // best (lowest metric) found so far, or nil
+
+ for i := range rs {
+ r := &rs[i]
+ if r.Loopback || r.DestinationPrefix.PrefixLength != 0 || !r.DestinationPrefix.Prefix.IP().Equal(unspec) {
+ // Not a default route, so skip
+ continue
}
- f = mem.AppendFields(f[:0], line)
- if len(f) < 3 || !f[0].EqualString("0.0.0.0") || !f[1].EqualString("0.0.0.0") {
- return nil
+
+ ip, ok := netaddr.FromStdIP(r.NextHop.IP())
+ if !ok {
+ // Not a valid gateway, so skip (won't happen though)
+ continue
}
- ipm := f[2]
- ip, err := netaddr.ParseIP(string(mem.Append(nil, ipm)))
- if err == nil && isPrivateIP(ip) {
+
+ if best == nil {
+ best = r
ret = ip
+ continue
}
- return nil
- })
+
+ // We can get here only if there are multiple default gateways defined (rare case),
+ // in which case we need to calculate the effective metric.
+ // Effective metric is sum of interface metric and route metric offset
+ if ifaceMetricCache == nil {
+ // If we're here it means that previous route still isn't updated, so update it
+ best.Metric += getIfaceMetric(best.InterfaceLUID)
+ }
+ r.Metric += getIfaceMetric(r.InterfaceLUID)
+
+ if best.Metric > r.Metric || best.Metric == r.Metric && ret.Compare(ip) > 0 {
+ // Pick the route with lower metric, or lower IP if metrics are equal
+ best = r
+ ret = ip
+ }
+ }
+
+ if !ret.IsZero() && !isPrivateIP(ret) {
+ // Default route has a non-private gateway
+ return netaddr.IP{}, false
+ }
+
return ret, !ret.IsZero()
}
diff --git a/net/packet/header.go b/net/packet/header.go
index 86680a5a7..5cf4ef650 100644
--- a/net/packet/header.go
+++ b/net/packet/header.go
@@ -10,6 +10,7 @@ import (
)
const tcpHeaderLength = 20
+const sctpHeaderLength = 12
// maxPacketLength is the largest length that all headers support.
// IPv4 headers using uint16 for this forces an upper bound of 64KB.
diff --git a/net/packet/icmp4.go b/net/packet/icmp4.go
index 8a1568114..da4774887 100644
--- a/net/packet/icmp4.go
+++ b/net/packet/icmp4.go
@@ -4,7 +4,11 @@
package packet
-import "encoding/binary"
+import (
+ "encoding/binary"
+
+ "tailscale.com/types/ipproto"
+)
// icmp4HeaderLength is the size of the ICMPv4 packet header, not
// including the outer IP layer or the variable "response data"
@@ -66,7 +70,7 @@ func (h ICMP4Header) Marshal(buf []byte) error {
return errLargePacket
}
// The caller does not need to set this.
- h.IPProto = ICMPv4
+ h.IPProto = ipproto.ICMPv4
buf[20] = uint8(h.Type)
buf[21] = uint8(h.Code)
diff --git a/net/packet/ip.go b/net/packet/ip.go
deleted file mode 100644
index 34194f344..000000000
--- a/net/packet/ip.go
+++ /dev/null
@@ -1,66 +0,0 @@
-// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-package packet
-
-// IPProto is an IP subprotocol as defined by the IANA protocol
-// numbers list
-// (https://www.iana.org/assignments/protocol-numbers/protocol-numbers.xhtml),
-// or the special values Unknown or Fragment.
-type IPProto uint8
-
-const (
- // Unknown represents an unknown or unsupported protocol; it's
- // deliberately the zero value. Strictly speaking the zero
- // value is IPv6 hop-by-hop extensions, but we don't support
- // those, so this is still technically correct.
- Unknown IPProto = 0x00
-
- // Values from the IANA registry.
- ICMPv4 IPProto = 0x01
- IGMP IPProto = 0x02
- ICMPv6 IPProto = 0x3a
- TCP IPProto = 0x06
- UDP IPProto = 0x11
-
- // TSMP is the Tailscale Message Protocol (our ICMP-ish
- // thing), an IP protocol used only between Tailscale nodes
- // (still encrypted by WireGuard) that communicates why things
- // failed, etc.
- //
- // Proto number 99 is reserved for "any private encryption
- // scheme". We never accept these from the host OS stack nor
- // send them to the host network stack. It's only used between
- // nodes.
- TSMP IPProto = 99
-
- // Fragment represents any non-first IP fragment, for which we
- // don't have the sub-protocol header (and therefore can't
- // figure out what the sub-protocol is).
- //
- // 0xFF is reserved in the IANA registry, so we steal it for
- // internal use.
- Fragment IPProto = 0xFF
-)
-
-func (p IPProto) String() string {
- switch p {
- case Fragment:
- return "Frag"
- case ICMPv4:
- return "ICMPv4"
- case IGMP:
- return "IGMP"
- case ICMPv6:
- return "ICMPv6"
- case UDP:
- return "UDP"
- case TCP:
- return "TCP"
- case TSMP:
- return "TSMP"
- default:
- return "Unknown"
- }
-}
diff --git a/net/packet/ip4.go b/net/packet/ip4.go
index 0240abaa1..2c090d9f1 100644
--- a/net/packet/ip4.go
+++ b/net/packet/ip4.go
@@ -9,6 +9,7 @@ import (
"errors"
"inet.af/netaddr"
+ "tailscale.com/types/ipproto"
)
// ip4HeaderLength is the length of an IPv4 header with no IP options.
@@ -16,7 +17,7 @@ const ip4HeaderLength = 20
// IP4Header represents an IPv4 packet header.
type IP4Header struct {
- IPProto IPProto
+ IPProto ipproto.Proto
IPID uint16
Src netaddr.IP
Dst netaddr.IP
diff --git a/net/packet/ip6.go b/net/packet/ip6.go
index 59f605b32..e181f1dde 100644
--- a/net/packet/ip6.go
+++ b/net/packet/ip6.go
@@ -8,6 +8,7 @@ import (
"encoding/binary"
"inet.af/netaddr"
+ "tailscale.com/types/ipproto"
)
// ip6HeaderLength is the length of an IPv6 header with no IP options.
@@ -15,7 +16,7 @@ const ip6HeaderLength = 40
// IP6Header represents an IPv6 packet header.
type IP6Header struct {
- IPProto IPProto
+ IPProto ipproto.Proto
IPID uint32 // only lower 20 bits used
Src netaddr.IP
Dst netaddr.IP
diff --git a/net/packet/packet.go b/net/packet/packet.go
index a88a1af7a..05c4a382f 100644
--- a/net/packet/packet.go
+++ b/net/packet/packet.go
@@ -11,9 +11,12 @@ import (
"strings"
"inet.af/netaddr"
+ "tailscale.com/types/ipproto"
"tailscale.com/types/strbuilder"
)
+const unknown = ipproto.Unknown
+
// RFC1858: prevent overlapping fragment attacks.
const minFrag = 60 + 20 // max IPv4 header + basic TCP header
@@ -44,7 +47,7 @@ type Parsed struct {
// 6), or 0 if the packet doesn't look like IPv4 or IPv6.
IPVersion uint8
// IPProto is the IP subprotocol (UDP, TCP, etc.). Valid iff IPVersion != 0.
- IPProto IPProto
+ IPProto ipproto.Proto
// SrcIP4 is the source address. Family matches IPVersion. Port is
// valid iff IPProto == TCP || IPProto == UDP.
Src netaddr.IPPort
@@ -100,7 +103,7 @@ func (q *Parsed) Decode(b []byte) {
if len(b) < 1 {
q.IPVersion = 0
- q.IPProto = Unknown
+ q.IPProto = unknown
return
}
@@ -112,7 +115,7 @@ func (q *Parsed) Decode(b []byte) {
q.decode6(b)
default:
q.IPVersion = 0
- q.IPProto = Unknown
+ q.IPProto = unknown
}
}
@@ -125,16 +128,16 @@ func (q *Parsed) StuffForTesting(len int) {
func (q *Parsed) decode4(b []byte) {
if len(b) < ip4HeaderLength {
q.IPVersion = 0
- q.IPProto = Unknown
+ q.IPProto = unknown
return
}
// Check that it's IPv4.
- q.IPProto = IPProto(b[9])
+ q.IPProto = ipproto.Proto(b[9])
q.length = int(binary.BigEndian.Uint16(b[2:4]))
if len(b) < q.length {
// Packet was cut off before full IPv4 length.
- q.IPProto = Unknown
+ q.IPProto = unknown
return
}
@@ -145,7 +148,7 @@ func (q *Parsed) decode4(b []byte) {
q.subofs = int((b[0] & 0x0F) << 2)
if q.subofs > q.length {
// next-proto starts beyond end of packet.
- q.IPProto = Unknown
+ q.IPProto = unknown
return
}
sub := b[q.subofs:]
@@ -170,29 +173,29 @@ func (q *Parsed) decode4(b []byte) {
// This is the first fragment
if moreFrags && len(sub) < minFrag {
// Suspiciously short first fragment, dump it.
- q.IPProto = Unknown
+ q.IPProto = unknown
return
}
// otherwise, this is either non-fragmented (the usual case)
// or a big enough initial fragment that we can read the
// whole subprotocol header.
switch q.IPProto {
- case ICMPv4:
+ case ipproto.ICMPv4:
if len(sub) < icmp4HeaderLength {
- q.IPProto = Unknown
+ q.IPProto = unknown
return
}
q.Src.Port = 0
q.Dst.Port = 0
q.dataofs = q.subofs + icmp4HeaderLength
return
- case IGMP:
+ case ipproto.IGMP:
// Keep IPProto, but don't parse anything else
// out.
return
- case TCP:
+ case ipproto.TCP:
if len(sub) < tcpHeaderLength {
- q.IPProto = Unknown
+ q.IPProto = unknown
return
}
q.Src.Port = binary.BigEndian.Uint16(sub[0:2])
@@ -201,21 +204,29 @@ func (q *Parsed) decode4(b []byte) {
headerLength := (sub[12] & 0xF0) >> 2
q.dataofs = q.subofs + int(headerLength)
return
- case UDP:
+ case ipproto.UDP:
if len(sub) < udpHeaderLength {
- q.IPProto = Unknown
+ q.IPProto = unknown
return
}
q.Src.Port = binary.BigEndian.Uint16(sub[0:2])
q.Dst.Port = binary.BigEndian.Uint16(sub[2:4])
q.dataofs = q.subofs + udpHeaderLength
return
- case TSMP:
+ case ipproto.SCTP:
+ if len(sub) < sctpHeaderLength {
+ q.IPProto = unknown
+ return
+ }
+ q.Src.Port = binary.BigEndian.Uint16(sub[0:2])
+ q.Dst.Port = binary.BigEndian.Uint16(sub[2:4])
+ return
+ case ipproto.TSMP:
// Inter-tailscale messages.
q.dataofs = q.subofs
return
default:
- q.IPProto = Unknown
+ q.IPProto = unknown
return
}
} else {
@@ -223,7 +234,7 @@ func (q *Parsed) decode4(b []byte) {
if fragOfs < minFrag {
// First frag was suspiciously short, so we can't
// trust the followup either.
- q.IPProto = Unknown
+ q.IPProto = unknown
return
}
// otherwise, we have to permit the fragment to slide through.
@@ -232,7 +243,7 @@ func (q *Parsed) decode4(b []byte) {
// but that would require statefulness. Anyway, receivers'
// kernels know to drop fragments where the initial fragment
// doesn't arrive.
- q.IPProto = Fragment
+ q.IPProto = ipproto.Fragment
return
}
}
@@ -240,15 +251,15 @@ func (q *Parsed) decode4(b []byte) {
func (q *Parsed) decode6(b []byte) {
if len(b) < ip6HeaderLength {
q.IPVersion = 0
- q.IPProto = Unknown
+ q.IPProto = unknown
return
}
- q.IPProto = IPProto(b[6])
+ q.IPProto = ipproto.Proto(b[6])
q.length = int(binary.BigEndian.Uint16(b[4:6])) + ip6HeaderLength
if len(b) < q.length {
// Packet was cut off before the full IPv6 length.
- q.IPProto = Unknown
+ q.IPProto = unknown
return
}
@@ -274,17 +285,17 @@ func (q *Parsed) decode6(b []byte) {
sub = sub[:len(sub):len(sub)] // help the compiler do bounds check elimination
switch q.IPProto {
- case ICMPv6:
+ case ipproto.ICMPv6:
if len(sub) < icmp6HeaderLength {
- q.IPProto = Unknown
+ q.IPProto = unknown
return
}
q.Src.Port = 0
q.Dst.Port = 0
q.dataofs = q.subofs + icmp6HeaderLength
- case TCP:
+ case ipproto.TCP:
if len(sub) < tcpHeaderLength {
- q.IPProto = Unknown
+ q.IPProto = unknown
return
}
q.Src.Port = binary.BigEndian.Uint16(sub[0:2])
@@ -293,20 +304,28 @@ func (q *Parsed) decode6(b []byte) {
headerLength := (sub[12] & 0xF0) >> 2
q.dataofs = q.subofs + int(headerLength)
return
- case UDP:
+ case ipproto.UDP:
if len(sub) < udpHeaderLength {
- q.IPProto = Unknown
+ q.IPProto = unknown
return
}
q.Src.Port = binary.BigEndian.Uint16(sub[0:2])
q.Dst.Port = binary.BigEndian.Uint16(sub[2:4])
q.dataofs = q.subofs + udpHeaderLength
- case TSMP:
+ case ipproto.SCTP:
+ if len(sub) < sctpHeaderLength {
+ q.IPProto = unknown
+ return
+ }
+ q.Src.Port = binary.BigEndian.Uint16(sub[0:2])
+ q.Dst.Port = binary.BigEndian.Uint16(sub[2:4])
+ return
+ case ipproto.TSMP:
// Inter-tailscale messages.
q.dataofs = q.subofs
return
default:
- q.IPProto = Unknown
+ q.IPProto = unknown
return
}
}
@@ -324,6 +343,19 @@ func (q *Parsed) IP4Header() IP4Header {
}
}
+func (q *Parsed) IP6Header() IP6Header {
+ if q.IPVersion != 6 {
+ panic("IP6Header called on non-IPv6 Parsed")
+ }
+ ipid := (binary.BigEndian.Uint32(q.b[:4]) << 12) >> 12
+ return IP6Header{
+ IPID: ipid,
+ IPProto: q.IPProto,
+ Src: q.Src.IP,
+ Dst: q.Dst.IP,
+ }
+}
+
func (q *Parsed) ICMP4Header() ICMP4Header {
if q.IPVersion != 4 {
panic("IP4Header called on non-IPv4 Parsed")
@@ -367,13 +399,13 @@ func (q *Parsed) IsTCPSyn() bool {
// IsError reports whether q is an ICMP "Error" packet.
func (q *Parsed) IsError() bool {
switch q.IPProto {
- case ICMPv4:
+ case ipproto.ICMPv4:
if len(q.b) < q.subofs+8 {
return false
}
t := ICMP4Type(q.b[q.subofs])
return t == ICMP4Unreachable || t == ICMP4TimeExceeded
- case ICMPv6:
+ case ipproto.ICMPv6:
if len(q.b) < q.subofs+8 {
return false
}
@@ -387,9 +419,9 @@ func (q *Parsed) IsError() bool {
// IsEchoRequest reports whether q is an ICMP Echo Request.
func (q *Parsed) IsEchoRequest() bool {
switch q.IPProto {
- case ICMPv4:
+ case ipproto.ICMPv4:
return len(q.b) >= q.subofs+8 && ICMP4Type(q.b[q.subofs]) == ICMP4EchoRequest && ICMP4Code(q.b[q.subofs+1]) == ICMP4NoCode
- case ICMPv6:
+ case ipproto.ICMPv6:
return len(q.b) >= q.subofs+8 && ICMP6Type(q.b[q.subofs]) == ICMP6EchoRequest && ICMP6Code(q.b[q.subofs+1]) == ICMP6NoCode
default:
return false
@@ -399,9 +431,9 @@ func (q *Parsed) IsEchoRequest() bool {
// IsEchoRequest reports whether q is an IPv4 ICMP Echo Response.
func (q *Parsed) IsEchoResponse() bool {
switch q.IPProto {
- case ICMPv4:
+ case ipproto.ICMPv4:
return len(q.b) >= q.subofs+8 && ICMP4Type(q.b[q.subofs]) == ICMP4EchoReply && ICMP4Code(q.b[q.subofs+1]) == ICMP4NoCode
- case ICMPv6:
+ case ipproto.ICMPv6:
return len(q.b) >= q.subofs+8 && ICMP6Type(q.b[q.subofs]) == ICMP6EchoReply && ICMP6Code(q.b[q.subofs+1]) == ICMP6NoCode
default:
return false
diff --git a/net/packet/packet_test.go b/net/packet/packet_test.go
index 8bac5db4a..ac4fa33f3 100644
--- a/net/packet/packet_test.go
+++ b/net/packet/packet_test.go
@@ -10,6 +10,19 @@ import (
"testing"
"inet.af/netaddr"
+ "tailscale.com/types/ipproto"
+)
+
+const (
+ Unknown = ipproto.Unknown
+ TCP = ipproto.TCP
+ UDP = ipproto.UDP
+ SCTP = ipproto.SCTP
+ IGMP = ipproto.IGMP
+ ICMPv4 = ipproto.ICMPv4
+ ICMPv6 = ipproto.ICMPv6
+ TSMP = ipproto.TSMP
+ Fragment = ipproto.Fragment
)
func mustIPPort(s string) netaddr.IPPort {
@@ -305,6 +318,39 @@ var ipv4TSMPDecode = Parsed{
Dst: mustIPPort("100.74.70.3:0"),
}
+// IPv4 SCTP
+var sctpBuffer = []byte{
+ // IPv4 header:
+ 0x45, 0x00,
+ 0x00, 0x20, // 20 + 12 bytes total
+ 0x00, 0x00, // ID
+ 0x00, 0x00, // Fragment
+ 0x40, // TTL
+ byte(SCTP),
+ // Checksum, unchecked:
+ 1, 2,
+ // source IP:
+ 0x64, 0x5e, 0x0c, 0x0e,
+ // dest IP:
+ 0x64, 0x4a, 0x46, 0x03,
+ // Src Port, Dest Port:
+ 0x00, 0x7b, 0x01, 0xc8,
+ // Verification tag:
+ 1, 2, 3, 4,
+ // Checksum: (unchecked)
+ 5, 6, 7, 8,
+}
+
+var sctpDecode = Parsed{
+ b: sctpBuffer,
+ subofs: 20,
+ length: 20 + 12,
+ IPVersion: 4,
+ IPProto: SCTP,
+ Src: mustIPPort("100.94.12.14:123"),
+ Dst: mustIPPort("100.74.70.3:456"),
+}
+
func TestParsedString(t *testing.T) {
tests := []struct {
name string
@@ -320,6 +366,7 @@ func TestParsedString(t *testing.T) {
{"igmp", igmpPacketDecode, "IGMP{192.168.1.82:0 > 224.0.0.251:0}"},
{"unknown", unknownPacketDecode, "Unknown{???}"},
{"ipv4_tsmp", ipv4TSMPDecode, "TSMP{100.94.12.14:0 > 100.74.70.3:0}"},
+ {"sctp", sctpDecode, "SCTP{100.94.12.14:123 > 100.74.70.3:456}"},
}
for _, tt := range tests {
@@ -357,6 +404,7 @@ func TestDecode(t *testing.T) {
{"unknown", unknownPacketBuffer, unknownPacketDecode},
{"invalid4", invalid4RequestBuffer, invalid4RequestDecode},
{"ipv4_tsmp", ipv4TSMPBuffer, ipv4TSMPDecode},
+ {"ipv4_sctp", sctpBuffer, sctpDecode},
}
for _, tt := range tests {
diff --git a/net/packet/tsmp.go b/net/packet/tsmp.go
index 2346c9419..fb257556c 100644
--- a/net/packet/tsmp.go
+++ b/net/packet/tsmp.go
@@ -17,6 +17,7 @@ import (
"inet.af/netaddr"
"tailscale.com/net/flowtrack"
+ "tailscale.com/types/ipproto"
)
// TailscaleRejectedHeader is a TSMP message that says that one
@@ -39,7 +40,7 @@ type TailscaleRejectedHeader struct {
IPDst netaddr.IP // IPv4 or IPv6 header's dst IP
Src netaddr.IPPort // rejected flow's src
Dst netaddr.IPPort // rejected flow's dst
- Proto IPProto // proto that was rejected (TCP or UDP)
+ Proto ipproto.Proto // proto that was rejected (TCP or UDP)
Reason TailscaleRejectReason // why the connection was rejected
// MaybeBroken is whether the rejection is non-terminal (the
@@ -57,7 +58,7 @@ type TailscaleRejectedHeader struct {
const rejectFlagBitMaybeBroken = 0x1
func (rh TailscaleRejectedHeader) Flow() flowtrack.Tuple {
- return flowtrack.Tuple{Src: rh.Src, Dst: rh.Dst}
+ return flowtrack.Tuple{Proto: rh.Proto, Src: rh.Src, Dst: rh.Dst}
}
func (rh TailscaleRejectedHeader) String() string {
@@ -69,6 +70,12 @@ type TSMPType uint8
const (
// TSMPTypeRejectedConn is the type byte for a TailscaleRejectedHeader.
TSMPTypeRejectedConn TSMPType = '!'
+
+ // TSMPTypePing is the type byte for a TailscalePingRequest.
+ TSMPTypePing TSMPType = 'p'
+
+ // TSMPTypePong is the type byte for a TailscalePongResponse.
+ TSMPTypePong TSMPType = 'o'
)
type TailscaleRejectReason byte
@@ -138,7 +145,7 @@ func (h TailscaleRejectedHeader) Marshal(buf []byte) error {
}
if h.Src.IP.Is4() {
iph := IP4Header{
- IPProto: TSMP,
+ IPProto: ipproto.TSMP,
Src: h.IPSrc,
Dst: h.IPDst,
}
@@ -146,7 +153,7 @@ func (h TailscaleRejectedHeader) Marshal(buf []byte) error {
buf = buf[ip4HeaderLength:]
} else if h.Src.IP.Is6() {
iph := IP6Header{
- IPProto: TSMP,
+ IPProto: ipproto.TSMP,
Src: h.IPSrc,
Dst: h.IPDst,
}
@@ -181,7 +188,7 @@ func (pp *Parsed) AsTailscaleRejectedHeader() (h TailscaleRejectedHeader, ok boo
return
}
h = TailscaleRejectedHeader{
- Proto: IPProto(p[1]),
+ Proto: ipproto.Proto(p[1]),
Reason: TailscaleRejectReason(p[2]),
IPSrc: pp.Src.IP,
IPDst: pp.Dst.IP,
@@ -194,3 +201,58 @@ func (pp *Parsed) AsTailscaleRejectedHeader() (h TailscaleRejectedHeader, ok boo
}
return h, true
}
+
+// TSMPPingRequest is a TSMP message that's like an ICMP ping request.
+//
+// On the wire, after the IP header, it's currently 9 bytes:
+// * 'p' (TSMPTypePing)
+// * 8 opaque ping bytes to copy back in the response
+type TSMPPingRequest struct {
+ Data [8]byte
+}
+
+func (pp *Parsed) AsTSMPPing() (h TSMPPingRequest, ok bool) {
+ if pp.IPProto != ipproto.TSMP {
+ return
+ }
+ p := pp.Payload()
+ if len(p) < 9 || p[0] != byte(TSMPTypePing) {
+ return
+ }
+ copy(h.Data[:], p[1:])
+ return h, true
+}
+
+type TSMPPongReply struct {
+ IPHeader Header
+ Data [8]byte
+}
+
+func (pp *Parsed) AsTSMPPong() (data [8]byte, ok bool) {
+ if pp.IPProto != ipproto.TSMP {
+ return
+ }
+ p := pp.Payload()
+ if len(p) < 9 || p[0] != byte(TSMPTypePong) {
+ return
+ }
+ copy(data[:], p[1:])
+ return data, true
+}
+
+func (h TSMPPongReply) Len() int {
+ return h.IPHeader.Len() + 9
+}
+
+func (h TSMPPongReply) Marshal(buf []byte) error {
+ if len(buf) < h.Len() {
+ return errSmallBuffer
+ }
+ if err := h.IPHeader.Marshal(buf); err != nil {
+ return err
+ }
+ buf = buf[h.IPHeader.Len():]
+ buf[0] = byte(TSMPTypePong)
+ copy(buf[1:], h.Data[:])
+ return nil
+}
diff --git a/net/packet/udp4.go b/net/packet/udp4.go
index 82aa30179..ce179f89d 100644
--- a/net/packet/udp4.go
+++ b/net/packet/udp4.go
@@ -4,7 +4,11 @@
package packet
-import "encoding/binary"
+import (
+ "encoding/binary"
+
+ "tailscale.com/types/ipproto"
+)
// udpHeaderLength is the size of the UDP packet header, not including
// the outer IP header.
@@ -31,7 +35,7 @@ func (h UDP4Header) Marshal(buf []byte) error {
return errLargePacket
}
// The caller does not need to set this.
- h.IPProto = UDP
+ h.IPProto = ipproto.UDP
length := len(buf) - h.IP4Header.Len()
binary.BigEndian.PutUint16(buf[20:22], h.SrcPort)
diff --git a/net/packet/udp6.go b/net/packet/udp6.go
index 0450eae9e..18213c1fb 100644
--- a/net/packet/udp6.go
+++ b/net/packet/udp6.go
@@ -4,7 +4,11 @@
package packet
-import "encoding/binary"
+import (
+ "encoding/binary"
+
+ "tailscale.com/types/ipproto"
+)
// UDP6Header is an IPv6+UDP header.
type UDP6Header struct {
@@ -27,7 +31,7 @@ func (h UDP6Header) Marshal(buf []byte) error {
return errLargePacket
}
// The caller does not need to set this.
- h.IPProto = UDP
+ h.IPProto = ipproto.UDP
length := len(buf) - h.IP6Header.Len()
binary.BigEndian.PutUint16(buf[40:42], h.SrcPort)
diff --git a/net/tsaddr/tsaddr.go b/net/tsaddr/tsaddr.go
index 44cf5cf23..9bf81326e 100644
--- a/net/tsaddr/tsaddr.go
+++ b/net/tsaddr/tsaddr.go
@@ -37,7 +37,7 @@ var (
)
// TailscaleServiceIP returns the listen address of services
-// provided by Tailscale itself such as the Magic DNS proxy.
+// provided by Tailscale itself such as the MagicDNS proxy.
func TailscaleServiceIP() netaddr.IP {
serviceIP.Do(func() { mustIP(&serviceIP.v, "100.100.100.100") })
return serviceIP.v
diff --git a/net/tstun/fake.go b/net/tstun/fake.go
new file mode 100644
index 000000000..f9c3a9d6f
--- /dev/null
+++ b/net/tstun/fake.go
@@ -0,0 +1,54 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package tstun
+
+import (
+ "io"
+ "os"
+
+ "github.com/tailscale/wireguard-go/tun"
+)
+
+type fakeTUN struct {
+ evchan chan tun.Event
+ closechan chan struct{}
+}
+
+// NewFake returns a tun.Device that does nothing.
+func NewFake() tun.Device {
+ return &fakeTUN{
+ evchan: make(chan tun.Event),
+ closechan: make(chan struct{}),
+ }
+}
+
+func (t *fakeTUN) File() *os.File {
+ panic("fakeTUN.File() called, which makes no sense")
+}
+
+func (t *fakeTUN) Close() error {
+ close(t.closechan)
+ close(t.evchan)
+ return nil
+}
+
+func (t *fakeTUN) Read(out []byte, offset int) (int, error) {
+ <-t.closechan
+ return 0, io.EOF
+}
+
+func (t *fakeTUN) Write(b []byte, n int) (int, error) {
+ select {
+ case <-t.closechan:
+ return 0, ErrClosed
+ default:
+ }
+ return len(b), nil
+}
+
+func (t *fakeTUN) Flush() error { return nil }
+func (t *fakeTUN) MTU() (int, error) { return 1500, nil }
+func (t *fakeTUN) Name() (string, error) { return "FakeTUN", nil }
+func (t *fakeTUN) Events() chan tun.Event { return t.evchan }
diff --git a/net/tstun/ifstatus_noop.go b/net/tstun/ifstatus_noop.go
new file mode 100644
index 000000000..223be7949
--- /dev/null
+++ b/net/tstun/ifstatus_noop.go
@@ -0,0 +1,19 @@
+// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build !windows
+
+package tstun
+
+import (
+ "time"
+
+ "github.com/tailscale/wireguard-go/tun"
+ "tailscale.com/types/logger"
+)
+
+// Dummy implementation that does nothing.
+func waitInterfaceUp(iface tun.Device, timeout time.Duration, logf logger.Logf) error {
+ return nil
+}
diff --git a/net/tstun/ifstatus_windows.go b/net/tstun/ifstatus_windows.go
new file mode 100644
index 000000000..840e50f4d
--- /dev/null
+++ b/net/tstun/ifstatus_windows.go
@@ -0,0 +1,111 @@
+// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package tstun
+
+import (
+ "fmt"
+ "sync"
+ "time"
+
+ "github.com/tailscale/wireguard-go/tun"
+ "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
+ "tailscale.com/types/logger"
+)
+
+// ifaceWatcher waits for an interface to be up.
+type ifaceWatcher struct {
+ logf logger.Logf
+ luid winipcfg.LUID
+
+ mu sync.Mutex // guards following
+ done bool
+ sig chan bool
+}
+
+// callback is the callback we register with Windows to call when IP interface changes.
+func (iw *ifaceWatcher) callback(notificationType winipcfg.MibNotificationType, iface *winipcfg.MibIPInterfaceRow) {
+ // Probably should check only when MibParameterNotification, but just in case included MibAddInstance also.
+ if notificationType == winipcfg.MibParameterNotification || notificationType == winipcfg.MibAddInstance {
+ // Out of paranoia, start a goroutine to finish our work, to return to Windows out of this callback.
+ go iw.isUp()
+ }
+}
+
+func (iw *ifaceWatcher) isUp() bool {
+ iw.mu.Lock()
+ defer iw.mu.Unlock()
+
+ if iw.done {
+ // We already know that it's up
+ return true
+ }
+
+ if iw.getOperStatus() != winipcfg.IfOperStatusUp {
+ return false
+ }
+
+ iw.done = true
+ iw.sig <- true
+ return true
+}
+
+func (iw *ifaceWatcher) getOperStatus() winipcfg.IfOperStatus {
+ ifc, err := iw.luid.Interface()
+ if err != nil {
+ iw.logf("iw.luid.Interface error: %v", err)
+ return 0
+ }
+ return ifc.OperStatus
+}
+
+func waitInterfaceUp(iface tun.Device, timeout time.Duration, logf logger.Logf) error {
+ iw := &ifaceWatcher{
+ luid: winipcfg.LUID(iface.(*tun.NativeTun).LUID()),
+ logf: logger.WithPrefix(logf, "waitInterfaceUp: "),
+ }
+
+ // Just in case check the status first
+ if iw.getOperStatus() == winipcfg.IfOperStatusUp {
+ iw.logf("TUN interface already up; no need to wait")
+ return nil
+ }
+
+ iw.sig = make(chan bool, 1)
+ cb, err := winipcfg.RegisterInterfaceChangeCallback(iw.callback)
+ if err != nil {
+ iw.logf("RegisterInterfaceChangeCallback error: %v", err)
+ return err
+ }
+ defer cb.Unregister()
+
+ t0 := time.Now()
+ expires := t0.Add(timeout)
+ ticker := time.NewTicker(10 * time.Second)
+ defer ticker.Stop()
+
+ for {
+ iw.logf("waiting for TUN interface to come up...")
+
+ select {
+ case <-iw.sig:
+ iw.logf("TUN interface is up after %v", time.Since(t0))
+ return nil
+ case <-ticker.C:
+ break
+ }
+
+ if iw.isUp() {
+ // Very unlikely to happen - either NotifyIpInterfaceChange doesn't work
+ // or it came up in the same moment as tick. Indicate this in the log message.
+ iw.logf("TUN interface is up after %v (on poll, without notification)", time.Since(t0))
+ return nil
+ }
+
+ if expires.Before(time.Now()) {
+ iw.logf("timeout waiting %v for TUN interface to come up", timeout)
+ return fmt.Errorf("timeout waiting for TUN interface to come up")
+ }
+ }
+}
diff --git a/net/tstun/tun.go b/net/tstun/tun.go
new file mode 100644
index 000000000..d480d3244
--- /dev/null
+++ b/net/tstun/tun.go
@@ -0,0 +1,129 @@
+// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package tun creates a tuntap device, working around OS-specific
+// quirks if necessary.
+package tstun
+
+import (
+ "bytes"
+ "os"
+ "os/exec"
+ "runtime"
+ "time"
+
+ "github.com/tailscale/wireguard-go/tun"
+ "tailscale.com/types/logger"
+ "tailscale.com/version/distro"
+)
+
+// minimalMTU is the MTU we set on tailscale's TUN
+// interface. wireguard-go defaults to 1420 bytes, which only works if
+// the "outer" MTU is 1500 bytes. This breaks on DSL connections
+// (typically 1492 MTU) and on GCE (1460 MTU?!).
+//
+// 1280 is the smallest MTU allowed for IPv6, which is a sensible
+// "probably works everywhere" setting until we develop proper PMTU
+// discovery.
+const minimalMTU = 1280
+
+// New returns a tun.Device for the requested device name.
+func New(logf logger.Logf, tunName string) (tun.Device, error) {
+ dev, err := tun.CreateTUN(tunName, minimalMTU)
+ if err != nil {
+ return nil, err
+ }
+ if err := waitInterfaceUp(dev, 90*time.Second, logf); err != nil {
+ return nil, err
+ }
+ return dev, nil
+}
+
+// Diagnose tries to explain a tuntap device creation failure.
+// It pokes around the system and logs some diagnostic info that might
+// help debug why tun creation failed. Because device creation has
+// already failed and the program's about to end, log a lot.
+func Diagnose(logf logger.Logf, tunName string) {
+ switch runtime.GOOS {
+ case "linux":
+ diagnoseLinuxTUNFailure(tunName, logf)
+ case "darwin":
+ diagnoseDarwinTUNFailure(tunName, logf)
+ default:
+ logf("no TUN failure diagnostics for OS %q", runtime.GOOS)
+ }
+}
+
+func diagnoseDarwinTUNFailure(tunName string, logf logger.Logf) {
+ if os.Getuid() != 0 {
+ logf("failed to create TUN device as non-root user; use 'sudo tailscaled', or run under launchd with 'sudo tailscaled install-system-daemon'")
+ }
+ if tunName != "utun" {
+ logf("failed to create TUN device %q; try using tun device \"utun\" instead for automatic selection", tunName)
+ }
+}
+
+func diagnoseLinuxTUNFailure(tunName string, logf logger.Logf) {
+ kernel, err := exec.Command("uname", "-r").Output()
+ kernel = bytes.TrimSpace(kernel)
+ if err != nil {
+ logf("no TUN, and failed to look up kernel version: %v", err)
+ return
+ }
+ logf("Linux kernel version: %s", kernel)
+
+ modprobeOut, err := exec.Command("/sbin/modprobe", "tun").CombinedOutput()
+ if err == nil {
+ logf("'modprobe tun' successful")
+ // Either tun is currently loaded, or it's statically
+ // compiled into the kernel (which modprobe checks
+ // with /lib/modules/$(uname -r)/modules.builtin)
+ //
+ // So if there's a problem at this point, it's
+ // probably because /dev/net/tun doesn't exist.
+ const dev = "/dev/net/tun"
+ if fi, err := os.Stat(dev); err != nil {
+ logf("tun module loaded in kernel, but %s does not exist", dev)
+ } else {
+ logf("%s: %v", dev, fi.Mode())
+ }
+
+ // We failed to find why it failed. Just let our
+ // caller report the error it got from wireguard-go.
+ return
+ }
+ logf("is CONFIG_TUN enabled in your kernel? `modprobe tun` failed with: %s", modprobeOut)
+
+ switch distro.Get() {
+ case distro.Debian:
+ dpkgOut, err := exec.Command("dpkg", "-S", "kernel/drivers/net/tun.ko").CombinedOutput()
+ if len(bytes.TrimSpace(dpkgOut)) == 0 || err != nil {
+ logf("tun module not loaded nor found on disk")
+ return
+ }
+ if !bytes.Contains(dpkgOut, kernel) {
+ logf("kernel/drivers/net/tun.ko found on disk, but not for current kernel; are you in middle of a system update and haven't rebooted? found: %s", dpkgOut)
+ }
+ case distro.Arch:
+ findOut, err := exec.Command("find", "/lib/modules/", "-path", "*/net/tun.ko*").CombinedOutput()
+ if len(bytes.TrimSpace(findOut)) == 0 || err != nil {
+ logf("tun module not loaded nor found on disk")
+ return
+ }
+ if !bytes.Contains(findOut, kernel) {
+ logf("kernel/drivers/net/tun.ko found on disk, but not for current kernel; are you in middle of a system update and haven't rebooted? found: %s", findOut)
+ }
+ case distro.OpenWrt:
+ out, err := exec.Command("opkg", "list-installed").CombinedOutput()
+ if err != nil {
+ logf("error querying OpenWrt installed packages: %s", out)
+ return
+ }
+ for _, pkg := range []string{"kmod-tun", "ca-bundle"} {
+ if !bytes.Contains(out, []byte(pkg+" - ")) {
+ logf("Missing required package %s; run: opkg install %s", pkg, pkg)
+ }
+ }
+ }
+}
diff --git a/net/tstun/tun_windows.go b/net/tstun/tun_windows.go
new file mode 100644
index 000000000..dc5fc2d79
--- /dev/null
+++ b/net/tstun/tun_windows.go
@@ -0,0 +1,24 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package tstun
+
+import (
+ "github.com/tailscale/wireguard-go/tun"
+ "github.com/tailscale/wireguard-go/tun/wintun"
+ "golang.org/x/sys/windows"
+)
+
+func init() {
+ var err error
+ tun.WintunPool, err = wintun.MakePool("Tailscale")
+ if err != nil {
+ panic(err)
+ }
+ guid, err := windows.GUIDFromString("{37217669-42da-4657-a55b-0d995d328250}")
+ if err != nil {
+ panic(err)
+ }
+ tun.WintunStaticRequestedGUID = &guid
+}
diff --git a/net/tstun/wrap.go b/net/tstun/wrap.go
new file mode 100644
index 000000000..70225c52e
--- /dev/null
+++ b/net/tstun/wrap.go
@@ -0,0 +1,498 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package tstun provides a TUN struct implementing the tun.Device interface
+// with additional features as required by wgengine.
+package tstun
+
+import (
+ "errors"
+ "io"
+ "os"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "github.com/tailscale/wireguard-go/device"
+ "github.com/tailscale/wireguard-go/tun"
+ "inet.af/netaddr"
+ "tailscale.com/net/packet"
+ "tailscale.com/types/ipproto"
+ "tailscale.com/types/logger"
+ "tailscale.com/wgengine/filter"
+)
+
+const maxBufferSize = device.MaxMessageSize
+
+// PacketStartOffset is the minimal amount of leading space that must exist
+// before &packet[offset] in a packet passed to Read, Write, or InjectInboundDirect.
+// This is necessary to avoid reallocation in wireguard-go internals.
+const PacketStartOffset = device.MessageTransportHeaderSize
+
+// MaxPacketSize is the maximum size (in bytes)
+// of a packet that can be injected into a tstun.Wrapper.
+const MaxPacketSize = device.MaxContentSize
+
+var (
+ // ErrClosed is returned when attempting an operation on a closed Wrapper.
+ ErrClosed = errors.New("device closed")
+ // ErrFiltered is returned when the acted-on packet is rejected by a filter.
+ ErrFiltered = errors.New("packet dropped by filter")
+)
+
+var (
+ errPacketTooBig = errors.New("packet too big")
+ errOffsetTooBig = errors.New("offset larger than buffer length")
+ errOffsetTooSmall = errors.New("offset smaller than PacketStartOffset")
+)
+
+// parsedPacketPool holds a pool of Parsed structs for use in filtering.
+// This is needed because escape analysis cannot see that parsed packets
+// do not escape through {Pre,Post}Filter{In,Out}.
+var parsedPacketPool = sync.Pool{New: func() interface{} { return new(packet.Parsed) }}
+
+// FilterFunc is a packet-filtering function with access to the Wrapper device.
+// It must not hold onto the packet struct, as its backing storage will be reused.
+type FilterFunc func(*packet.Parsed, *Wrapper) filter.Response
+
+// Wrapper augments a tun.Device with packet filtering and injection.
+type Wrapper struct {
+ logf logger.Logf
+ // tdev is the underlying Wrapper device.
+ tdev tun.Device
+
+ closeOnce sync.Once
+
+ lastActivityAtomic int64 // unix seconds of last send or receive
+
+ destIPActivity atomic.Value // of map[netaddr.IP]func()
+
+ // buffer stores the oldest unconsumed packet from tdev.
+ // It is made a static buffer in order to avoid allocations.
+ buffer [maxBufferSize]byte
+ // bufferConsumed synchronizes access to buffer (shared by Read and poll).
+ bufferConsumed chan struct{}
+
+ // closed signals poll (by closing) when the device is closed.
+ closed chan struct{}
+ // errors is the error queue populated by poll.
+ errors chan error
+ // outbound is the queue by which packets leave the TUN device.
+ //
+ // The directions are relative to the network, not the device:
+ // inbound packets arrive via UDP and are written into the TUN device;
+ // outbound packets are read from the TUN device and sent out via UDP.
+ // This queue is needed because although inbound writes are synchronous,
+ // the other direction must wait on a Wireguard goroutine to poll it.
+ //
+ // Empty reads are skipped by Wireguard, so it is always legal
+ // to discard an empty packet instead of sending it through t.outbound.
+ outbound chan []byte
+
+ // fitler stores the currently active package filter
+ filter atomic.Value // of *filter.Filter
+ // filterFlags control the verbosity of logging packet drops/accepts.
+ filterFlags filter.RunFlags
+
+ // PreFilterIn is the inbound filter function that runs before the main filter
+ // and therefore sees the packets that may be later dropped by it.
+ PreFilterIn FilterFunc
+ // PostFilterIn is the inbound filter function that runs after the main filter.
+ PostFilterIn FilterFunc
+ // PreFilterOut is the outbound filter function that runs before the main filter
+ // and therefore sees the packets that may be later dropped by it.
+ PreFilterOut FilterFunc
+ // PostFilterOut is the outbound filter function that runs after the main filter.
+ PostFilterOut FilterFunc
+
+ // OnTSMPPongReceived, if non-nil, is called whenever a TSMP pong arrives.
+ OnTSMPPongReceived func(data [8]byte)
+
+ // disableFilter disables all filtering when set. This should only be used in tests.
+ disableFilter bool
+}
+
+func Wrap(logf logger.Logf, tdev tun.Device) *Wrapper {
+ tun := &Wrapper{
+ logf: logger.WithPrefix(logf, "tstun: "),
+ tdev: tdev,
+ // bufferConsumed is conceptually a condition variable:
+ // a goroutine should not block when setting it, even with no listeners.
+ bufferConsumed: make(chan struct{}, 1),
+ closed: make(chan struct{}),
+ errors: make(chan error),
+ outbound: make(chan []byte),
+ // TODO(dmytro): (highly rate-limited) hexdumps should happen on unknown packets.
+ filterFlags: filter.LogAccepts | filter.LogDrops,
+ }
+
+ go tun.poll()
+ // The buffer starts out consumed.
+ tun.bufferConsumed <- struct{}{}
+
+ return tun
+}
+
+// SetDestIPActivityFuncs sets a map of funcs to run per packet
+// destination (the map keys).
+//
+// The map ownership passes to the Wrapper. It must be non-nil.
+func (t *Wrapper) SetDestIPActivityFuncs(m map[netaddr.IP]func()) {
+ t.destIPActivity.Store(m)
+}
+
+func (t *Wrapper) Close() error {
+ var err error
+ t.closeOnce.Do(func() {
+ // Other channels need not be closed: poll will exit gracefully after this.
+ close(t.closed)
+
+ err = t.tdev.Close()
+ })
+ return err
+}
+
+func (t *Wrapper) Events() chan tun.Event {
+ return t.tdev.Events()
+}
+
+func (t *Wrapper) File() *os.File {
+ return t.tdev.File()
+}
+
+func (t *Wrapper) Flush() error {
+ return t.tdev.Flush()
+}
+
+func (t *Wrapper) MTU() (int, error) {
+ return t.tdev.MTU()
+}
+
+func (t *Wrapper) Name() (string, error) {
+ return t.tdev.Name()
+}
+
+// poll polls t.tdev.Read, placing the oldest unconsumed packet into t.buffer.
+// This is needed because t.tdev.Read in general may block (it does on Windows),
+// so packets may be stuck in t.outbound if t.Read called t.tdev.Read directly.
+func (t *Wrapper) poll() {
+ for {
+ select {
+ case <-t.closed:
+ return
+ case <-t.bufferConsumed:
+ // continue
+ }
+
+ // Read may use memory in t.buffer before PacketStartOffset for mandatory headers.
+ // This is the rationale behind the tun.Wrapper.{Read,Write} interfaces
+ // and the reason t.buffer has size MaxMessageSize and not MaxContentSize.
+ n, err := t.tdev.Read(t.buffer[:], PacketStartOffset)
+ if err != nil {
+ select {
+ case <-t.closed:
+ return
+ case t.errors <- err:
+ // In principle, read errors are not fatal (but wireguard-go disagrees).
+ t.bufferConsumed <- struct{}{}
+ }
+ continue
+ }
+
+ // Wireguard will skip an empty read,
+ // so we might as well do it here to avoid the send through t.outbound.
+ if n == 0 {
+ t.bufferConsumed <- struct{}{}
+ continue
+ }
+
+ select {
+ case <-t.closed:
+ return
+ case t.outbound <- t.buffer[PacketStartOffset : PacketStartOffset+n]:
+ // continue
+ }
+ }
+}
+
+var magicDNSIPPort = netaddr.MustParseIPPort("100.100.100.100:0")
+
+func (t *Wrapper) filterOut(p *packet.Parsed) filter.Response {
+ // Fake ICMP echo responses to MagicDNS (100.100.100.100).
+ if p.IsEchoRequest() && p.Dst == magicDNSIPPort {
+ header := p.ICMP4Header()
+ header.ToResponse()
+ outp := packet.Generate(&header, p.Payload())
+ t.InjectInboundCopy(outp)
+ return filter.DropSilently // don't pass on to OS; already handled
+ }
+
+ if t.PreFilterOut != nil {
+ if res := t.PreFilterOut(p, t); res.IsDrop() {
+ return res
+ }
+ }
+
+ filt, _ := t.filter.Load().(*filter.Filter)
+
+ if filt == nil {
+ return filter.Drop
+ }
+
+ if filt.RunOut(p, t.filterFlags) != filter.Accept {
+ return filter.Drop
+ }
+
+ if t.PostFilterOut != nil {
+ if res := t.PostFilterOut(p, t); res.IsDrop() {
+ return res
+ }
+ }
+
+ return filter.Accept
+}
+
+// noteActivity records that there was a read or write at the current time.
+func (t *Wrapper) noteActivity() {
+ atomic.StoreInt64(&t.lastActivityAtomic, time.Now().Unix())
+}
+
+// IdleDuration reports how long it's been since the last read or write to this device.
+//
+// Its value is only accurate to roughly second granularity.
+// If there's never been activity, the duration is since 1970.
+func (t *Wrapper) IdleDuration() time.Duration {
+ sec := atomic.LoadInt64(&t.lastActivityAtomic)
+ return time.Since(time.Unix(sec, 0))
+}
+
+func (t *Wrapper) Read(buf []byte, offset int) (int, error) {
+ var n int
+
+ wasInjectedPacket := false
+
+ select {
+ case <-t.closed:
+ return 0, io.EOF
+ case err := <-t.errors:
+ return 0, err
+ case pkt := <-t.outbound:
+ n = copy(buf[offset:], pkt)
+ // t.buffer has a fixed location in memory,
+ // so this is the easiest way to tell when it has been consumed.
+ // &pkt[0] can be used because empty packets do not reach t.outbound.
+ if &pkt[0] == &t.buffer[PacketStartOffset] {
+ t.bufferConsumed <- struct{}{}
+ } else {
+ // If the packet is not from t.buffer, then it is an injected packet.
+ wasInjectedPacket = true
+ }
+ }
+
+ p := parsedPacketPool.Get().(*packet.Parsed)
+ defer parsedPacketPool.Put(p)
+ p.Decode(buf[offset : offset+n])
+
+ if m, ok := t.destIPActivity.Load().(map[netaddr.IP]func()); ok {
+ if fn := m[p.Dst.IP]; fn != nil {
+ fn()
+ }
+ }
+
+ // For injected packets, we return early to bypass filtering.
+ if wasInjectedPacket {
+ t.noteActivity()
+ return n, nil
+ }
+
+ if !t.disableFilter {
+ response := t.filterOut(p)
+ if response != filter.Accept {
+ // Wireguard considers read errors fatal; pretend nothing was read
+ return 0, nil
+ }
+ }
+
+ t.noteActivity()
+ return n, nil
+}
+
+func (t *Wrapper) filterIn(buf []byte) filter.Response {
+ p := parsedPacketPool.Get().(*packet.Parsed)
+ defer parsedPacketPool.Put(p)
+ p.Decode(buf)
+
+ if p.IPProto == ipproto.TSMP {
+ if pingReq, ok := p.AsTSMPPing(); ok {
+ t.noteActivity()
+ t.injectOutboundPong(p, pingReq)
+ return filter.DropSilently
+ } else if data, ok := p.AsTSMPPong(); ok {
+ if f := t.OnTSMPPongReceived; f != nil {
+ f(data)
+ }
+ }
+ }
+
+ if t.PreFilterIn != nil {
+ if res := t.PreFilterIn(p, t); res.IsDrop() {
+ return res
+ }
+ }
+
+ filt, _ := t.filter.Load().(*filter.Filter)
+
+ if filt == nil {
+ return filter.Drop
+ }
+
+ if filt.RunIn(p, t.filterFlags) != filter.Accept {
+
+ // Tell them, via TSMP, we're dropping them due to the ACL.
+ // Their host networking stack can translate this into ICMP
+ // or whatnot as required. But notably, their GUI or tailscale CLI
+ // can show them a rejection history with reasons.
+ if p.IPVersion == 4 && p.IPProto == ipproto.TCP && p.TCPFlags&packet.TCPSyn != 0 {
+ rj := packet.TailscaleRejectedHeader{
+ IPSrc: p.Dst.IP,
+ IPDst: p.Src.IP,
+ Src: p.Src,
+ Dst: p.Dst,
+ Proto: p.IPProto,
+ Reason: packet.RejectedDueToACLs,
+ }
+ if filt.ShieldsUp() {
+ rj.Reason = packet.RejectedDueToShieldsUp
+ }
+ pkt := packet.Generate(rj, nil)
+ t.InjectOutbound(pkt)
+
+ // TODO(bradfitz): also send a TCP RST, after the TSMP message.
+ }
+
+ return filter.Drop
+ }
+
+ if t.PostFilterIn != nil {
+ if res := t.PostFilterIn(p, t); res.IsDrop() {
+ return res
+ }
+ }
+
+ return filter.Accept
+}
+
+// Write accepts an incoming packet. The packet begins at buf[offset:],
+// like wireguard-go/tun.Device.Write.
+func (t *Wrapper) Write(buf []byte, offset int) (int, error) {
+ if !t.disableFilter {
+ res := t.filterIn(buf[offset:])
+ if res == filter.DropSilently {
+ return len(buf), nil
+ }
+ if res != filter.Accept {
+ return 0, ErrFiltered
+ }
+ }
+
+ t.noteActivity()
+ return t.tdev.Write(buf, offset)
+}
+
+func (t *Wrapper) GetFilter() *filter.Filter {
+ filt, _ := t.filter.Load().(*filter.Filter)
+ return filt
+}
+
+func (t *Wrapper) SetFilter(filt *filter.Filter) {
+ t.filter.Store(filt)
+}
+
+// InjectInboundDirect makes the Wrapper device behave as if a packet
+// with the given contents was received from the network.
+// It blocks and does not take ownership of the packet.
+// The injected packet will not pass through inbound filters.
+//
+// The packet contents are to start at &buf[offset].
+// offset must be greater or equal to PacketStartOffset.
+// The space before &buf[offset] will be used by Wireguard.
+func (t *Wrapper) InjectInboundDirect(buf []byte, offset int) error {
+ if len(buf) > MaxPacketSize {
+ return errPacketTooBig
+ }
+ if len(buf) < offset {
+ return errOffsetTooBig
+ }
+ if offset < PacketStartOffset {
+ return errOffsetTooSmall
+ }
+
+ // Write to the underlying device to skip filters.
+ _, err := t.tdev.Write(buf, offset)
+ return err
+}
+
+// InjectInboundCopy takes a packet without leading space,
+// reallocates it to conform to the InjectInboundDirect interface
+// and calls InjectInboundDirect on it. Injecting a nil packet is a no-op.
+func (t *Wrapper) InjectInboundCopy(packet []byte) error {
+ // We duplicate this check from InjectInboundDirect here
+ // to avoid wasting an allocation on an oversized packet.
+ if len(packet) > MaxPacketSize {
+ return errPacketTooBig
+ }
+ if len(packet) == 0 {
+ return nil
+ }
+
+ buf := make([]byte, PacketStartOffset+len(packet))
+ copy(buf[PacketStartOffset:], packet)
+
+ return t.InjectInboundDirect(buf, PacketStartOffset)
+}
+
+func (t *Wrapper) injectOutboundPong(pp *packet.Parsed, req packet.TSMPPingRequest) {
+ pong := packet.TSMPPongReply{
+ Data: req.Data,
+ }
+ switch pp.IPVersion {
+ case 4:
+ h4 := pp.IP4Header()
+ h4.ToResponse()
+ pong.IPHeader = h4
+ case 6:
+ h6 := pp.IP6Header()
+ h6.ToResponse()
+ pong.IPHeader = h6
+ default:
+ return
+ }
+
+ t.InjectOutbound(packet.Generate(pong, nil))
+}
+
+// InjectOutbound makes the Wrapper device behave as if a packet
+// with the given contents was sent to the network.
+// It does not block, but takes ownership of the packet.
+// The injected packet will not pass through outbound filters.
+// Injecting an empty packet is a no-op.
+func (t *Wrapper) InjectOutbound(packet []byte) error {
+ if len(packet) > MaxPacketSize {
+ return errPacketTooBig
+ }
+ if len(packet) == 0 {
+ return nil
+ }
+ select {
+ case <-t.closed:
+ return ErrClosed
+ case t.outbound <- packet:
+ return nil
+ }
+}
+
+// Unwrap returns the underlying tun.Device.
+func (t *Wrapper) Unwrap() tun.Device {
+ return t.tdev
+}
diff --git a/net/tstun/wrap_test.go b/net/tstun/wrap_test.go
new file mode 100644
index 000000000..4032b168b
--- /dev/null
+++ b/net/tstun/wrap_test.go
@@ -0,0 +1,387 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package tstun
+
+import (
+ "bytes"
+ "fmt"
+ "strconv"
+ "strings"
+ "sync/atomic"
+ "testing"
+ "unsafe"
+
+ "github.com/tailscale/wireguard-go/tun/tuntest"
+ "inet.af/netaddr"
+ "tailscale.com/net/packet"
+ "tailscale.com/types/ipproto"
+ "tailscale.com/types/logger"
+ "tailscale.com/wgengine/filter"
+)
+
+func udp4(src, dst string, sport, dport uint16) []byte {
+ sip, err := netaddr.ParseIP(src)
+ if err != nil {
+ panic(err)
+ }
+ dip, err := netaddr.ParseIP(dst)
+ if err != nil {
+ panic(err)
+ }
+ header := &packet.UDP4Header{
+ IP4Header: packet.IP4Header{
+ Src: sip,
+ Dst: dip,
+ IPID: 0,
+ },
+ SrcPort: sport,
+ DstPort: dport,
+ }
+ return packet.Generate(header, []byte("udp_payload"))
+}
+
+func nets(nets ...string) (ret []netaddr.IPPrefix) {
+ for _, s := range nets {
+ if i := strings.IndexByte(s, '/'); i == -1 {
+ ip, err := netaddr.ParseIP(s)
+ if err != nil {
+ panic(err)
+ }
+ bits := uint8(32)
+ if ip.Is6() {
+ bits = 128
+ }
+ ret = append(ret, netaddr.IPPrefix{IP: ip, Bits: bits})
+ } else {
+ pfx, err := netaddr.ParseIPPrefix(s)
+ if err != nil {
+ panic(err)
+ }
+ ret = append(ret, pfx)
+ }
+ }
+ return ret
+}
+
+func ports(s string) filter.PortRange {
+ if s == "*" {
+ return filter.PortRange{First: 0, Last: 65535}
+ }
+
+ var fs, ls string
+ i := strings.IndexByte(s, '-')
+ if i == -1 {
+ fs = s
+ ls = fs
+ } else {
+ fs = s[:i]
+ ls = s[i+1:]
+ }
+ first, err := strconv.ParseInt(fs, 10, 16)
+ if err != nil {
+ panic(fmt.Sprintf("invalid NetPortRange %q", s))
+ }
+ last, err := strconv.ParseInt(ls, 10, 16)
+ if err != nil {
+ panic(fmt.Sprintf("invalid NetPortRange %q", s))
+ }
+ return filter.PortRange{First: uint16(first), Last: uint16(last)}
+}
+
+func netports(netPorts ...string) (ret []filter.NetPortRange) {
+ for _, s := range netPorts {
+ i := strings.LastIndexByte(s, ':')
+ if i == -1 {
+ panic(fmt.Sprintf("invalid NetPortRange %q", s))
+ }
+
+ npr := filter.NetPortRange{
+ Net: nets(s[:i])[0],
+ Ports: ports(s[i+1:]),
+ }
+ ret = append(ret, npr)
+ }
+ return ret
+}
+
+func setfilter(logf logger.Logf, tun *Wrapper) {
+ protos := []ipproto.Proto{
+ ipproto.TCP,
+ ipproto.UDP,
+ }
+ matches := []filter.Match{
+ {IPProto: protos, Srcs: nets("5.6.7.8"), Dsts: netports("1.2.3.4:89-90")},
+ {IPProto: protos, Srcs: nets("1.2.3.4"), Dsts: netports("5.6.7.8:98")},
+ }
+ var sb netaddr.IPSetBuilder
+ sb.AddPrefix(netaddr.MustParseIPPrefix("1.2.0.0/16"))
+ tun.SetFilter(filter.New(matches, sb.IPSet(), sb.IPSet(), nil, logf))
+}
+
+func newChannelTUN(logf logger.Logf, secure bool) (*tuntest.ChannelTUN, *Wrapper) {
+ chtun := tuntest.NewChannelTUN()
+ tun := Wrap(logf, chtun.TUN())
+ if secure {
+ setfilter(logf, tun)
+ } else {
+ tun.disableFilter = true
+ }
+ return chtun, tun
+}
+
+func newFakeTUN(logf logger.Logf, secure bool) (*fakeTUN, *Wrapper) {
+ ftun := NewFake()
+ tun := Wrap(logf, ftun)
+ if secure {
+ setfilter(logf, tun)
+ } else {
+ tun.disableFilter = true
+ }
+ return ftun.(*fakeTUN), tun
+}
+
+func TestReadAndInject(t *testing.T) {
+ chtun, tun := newChannelTUN(t.Logf, false)
+ defer tun.Close()
+
+ const size = 2 // all payloads have this size
+ written := []string{"w0", "w1"}
+ injected := []string{"i0", "i1"}
+
+ go func() {
+ for _, packet := range written {
+ payload := []byte(packet)
+ chtun.Outbound <- payload
+ }
+ }()
+
+ for _, packet := range injected {
+ go func(packet string) {
+ payload := []byte(packet)
+ err := tun.InjectOutbound(payload)
+ if err != nil {
+ t.Errorf("%s: error: %v", packet, err)
+ }
+ }(packet)
+ }
+
+ var buf [MaxPacketSize]byte
+ var seen = make(map[string]bool)
+ // We expect the same packets back, in no particular order.
+ for i := 0; i < len(written)+len(injected); i++ {
+ n, err := tun.Read(buf[:], 0)
+ if err != nil {
+ t.Errorf("read %d: error: %v", i, err)
+ }
+ if n != size {
+ t.Errorf("read %d: got size %d; want %d", i, n, size)
+ }
+ got := string(buf[:n])
+ t.Logf("read %d: got %s", i, got)
+ seen[got] = true
+ }
+
+ for _, packet := range written {
+ if !seen[packet] {
+ t.Errorf("%s not received", packet)
+ }
+ }
+ for _, packet := range injected {
+ if !seen[packet] {
+ t.Errorf("%s not received", packet)
+ }
+ }
+}
+
+func TestWriteAndInject(t *testing.T) {
+ chtun, tun := newChannelTUN(t.Logf, false)
+ defer tun.Close()
+
+ const size = 2 // all payloads have this size
+ written := []string{"w0", "w1"}
+ injected := []string{"i0", "i1"}
+
+ go func() {
+ for _, packet := range written {
+ payload := []byte(packet)
+ n, err := tun.Write(payload, 0)
+ if err != nil {
+ t.Errorf("%s: error: %v", packet, err)
+ }
+ if n != size {
+ t.Errorf("%s: got size %d; want %d", packet, n, size)
+ }
+ }
+ }()
+
+ for _, packet := range injected {
+ go func(packet string) {
+ payload := []byte(packet)
+ err := tun.InjectInboundCopy(payload)
+ if err != nil {
+ t.Errorf("%s: error: %v", packet, err)
+ }
+ }(packet)
+ }
+
+ seen := make(map[string]bool)
+ // We expect the same packets back, in no particular order.
+ for i := 0; i < len(written)+len(injected); i++ {
+ packet := <-chtun.Inbound
+ got := string(packet)
+ t.Logf("read %d: got %s", i, got)
+ seen[got] = true
+ }
+
+ for _, packet := range written {
+ if !seen[packet] {
+ t.Errorf("%s not received", packet)
+ }
+ }
+ for _, packet := range injected {
+ if !seen[packet] {
+ t.Errorf("%s not received", packet)
+ }
+ }
+}
+
+func TestFilter(t *testing.T) {
+ chtun, tun := newChannelTUN(t.Logf, true)
+ defer tun.Close()
+
+ type direction int
+
+ const (
+ in direction = iota
+ out
+ )
+
+ tests := []struct {
+ name string
+ dir direction
+ drop bool
+ data []byte
+ }{
+ {"junk_in", in, true, []byte("\x45not a valid IPv4 packet")},
+ {"junk_out", out, true, []byte("\x45not a valid IPv4 packet")},
+ {"bad_port_in", in, true, udp4("5.6.7.8", "1.2.3.4", 22, 22)},
+ {"bad_port_out", out, false, udp4("1.2.3.4", "5.6.7.8", 22, 22)},
+ {"bad_ip_in", in, true, udp4("8.1.1.1", "1.2.3.4", 89, 89)},
+ {"bad_ip_out", out, false, udp4("1.2.3.4", "8.1.1.1", 98, 98)},
+ {"good_packet_in", in, false, udp4("5.6.7.8", "1.2.3.4", 89, 89)},
+ {"good_packet_out", out, false, udp4("1.2.3.4", "5.6.7.8", 98, 98)},
+ }
+
+ // A reader on the other end of the tun.
+ go func() {
+ var recvbuf []byte
+ for {
+ select {
+ case <-tun.closed:
+ return
+ case recvbuf = <-chtun.Inbound:
+ // continue
+ }
+ for _, tt := range tests {
+ if tt.drop && bytes.Equal(recvbuf, tt.data) {
+ t.Errorf("did not drop %s", tt.name)
+ }
+ }
+ }
+ }()
+
+ var buf [MaxPacketSize]byte
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ var n int
+ var err error
+ var filtered bool
+
+ if tt.dir == in {
+ _, err = tun.Write(tt.data, 0)
+ if err == ErrFiltered {
+ filtered = true
+ err = nil
+ }
+ } else {
+ chtun.Outbound <- tt.data
+ n, err = tun.Read(buf[:], 0)
+ // In the read direction, errors are fatal, so we return n = 0 instead.
+ filtered = (n == 0)
+ }
+
+ if err != nil {
+ t.Errorf("got err %v; want nil", err)
+ }
+
+ if filtered {
+ if !tt.drop {
+ t.Errorf("got drop; want accept")
+ }
+ } else {
+ if tt.drop {
+ t.Errorf("got accept; want drop")
+ }
+ }
+ })
+ }
+}
+
+func TestAllocs(t *testing.T) {
+ ftun, tun := newFakeTUN(t.Logf, false)
+ defer tun.Close()
+
+ buf := []byte{0x00}
+ allocs := testing.AllocsPerRun(100, func() {
+ _, err := ftun.Write(buf, 0)
+ if err != nil {
+ t.Errorf("write: error: %v", err)
+ return
+ }
+ })
+
+ if allocs > 0 {
+ t.Errorf("read allocs = %v; want 0", allocs)
+ }
+}
+
+func TestClose(t *testing.T) {
+ ftun, tun := newFakeTUN(t.Logf, false)
+
+ data := udp4("1.2.3.4", "5.6.7.8", 98, 98)
+ _, err := ftun.Write(data, 0)
+ if err != nil {
+ t.Error(err)
+ }
+
+ tun.Close()
+ _, err = ftun.Write(data, 0)
+ if err == nil {
+ t.Error("Expected error from ftun.Write() after Close()")
+ }
+}
+
+func BenchmarkWrite(b *testing.B) {
+ ftun, tun := newFakeTUN(b.Logf, true)
+ defer tun.Close()
+
+ packet := udp4("5.6.7.8", "1.2.3.4", 89, 89)
+ for i := 0; i < b.N; i++ {
+ _, err := ftun.Write(packet, 0)
+ if err != nil {
+ b.Errorf("err = %v; want nil", err)
+ }
+ }
+}
+
+func TestAtomic64Alignment(t *testing.T) {
+ off := unsafe.Offsetof(Wrapper{}.lastActivityAtomic)
+ if off%8 != 0 {
+ t.Errorf("offset %v not 8-byte aligned", off)
+ }
+
+ c := new(Wrapper)
+ atomic.StoreInt64(&c.lastActivityAtomic, 123)
+}