summaryrefslogtreecommitdiffhomepage
path: root/net/dns
diff options
context:
space:
mode:
authorNaman Sood <mail@nsood.in>2021-03-29 14:28:08 -0400
committerNaman Sood <mail@nsood.in>2021-03-29 14:28:08 -0400
commitc0a88a0129ebf0f9886b93b1f4e4f04a7c3bb86f (patch)
tree57d5aef2985e3424e5bb6f4c810628aa3ccbf5d0 /net/dns
parent47bd3c4cf5543fd7ecb049302c37c1001fa9f2d6 (diff)
parenta4c679e64691a3f0ba41ad9078312ca67e5e67fd (diff)
downloadtailscale-naman/netstack-subnet-routing.tar.xz
tailscale-naman/netstack-subnet-routing.zip
Signed-off-by: Naman Sood <mail@nsood.in>
Diffstat (limited to 'net/dns')
-rw-r--r--net/dns/config.go77
-rw-r--r--net/dns/direct.go188
-rw-r--r--net/dns/flush_windows.go19
-rw-r--r--net/dns/forwarder.go474
-rw-r--r--net/dns/manager.go100
-rw-r--r--net/dns/manager_default.go14
-rw-r--r--net/dns/manager_freebsd.go14
-rw-r--r--net/dns/manager_linux.go27
-rw-r--r--net/dns/manager_openbsd.go9
-rw-r--r--net/dns/manager_windows.go118
-rw-r--r--net/dns/map.go160
-rw-r--r--net/dns/map_test.go156
-rw-r--r--net/dns/neterr_darwin.go25
-rw-r--r--net/dns/neterr_other.go10
-rw-r--r--net/dns/neterr_windows.go29
-rw-r--r--net/dns/nm.go205
-rw-r--r--net/dns/noop.go17
-rw-r--r--net/dns/registry_windows.go76
-rw-r--r--net/dns/resolvconf.go157
-rw-r--r--net/dns/resolved.go188
-rw-r--r--net/dns/tsdns.go662
-rw-r--r--net/dns/tsdns_server_test.go95
-rw-r--r--net/dns/tsdns_test.go816
23 files changed, 3636 insertions, 0 deletions
diff --git a/net/dns/config.go b/net/dns/config.go
new file mode 100644
index 000000000..db08df7aa
--- /dev/null
+++ b/net/dns/config.go
@@ -0,0 +1,77 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package dns
+
+import (
+ "inet.af/netaddr"
+
+ "tailscale.com/types/logger"
+)
+
+// Config is the set of parameters that uniquely determine
+// the state to which a manager should bring system DNS settings.
+type Config struct {
+ // Nameservers are the IP addresses of the nameservers to use.
+ Nameservers []netaddr.IP
+ // Domains are the search domains to use.
+ Domains []string
+ // PerDomain indicates whether it is preferred to use Nameservers
+ // only for DNS queries for subdomains of Domains.
+ // Note that Nameservers may still be applied to all queries
+ // if the manager does not support per-domain settings.
+ PerDomain bool
+ // Proxied indicates whether DNS requests are proxied through a dns.Resolver.
+ // This enables MagicDNS.
+ Proxied bool
+}
+
+// Equal determines whether its argument and receiver
+// represent equivalent DNS configurations (then DNS reconfig is a no-op).
+func (lhs Config) Equal(rhs Config) bool {
+ if lhs.Proxied != rhs.Proxied || lhs.PerDomain != rhs.PerDomain {
+ return false
+ }
+
+ if len(lhs.Nameservers) != len(rhs.Nameservers) {
+ return false
+ }
+
+ if len(lhs.Domains) != len(rhs.Domains) {
+ return false
+ }
+
+ // With how we perform resolution order shouldn't matter,
+ // but it is unlikely that we will encounter different orders.
+ for i, server := range lhs.Nameservers {
+ if rhs.Nameservers[i] != server {
+ return false
+ }
+ }
+
+ // The order of domains, on the other hand, is significant.
+ for i, domain := range lhs.Domains {
+ if rhs.Domains[i] != domain {
+ return false
+ }
+ }
+
+ return true
+}
+
+// ManagerConfig is the set of parameters from which
+// a manager implementation is chosen and initialized.
+type ManagerConfig struct {
+ // Logf is the logger for the manager to use.
+ // It is wrapped with a "dns: " prefix.
+ Logf logger.Logf
+ // InterfaceName is the name of the interface with which DNS settings should be associated.
+ InterfaceName string
+ // Cleanup indicates that the manager is created for cleanup only.
+ // A no-op manager will be instantiated if the system needs no cleanup.
+ Cleanup bool
+ // PerDomain indicates that a manager capable of per-domain configuration is preferred.
+ // Certain managers are per-domain only; they will not be considered if this is false.
+ PerDomain bool
+}
diff --git a/net/dns/direct.go b/net/dns/direct.go
new file mode 100644
index 000000000..bd1c03b9d
--- /dev/null
+++ b/net/dns/direct.go
@@ -0,0 +1,188 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build linux freebsd openbsd
+
+package dns
+
+import (
+ "bufio"
+ "bytes"
+ "errors"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "os"
+ "os/exec"
+ "runtime"
+ "strings"
+
+ "inet.af/netaddr"
+ "tailscale.com/atomicfile"
+)
+
+const (
+ tsConf = "/etc/resolv.tailscale.conf"
+ backupConf = "/etc/resolv.pre-tailscale-backup.conf"
+ resolvConf = "/etc/resolv.conf"
+)
+
+// writeResolvConf writes DNS configuration in resolv.conf format to the given writer.
+func writeResolvConf(w io.Writer, servers []netaddr.IP, domains []string) {
+ io.WriteString(w, "# resolv.conf(5) file generated by tailscale\n")
+ io.WriteString(w, "# DO NOT EDIT THIS FILE BY HAND -- CHANGES WILL BE OVERWRITTEN\n\n")
+ for _, ns := range servers {
+ io.WriteString(w, "nameserver ")
+ io.WriteString(w, ns.String())
+ io.WriteString(w, "\n")
+ }
+ if len(domains) > 0 {
+ io.WriteString(w, "search")
+ for _, domain := range domains {
+ io.WriteString(w, " ")
+ io.WriteString(w, domain)
+ }
+ io.WriteString(w, "\n")
+ }
+}
+
+// readResolvConf reads DNS configuration from /etc/resolv.conf.
+func readResolvConf() (Config, error) {
+ var config Config
+
+ f, err := os.Open("/etc/resolv.conf")
+ if err != nil {
+ return config, err
+ }
+
+ scanner := bufio.NewScanner(f)
+ for scanner.Scan() {
+ line := strings.TrimSpace(scanner.Text())
+
+ if strings.HasPrefix(line, "nameserver") {
+ nameserver := strings.TrimPrefix(line, "nameserver")
+ nameserver = strings.TrimSpace(nameserver)
+ ip, err := netaddr.ParseIP(nameserver)
+ if err != nil {
+ return config, err
+ }
+ config.Nameservers = append(config.Nameservers, ip)
+ continue
+ }
+
+ if strings.HasPrefix(line, "search") {
+ domain := strings.TrimPrefix(line, "search")
+ domain = strings.TrimSpace(domain)
+ config.Domains = append(config.Domains, domain)
+ continue
+ }
+ }
+
+ return config, nil
+}
+
+// isResolvedRunning reports whether systemd-resolved is running on the system,
+// even if it is not managing the system DNS settings.
+func isResolvedRunning() bool {
+ if runtime.GOOS != "linux" {
+ return false
+ }
+
+ // systemd-resolved is never installed without systemd.
+ _, err := exec.LookPath("systemctl")
+ if err != nil {
+ return false
+ }
+
+ // is-active exits with code 3 if the service is not active.
+ err = exec.Command("systemctl", "is-active", "systemd-resolved.service").Run()
+
+ return err == nil
+}
+
+// directManager is a managerImpl which replaces /etc/resolv.conf with a file
+// generated from the given configuration, creating a backup of its old state.
+//
+// This way of configuring DNS is precarious, since it does not react
+// to the disappearance of the Tailscale interface.
+// The caller must call Down before program shutdown
+// or as cleanup if the program terminates unexpectedly.
+type directManager struct{}
+
+func newDirectManager(mconfig ManagerConfig) managerImpl {
+ return directManager{}
+}
+
+// Up implements managerImpl.
+func (m directManager) Up(config Config) error {
+ // Write the tsConf file.
+ buf := new(bytes.Buffer)
+ writeResolvConf(buf, config.Nameservers, config.Domains)
+ if err := atomicfile.WriteFile(tsConf, buf.Bytes(), 0644); err != nil {
+ return err
+ }
+
+ if linkPath, err := os.Readlink(resolvConf); err != nil {
+ // Remove any old backup that may exist.
+ os.Remove(backupConf)
+
+ // Backup the existing /etc/resolv.conf file.
+ contents, err := ioutil.ReadFile(resolvConf)
+ // If the original did not exist, still back up an empty file.
+ // The presence of a backup file is the way we know that Up ran.
+ if err != nil && !errors.Is(err, os.ErrNotExist) {
+ return err
+ }
+ if err := atomicfile.WriteFile(backupConf, contents, 0644); err != nil {
+ return err
+ }
+ } else if linkPath != tsConf {
+ // Backup the existing symlink.
+ os.Remove(backupConf)
+ if err := os.Symlink(linkPath, backupConf); err != nil {
+ return err
+ }
+ } else {
+ // Nothing to do, resolvConf already points to tsConf.
+ return nil
+ }
+
+ os.Remove(resolvConf)
+ if err := os.Symlink(tsConf, resolvConf); err != nil {
+ return err
+ }
+
+ if isResolvedRunning() {
+ exec.Command("systemctl", "restart", "systemd-resolved.service").Run() // Best-effort.
+ }
+
+ return nil
+}
+
+// Down implements managerImpl.
+func (m directManager) Down() error {
+ if _, err := os.Stat(backupConf); err != nil {
+ // If the backup file does not exist, then Up never ran successfully.
+ if os.IsNotExist(err) {
+ return nil
+ }
+ return err
+ }
+
+ if ln, err := os.Readlink(resolvConf); err != nil {
+ return err
+ } else if ln != tsConf {
+ return fmt.Errorf("resolv.conf is not a symlink to %s", tsConf)
+ }
+ if err := os.Rename(backupConf, resolvConf); err != nil {
+ return err
+ }
+ os.Remove(tsConf)
+
+ if isResolvedRunning() {
+ exec.Command("systemctl", "restart", "systemd-resolved.service").Run() // Best-effort.
+ }
+
+ return nil
+}
diff --git a/net/dns/flush_windows.go b/net/dns/flush_windows.go
new file mode 100644
index 000000000..3c7e7d645
--- /dev/null
+++ b/net/dns/flush_windows.go
@@ -0,0 +1,19 @@
+// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package dns
+
+import (
+ "fmt"
+ "os/exec"
+)
+
+// Flush clears the local resolver cache.
+func Flush() error {
+ out, err := exec.Command("ipconfig", "/flushdns").CombinedOutput()
+ if err != nil {
+ return fmt.Errorf("%v (output: %s)", err, out)
+ }
+ return nil
+}
diff --git a/net/dns/forwarder.go b/net/dns/forwarder.go
new file mode 100644
index 000000000..519c00027
--- /dev/null
+++ b/net/dns/forwarder.go
@@ -0,0 +1,474 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package dns
+
+import (
+ "bytes"
+ "context"
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "hash/crc32"
+ "math/rand"
+ "net"
+ "os"
+ "sync"
+ "time"
+
+ "inet.af/netaddr"
+ "tailscale.com/logtail/backoff"
+ "tailscale.com/net/netns"
+ "tailscale.com/types/logger"
+)
+
+// headerBytes is the number of bytes in a DNS message header.
+const headerBytes = 12
+
+// connCount is the number of UDP connections to use for forwarding.
+const connCount = 32
+
+const (
+ // cleanupInterval is the interval between purged of timed-out entries from txMap.
+ cleanupInterval = 30 * time.Second
+ // responseTimeout is the maximal amount of time to wait for a DNS response.
+ responseTimeout = 5 * time.Second
+)
+
+var errNoUpstreams = errors.New("upstream nameservers not set")
+
+var aLongTimeAgo = time.Unix(0, 1)
+
+type forwardingRecord struct {
+ src netaddr.IPPort
+ createdAt time.Time
+}
+
+// txid identifies a DNS transaction.
+//
+// As the standard DNS Request ID is only 16 bits, we extend it:
+// the lower 32 bits are the zero-extended bits of the DNS Request ID;
+// the upper 32 bits are the CRC32 checksum of the first question in the request.
+// This makes probability of txid collision negligible.
+type txid uint64
+
+// getTxID computes the txid of the given DNS packet.
+func getTxID(packet []byte) txid {
+ if len(packet) < headerBytes {
+ return 0
+ }
+
+ dnsid := binary.BigEndian.Uint16(packet[0:2])
+ qcount := binary.BigEndian.Uint16(packet[4:6])
+ if qcount == 0 {
+ return txid(dnsid)
+ }
+
+ offset := headerBytes
+ for i := uint16(0); i < qcount; i++ {
+ // Note: this relies on the fact that names are not compressed in questions,
+ // so they are guaranteed to end with a NUL byte.
+ //
+ // Justification:
+ // RFC 1035 doesn't seem to explicitly prohibit compressing names in questions,
+ // but this is exceedingly unlikely to be done in practice. A DNS request
+ // with multiple questions is ill-defined (which questions do the header flags apply to?)
+ // and a single question would have to contain a pointer to an *answer*,
+ // which would be excessively smart, pointless (an answer can just as well refer to the question)
+ // and perhaps even prohibited: a draft RFC (draft-ietf-dnsind-local-compression-05) states:
+ //
+ // > It is important that these pointers always point backwards.
+ //
+ // This is said in summarizing RFC 1035, although that phrase does not appear in the original RFC.
+ // Additionally, (https://cr.yp.to/djbdns/notes.html) states:
+ //
+ // > The precise rule is that a name can be compressed if it is a response owner name,
+ // > the name in NS data, the name in CNAME data, the name in PTR data, the name in MX data,
+ // > or one of the names in SOA data.
+ namebytes := bytes.IndexByte(packet[offset:], 0)
+ // ... | name | NUL | type | class
+ // ?? 1 2 2
+ offset = offset + namebytes + 5
+ if len(packet) < offset {
+ // Corrupt packet; don't crash.
+ return txid(dnsid)
+ }
+ }
+
+ hash := crc32.ChecksumIEEE(packet[headerBytes:offset])
+ return (txid(hash) << 32) | txid(dnsid)
+}
+
+// forwarder forwards DNS packets to a number of upstream nameservers.
+type forwarder struct {
+ logf logger.Logf
+
+ // responses is a channel by which responses are returned.
+ responses chan Packet
+ // closed signals all goroutines to stop.
+ closed chan struct{}
+ // wg signals when all goroutines have stopped.
+ wg sync.WaitGroup
+
+ // conns are the UDP connections used for forwarding.
+ // A random one is selected for each request, regardless of the target upstream.
+ conns []*fwdConn
+
+ mu sync.Mutex
+ // upstreams are the nameserver addresses that should be used for forwarding.
+ upstreams []net.Addr
+ // txMap maps DNS txids to active forwarding records.
+ txMap map[txid]forwardingRecord
+}
+
+func init() {
+ rand.Seed(time.Now().UnixNano())
+}
+
+func newForwarder(logf logger.Logf, responses chan Packet) *forwarder {
+ return &forwarder{
+ logf: logger.WithPrefix(logf, "forward: "),
+ responses: responses,
+ closed: make(chan struct{}),
+ conns: make([]*fwdConn, connCount),
+ txMap: make(map[txid]forwardingRecord),
+ }
+}
+
+func (f *forwarder) Start() error {
+ f.wg.Add(connCount + 1)
+ for idx := range f.conns {
+ f.conns[idx] = newFwdConn(f.logf, idx)
+ go f.recv(f.conns[idx])
+ }
+ go f.cleanMap()
+
+ return nil
+}
+
+func (f *forwarder) Close() {
+ select {
+ case <-f.closed:
+ return
+ default:
+ // continue
+ }
+ close(f.closed)
+
+ for _, conn := range f.conns {
+ conn.close()
+ }
+
+ f.wg.Wait()
+}
+
+func (f *forwarder) rebindFromNetworkChange() {
+ for _, c := range f.conns {
+ c.mu.Lock()
+ c.reconnectLocked()
+ c.mu.Unlock()
+ }
+}
+
+func (f *forwarder) setUpstreams(upstreams []net.Addr) {
+ f.mu.Lock()
+ f.upstreams = upstreams
+ f.mu.Unlock()
+}
+
+// send sends packet to dst. It is best effort.
+func (f *forwarder) send(packet []byte, dst net.Addr) {
+ connIdx := rand.Intn(connCount)
+ conn := f.conns[connIdx]
+ conn.send(packet, dst)
+}
+
+func (f *forwarder) recv(conn *fwdConn) {
+ defer f.wg.Done()
+
+ for {
+ select {
+ case <-f.closed:
+ return
+ default:
+ }
+ out := make([]byte, maxResponseBytes)
+ n := conn.read(out)
+ if n == 0 {
+ continue
+ }
+ if n < headerBytes {
+ f.logf("recv: packet too small (%d bytes)", n)
+ }
+
+ out = out[:n]
+ txid := getTxID(out)
+
+ f.mu.Lock()
+
+ record, found := f.txMap[txid]
+ // At most one nameserver will return a response:
+ // the first one to do so will delete txid from the map.
+ if !found {
+ f.mu.Unlock()
+ continue
+ }
+ delete(f.txMap, txid)
+
+ f.mu.Unlock()
+
+ packet := Packet{
+ Payload: out,
+ Addr: record.src,
+ }
+ select {
+ case <-f.closed:
+ return
+ case f.responses <- packet:
+ // continue
+ }
+ }
+}
+
+// cleanMap periodically deletes timed-out forwarding records from f.txMap to bound growth.
+func (f *forwarder) cleanMap() {
+ defer f.wg.Done()
+
+ t := time.NewTicker(cleanupInterval)
+ defer t.Stop()
+
+ var now time.Time
+ for {
+ select {
+ case <-f.closed:
+ return
+ case now = <-t.C:
+ // continue
+ }
+
+ f.mu.Lock()
+ for k, v := range f.txMap {
+ if now.Sub(v.createdAt) > responseTimeout {
+ delete(f.txMap, k)
+ }
+ }
+ f.mu.Unlock()
+ }
+}
+
+// forward forwards the query to all upstream nameservers and returns the first response.
+func (f *forwarder) forward(query Packet) error {
+ txid := getTxID(query.Payload)
+
+ f.mu.Lock()
+
+ upstreams := f.upstreams
+ if len(upstreams) == 0 {
+ f.mu.Unlock()
+ return errNoUpstreams
+ }
+ f.txMap[txid] = forwardingRecord{
+ src: query.Addr,
+ createdAt: time.Now(),
+ }
+
+ f.mu.Unlock()
+
+ for _, upstream := range upstreams {
+ f.send(query.Payload, upstream)
+ }
+
+ return nil
+}
+
+// A fwdConn manages a single connection used to forward DNS requests.
+// Net link changes can cause a *net.UDPConn to become permanently unusable, particularly on macOS.
+// fwdConn detects such situations and transparently creates new connections.
+type fwdConn struct {
+ // logf allows a fwdConn to log.
+ logf logger.Logf
+
+ // wg tracks the number of outstanding conn.Read and conn.Write calls.
+ wg sync.WaitGroup
+ // change allows calls to read to block until a the network connection has been replaced.
+ change *sync.Cond
+
+ // mu protects fields that follow it; it is also change's Locker.
+ mu sync.Mutex
+ // closed tracks whether fwdConn has been permanently closed.
+ closed bool
+ // conn is the current active connection.
+ conn net.PacketConn
+}
+
+func newFwdConn(logf logger.Logf, idx int) *fwdConn {
+ c := new(fwdConn)
+ c.logf = logger.WithPrefix(logf, fmt.Sprintf("fwdConn %d: ", idx))
+ c.change = sync.NewCond(&c.mu)
+ // c.conn is created lazily in send
+ return c
+}
+
+// send sends packet to dst using c's connection.
+// It is best effort. It is UDP, after all. Failures are logged.
+func (c *fwdConn) send(packet []byte, dst net.Addr) {
+ var b *backoff.Backoff // lazily initialized, since it is not needed in the common case
+ backOff := func(err error) {
+ if b == nil {
+ b = backoff.NewBackoff("dns-fwdConn-send", c.logf, 30*time.Second)
+ }
+ b.BackOff(context.Background(), err)
+ }
+
+ for {
+ // Gather the current connection.
+ // We can't hold the lock while we call WriteTo.
+ c.mu.Lock()
+ conn := c.conn
+ closed := c.closed
+ if closed {
+ c.mu.Unlock()
+ return
+ }
+ if conn == nil {
+ c.reconnectLocked()
+ c.mu.Unlock()
+ continue
+ }
+ c.mu.Unlock()
+
+ c.wg.Add(1)
+ _, err := conn.WriteTo(packet, dst)
+ c.wg.Done()
+ if err == nil {
+ // Success
+ return
+ }
+ if errors.Is(err, os.ErrDeadlineExceeded) {
+ // We intentionally closed this connection.
+ // It has been replaced by a new connection. Try again.
+ continue
+ }
+ // Something else went wrong.
+ // We have three choices here: try again, give up, or create a new connection.
+ var opErr *net.OpError
+ if !errors.As(err, &opErr) {
+ // Weird. All errors from the net package should be *net.OpError. Bail.
+ c.logf("send: non-*net.OpErr %v (%T)", err, err)
+ return
+ }
+ if opErr.Temporary() || opErr.Timeout() {
+ // I doubt that either of these can happen (this is UDP),
+ // but go ahead and try again.
+ backOff(err)
+ continue
+ }
+ if networkIsDown(err) {
+ // Fail.
+ c.logf("send: network is down")
+ return
+ }
+ if networkIsUnreachable(err) {
+ // This can be caused by a link change.
+ // Replace the existing connection with a new one.
+ c.mu.Lock()
+ // It's possible that multiple senders discovered simultaneously
+ // that the network is unreachable. Avoid reconnecting multiple times:
+ // Only reconnect if the current connection is the one that we
+ // discovered to be problematic.
+ if c.conn == conn {
+ backOff(err)
+ c.reconnectLocked()
+ }
+ c.mu.Unlock()
+ // Try again with our new network connection.
+ continue
+ }
+ // Unrecognized error. Fail.
+ c.logf("send: unrecognized error: %v", err)
+ return
+ }
+}
+
+// read waits for a response from c's connection.
+// It returns the number of bytes read, which may be 0
+// in case of an error or a closed connection.
+func (c *fwdConn) read(out []byte) int {
+ for {
+ // Gather the current connection.
+ // We can't hold the lock while we call ReadFrom.
+ c.mu.Lock()
+ conn := c.conn
+ closed := c.closed
+ if closed {
+ c.mu.Unlock()
+ return 0
+ }
+ if conn == nil {
+ // There is no current connection.
+ // Wait for the connection to change, then try again.
+ c.change.Wait()
+ c.mu.Unlock()
+ continue
+ }
+ c.mu.Unlock()
+
+ c.wg.Add(1)
+ n, _, err := conn.ReadFrom(out)
+ c.wg.Done()
+ if err == nil {
+ // Success.
+ return n
+ }
+ if errors.Is(err, os.ErrDeadlineExceeded) {
+ // We intentionally closed this connection.
+ // It has been replaced by a new connection. Try again.
+ continue
+ }
+
+ c.logf("read: unrecognized error: %v", err)
+ return 0
+ }
+}
+
+// reconnectLocked replaces the current connection with a new one.
+// c.mu must be locked.
+func (c *fwdConn) reconnectLocked() {
+ c.closeConnLocked()
+ // Make a new connection.
+ conn, err := netns.Listener().ListenPacket(context.Background(), "udp", "")
+ if err != nil {
+ c.logf("ListenPacket failed: %v", err)
+ } else {
+ c.conn = conn
+ }
+ // Broadcast that a new connection is available.
+ c.change.Broadcast()
+}
+
+// closeCurrentConn closes the current connection.
+// c.mu must be locked.
+func (c *fwdConn) closeConnLocked() {
+ if c.conn == nil {
+ return
+ }
+ // Unblock all readers/writers, wait for them, close the connection.
+ c.conn.SetDeadline(aLongTimeAgo)
+ c.wg.Wait()
+ c.conn.Close()
+ c.conn = nil
+}
+
+// close permanently closes c.
+func (c *fwdConn) close() {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ if c.closed {
+ return
+ }
+ c.closed = true
+ c.closeConnLocked()
+ // Unblock any remaining readers.
+ c.change.Broadcast()
+}
diff --git a/net/dns/manager.go b/net/dns/manager.go
new file mode 100644
index 000000000..8e2fc9d71
--- /dev/null
+++ b/net/dns/manager.go
@@ -0,0 +1,100 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package dns
+
+import (
+ "time"
+
+ "tailscale.com/types/logger"
+)
+
+// We use file-ignore below instead of ignore because on some platforms,
+// the lint exception is necessary and on others it is not,
+// and plain ignore complains if the exception is unnecessary.
+
+//lint:file-ignore U1000 reconfigTimeout is used on some platforms but not others
+
+// reconfigTimeout is the time interval within which Manager.{Up,Down} should complete.
+//
+// This is particularly useful because certain conditions can cause indefinite hangs
+// (such as improper dbus auth followed by contextless dbus.Object.Call).
+// Such operations should be wrapped in a timeout context.
+const reconfigTimeout = time.Second
+
+type managerImpl interface {
+ // Up updates system DNS settings to match the given configuration.
+ Up(Config) error
+ // Down undoes the effects of Up.
+ // It is idempotent and performs no action if Up has never been called.
+ Down() error
+}
+
+// Manager manages system DNS settings.
+type Manager struct {
+ logf logger.Logf
+
+ impl managerImpl
+
+ config Config
+ mconfig ManagerConfig
+}
+
+// NewManagers created a new manager from the given config.
+func NewManager(mconfig ManagerConfig) *Manager {
+ mconfig.Logf = logger.WithPrefix(mconfig.Logf, "dns: ")
+ m := &Manager{
+ logf: mconfig.Logf,
+ impl: newManager(mconfig),
+
+ config: Config{PerDomain: mconfig.PerDomain},
+ mconfig: mconfig,
+ }
+
+ m.logf("using %T", m.impl)
+ return m
+}
+
+func (m *Manager) Set(config Config) error {
+ if config.Equal(m.config) {
+ return nil
+ }
+
+ m.logf("Set: %+v", config)
+
+ if len(config.Nameservers) == 0 {
+ err := m.impl.Down()
+ // If we save the config, we will not retry next time. Only do this on success.
+ if err == nil {
+ m.config = config
+ }
+ return err
+ }
+
+ // Switching to and from per-domain mode may require a change of manager.
+ if config.PerDomain != m.config.PerDomain {
+ if err := m.impl.Down(); err != nil {
+ return err
+ }
+ m.mconfig.PerDomain = config.PerDomain
+ m.impl = newManager(m.mconfig)
+ m.logf("switched to %T", m.impl)
+ }
+
+ err := m.impl.Up(config)
+ // If we save the config, we will not retry next time. Only do this on success.
+ if err == nil {
+ m.config = config
+ }
+
+ return err
+}
+
+func (m *Manager) Up() error {
+ return m.impl.Up(m.config)
+}
+
+func (m *Manager) Down() error {
+ return m.impl.Down()
+}
diff --git a/net/dns/manager_default.go b/net/dns/manager_default.go
new file mode 100644
index 000000000..04c8bb811
--- /dev/null
+++ b/net/dns/manager_default.go
@@ -0,0 +1,14 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build !linux,!freebsd,!openbsd,!windows
+
+package dns
+
+func newManager(mconfig ManagerConfig) managerImpl {
+ // TODO(dmytro): on darwin, we should use a macOS-specific method such as scutil.
+ // This is currently not implemented. Editing /etc/resolv.conf does not work,
+ // as most applications use the system resolver, which disregards it.
+ return newNoopManager(mconfig)
+}
diff --git a/net/dns/manager_freebsd.go b/net/dns/manager_freebsd.go
new file mode 100644
index 000000000..232635f7e
--- /dev/null
+++ b/net/dns/manager_freebsd.go
@@ -0,0 +1,14 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package dns
+
+func newManager(mconfig ManagerConfig) managerImpl {
+ switch {
+ case isResolvconfActive():
+ return newResolvconfManager(mconfig)
+ default:
+ return newDirectManager(mconfig)
+ }
+}
diff --git a/net/dns/manager_linux.go b/net/dns/manager_linux.go
new file mode 100644
index 000000000..f53aed7d3
--- /dev/null
+++ b/net/dns/manager_linux.go
@@ -0,0 +1,27 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package dns
+
+func newManager(mconfig ManagerConfig) managerImpl {
+ switch {
+ // systemd-resolved should only activate per-domain.
+ case isResolvedActive() && mconfig.PerDomain:
+ if mconfig.Cleanup {
+ return newNoopManager(mconfig)
+ } else {
+ return newResolvedManager(mconfig)
+ }
+ case isNMActive():
+ if mconfig.Cleanup {
+ return newNoopManager(mconfig)
+ } else {
+ return newNMManager(mconfig)
+ }
+ case isResolvconfActive():
+ return newResolvconfManager(mconfig)
+ default:
+ return newDirectManager(mconfig)
+ }
+}
diff --git a/net/dns/manager_openbsd.go b/net/dns/manager_openbsd.go
new file mode 100644
index 000000000..228e3cca5
--- /dev/null
+++ b/net/dns/manager_openbsd.go
@@ -0,0 +1,9 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package dns
+
+func newManager(mconfig ManagerConfig) managerImpl {
+ return newDirectManager(mconfig)
+}
diff --git a/net/dns/manager_windows.go b/net/dns/manager_windows.go
new file mode 100644
index 000000000..5940404e7
--- /dev/null
+++ b/net/dns/manager_windows.go
@@ -0,0 +1,118 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package dns
+
+import (
+ "fmt"
+ "os/exec"
+ "strings"
+ "syscall"
+ "time"
+
+ "golang.org/x/sys/windows/registry"
+ "tailscale.com/types/logger"
+)
+
+const (
+ ipv4RegBase = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters`
+ ipv6RegBase = `SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters`
+)
+
+type windowsManager struct {
+ logf logger.Logf
+ guid string
+}
+
+func newManager(mconfig ManagerConfig) managerImpl {
+ return windowsManager{
+ logf: mconfig.Logf,
+ guid: mconfig.InterfaceName,
+ }
+}
+
+// keyOpenTimeout is how long we wait for a registry key to
+// appear. For some reason, registry keys tied to ephemeral interfaces
+// can take a long while to appear after interface creation, and we
+// can end up racing with that.
+const keyOpenTimeout = 20 * time.Second
+
+func setRegistryString(path, name, value string) error {
+ key, err := openKeyWait(registry.LOCAL_MACHINE, path, registry.SET_VALUE, keyOpenTimeout)
+ if err != nil {
+ return fmt.Errorf("opening %s: %w", path, err)
+ }
+ defer key.Close()
+
+ err = key.SetStringValue(name, value)
+ if err != nil {
+ return fmt.Errorf("setting %s[%s]: %w", path, name, err)
+ }
+ return nil
+}
+
+func (m windowsManager) setNameservers(basePath string, nameservers []string) error {
+ path := fmt.Sprintf(`%s\Interfaces\%s`, basePath, m.guid)
+ value := strings.Join(nameservers, ",")
+ return setRegistryString(path, "NameServer", value)
+}
+
+func (m windowsManager) setDomains(basePath string, domains []string) error {
+ path := fmt.Sprintf(`%s\Interfaces\%s`, basePath, m.guid)
+ value := strings.Join(domains, ",")
+ return setRegistryString(path, "SearchList", value)
+}
+
+func (m windowsManager) Up(config Config) error {
+ var ipsv4 []string
+ var ipsv6 []string
+
+ for _, ip := range config.Nameservers {
+ if ip.Is4() {
+ ipsv4 = append(ipsv4, ip.String())
+ } else {
+ ipsv6 = append(ipsv6, ip.String())
+ }
+ }
+
+ if err := m.setNameservers(ipv4RegBase, ipsv4); err != nil {
+ return err
+ }
+ if err := m.setDomains(ipv4RegBase, config.Domains); err != nil {
+ return err
+ }
+
+ if err := m.setNameservers(ipv6RegBase, ipsv6); err != nil {
+ return err
+ }
+ if err := m.setDomains(ipv6RegBase, config.Domains); err != nil {
+ return err
+ }
+
+ // Force DNS re-registration in Active Directory. What we actually
+ // care about is that this command invokes the undocumented hidden
+ // function that forces Windows to notice that adapter settings
+ // have changed, which makes the DNS settings actually take
+ // effect.
+ //
+ // This command can take a few seconds to run, so run it async, best effort.
+ go func() {
+ t0 := time.Now()
+ m.logf("running ipconfig /registerdns ...")
+ cmd := exec.Command("ipconfig", "/registerdns")
+ cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
+ d := time.Since(t0).Round(time.Millisecond)
+ if err := cmd.Run(); err != nil {
+ m.logf("error running ipconfig /registerdns after %v: %v", d, err)
+ } else {
+ m.logf("ran ipconfig /registerdns in %v", d)
+ }
+ }()
+
+ return nil
+}
+
+func (m windowsManager) Down() error {
+ return m.Up(Config{Nameservers: nil, Domains: nil})
+}
diff --git a/net/dns/map.go b/net/dns/map.go
new file mode 100644
index 000000000..119b6cc0a
--- /dev/null
+++ b/net/dns/map.go
@@ -0,0 +1,160 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package dns
+
+import (
+ "sort"
+ "strings"
+
+ "inet.af/netaddr"
+)
+
+// Map is all the data Resolver needs to resolve DNS queries within the Tailscale network.
+type Map struct {
+ // nameToIP is a mapping of Tailscale domain names to their IP addresses.
+ // For example, monitoring.tailscale.us -> 100.64.0.1.
+ nameToIP map[string]netaddr.IP
+ // ipToName is the inverse of nameToIP.
+ ipToName map[netaddr.IP]string
+ // names are the keys of nameToIP in sorted order.
+ names []string
+ // rootDomains are the domains whose subdomains should always
+ // be resolved locally to prevent leakage of sensitive names.
+ rootDomains []string // e.g. "user.provider.beta.tailscale.net."
+}
+
+// NewMap returns a new Map with name to address mapping given by nameToIP.
+//
+// rootDomains are the domains whose subdomains should always be
+// resolved locally to prevent leakage of sensitive names. They should
+// end in a period ("user-foo.tailscale.net.").
+func NewMap(initNameToIP map[string]netaddr.IP, rootDomains []string) *Map {
+ // TODO(dmytro): we have to allocate names and ipToName, but nameToIP can be avoided.
+ // It is here because control sends us names not in canonical form. Change this.
+ names := make([]string, 0, len(initNameToIP))
+ nameToIP := make(map[string]netaddr.IP, len(initNameToIP))
+ ipToName := make(map[netaddr.IP]string, len(initNameToIP))
+
+ for name, ip := range initNameToIP {
+ if len(name) == 0 {
+ // Nothing useful can be done with empty names.
+ continue
+ }
+ if name[len(name)-1] != '.' {
+ name += "."
+ }
+ names = append(names, name)
+ nameToIP[name] = ip
+ ipToName[ip] = name
+ }
+ sort.Strings(names)
+
+ return &Map{
+ nameToIP: nameToIP,
+ ipToName: ipToName,
+ names: names,
+
+ rootDomains: rootDomains,
+ }
+}
+
+func printSingleNameIP(buf *strings.Builder, name string, ip netaddr.IP) {
+ buf.WriteString(name)
+ buf.WriteByte('\t')
+ buf.WriteString(ip.String())
+ buf.WriteByte('\n')
+}
+
+func (m *Map) Pretty() string {
+ buf := new(strings.Builder)
+ for _, name := range m.names {
+ printSingleNameIP(buf, name, m.nameToIP[name])
+ }
+ return buf.String()
+}
+
+func (m *Map) PrettyDiffFrom(old *Map) string {
+ var (
+ oldNameToIP map[string]netaddr.IP
+ newNameToIP map[string]netaddr.IP
+ oldNames []string
+ newNames []string
+ )
+ if old != nil {
+ oldNameToIP = old.nameToIP
+ oldNames = old.names
+ }
+ if m != nil {
+ newNameToIP = m.nameToIP
+ newNames = m.names
+ }
+
+ buf := new(strings.Builder)
+ space := func() bool {
+ return buf.Len() < (1 << 10)
+ }
+
+ for len(oldNames) > 0 && len(newNames) > 0 {
+ var name string
+
+ newName, oldName := newNames[0], oldNames[0]
+ switch {
+ case oldName < newName:
+ name = oldName
+ oldNames = oldNames[1:]
+ case oldName > newName:
+ name = newName
+ newNames = newNames[1:]
+ case oldNames[0] == newNames[0]:
+ name = oldNames[0]
+ oldNames = oldNames[1:]
+ newNames = newNames[1:]
+ }
+ if !space() {
+ continue
+ }
+
+ ipOld, inOld := oldNameToIP[name]
+ ipNew, inNew := newNameToIP[name]
+ switch {
+ case !inOld:
+ buf.WriteByte('+')
+ printSingleNameIP(buf, name, ipNew)
+ case !inNew:
+ buf.WriteByte('-')
+ printSingleNameIP(buf, name, ipOld)
+ case ipOld != ipNew:
+ buf.WriteByte('-')
+ printSingleNameIP(buf, name, ipOld)
+ buf.WriteByte('+')
+ printSingleNameIP(buf, name, ipNew)
+ }
+ }
+
+ for _, name := range oldNames {
+ if !space() {
+ break
+ }
+ if _, ok := newNameToIP[name]; !ok {
+ buf.WriteByte('-')
+ printSingleNameIP(buf, name, oldNameToIP[name])
+ }
+ }
+
+ for _, name := range newNames {
+ if !space() {
+ break
+ }
+ if _, ok := oldNameToIP[name]; !ok {
+ buf.WriteByte('+')
+ printSingleNameIP(buf, name, newNameToIP[name])
+ }
+ }
+ if !space() {
+ buf.WriteString("... [truncated]\n")
+ }
+
+ return buf.String()
+}
diff --git a/net/dns/map_test.go b/net/dns/map_test.go
new file mode 100644
index 000000000..c438f95a0
--- /dev/null
+++ b/net/dns/map_test.go
@@ -0,0 +1,156 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package dns
+
+import (
+ "fmt"
+ "strings"
+ "testing"
+
+ "inet.af/netaddr"
+)
+
+func TestPretty(t *testing.T) {
+ tests := []struct {
+ name string
+ dmap *Map
+ want string
+ }{
+ {"empty", NewMap(nil, nil), ""},
+ {
+ "single",
+ NewMap(map[string]netaddr.IP{
+ "hello.ipn.dev.": netaddr.IPv4(100, 101, 102, 103),
+ }, nil),
+ "hello.ipn.dev.\t100.101.102.103\n",
+ },
+ {
+ "multiple",
+ NewMap(map[string]netaddr.IP{
+ "test1.domain.": netaddr.IPv4(100, 101, 102, 103),
+ "test2.sub.domain.": netaddr.IPv4(100, 99, 9, 1),
+ }, nil),
+ "test1.domain.\t100.101.102.103\ntest2.sub.domain.\t100.99.9.1\n",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := tt.dmap.Pretty()
+ if tt.want != got {
+ t.Errorf("want %v; got %v", tt.want, got)
+ }
+ })
+ }
+}
+
+func TestPrettyDiffFrom(t *testing.T) {
+ tests := []struct {
+ name string
+ map1 *Map
+ map2 *Map
+ want string
+ }{
+ {
+ "from_empty",
+ nil,
+ NewMap(map[string]netaddr.IP{
+ "test1.ipn.dev.": netaddr.IPv4(100, 101, 102, 103),
+ "test2.ipn.dev.": netaddr.IPv4(100, 103, 102, 101),
+ }, nil),
+ "+test1.ipn.dev.\t100.101.102.103\n+test2.ipn.dev.\t100.103.102.101\n",
+ },
+ {
+ "equal",
+ NewMap(map[string]netaddr.IP{
+ "test1.ipn.dev.": netaddr.IPv4(100, 101, 102, 103),
+ "test2.ipn.dev.": netaddr.IPv4(100, 103, 102, 101),
+ }, nil),
+ NewMap(map[string]netaddr.IP{
+ "test2.ipn.dev.": netaddr.IPv4(100, 103, 102, 101),
+ "test1.ipn.dev.": netaddr.IPv4(100, 101, 102, 103),
+ }, nil),
+ "",
+ },
+ {
+ "changed_ip",
+ NewMap(map[string]netaddr.IP{
+ "test1.ipn.dev.": netaddr.IPv4(100, 101, 102, 103),
+ "test2.ipn.dev.": netaddr.IPv4(100, 103, 102, 101),
+ }, nil),
+ NewMap(map[string]netaddr.IP{
+ "test2.ipn.dev.": netaddr.IPv4(100, 104, 102, 101),
+ "test1.ipn.dev.": netaddr.IPv4(100, 101, 102, 103),
+ }, nil),
+ "-test2.ipn.dev.\t100.103.102.101\n+test2.ipn.dev.\t100.104.102.101\n",
+ },
+ {
+ "new_domain",
+ NewMap(map[string]netaddr.IP{
+ "test1.ipn.dev.": netaddr.IPv4(100, 101, 102, 103),
+ "test2.ipn.dev.": netaddr.IPv4(100, 103, 102, 101),
+ }, nil),
+ NewMap(map[string]netaddr.IP{
+ "test3.ipn.dev.": netaddr.IPv4(100, 105, 106, 107),
+ "test2.ipn.dev.": netaddr.IPv4(100, 103, 102, 101),
+ "test1.ipn.dev.": netaddr.IPv4(100, 101, 102, 103),
+ }, nil),
+ "+test3.ipn.dev.\t100.105.106.107\n",
+ },
+ {
+ "gone_domain",
+ NewMap(map[string]netaddr.IP{
+ "test1.ipn.dev.": netaddr.IPv4(100, 101, 102, 103),
+ "test2.ipn.dev.": netaddr.IPv4(100, 103, 102, 101),
+ }, nil),
+ NewMap(map[string]netaddr.IP{
+ "test1.ipn.dev.": netaddr.IPv4(100, 101, 102, 103),
+ }, nil),
+ "-test2.ipn.dev.\t100.103.102.101\n",
+ },
+ {
+ "mixed",
+ NewMap(map[string]netaddr.IP{
+ "test1.ipn.dev.": netaddr.IPv4(100, 101, 102, 103),
+ "test4.ipn.dev.": netaddr.IPv4(100, 107, 106, 105),
+ "test5.ipn.dev.": netaddr.IPv4(100, 64, 1, 1),
+ "test2.ipn.dev.": netaddr.IPv4(100, 103, 102, 101),
+ }, nil),
+ NewMap(map[string]netaddr.IP{
+ "test2.ipn.dev.": netaddr.IPv4(100, 104, 102, 101),
+ "test1.ipn.dev.": netaddr.IPv4(100, 100, 101, 102),
+ "test3.ipn.dev.": netaddr.IPv4(100, 64, 1, 1),
+ }, nil),
+ "-test1.ipn.dev.\t100.101.102.103\n+test1.ipn.dev.\t100.100.101.102\n" +
+ "-test2.ipn.dev.\t100.103.102.101\n+test2.ipn.dev.\t100.104.102.101\n" +
+ "+test3.ipn.dev.\t100.64.1.1\n-test4.ipn.dev.\t100.107.106.105\n-test5.ipn.dev.\t100.64.1.1\n",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := tt.map2.PrettyDiffFrom(tt.map1)
+ if tt.want != got {
+ t.Errorf("want %v; got %v", tt.want, got)
+ }
+ })
+ }
+
+ t.Run("truncated", func(t *testing.T) {
+ small := NewMap(nil, nil)
+ m := map[string]netaddr.IP{}
+ for i := 0; i < 5000; i++ {
+ m[fmt.Sprintf("host%d.ipn.dev.", i)] = netaddr.IPv4(100, 64, 1, 1)
+ }
+ veryBig := NewMap(m, nil)
+ diff := veryBig.PrettyDiffFrom(small)
+ if len(diff) > 3<<10 {
+ t.Errorf("pretty diff too large: %d bytes", len(diff))
+ }
+ if !strings.Contains(diff, "truncated") {
+ t.Errorf("big diff not truncated")
+ }
+ })
+}
diff --git a/net/dns/neterr_darwin.go b/net/dns/neterr_darwin.go
new file mode 100644
index 000000000..7fd621fc7
--- /dev/null
+++ b/net/dns/neterr_darwin.go
@@ -0,0 +1,25 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package dns
+
+import (
+ "errors"
+ "syscall"
+)
+
+// Avoid allocation when calling errors.Is below
+// by converting syscall.Errno to error here.
+var (
+ networkDown error = syscall.ENETDOWN
+ networkUnreachable error = syscall.ENETUNREACH
+)
+
+func networkIsDown(err error) bool {
+ return errors.Is(err, networkDown)
+}
+
+func networkIsUnreachable(err error) bool {
+ return errors.Is(err, networkUnreachable)
+}
diff --git a/net/dns/neterr_other.go b/net/dns/neterr_other.go
new file mode 100644
index 000000000..b652f6e8b
--- /dev/null
+++ b/net/dns/neterr_other.go
@@ -0,0 +1,10 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build !darwin,!windows
+
+package dns
+
+func networkIsDown(err error) bool { return false }
+func networkIsUnreachable(err error) bool { return false }
diff --git a/net/dns/neterr_windows.go b/net/dns/neterr_windows.go
new file mode 100644
index 000000000..2b197ee2b
--- /dev/null
+++ b/net/dns/neterr_windows.go
@@ -0,0 +1,29 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package dns
+
+import (
+ "net"
+ "os"
+
+ "golang.org/x/sys/windows"
+)
+
+func networkIsDown(err error) bool {
+ if oe, ok := err.(*net.OpError); ok && oe.Op == "write" {
+ if se, ok := oe.Err.(*os.SyscallError); ok {
+ if se.Syscall == "wsasendto" && se.Err == windows.WSAENETUNREACH {
+ return true
+ }
+ }
+ }
+ return false
+}
+
+func networkIsUnreachable(err error) bool {
+ // TODO(bradfitz,josharian): something here? what is the
+ // difference between down and unreachable? Add comments.
+ return false
+}
diff --git a/net/dns/nm.go b/net/dns/nm.go
new file mode 100644
index 000000000..a597fa60d
--- /dev/null
+++ b/net/dns/nm.go
@@ -0,0 +1,205 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build linux
+
+package dns
+
+import (
+ "bufio"
+ "bytes"
+ "context"
+ "fmt"
+ "os"
+ "os/exec"
+
+ "github.com/godbus/dbus/v5"
+ "tailscale.com/util/endian"
+)
+
+// isNMActive determines if NetworkManager is currently managing system DNS settings.
+func isNMActive() bool {
+ // This is somewhat tricky because NetworkManager supports a number
+ // of DNS configuration modes. In all cases, we expect it to be installed
+ // and /etc/resolv.conf to contain a mention of NetworkManager in the comments.
+ _, err := exec.LookPath("NetworkManager")
+ if err != nil {
+ return false
+ }
+
+ f, err := os.Open("/etc/resolv.conf")
+ if err != nil {
+ return false
+ }
+ defer f.Close()
+
+ scanner := bufio.NewScanner(f)
+ for scanner.Scan() {
+ line := scanner.Bytes()
+ // Look for the word "NetworkManager" until comments end.
+ if len(line) > 0 && line[0] != '#' {
+ return false
+ }
+ if bytes.Contains(line, []byte("NetworkManager")) {
+ return true
+ }
+ }
+ return false
+}
+
+// nmManager uses the NetworkManager DBus API.
+type nmManager struct {
+ interfaceName string
+}
+
+func newNMManager(mconfig ManagerConfig) managerImpl {
+ return nmManager{
+ interfaceName: mconfig.InterfaceName,
+ }
+}
+
+type nmConnectionSettings map[string]map[string]dbus.Variant
+
+// Up implements managerImpl.
+func (m nmManager) Up(config Config) error {
+ ctx, cancel := context.WithTimeout(context.Background(), reconfigTimeout)
+ defer cancel()
+
+ // conn is a shared connection whose lifecycle is managed by the dbus package.
+ // We should not interfere with that by closing it.
+ conn, err := dbus.SystemBus()
+ if err != nil {
+ return fmt.Errorf("connecting to system bus: %w", err)
+ }
+
+ // This is how we get at the DNS settings:
+ //
+ // org.freedesktop.NetworkManager
+ // |
+ // [GetDeviceByIpIface]
+ // |
+ // v
+ // org.freedesktop.NetworkManager.Device <--------\
+ // (describes a network interface) |
+ // | |
+ // [GetAppliedConnection] [Reapply]
+ // | |
+ // v |
+ // org.freedesktop.NetworkManager.Connection |
+ // (connection settings) ------/
+ // contains {dns, dns-priority, dns-search}
+ //
+ // Ref: https://developer.gnome.org/NetworkManager/stable/settings-ipv4.html.
+
+ nm := conn.Object(
+ "org.freedesktop.NetworkManager",
+ dbus.ObjectPath("/org/freedesktop/NetworkManager"),
+ )
+
+ var devicePath dbus.ObjectPath
+ err = nm.CallWithContext(
+ ctx, "org.freedesktop.NetworkManager.GetDeviceByIpIface", 0,
+ m.interfaceName,
+ ).Store(&devicePath)
+ if err != nil {
+ return fmt.Errorf("getDeviceByIpIface: %w", err)
+ }
+ device := conn.Object("org.freedesktop.NetworkManager", devicePath)
+
+ var (
+ settings nmConnectionSettings
+ version uint64
+ )
+ err = device.CallWithContext(
+ ctx, "org.freedesktop.NetworkManager.Device.GetAppliedConnection", 0,
+ uint32(0),
+ ).Store(&settings, &version)
+ if err != nil {
+ return fmt.Errorf("getAppliedConnection: %w", err)
+ }
+
+ // Frustratingly, NetworkManager represents IPv4 addresses as uint32s,
+ // although IPv6 addresses are represented as byte arrays.
+ // Perform the conversion here.
+ var (
+ dnsv4 []uint32
+ dnsv6 [][]byte
+ )
+ for _, ip := range config.Nameservers {
+ b := ip.As16()
+ if ip.Is4() {
+ dnsv4 = append(dnsv4, endian.Native.Uint32(b[12:]))
+ } else {
+ dnsv6 = append(dnsv6, b[:])
+ }
+ }
+
+ ipv4Map := settings["ipv4"]
+ ipv4Map["dns"] = dbus.MakeVariant(dnsv4)
+ ipv4Map["dns-search"] = dbus.MakeVariant(config.Domains)
+ // We should only request priority if we have nameservers to set.
+ if len(dnsv4) == 0 {
+ ipv4Map["dns-priority"] = dbus.MakeVariant(100)
+ } else {
+ // dns-priority = -1 ensures that we have priority
+ // over other interfaces, except those exploiting this same trick.
+ // Ref: https://bugs.launchpad.net/ubuntu/+source/network-manager/+bug/1211110/comments/92.
+ ipv4Map["dns-priority"] = dbus.MakeVariant(-1)
+ }
+ // In principle, we should not need set this to true,
+ // as our interface does not configure any automatic DNS settings (presumably via DHCP).
+ // All the same, better to be safe.
+ ipv4Map["ignore-auto-dns"] = dbus.MakeVariant(true)
+
+ ipv6Map := settings["ipv6"]
+ // This is a hack.
+ // Methods "disabled", "ignore", "link-local" (IPv6 default) prevent us from setting DNS.
+ // It seems that our only recourse is "manual" or "auto".
+ // "manual" requires addresses, so we use "auto", which will assign us a random IPv6 /64.
+ ipv6Map["method"] = dbus.MakeVariant("auto")
+ // Our IPv6 config is a fake, so it should never become the default route.
+ ipv6Map["never-default"] = dbus.MakeVariant(true)
+ // Moreover, we should ignore all autoconfigured routes (hopefully none), as they are bogus.
+ ipv6Map["ignore-auto-routes"] = dbus.MakeVariant(true)
+
+ // Finally, set the actual DNS config.
+ ipv6Map["dns"] = dbus.MakeVariant(dnsv6)
+ ipv6Map["dns-search"] = dbus.MakeVariant(config.Domains)
+ if len(dnsv6) == 0 {
+ ipv6Map["dns-priority"] = dbus.MakeVariant(100)
+ } else {
+ ipv6Map["dns-priority"] = dbus.MakeVariant(-1)
+ }
+ ipv6Map["ignore-auto-dns"] = dbus.MakeVariant(true)
+
+ // deprecatedProperties are the properties in interface settings
+ // that are deprecated by NetworkManager.
+ //
+ // In practice, this means that they are returned for reading,
+ // but submitting a settings object with them present fails
+ // with hard-to-diagnose errors. They must be removed.
+ deprecatedProperties := []string{
+ "addresses", "routes",
+ }
+
+ for _, property := range deprecatedProperties {
+ delete(ipv4Map, property)
+ delete(ipv6Map, property)
+ }
+
+ err = device.CallWithContext(
+ ctx, "org.freedesktop.NetworkManager.Device.Reapply", 0,
+ settings, version, uint32(0),
+ ).Store()
+ if err != nil {
+ return fmt.Errorf("reapply: %w", err)
+ }
+
+ return nil
+}
+
+// Down implements managerImpl.
+func (m nmManager) Down() error {
+ return m.Up(Config{Nameservers: nil, Domains: nil})
+}
diff --git a/net/dns/noop.go b/net/dns/noop.go
new file mode 100644
index 000000000..35c07a232
--- /dev/null
+++ b/net/dns/noop.go
@@ -0,0 +1,17 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package dns
+
+type noopManager struct{}
+
+// Up implements managerImpl.
+func (m noopManager) Up(Config) error { return nil }
+
+// Down implements managerImpl.
+func (m noopManager) Down() error { return nil }
+
+func newNoopManager(mconfig ManagerConfig) managerImpl {
+ return noopManager{}
+}
diff --git a/net/dns/registry_windows.go b/net/dns/registry_windows.go
new file mode 100644
index 000000000..f8e1f514a
--- /dev/null
+++ b/net/dns/registry_windows.go
@@ -0,0 +1,76 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+//
+// The code in this file originates from https://git.zx2c4.com/wireguard-go:
+// Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
+// Copying license: https://git.zx2c4.com/wireguard-go/tree/COPYING
+
+package dns
+
+import (
+ "fmt"
+ "runtime"
+ "strings"
+ "time"
+
+ "golang.org/x/sys/windows"
+ "golang.org/x/sys/windows/registry"
+)
+
+func openKeyWait(k registry.Key, path string, access uint32, timeout time.Duration) (registry.Key, error) {
+ runtime.LockOSThread()
+ defer runtime.UnlockOSThread()
+
+ deadline := time.Now().Add(timeout)
+ pathSpl := strings.Split(path, "\\")
+ for i := 0; ; i++ {
+ keyName := pathSpl[i]
+ isLast := i+1 == len(pathSpl)
+
+ event, err := windows.CreateEvent(nil, 0, 0, nil)
+ if err != nil {
+ return 0, fmt.Errorf("windows.CreateEvent: %v", err)
+ }
+ defer windows.CloseHandle(event)
+
+ var key registry.Key
+ for {
+ err = windows.RegNotifyChangeKeyValue(windows.Handle(k), false, windows.REG_NOTIFY_CHANGE_NAME, event, true)
+ if err != nil {
+ return 0, fmt.Errorf("windows.RegNotifyChangeKeyValue: %v", err)
+ }
+
+ var accessFlags uint32
+ if isLast {
+ accessFlags = access
+ } else {
+ accessFlags = registry.NOTIFY
+ }
+ key, err = registry.OpenKey(k, keyName, accessFlags)
+ if err == windows.ERROR_FILE_NOT_FOUND || err == windows.ERROR_PATH_NOT_FOUND {
+ timeout := time.Until(deadline) / time.Millisecond
+ if timeout < 0 {
+ timeout = 0
+ }
+ s, err := windows.WaitForSingleObject(event, uint32(timeout))
+ if err != nil {
+ return 0, fmt.Errorf("windows.WaitForSingleObject: %v", err)
+ }
+ if s == uint32(windows.WAIT_TIMEOUT) { // windows.WAIT_TIMEOUT status const is misclassified as error in golang.org/x/sys/windows
+ return 0, fmt.Errorf("timeout waiting for registry key")
+ }
+ } else if err != nil {
+ return 0, fmt.Errorf("registry.OpenKey(%v): %v", path, err)
+ } else {
+ if isLast {
+ return key, nil
+ }
+ defer key.Close()
+ break
+ }
+ }
+
+ k = key
+ }
+}
diff --git a/net/dns/resolvconf.go b/net/dns/resolvconf.go
new file mode 100644
index 000000000..8bf97ee88
--- /dev/null
+++ b/net/dns/resolvconf.go
@@ -0,0 +1,157 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build linux freebsd
+
+package dns
+
+import (
+ "bufio"
+ "bytes"
+ "fmt"
+ "os"
+ "os/exec"
+)
+
+// isResolvconfActive indicates whether the system appears to be using resolvconf.
+// If this is true, then directManager should be avoided:
+// resolvconf has exclusive ownership of /etc/resolv.conf.
+func isResolvconfActive() bool {
+ // Sanity-check first: if there is no resolvconf binary, then this is fruitless.
+ //
+ // However, this binary may be a shim like the one systemd-resolved provides.
+ // Such a shim may not behave as expected: in particular, systemd-resolved
+ // does not seem to respect the exclusive mode -x, saying:
+ // -x Send DNS traffic preferably over this interface
+ // whereas e.g. openresolv sends DNS traffix _exclusively_ over that interface,
+ // or not at all (in case of another exclusive-mode request later in time).
+ //
+ // Moreover, resolvconf may be installed but unused, in which case we should
+ // not use it either, lest we clobber existing configuration.
+ //
+ // To handle all the above correctly, we scan the comments in /etc/resolv.conf
+ // to ensure that it was generated by a resolvconf implementation.
+ _, err := exec.LookPath("resolvconf")
+ if err != nil {
+ return false
+ }
+
+ f, err := os.Open("/etc/resolv.conf")
+ if err != nil {
+ return false
+ }
+ defer f.Close()
+
+ scanner := bufio.NewScanner(f)
+ for scanner.Scan() {
+ line := scanner.Bytes()
+ // Look for the word "resolvconf" until comments end.
+ if len(line) > 0 && line[0] != '#' {
+ return false
+ }
+ if bytes.Contains(line, []byte("resolvconf")) {
+ return true
+ }
+ }
+ return false
+}
+
+// resolvconfImpl enumerates supported implementations of the resolvconf CLI.
+type resolvconfImpl uint8
+
+const (
+ // resolvconfOpenresolv is the implementation packaged as "openresolv" on Ubuntu.
+ // It supports exclusive mode and interface metrics.
+ resolvconfOpenresolv resolvconfImpl = iota
+ // resolvconfLegacy is the implementation by Thomas Hood packaged as "resolvconf" on Ubuntu.
+ // It does not support exclusive mode or interface metrics.
+ resolvconfLegacy
+)
+
+func (impl resolvconfImpl) String() string {
+ switch impl {
+ case resolvconfOpenresolv:
+ return "openresolv"
+ case resolvconfLegacy:
+ return "legacy"
+ default:
+ return "unknown"
+ }
+}
+
+// getResolvconfImpl returns the implementation of resolvconf that appears to be in use.
+func getResolvconfImpl() resolvconfImpl {
+ err := exec.Command("resolvconf", "-v").Run()
+ if err != nil {
+ if exitErr, ok := err.(*exec.ExitError); ok {
+ // Thomas Hood's resolvconf has a minimal flag set
+ // and exits with code 99 when passed an unknown flag.
+ if exitErr.ExitCode() == 99 {
+ return resolvconfLegacy
+ }
+ }
+ }
+ return resolvconfOpenresolv
+}
+
+type resolvconfManager struct {
+ impl resolvconfImpl
+}
+
+func newResolvconfManager(mconfig ManagerConfig) managerImpl {
+ impl := getResolvconfImpl()
+ mconfig.Logf("resolvconf implementation is %s", impl)
+
+ return resolvconfManager{
+ impl: impl,
+ }
+}
+
+// resolvconfConfigName is the name of the config submitted to resolvconf.
+// It has this form to match the "tun*" rule in interface-order
+// when running resolvconfLegacy, hopefully placing our config first.
+const resolvconfConfigName = "tun-tailscale.inet"
+
+// Up implements managerImpl.
+func (m resolvconfManager) Up(config Config) error {
+ stdin := new(bytes.Buffer)
+ writeResolvConf(stdin, config.Nameservers, config.Domains) // dns_direct.go
+
+ var cmd *exec.Cmd
+ switch m.impl {
+ case resolvconfOpenresolv:
+ // Request maximal priority (metric 0) and exclusive mode.
+ cmd = exec.Command("resolvconf", "-m", "0", "-x", "-a", resolvconfConfigName)
+ case resolvconfLegacy:
+ // This does not quite give us the desired behavior (queries leak),
+ // but there is nothing else we can do without messing with other interfaces' settings.
+ cmd = exec.Command("resolvconf", "-a", resolvconfConfigName)
+ }
+ cmd.Stdin = stdin
+ out, err := cmd.CombinedOutput()
+ if err != nil {
+ return fmt.Errorf("running %s: %s", cmd, out)
+ }
+
+ return nil
+}
+
+// Down implements managerImpl.
+func (m resolvconfManager) Down() error {
+ var cmd *exec.Cmd
+ switch m.impl {
+ case resolvconfOpenresolv:
+ cmd = exec.Command("resolvconf", "-f", "-d", resolvconfConfigName)
+ case resolvconfLegacy:
+ // resolvconfLegacy lacks the -f flag.
+ // Instead, it succeeds even when the config does not exist.
+ cmd = exec.Command("resolvconf", "-d", resolvconfConfigName)
+ }
+ out, err := cmd.CombinedOutput()
+ if err != nil {
+ return fmt.Errorf("running %s: %s", cmd, out)
+ }
+
+ return nil
+}
diff --git a/net/dns/resolved.go b/net/dns/resolved.go
new file mode 100644
index 000000000..9d8c40d90
--- /dev/null
+++ b/net/dns/resolved.go
@@ -0,0 +1,188 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build linux
+
+package dns
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "os/exec"
+
+ "github.com/godbus/dbus/v5"
+ "golang.org/x/sys/unix"
+ "inet.af/netaddr"
+ "tailscale.com/net/interfaces"
+)
+
+// resolvedListenAddr is the listen address of the resolved stub resolver.
+//
+// We only consider resolved to be the system resolver if the stub resolver is;
+// that is, if this address is the sole nameserver in /etc/resolved.conf.
+// In other cases, resolved may be managing the system DNS configuration directly.
+// Then the nameserver list will be a concatenation of those for all
+// the interfaces that register their interest in being a default resolver with
+// SetLinkDomains([]{{"~.", true}, ...})
+// which includes at least the interface with the default route, i.e. not us.
+// This does not work for us: there is a possibility of getting NXDOMAIN
+// from the other nameservers before we are asked or get a chance to respond.
+// We consider this case as lacking resolved support and fall through to dnsDirect.
+//
+// While it may seem that we need to read a config option to get at this,
+// this address is, in fact, hard-coded into resolved.
+var resolvedListenAddr = netaddr.IPv4(127, 0, 0, 53)
+
+var errNotReady = errors.New("interface not ready")
+
+type resolvedLinkNameserver struct {
+ Family int32
+ Address []byte
+}
+
+type resolvedLinkDomain struct {
+ Domain string
+ RoutingOnly bool
+}
+
+// isResolvedActive determines if resolved is currently managing system DNS settings.
+func isResolvedActive() bool {
+ // systemd-resolved is never installed without systemd.
+ _, err := exec.LookPath("systemctl")
+ if err != nil {
+ return false
+ }
+
+ // is-active exits with code 3 if the service is not active.
+ err = exec.Command("systemctl", "is-active", "systemd-resolved").Run()
+ if err != nil {
+ return false
+ }
+
+ config, err := readResolvConf()
+ if err != nil {
+ return false
+ }
+
+ // The sole nameserver must be the systemd-resolved stub.
+ if len(config.Nameservers) == 1 && config.Nameservers[0] == resolvedListenAddr {
+ return true
+ }
+
+ return false
+}
+
+// resolvedManager uses the systemd-resolved DBus API.
+type resolvedManager struct{}
+
+func newResolvedManager(mconfig ManagerConfig) managerImpl {
+ return resolvedManager{}
+}
+
+// Up implements managerImpl.
+func (m resolvedManager) Up(config Config) error {
+ ctx, cancel := context.WithTimeout(context.Background(), reconfigTimeout)
+ defer cancel()
+
+ // conn is a shared connection whose lifecycle is managed by the dbus package.
+ // We should not interfere with that by closing it.
+ conn, err := dbus.SystemBus()
+ if err != nil {
+ return fmt.Errorf("connecting to system bus: %w", err)
+ }
+
+ resolved := conn.Object(
+ "org.freedesktop.resolve1",
+ dbus.ObjectPath("/org/freedesktop/resolve1"),
+ )
+
+ // In principle, we could persist this in the manager struct
+ // if we knew that interface indices are persistent. This does not seem to be the case.
+ _, iface, err := interfaces.Tailscale()
+ if err != nil {
+ return fmt.Errorf("getting interface index: %w", err)
+ }
+ if iface == nil {
+ return errNotReady
+ }
+
+ var linkNameservers = make([]resolvedLinkNameserver, len(config.Nameservers))
+ for i, server := range config.Nameservers {
+ ip := server.As16()
+ if server.Is4() {
+ linkNameservers[i] = resolvedLinkNameserver{
+ Family: unix.AF_INET,
+ Address: ip[12:],
+ }
+ } else {
+ linkNameservers[i] = resolvedLinkNameserver{
+ Family: unix.AF_INET6,
+ Address: ip[:],
+ }
+ }
+ }
+
+ err = resolved.CallWithContext(
+ ctx, "org.freedesktop.resolve1.Manager.SetLinkDNS", 0,
+ iface.Index, linkNameservers,
+ ).Store()
+ if err != nil {
+ return fmt.Errorf("setLinkDNS: %w", err)
+ }
+
+ var linkDomains = make([]resolvedLinkDomain, len(config.Domains))
+ for i, domain := range config.Domains {
+ linkDomains[i] = resolvedLinkDomain{
+ Domain: domain,
+ RoutingOnly: false,
+ }
+ }
+
+ err = resolved.CallWithContext(
+ ctx, "org.freedesktop.resolve1.Manager.SetLinkDomains", 0,
+ iface.Index, linkDomains,
+ ).Store()
+ if err != nil {
+ return fmt.Errorf("setLinkDomains: %w", err)
+ }
+
+ return nil
+}
+
+// Down implements managerImpl.
+func (m resolvedManager) Down() error {
+ ctx, cancel := context.WithTimeout(context.Background(), reconfigTimeout)
+ defer cancel()
+
+ // conn is a shared connection whose lifecycle is managed by the dbus package.
+ // We should not interfere with that by closing it.
+ conn, err := dbus.SystemBus()
+ if err != nil {
+ return fmt.Errorf("connecting to system bus: %w", err)
+ }
+
+ resolved := conn.Object(
+ "org.freedesktop.resolve1",
+ dbus.ObjectPath("/org/freedesktop/resolve1"),
+ )
+
+ _, iface, err := interfaces.Tailscale()
+ if err != nil {
+ return fmt.Errorf("getting interface index: %w", err)
+ }
+ if iface == nil {
+ return errNotReady
+ }
+
+ err = resolved.CallWithContext(
+ ctx, "org.freedesktop.resolve1.Manager.RevertLink", 0,
+ iface.Index,
+ ).Store()
+ if err != nil {
+ return fmt.Errorf("RevertLink: %w", err)
+ }
+
+ return nil
+}
diff --git a/net/dns/tsdns.go b/net/dns/tsdns.go
new file mode 100644
index 000000000..2b530b81e
--- /dev/null
+++ b/net/dns/tsdns.go
@@ -0,0 +1,662 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package dns provides a Resolver capable of resolving
+// domains on a Tailscale network.
+package dns
+
+import (
+ "encoding/hex"
+ "errors"
+ "net"
+ "strings"
+ "sync"
+ "time"
+
+ dns "golang.org/x/net/dns/dnsmessage"
+ "inet.af/netaddr"
+ "tailscale.com/net/interfaces"
+ "tailscale.com/types/logger"
+ "tailscale.com/util/dnsname"
+ "tailscale.com/wgengine/monitor"
+)
+
+// maxResponseBytes is the maximum size of a response from a Resolver.
+const maxResponseBytes = 512
+
+// queueSize is the maximal number of DNS requests that can await polling.
+// If EnqueueRequest is called when this many requests are already pending,
+// the request will be dropped to avoid blocking the caller.
+const queueSize = 64
+
+// defaultTTL is the TTL of all responses from Resolver.
+const defaultTTL = 600 * time.Second
+
+// ErrClosed indicates that the resolver has been closed and readers should exit.
+var ErrClosed = errors.New("closed")
+
+var (
+ errFullQueue = errors.New("request queue full")
+ errMapNotSet = errors.New("domain map not set")
+ errNotForwarding = errors.New("forwarding disabled")
+ errNotImplemented = errors.New("query type not implemented")
+ errNotQuery = errors.New("not a DNS query")
+ errNotOurName = errors.New("not a Tailscale DNS name")
+)
+
+// Packet represents a DNS payload together with the address of its origin.
+type Packet struct {
+ // Payload is the application layer DNS payload.
+ // Resolver assumes ownership of the request payload when it is enqueued
+ // and cedes ownership of the response payload when it is returned from NextResponse.
+ Payload []byte
+ // Addr is the source address for a request and the destination address for a response.
+ Addr netaddr.IPPort
+}
+
+// Resolver is a DNS resolver for nodes on the Tailscale network,
+// associating them with domain names of the form <mynode>.<mydomain>.<root>.
+// If it is asked to resolve a domain that is not of that form,
+// it delegates to upstream nameservers if any are set.
+type Resolver struct {
+ logf logger.Logf
+ linkMon *monitor.Mon // or nil
+ unregLinkMon func() // or nil
+ // forwarder forwards requests to upstream nameservers.
+ forwarder *forwarder
+
+ // queue is a buffered channel holding DNS requests queued for resolution.
+ queue chan Packet
+ // responses is an unbuffered channel to which responses are returned.
+ responses chan Packet
+ // errors is an unbuffered channel to which errors are returned.
+ errors chan error
+ // closed signals all goroutines to stop.
+ closed chan struct{}
+ // wg signals when all goroutines have stopped.
+ wg sync.WaitGroup
+
+ // mu guards the following fields from being updated while used.
+ mu sync.Mutex
+ // dnsMap is the map most recently received from the control server.
+ dnsMap *Map
+}
+
+// ResolverConfig is the set of configuration options for a Resolver.
+type ResolverConfig struct {
+ // Logf is the logger to use throughout the Resolver.
+ Logf logger.Logf
+ // Forward determines whether the resolver will forward packets to
+ // nameservers set with SetUpstreams if the domain name is not of a Tailscale node.
+ Forward bool
+ // LinkMonitor optionally provides a link monitor to use to rebind
+ // connections on link changes.
+ // If nil, rebinds are not performend.
+ LinkMonitor *monitor.Mon
+}
+
+// NewResolver constructs a resolver associated with the given root domain.
+// The root domain must be in canonical form (with a trailing period).
+func NewResolver(config ResolverConfig) *Resolver {
+ r := &Resolver{
+ logf: logger.WithPrefix(config.Logf, "dns: "),
+ linkMon: config.LinkMonitor,
+ queue: make(chan Packet, queueSize),
+ responses: make(chan Packet),
+ errors: make(chan error),
+ closed: make(chan struct{}),
+ }
+
+ if config.Forward {
+ r.forwarder = newForwarder(r.logf, r.responses)
+ }
+ if r.linkMon != nil {
+ r.unregLinkMon = r.linkMon.RegisterChangeCallback(r.onLinkMonitorChange)
+ }
+
+ return r
+}
+
+func (r *Resolver) Start() error {
+ if r.forwarder != nil {
+ if err := r.forwarder.Start(); err != nil {
+ return err
+ }
+ }
+
+ r.wg.Add(1)
+ go r.poll()
+
+ return nil
+}
+
+// Close shuts down the resolver and ensures poll goroutines have exited.
+// The Resolver cannot be used again after Close is called.
+func (r *Resolver) Close() {
+ select {
+ case <-r.closed:
+ return
+ default:
+ // continue
+ }
+ close(r.closed)
+
+ if r.unregLinkMon != nil {
+ r.unregLinkMon()
+ }
+
+ if r.forwarder != nil {
+ r.forwarder.Close()
+ }
+
+ r.wg.Wait()
+}
+
+func (r *Resolver) onLinkMonitorChange(changed bool, state *interfaces.State) {
+ if !changed {
+ return
+ }
+ if r.forwarder != nil {
+ r.forwarder.rebindFromNetworkChange()
+ }
+}
+
+// SetMap sets the resolver's DNS map, taking ownership of it.
+func (r *Resolver) SetMap(m *Map) {
+ r.mu.Lock()
+ oldMap := r.dnsMap
+ r.dnsMap = m
+ r.mu.Unlock()
+ r.logf("map diff:\n%s", m.PrettyDiffFrom(oldMap))
+}
+
+// SetUpstreams sets the addresses of the resolver's
+// upstream nameservers, taking ownership of the argument.
+func (r *Resolver) SetUpstreams(upstreams []net.Addr) {
+ if r.forwarder != nil {
+ r.forwarder.setUpstreams(upstreams)
+ }
+ r.logf("set upstreams: %v", upstreams)
+}
+
+// EnqueueRequest places the given DNS request in the resolver's queue.
+// It takes ownership of the payload and does not block.
+// If the queue is full, the request will be dropped and an error will be returned.
+func (r *Resolver) EnqueueRequest(request Packet) error {
+ select {
+ case <-r.closed:
+ return ErrClosed
+ case r.queue <- request:
+ return nil
+ default:
+ return errFullQueue
+ }
+}
+
+// NextResponse returns a DNS response to a previously enqueued request.
+// It blocks until a response is available and gives up ownership of the response payload.
+func (r *Resolver) NextResponse() (Packet, error) {
+ select {
+ case <-r.closed:
+ return Packet{}, ErrClosed
+ case resp := <-r.responses:
+ return resp, nil
+ case err := <-r.errors:
+ return Packet{}, err
+ }
+}
+
+// Resolve maps a given domain name to the IP address of the host that owns it,
+// if the IP address conforms to the DNS resource type given by tp (one of A, AAAA, ALL).
+// The domain name must be in canonical form (with a trailing period).
+func (r *Resolver) Resolve(domain string, tp dns.Type) (netaddr.IP, dns.RCode, error) {
+ r.mu.Lock()
+ dnsMap := r.dnsMap
+ r.mu.Unlock()
+
+ if dnsMap == nil {
+ return netaddr.IP{}, dns.RCodeServerFailure, errMapNotSet
+ }
+
+ // Reject .onion domains per RFC 7686.
+ if dnsname.HasSuffix(domain, ".onion") {
+ return netaddr.IP{}, dns.RCodeNameError, nil
+ }
+
+ anyHasSuffix := false
+ for _, suffix := range dnsMap.rootDomains {
+ if dnsname.HasSuffix(domain, suffix) {
+ anyHasSuffix = true
+ break
+ }
+ }
+ addr, found := dnsMap.nameToIP[domain]
+ if !found {
+ if !anyHasSuffix {
+ return netaddr.IP{}, dns.RCodeRefused, nil
+ }
+ return netaddr.IP{}, dns.RCodeNameError, nil
+ }
+
+ // Refactoring note: this must happen after we check suffixes,
+ // otherwise we will respond with NOTIMP to requests that should be forwarded.
+ switch tp {
+ case dns.TypeA:
+ if !addr.Is4() {
+ return netaddr.IP{}, dns.RCodeSuccess, nil
+ }
+ return addr, dns.RCodeSuccess, nil
+ case dns.TypeAAAA:
+ if !addr.Is6() {
+ return netaddr.IP{}, dns.RCodeSuccess, nil
+ }
+ return addr, dns.RCodeSuccess, nil
+ case dns.TypeALL:
+ // Answer with whatever we've got.
+ // It could be IPv4, IPv6, or a zero addr.
+ // TODO: Return all available resolutions (A and AAAA, if we have them).
+ return addr, dns.RCodeSuccess, nil
+
+ // Leave some some record types explicitly unimplemented.
+ // These types relate to recursive resolution or special
+ // DNS sematics and might be implemented in the future.
+ case dns.TypeNS, dns.TypeSOA, dns.TypeAXFR, dns.TypeHINFO:
+ return netaddr.IP{}, dns.RCodeNotImplemented, errNotImplemented
+
+ // For everything except for the few types above that are explictly not implemented, return no records.
+ // This is what other DNS systems do: always return NOERROR
+ // without any records whenever the requested record type is unknown.
+ // You can try this with:
+ // dig -t TYPE9824 example.com
+ // and note that NOERROR is returned, despite that record type being made up.
+ default:
+ // no records exist of this type
+ return netaddr.IP{}, dns.RCodeSuccess, nil
+ }
+}
+
+// ResolveReverse returns the unique domain name that maps to the given address.
+// The returned domain name is in canonical form (with a trailing period).
+func (r *Resolver) ResolveReverse(ip netaddr.IP) (string, dns.RCode, error) {
+ r.mu.Lock()
+ dnsMap := r.dnsMap
+ r.mu.Unlock()
+
+ if dnsMap == nil {
+ return "", dns.RCodeServerFailure, errMapNotSet
+ }
+ name, found := dnsMap.ipToName[ip]
+ if !found {
+ return "", dns.RCodeNameError, nil
+ }
+ return name, dns.RCodeSuccess, nil
+}
+
+func (r *Resolver) poll() {
+ defer r.wg.Done()
+
+ var packet Packet
+ for {
+ select {
+ case <-r.closed:
+ return
+ case packet = <-r.queue:
+ // continue
+ }
+
+ out, err := r.respond(packet.Payload)
+
+ if err == errNotOurName {
+ if r.forwarder != nil {
+ err = r.forwarder.forward(packet)
+ if err == nil {
+ // forward will send response into r.responses, nothing to do.
+ continue
+ }
+ } else {
+ err = errNotForwarding
+ }
+ }
+
+ if err != nil {
+ select {
+ case <-r.closed:
+ return
+ case r.errors <- err:
+ // continue
+ }
+ } else {
+ packet.Payload = out
+ select {
+ case <-r.closed:
+ return
+ case r.responses <- packet:
+ // continue
+ }
+ }
+ }
+}
+
+type response struct {
+ Header dns.Header
+ Question dns.Question
+ // Name is the response to a PTR query.
+ Name string
+ // IP is the response to an A, AAAA, or ALL query.
+ IP netaddr.IP
+}
+
+// parseQuery parses the query in given packet into a response struct.
+func parseQuery(query []byte, resp *response) error {
+ var parser dns.Parser
+ var err error
+
+ resp.Header, err = parser.Start(query)
+ if err != nil {
+ return err
+ }
+
+ if resp.Header.Response {
+ return errNotQuery
+ }
+
+ resp.Question, err = parser.Question()
+ if err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// marshalARecord serializes an A record into an active builder.
+// The caller may continue using the builder following the call.
+func marshalARecord(name dns.Name, ip netaddr.IP, builder *dns.Builder) error {
+ var answer dns.AResource
+
+ answerHeader := dns.ResourceHeader{
+ Name: name,
+ Type: dns.TypeA,
+ Class: dns.ClassINET,
+ TTL: uint32(defaultTTL / time.Second),
+ }
+ ipbytes := ip.As4()
+ copy(answer.A[:], ipbytes[:])
+ return builder.AResource(answerHeader, answer)
+}
+
+// marshalAAAARecord serializes an AAAA record into an active builder.
+// The caller may continue using the builder following the call.
+func marshalAAAARecord(name dns.Name, ip netaddr.IP, builder *dns.Builder) error {
+ var answer dns.AAAAResource
+
+ answerHeader := dns.ResourceHeader{
+ Name: name,
+ Type: dns.TypeAAAA,
+ Class: dns.ClassINET,
+ TTL: uint32(defaultTTL / time.Second),
+ }
+ ipbytes := ip.As16()
+ copy(answer.AAAA[:], ipbytes[:])
+ return builder.AAAAResource(answerHeader, answer)
+}
+
+// marshalPTRRecord serializes a PTR record into an active builder.
+// The caller may continue using the builder following the call.
+func marshalPTRRecord(queryName dns.Name, name string, builder *dns.Builder) error {
+ var answer dns.PTRResource
+ var err error
+
+ answerHeader := dns.ResourceHeader{
+ Name: queryName,
+ Type: dns.TypePTR,
+ Class: dns.ClassINET,
+ TTL: uint32(defaultTTL / time.Second),
+ }
+ answer.PTR, err = dns.NewName(name)
+ if err != nil {
+ return err
+ }
+ return builder.PTRResource(answerHeader, answer)
+}
+
+// marshalResponse serializes the DNS response into a new buffer.
+func marshalResponse(resp *response) ([]byte, error) {
+ resp.Header.Response = true
+ resp.Header.Authoritative = true
+ if resp.Header.RecursionDesired {
+ resp.Header.RecursionAvailable = true
+ }
+
+ builder := dns.NewBuilder(nil, resp.Header)
+
+ isSuccess := resp.Header.RCode == dns.RCodeSuccess
+
+ if resp.Question.Type != 0 || isSuccess {
+ err := builder.StartQuestions()
+ if err != nil {
+ return nil, err
+ }
+
+ err = builder.Question(resp.Question)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ // Only successful responses contain answers.
+ if !isSuccess {
+ return builder.Finish()
+ }
+
+ err := builder.StartAnswers()
+ if err != nil {
+ return nil, err
+ }
+
+ switch resp.Question.Type {
+ case dns.TypeA, dns.TypeAAAA, dns.TypeALL:
+ if resp.IP.Is4() {
+ err = marshalARecord(resp.Question.Name, resp.IP, &builder)
+ } else if resp.IP.Is6() {
+ err = marshalAAAARecord(resp.Question.Name, resp.IP, &builder)
+ }
+ case dns.TypePTR:
+ err = marshalPTRRecord(resp.Question.Name, resp.Name, &builder)
+ }
+ if err != nil {
+ return nil, err
+ }
+
+ return builder.Finish()
+}
+
+const (
+ rdnsv4Suffix = ".in-addr.arpa."
+ rdnsv6Suffix = ".ip6.arpa."
+)
+
+// hasRDNSBonjourPrefix reports whether name has a Bonjour Service Prefix..
+//
+// https://tools.ietf.org/html/rfc6763 lists
+// "five special RR names" for Bonjour service discovery:
+//
+// b._dns-sd._udp.<domain>.
+// db._dns-sd._udp.<domain>.
+// r._dns-sd._udp.<domain>.
+// dr._dns-sd._udp.<domain>.
+// lb._dns-sd._udp.<domain>.
+func hasRDNSBonjourPrefix(s string) bool {
+ // Even the shortest name containing a Bonjour prefix is long,
+ // so check length (cheap) and bail early if possible.
+ if len(s) < len("*._dns-sd._udp.0.0.0.0.in-addr.arpa.") {
+ return false
+ }
+ dot := strings.IndexByte(s, '.')
+ if dot == -1 {
+ return false // shouldn't happen
+ }
+ switch s[:dot] {
+ case "b", "db", "r", "dr", "lb":
+ default:
+ return false
+ }
+
+ return strings.HasPrefix(s[dot:], "._dns-sd._udp.")
+}
+
+// rawNameToLower converts a raw DNS name to a string, lowercasing it.
+func rawNameToLower(name []byte) string {
+ var sb strings.Builder
+ sb.Grow(len(name))
+
+ for _, b := range name {
+ if 'A' <= b && b <= 'Z' {
+ b = b - 'A' + 'a'
+ }
+ sb.WriteByte(b)
+ }
+
+ return sb.String()
+}
+
+// ptrNameToIPv4 transforms a PTR name representing an IPv4 address to said address.
+// Such names are IPv4 labels in reverse order followed by .in-addr.arpa.
+// For example,
+// 4.3.2.1.in-addr.arpa
+// is transformed to
+// 1.2.3.4
+func rdnsNameToIPv4(name string) (ip netaddr.IP, ok bool) {
+ name = strings.TrimSuffix(name, rdnsv4Suffix)
+ ip, err := netaddr.ParseIP(string(name))
+ if err != nil {
+ return netaddr.IP{}, false
+ }
+ if !ip.Is4() {
+ return netaddr.IP{}, false
+ }
+ b := ip.As4()
+ return netaddr.IPv4(b[3], b[2], b[1], b[0]), true
+}
+
+// ptrNameToIPv6 transforms a PTR name representing an IPv6 address to said address.
+// Such names are dot-separated nibbles in reverse order followed by .ip6.arpa.
+// For example,
+// b.a.9.8.7.6.5.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.
+// is transformed to
+// 2001:db8::567:89ab
+func rdnsNameToIPv6(name string) (ip netaddr.IP, ok bool) {
+ var b [32]byte
+ var ipb [16]byte
+
+ name = strings.TrimSuffix(name, rdnsv6Suffix)
+ // 32 nibbles and 31 dots between them.
+ if len(name) != 63 {
+ return netaddr.IP{}, false
+ }
+
+ // Dots and hex digits alternate.
+ prevDot := true
+ // i ranges over name backward; j ranges over b forward.
+ for i, j := len(name)-1, 0; i >= 0; i-- {
+ thisDot := (name[i] == '.')
+ if prevDot == thisDot {
+ return netaddr.IP{}, false
+ }
+ prevDot = thisDot
+
+ if !thisDot {
+ // This is safe assuming alternation.
+ // We do not check that non-dots are hex digits: hex.Decode below will do that.
+ b[j] = name[i]
+ j++
+ }
+ }
+
+ _, err := hex.Decode(ipb[:], b[:])
+ if err != nil {
+ return netaddr.IP{}, false
+ }
+
+ return netaddr.IPFrom16(ipb), true
+}
+
+// respondReverse returns a DNS response to a PTR query.
+// It is assumed that resp.Question is populated by respond before this is called.
+func (r *Resolver) respondReverse(query []byte, name string, resp *response) ([]byte, error) {
+ if hasRDNSBonjourPrefix(name) {
+ return nil, errNotOurName
+ }
+
+ var ip netaddr.IP
+ var ok bool
+ switch {
+ case strings.HasSuffix(name, rdnsv4Suffix):
+ ip, ok = rdnsNameToIPv4(name)
+ case strings.HasSuffix(name, rdnsv6Suffix):
+ ip, ok = rdnsNameToIPv6(name)
+ default:
+ return nil, errNotOurName
+ }
+
+ // It is more likely that we failed in parsing the name than that it is actually malformed.
+ // To avoid frustrating users, just log and delegate.
+ if !ok {
+ r.logf("parsing rdns: malformed name: %s", name)
+ return nil, errNotOurName
+ }
+
+ var err error
+ resp.Name, resp.Header.RCode, err = r.ResolveReverse(ip)
+ if err != nil {
+ r.logf("resolving rdns: %v", ip, err)
+ }
+ if resp.Header.RCode == dns.RCodeNameError {
+ return nil, errNotOurName
+ }
+
+ return marshalResponse(resp)
+}
+
+// respond returns a DNS response to query if it can be resolved locally.
+// Otherwise, it returns errNotOurName.
+func (r *Resolver) respond(query []byte) ([]byte, error) {
+ resp := new(response)
+
+ // ParseQuery is sufficiently fast to run on every DNS packet.
+ // This is considerably simpler than extracting the name by hand
+ // to shave off microseconds in case of delegation.
+ err := parseQuery(query, resp)
+ // We will not return this error: it is the sender's fault.
+ if err != nil {
+ if errors.Is(err, dns.ErrSectionDone) {
+ r.logf("parseQuery(%02x): no DNS questions", query)
+ } else {
+ r.logf("parseQuery(%02x): %v", query, err)
+ }
+ resp.Header.RCode = dns.RCodeFormatError
+ return marshalResponse(resp)
+ }
+ rawName := resp.Question.Name.Data[:resp.Question.Name.Length]
+ name := rawNameToLower(rawName)
+
+ // Always try to handle reverse lookups; delegate inside when not found.
+ // This way, queries for existent nodes do not leak,
+ // but we behave gracefully if non-Tailscale nodes exist in CGNATRange.
+ if resp.Question.Type == dns.TypePTR {
+ return r.respondReverse(query, name, resp)
+ }
+
+ resp.IP, resp.Header.RCode, err = r.Resolve(name, resp.Question.Type)
+ // This return code is special: it requests forwarding.
+ if resp.Header.RCode == dns.RCodeRefused {
+ return nil, errNotOurName
+ }
+
+ // We will not return this error: it is the sender's fault.
+ if err != nil {
+ r.logf("resolving: %v", err)
+ }
+
+ return marshalResponse(resp)
+}
diff --git a/net/dns/tsdns_server_test.go b/net/dns/tsdns_server_test.go
new file mode 100644
index 000000000..95544ba18
--- /dev/null
+++ b/net/dns/tsdns_server_test.go
@@ -0,0 +1,95 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package dns
+
+import (
+ "log"
+ "testing"
+
+ "github.com/miekg/dns"
+ "inet.af/netaddr"
+)
+
+// This file exists to isolate the test infrastructure
+// that depends on github.com/miekg/dns
+// from the rest, which only depends on dnsmessage.
+
+var dnsHandleFunc = dns.HandleFunc
+
+// resolveToIP returns a handler function which responds
+// to queries of type A it receives with an A record containing ipv4,
+// to queries of type AAAA with an AAAA record containing ipv6,
+// to queries of type NS with an NS record containg name.
+func resolveToIP(ipv4, ipv6 netaddr.IP, ns string) dns.HandlerFunc {
+ return func(w dns.ResponseWriter, req *dns.Msg) {
+ m := new(dns.Msg)
+ m.SetReply(req)
+
+ if len(req.Question) != 1 {
+ panic("not a single-question request")
+ }
+ question := req.Question[0]
+
+ var ans dns.RR
+ switch question.Qtype {
+ case dns.TypeA:
+ ans = &dns.A{
+ Hdr: dns.RR_Header{
+ Name: question.Name,
+ Rrtype: dns.TypeA,
+ Class: dns.ClassINET,
+ },
+ A: ipv4.IPAddr().IP,
+ }
+ case dns.TypeAAAA:
+ ans = &dns.AAAA{
+ Hdr: dns.RR_Header{
+ Name: question.Name,
+ Rrtype: dns.TypeAAAA,
+ Class: dns.ClassINET,
+ },
+ AAAA: ipv6.IPAddr().IP,
+ }
+ case dns.TypeNS:
+ ans = &dns.NS{
+ Hdr: dns.RR_Header{
+ Name: question.Name,
+ Rrtype: dns.TypeNS,
+ Class: dns.ClassINET,
+ },
+ Ns: ns,
+ }
+ }
+
+ m.Answer = append(m.Answer, ans)
+ w.WriteMsg(m)
+ }
+}
+
+func resolveToNXDOMAIN(w dns.ResponseWriter, req *dns.Msg) {
+ m := new(dns.Msg)
+ m.SetRcode(req, dns.RcodeNameError)
+ w.WriteMsg(m)
+}
+
+func serveDNS(tb testing.TB, addr string) (*dns.Server, chan error) {
+ server := &dns.Server{Addr: addr, Net: "udp"}
+
+ waitch := make(chan struct{})
+ server.NotifyStartedFunc = func() { close(waitch) }
+
+ errch := make(chan error, 1)
+ go func() {
+ err := server.ListenAndServe()
+ if err != nil {
+ log.Printf("ListenAndServe(%q): %v", addr, err)
+ }
+ errch <- err
+ close(errch)
+ }()
+
+ <-waitch
+ return server, errch
+}
diff --git a/net/dns/tsdns_test.go b/net/dns/tsdns_test.go
new file mode 100644
index 000000000..59bcd8ec1
--- /dev/null
+++ b/net/dns/tsdns_test.go
@@ -0,0 +1,816 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package dns
+
+import (
+ "bytes"
+ "errors"
+ "net"
+ "sync"
+ "testing"
+
+ dns "golang.org/x/net/dns/dnsmessage"
+ "inet.af/netaddr"
+ "tailscale.com/tstest"
+)
+
+var testipv4 = netaddr.IPv4(1, 2, 3, 4)
+var testipv6 = netaddr.IPv6Raw([16]byte{
+ 0x00, 0x01, 0x02, 0x03,
+ 0x04, 0x05, 0x06, 0x07,
+ 0x08, 0x09, 0x0a, 0x0b,
+ 0x0c, 0x0d, 0x0e, 0x0f,
+})
+
+var dnsMap = NewMap(
+ map[string]netaddr.IP{
+ "test1.ipn.dev.": testipv4,
+ "test2.ipn.dev.": testipv6,
+ },
+ []string{"ipn.dev."},
+)
+
+func dnspacket(domain string, tp dns.Type) []byte {
+ var dnsHeader dns.Header
+ question := dns.Question{
+ Name: dns.MustNewName(domain),
+ Type: tp,
+ Class: dns.ClassINET,
+ }
+
+ builder := dns.NewBuilder(nil, dnsHeader)
+ builder.StartQuestions()
+ builder.Question(question)
+ payload, _ := builder.Finish()
+
+ return payload
+}
+
+type dnsResponse struct {
+ ip netaddr.IP
+ name string
+ rcode dns.RCode
+}
+
+func unpackResponse(payload []byte) (dnsResponse, error) {
+ var response dnsResponse
+ var parser dns.Parser
+
+ h, err := parser.Start(payload)
+ if err != nil {
+ return response, err
+ }
+
+ if !h.Response {
+ return response, errors.New("not a response")
+ }
+
+ response.rcode = h.RCode
+ if response.rcode != dns.RCodeSuccess {
+ return response, nil
+ }
+
+ err = parser.SkipAllQuestions()
+ if err != nil {
+ return response, err
+ }
+
+ ah, err := parser.AnswerHeader()
+ if err != nil {
+ return response, err
+ }
+
+ switch ah.Type {
+ case dns.TypeA:
+ res, err := parser.AResource()
+ if err != nil {
+ return response, err
+ }
+ response.ip = netaddr.IPv4(res.A[0], res.A[1], res.A[2], res.A[3])
+ case dns.TypeAAAA:
+ res, err := parser.AAAAResource()
+ if err != nil {
+ return response, err
+ }
+ response.ip = netaddr.IPv6Raw(res.AAAA)
+ case dns.TypeNS:
+ res, err := parser.NSResource()
+ if err != nil {
+ return response, err
+ }
+ response.name = res.NS.String()
+ default:
+ return response, errors.New("type not in {A, AAAA, NS}")
+ }
+
+ return response, nil
+}
+
+func syncRespond(r *Resolver, query []byte) ([]byte, error) {
+ request := Packet{Payload: query}
+ r.EnqueueRequest(request)
+ resp, err := r.NextResponse()
+ return resp.Payload, err
+}
+
+func mustIP(str string) netaddr.IP {
+ ip, err := netaddr.ParseIP(str)
+ if err != nil {
+ panic(err)
+ }
+ return ip
+}
+
+func TestRDNSNameToIPv4(t *testing.T) {
+ tests := []struct {
+ name string
+ input string
+ wantIP netaddr.IP
+ wantOK bool
+ }{
+ {"valid", "4.123.24.1.in-addr.arpa.", netaddr.IPv4(1, 24, 123, 4), true},
+ {"double_dot", "1..2.3.in-addr.arpa.", netaddr.IP{}, false},
+ {"overflow", "1.256.3.4.in-addr.arpa.", netaddr.IP{}, false},
+ {"not_ip", "sub.do.ma.in.in-addr.arpa.", netaddr.IP{}, false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ ip, ok := rdnsNameToIPv4(tt.input)
+ if ok != tt.wantOK {
+ t.Errorf("ok = %v; want %v", ok, tt.wantOK)
+ } else if ok && ip != tt.wantIP {
+ t.Errorf("ip = %v; want %v", ip, tt.wantIP)
+ }
+ })
+ }
+}
+
+func TestRDNSNameToIPv6(t *testing.T) {
+ tests := []struct {
+ name string
+ input string
+ wantIP netaddr.IP
+ wantOK bool
+ }{
+ {
+ "valid",
+ "b.a.9.8.7.6.5.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.",
+ mustIP("2001:db8::567:89ab"),
+ true,
+ },
+ {
+ "double_dot",
+ "b..9.8.7.6.5.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.",
+ netaddr.IP{},
+ false,
+ },
+ {
+ "double_hex",
+ "b.a.98.0.7.6.5.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.",
+ netaddr.IP{},
+ false,
+ },
+ {
+ "not_hex",
+ "b.a.g.0.7.6.5.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.",
+ netaddr.IP{},
+ false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ ip, ok := rdnsNameToIPv6(tt.input)
+ if ok != tt.wantOK {
+ t.Errorf("ok = %v; want %v", ok, tt.wantOK)
+ } else if ok && ip != tt.wantIP {
+ t.Errorf("ip = %v; want %v", ip, tt.wantIP)
+ }
+ })
+ }
+}
+
+func TestResolve(t *testing.T) {
+ r := NewResolver(ResolverConfig{Logf: t.Logf, Forward: false})
+ r.SetMap(dnsMap)
+
+ if err := r.Start(); err != nil {
+ t.Fatalf("start: %v", err)
+ }
+ defer r.Close()
+
+ tests := []struct {
+ name string
+ qname string
+ qtype dns.Type
+ ip netaddr.IP
+ code dns.RCode
+ }{
+ {"ipv4", "test1.ipn.dev.", dns.TypeA, testipv4, dns.RCodeSuccess},
+ {"ipv6", "test2.ipn.dev.", dns.TypeAAAA, testipv6, dns.RCodeSuccess},
+ {"no-ipv6", "test1.ipn.dev.", dns.TypeAAAA, netaddr.IP{}, dns.RCodeSuccess},
+ {"nxdomain", "test3.ipn.dev.", dns.TypeA, netaddr.IP{}, dns.RCodeNameError},
+ {"foreign domain", "google.com.", dns.TypeA, netaddr.IP{}, dns.RCodeRefused},
+ {"all", "test1.ipn.dev.", dns.TypeA, testipv4, dns.RCodeSuccess},
+ {"mx-ipv4", "test1.ipn.dev.", dns.TypeMX, netaddr.IP{}, dns.RCodeSuccess},
+ {"mx-ipv6", "test2.ipn.dev.", dns.TypeMX, netaddr.IP{}, dns.RCodeSuccess},
+ {"mx-nxdomain", "test3.ipn.dev.", dns.TypeMX, netaddr.IP{}, dns.RCodeNameError},
+ {"ns-nxdomain", "test3.ipn.dev.", dns.TypeNS, netaddr.IP{}, dns.RCodeNameError},
+ {"onion-domain", "footest.onion.", dns.TypeA, netaddr.IP{}, dns.RCodeNameError},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ ip, code, err := r.Resolve(tt.qname, tt.qtype)
+ if err != nil {
+ t.Errorf("err = %v; want nil", err)
+ }
+ if code != tt.code {
+ t.Errorf("code = %v; want %v", code, tt.code)
+ }
+ // Only check ip for non-err
+ if ip != tt.ip {
+ t.Errorf("ip = %v; want %v", ip, tt.ip)
+ }
+ })
+ }
+}
+
+func TestResolveReverse(t *testing.T) {
+ r := NewResolver(ResolverConfig{Logf: t.Logf, Forward: false})
+ r.SetMap(dnsMap)
+
+ if err := r.Start(); err != nil {
+ t.Fatalf("start: %v", err)
+ }
+ defer r.Close()
+
+ tests := []struct {
+ name string
+ ip netaddr.IP
+ want string
+ code dns.RCode
+ }{
+ {"ipv4", testipv4, "test1.ipn.dev.", dns.RCodeSuccess},
+ {"ipv6", testipv6, "test2.ipn.dev.", dns.RCodeSuccess},
+ {"nxdomain", netaddr.IPv4(4, 3, 2, 1), "", dns.RCodeNameError},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ name, code, err := r.ResolveReverse(tt.ip)
+ if err != nil {
+ t.Errorf("err = %v; want nil", err)
+ }
+ if code != tt.code {
+ t.Errorf("code = %v; want %v", code, tt.code)
+ }
+ if name != tt.want {
+ t.Errorf("ip = %v; want %v", name, tt.want)
+ }
+ })
+ }
+}
+
+func ipv6Works() bool {
+ c, err := net.Listen("tcp", "[::1]:0")
+ if err != nil {
+ return false
+ }
+ c.Close()
+ return true
+}
+
+func TestDelegate(t *testing.T) {
+ tstest.ResourceCheck(t)
+
+ if !ipv6Works() {
+ t.Skip("skipping test that requires localhost IPv6")
+ }
+
+ dnsHandleFunc("test.site.", resolveToIP(testipv4, testipv6, "dns.test.site."))
+ dnsHandleFunc("nxdomain.site.", resolveToNXDOMAIN)
+
+ v4server, v4errch := serveDNS(t, "127.0.0.1:0")
+ v6server, v6errch := serveDNS(t, "[::1]:0")
+
+ defer func() {
+ if err := <-v4errch; err != nil {
+ t.Errorf("v4 server error: %v", err)
+ }
+ if err := <-v6errch; err != nil {
+ t.Errorf("v6 server error: %v", err)
+ }
+ }()
+ if v4server != nil {
+ defer v4server.Shutdown()
+ }
+ if v6server != nil {
+ defer v6server.Shutdown()
+ }
+
+ if v4server == nil || v6server == nil {
+ // There is an error in at least one of the channels
+ // and we cannot proceed; return to see it.
+ return
+ }
+
+ r := NewResolver(ResolverConfig{Logf: t.Logf, Forward: true})
+ r.SetMap(dnsMap)
+ r.SetUpstreams([]net.Addr{
+ v4server.PacketConn.LocalAddr(),
+ v6server.PacketConn.LocalAddr(),
+ })
+
+ if err := r.Start(); err != nil {
+ t.Fatalf("start: %v", err)
+ }
+ defer r.Close()
+
+ tests := []struct {
+ title string
+ query []byte
+ response dnsResponse
+ }{
+ {
+ "ipv4",
+ dnspacket("test.site.", dns.TypeA),
+ dnsResponse{ip: testipv4, rcode: dns.RCodeSuccess},
+ },
+ {
+ "ipv6",
+ dnspacket("test.site.", dns.TypeAAAA),
+ dnsResponse{ip: testipv6, rcode: dns.RCodeSuccess},
+ },
+ {
+ "ns",
+ dnspacket("test.site.", dns.TypeNS),
+ dnsResponse{name: "dns.test.site.", rcode: dns.RCodeSuccess},
+ },
+ {
+ "nxdomain",
+ dnspacket("nxdomain.site.", dns.TypeA),
+ dnsResponse{rcode: dns.RCodeNameError},
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.title, func(t *testing.T) {
+ payload, err := syncRespond(r, tt.query)
+ if err != nil {
+ t.Errorf("err = %v; want nil", err)
+ return
+ }
+ response, err := unpackResponse(payload)
+ if err != nil {
+ t.Errorf("extract: err = %v; want nil (in %x)", err, payload)
+ return
+ }
+ if response.rcode != tt.response.rcode {
+ t.Errorf("rcode = %v; want %v", response.rcode, tt.response.rcode)
+ }
+ if response.ip != tt.response.ip {
+ t.Errorf("ip = %v; want %v", response.ip, tt.response.ip)
+ }
+ if response.name != tt.response.name {
+ t.Errorf("name = %v; want %v", response.name, tt.response.name)
+ }
+ })
+ }
+}
+
+func TestDelegateCollision(t *testing.T) {
+ dnsHandleFunc("test.site.", resolveToIP(testipv4, testipv6, "dns.test.site."))
+
+ server, errch := serveDNS(t, "127.0.0.1:0")
+ defer func() {
+ if err := <-errch; err != nil {
+ t.Errorf("server error: %v", err)
+ }
+ }()
+
+ if server == nil {
+ return
+ }
+ defer server.Shutdown()
+
+ r := NewResolver(ResolverConfig{Logf: t.Logf, Forward: true})
+ r.SetMap(dnsMap)
+ r.SetUpstreams([]net.Addr{server.PacketConn.LocalAddr()})
+
+ if err := r.Start(); err != nil {
+ t.Fatalf("start: %v", err)
+ }
+ defer r.Close()
+
+ packets := []struct {
+ qname string
+ qtype dns.Type
+ addr netaddr.IPPort
+ }{
+ {"test.site.", dns.TypeA, netaddr.IPPort{IP: netaddr.IPv4(1, 1, 1, 1), Port: 1001}},
+ {"test.site.", dns.TypeAAAA, netaddr.IPPort{IP: netaddr.IPv4(1, 1, 1, 1), Port: 1002}},
+ }
+
+ // packets will have the same dns txid.
+ for _, p := range packets {
+ payload := dnspacket(p.qname, p.qtype)
+ req := Packet{Payload: payload, Addr: p.addr}
+ err := r.EnqueueRequest(req)
+ if err != nil {
+ t.Error(err)
+ }
+ }
+
+ // Despite the txid collision, the answer(s) should still match the query.
+ resp, err := r.NextResponse()
+ if err != nil {
+ t.Error(err)
+ }
+
+ var p dns.Parser
+ _, err = p.Start(resp.Payload)
+ if err != nil {
+ t.Error(err)
+ }
+ err = p.SkipAllQuestions()
+ if err != nil {
+ t.Error(err)
+ }
+ ans, err := p.AllAnswers()
+ if err != nil {
+ t.Error(err)
+ }
+
+ var wantType dns.Type
+ switch ans[0].Body.(type) {
+ case *dns.AResource:
+ wantType = dns.TypeA
+ case *dns.AAAAResource:
+ wantType = dns.TypeAAAA
+ default:
+ t.Errorf("unexpected answer type: %T", ans[0].Body)
+ }
+
+ for _, p := range packets {
+ if p.qtype == wantType && p.addr != resp.Addr {
+ t.Errorf("addr = %v; want %v", resp.Addr, p.addr)
+ }
+ }
+}
+
+func TestConcurrentSetMap(t *testing.T) {
+ r := NewResolver(ResolverConfig{Logf: t.Logf, Forward: false})
+
+ if err := r.Start(); err != nil {
+ t.Fatalf("start: %v", err)
+ }
+ defer r.Close()
+
+ // This is purely to ensure that Resolve does not race with SetMap.
+ var wg sync.WaitGroup
+ wg.Add(2)
+ go func() {
+ defer wg.Done()
+ r.SetMap(dnsMap)
+ }()
+ go func() {
+ defer wg.Done()
+ r.Resolve("test1.ipn.dev", dns.TypeA)
+ }()
+ wg.Wait()
+}
+
+func TestConcurrentSetUpstreams(t *testing.T) {
+ dnsHandleFunc("test.site.", resolveToIP(testipv4, testipv6, "dns.test.site."))
+
+ server, errch := serveDNS(t, "127.0.0.1:0")
+ defer func() {
+ if err := <-errch; err != nil {
+ t.Errorf("server error: %v", err)
+ }
+ }()
+
+ if server == nil {
+ return
+ }
+ defer server.Shutdown()
+
+ r := NewResolver(ResolverConfig{Logf: t.Logf, Forward: true})
+ r.SetMap(dnsMap)
+
+ if err := r.Start(); err != nil {
+ t.Fatalf("start: %v", err)
+ }
+ defer r.Close()
+
+ packet := dnspacket("test.site.", dns.TypeA)
+ // This is purely to ensure that delegation does not race with SetUpstreams.
+ var wg sync.WaitGroup
+ wg.Add(2)
+ go func() {
+ defer wg.Done()
+ r.SetUpstreams([]net.Addr{server.PacketConn.LocalAddr()})
+ }()
+ go func() {
+ defer wg.Done()
+ syncRespond(r, packet)
+ }()
+ wg.Wait()
+}
+
+var allResponse = []byte{
+ 0x00, 0x00, // transaction id: 0
+ 0x84, 0x00, // flags: response, authoritative, no error
+ 0x00, 0x01, // one question
+ 0x00, 0x01, // one answer
+ 0x00, 0x00, 0x00, 0x00, // no authority or additional RRs
+ // Question:
+ 0x05, 0x74, 0x65, 0x73, 0x74, 0x31, 0x03, 0x69, 0x70, 0x6e, 0x03, 0x64, 0x65, 0x76, 0x00, // name
+ 0x00, 0xff, 0x00, 0x01, // type ALL, class IN
+ // Answer:
+ 0x05, 0x74, 0x65, 0x73, 0x74, 0x31, 0x03, 0x69, 0x70, 0x6e, 0x03, 0x64, 0x65, 0x76, 0x00, // name
+ 0x00, 0x01, 0x00, 0x01, // type A, class IN
+ 0x00, 0x00, 0x02, 0x58, // TTL: 600
+ 0x00, 0x04, // length: 4 bytes
+ 0x01, 0x02, 0x03, 0x04, // A: 1.2.3.4
+}
+
+var ipv4Response = []byte{
+ 0x00, 0x00, // transaction id: 0
+ 0x84, 0x00, // flags: response, authoritative, no error
+ 0x00, 0x01, // one question
+ 0x00, 0x01, // one answer
+ 0x00, 0x00, 0x00, 0x00, // no authority or additional RRs
+ // Question:
+ 0x05, 0x74, 0x65, 0x73, 0x74, 0x31, 0x03, 0x69, 0x70, 0x6e, 0x03, 0x64, 0x65, 0x76, 0x00, // name
+ 0x00, 0x01, 0x00, 0x01, // type A, class IN
+ // Answer:
+ 0x05, 0x74, 0x65, 0x73, 0x74, 0x31, 0x03, 0x69, 0x70, 0x6e, 0x03, 0x64, 0x65, 0x76, 0x00, // name
+ 0x00, 0x01, 0x00, 0x01, // type A, class IN
+ 0x00, 0x00, 0x02, 0x58, // TTL: 600
+ 0x00, 0x04, // length: 4 bytes
+ 0x01, 0x02, 0x03, 0x04, // A: 1.2.3.4
+}
+
+var ipv6Response = []byte{
+ 0x00, 0x00, // transaction id: 0
+ 0x84, 0x00, // flags: response, authoritative, no error
+ 0x00, 0x01, // one question
+ 0x00, 0x01, // one answer
+ 0x00, 0x00, 0x00, 0x00, // no authority or additional RRs
+ // Question:
+ 0x05, 0x74, 0x65, 0x73, 0x74, 0x32, 0x03, 0x69, 0x70, 0x6e, 0x03, 0x64, 0x65, 0x76, 0x00, // name
+ 0x00, 0x1c, 0x00, 0x01, // type AAAA, class IN
+ // Answer:
+ 0x05, 0x74, 0x65, 0x73, 0x74, 0x32, 0x03, 0x69, 0x70, 0x6e, 0x03, 0x64, 0x65, 0x76, 0x00, // name
+ 0x00, 0x1c, 0x00, 0x01, // type AAAA, class IN
+ 0x00, 0x00, 0x02, 0x58, // TTL: 600
+ 0x00, 0x10, // length: 16 bytes
+ // AAAA: 0001:0203:0405:0607:0809:0A0B:0C0D:0E0F
+ 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0xb, 0xc, 0xd, 0xe, 0xf,
+}
+
+var ipv4UppercaseResponse = []byte{
+ 0x00, 0x00, // transaction id: 0
+ 0x84, 0x00, // flags: response, authoritative, no error
+ 0x00, 0x01, // one question
+ 0x00, 0x01, // one answer
+ 0x00, 0x00, 0x00, 0x00, // no authority or additional RRs
+ // Question:
+ 0x05, 0x54, 0x45, 0x53, 0x54, 0x31, 0x03, 0x49, 0x50, 0x4e, 0x03, 0x44, 0x45, 0x56, 0x00, // name
+ 0x00, 0x01, 0x00, 0x01, // type A, class IN
+ // Answer:
+ 0x05, 0x54, 0x45, 0x53, 0x54, 0x31, 0x03, 0x49, 0x50, 0x4e, 0x03, 0x44, 0x45, 0x56, 0x00, // name
+ 0x00, 0x01, 0x00, 0x01, // type A, class IN
+ 0x00, 0x00, 0x02, 0x58, // TTL: 600
+ 0x00, 0x04, // length: 4 bytes
+ 0x01, 0x02, 0x03, 0x04, // A: 1.2.3.4
+}
+
+var ptrResponse = []byte{
+ 0x00, 0x00, // transaction id: 0
+ 0x84, 0x00, // flags: response, authoritative, no error
+ 0x00, 0x01, // one question
+ 0x00, 0x01, // one answer
+ 0x00, 0x00, 0x00, 0x00, // no authority or additional RRs
+ // Question: 4.3.2.1.in-addr.arpa
+ 0x01, 0x34, 0x01, 0x33, 0x01, 0x32, 0x01, 0x31, 0x07,
+ 0x69, 0x6e, 0x2d, 0x61, 0x64, 0x64, 0x72, 0x04, 0x61, 0x72, 0x70, 0x61, 0x00,
+ 0x00, 0x0c, 0x00, 0x01, // type PTR, class IN
+ // Answer: 4.3.2.1.in-addr.arpa
+ 0x01, 0x34, 0x01, 0x33, 0x01, 0x32, 0x01, 0x31, 0x07,
+ 0x69, 0x6e, 0x2d, 0x61, 0x64, 0x64, 0x72, 0x04, 0x61, 0x72, 0x70, 0x61, 0x00,
+ 0x00, 0x0c, 0x00, 0x01, // type PTR, class IN
+ 0x00, 0x00, 0x02, 0x58, // TTL: 600
+ 0x00, 0x0f, // length: 15 bytes
+ // PTR: test1.ipn.dev
+ 0x05, 0x74, 0x65, 0x73, 0x74, 0x31, 0x03, 0x69, 0x70, 0x6e, 0x03, 0x64, 0x65, 0x76, 0x00,
+}
+
+var ptrResponse6 = []byte{
+ 0x00, 0x00, // transaction id: 0
+ 0x84, 0x00, // flags: response, authoritative, no error
+ 0x00, 0x01, // one question
+ 0x00, 0x01, // one answer
+ 0x00, 0x00, 0x00, 0x00, // no authority or additional RRs
+ // Question: f.0.e.0.d.0.c.0.b.0.a.0.9.0.8.0.7.0.6.0.5.0.4.0.3.0.2.0.1.0.0.0.ip6.arpa
+ 0x01, 0x66, 0x01, 0x30, 0x01, 0x65, 0x01, 0x30,
+ 0x01, 0x64, 0x01, 0x30, 0x01, 0x63, 0x01, 0x30,
+ 0x01, 0x62, 0x01, 0x30, 0x01, 0x61, 0x01, 0x30,
+ 0x01, 0x39, 0x01, 0x30, 0x01, 0x38, 0x01, 0x30,
+ 0x01, 0x37, 0x01, 0x30, 0x01, 0x36, 0x01, 0x30,
+ 0x01, 0x35, 0x01, 0x30, 0x01, 0x34, 0x01, 0x30,
+ 0x01, 0x33, 0x01, 0x30, 0x01, 0x32, 0x01, 0x30,
+ 0x01, 0x31, 0x01, 0x30, 0x01, 0x30, 0x01, 0x30,
+ 0x03, 0x69, 0x70, 0x36,
+ 0x04, 0x61, 0x72, 0x70, 0x61, 0x00,
+ 0x00, 0x0c, 0x00, 0x01, // type PTR, class IN6
+ // Answer: f.0.e.0.d.0.c.0.b.0.a.0.9.0.8.0.7.0.6.0.5.0.4.0.3.0.2.0.1.0.0.0.ip6.arpa
+ 0x01, 0x66, 0x01, 0x30, 0x01, 0x65, 0x01, 0x30,
+ 0x01, 0x64, 0x01, 0x30, 0x01, 0x63, 0x01, 0x30,
+ 0x01, 0x62, 0x01, 0x30, 0x01, 0x61, 0x01, 0x30,
+ 0x01, 0x39, 0x01, 0x30, 0x01, 0x38, 0x01, 0x30,
+ 0x01, 0x37, 0x01, 0x30, 0x01, 0x36, 0x01, 0x30,
+ 0x01, 0x35, 0x01, 0x30, 0x01, 0x34, 0x01, 0x30,
+ 0x01, 0x33, 0x01, 0x30, 0x01, 0x32, 0x01, 0x30,
+ 0x01, 0x31, 0x01, 0x30, 0x01, 0x30, 0x01, 0x30,
+ 0x03, 0x69, 0x70, 0x36,
+ 0x04, 0x61, 0x72, 0x70, 0x61, 0x00,
+ 0x00, 0x0c, 0x00, 0x01, // type PTR, class IN
+ 0x00, 0x00, 0x02, 0x58, // TTL: 600
+ 0x00, 0x0f, // length: 15 bytes
+ // PTR: test2.ipn.dev
+ 0x05, 0x74, 0x65, 0x73, 0x74, 0x32, 0x03, 0x69, 0x70, 0x6e, 0x03, 0x64, 0x65, 0x76, 0x00,
+}
+
+var nxdomainResponse = []byte{
+ 0x00, 0x00, // transaction id: 0
+ 0x84, 0x03, // flags: response, authoritative, error: nxdomain
+ 0x00, 0x01, // one question
+ 0x00, 0x00, // no answers
+ 0x00, 0x00, 0x00, 0x00, // no authority or additional RRs
+ // Question:
+ 0x05, 0x74, 0x65, 0x73, 0x74, 0x33, 0x03, 0x69, 0x70, 0x6e, 0x03, 0x64, 0x65, 0x76, 0x00, // name
+ 0x00, 0x01, 0x00, 0x01, // type A, class IN
+}
+
+var emptyResponse = []byte{
+ 0x00, 0x00, // transaction id: 0
+ 0x84, 0x00, // flags: response, authoritative, no error
+ 0x00, 0x01, // one question
+ 0x00, 0x00, // no answers
+ 0x00, 0x00, 0x00, 0x00, // no authority or additional RRs
+ // Question:
+ 0x05, 0x74, 0x65, 0x73, 0x74, 0x31, 0x03, 0x69, 0x70, 0x6e, 0x03, 0x64, 0x65, 0x76, 0x00, // name
+ 0x00, 0x1c, 0x00, 0x01, // type AAAA, class IN
+}
+
+func TestFull(t *testing.T) {
+ r := NewResolver(ResolverConfig{Logf: t.Logf, Forward: false})
+ r.SetMap(dnsMap)
+
+ if err := r.Start(); err != nil {
+ t.Fatalf("start: %v", err)
+ }
+ defer r.Close()
+
+ // One full packet and one error packet
+ tests := []struct {
+ name string
+ request []byte
+ response []byte
+ }{
+ {"all", dnspacket("test1.ipn.dev.", dns.TypeALL), allResponse},
+ {"ipv4", dnspacket("test1.ipn.dev.", dns.TypeA), ipv4Response},
+ {"ipv6", dnspacket("test2.ipn.dev.", dns.TypeAAAA), ipv6Response},
+ {"no-ipv6", dnspacket("test1.ipn.dev.", dns.TypeAAAA), emptyResponse},
+ {"upper", dnspacket("TEST1.IPN.DEV.", dns.TypeA), ipv4UppercaseResponse},
+ {"ptr", dnspacket("4.3.2.1.in-addr.arpa.", dns.TypePTR), ptrResponse},
+ {"ptr", dnspacket("f.0.e.0.d.0.c.0.b.0.a.0.9.0.8.0.7.0.6.0.5.0.4.0.3.0.2.0.1.0.0.0.ip6.arpa.",
+ dns.TypePTR), ptrResponse6},
+ {"nxdomain", dnspacket("test3.ipn.dev.", dns.TypeA), nxdomainResponse},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ response, err := syncRespond(r, tt.request)
+ if err != nil {
+ t.Errorf("err = %v; want nil", err)
+ }
+ if !bytes.Equal(response, tt.response) {
+ t.Errorf("response = %x; want %x", response, tt.response)
+ }
+ })
+ }
+}
+
+func TestAllocs(t *testing.T) {
+ r := NewResolver(ResolverConfig{Logf: t.Logf, Forward: false})
+ r.SetMap(dnsMap)
+
+ if err := r.Start(); err != nil {
+ t.Fatalf("start: %v", err)
+ }
+ defer r.Close()
+
+ // It is seemingly pointless to test allocs in the delegate path,
+ // as dialer.Dial -> Read -> Write alone comprise 12 allocs.
+ tests := []struct {
+ name string
+ query []byte
+ want int
+ }{
+ // Name lowercasing and response slice created by dns.NewBuilder.
+ {"forward", dnspacket("test1.ipn.dev.", dns.TypeA), 2},
+ // 3 extra allocs in rdnsNameToIPv4 and one in marshalPTRRecord (dns.NewName).
+ {"reverse", dnspacket("4.3.2.1.in-addr.arpa.", dns.TypePTR), 5},
+ }
+
+ for _, tt := range tests {
+ allocs := testing.AllocsPerRun(100, func() {
+ syncRespond(r, tt.query)
+ })
+ if int(allocs) > tt.want {
+ t.Errorf("%s: allocs = %v; want %v", tt.name, allocs, tt.want)
+ }
+ }
+}
+
+func TestTrimRDNSBonjourPrefix(t *testing.T) {
+ tests := []struct {
+ in string
+ want bool
+ }{
+ {"b._dns-sd._udp.0.10.20.172.in-addr.arpa.", true},
+ {"db._dns-sd._udp.0.10.20.172.in-addr.arpa.", true},
+ {"r._dns-sd._udp.0.10.20.172.in-addr.arpa.", true},
+ {"dr._dns-sd._udp.0.10.20.172.in-addr.arpa.", true},
+ {"lb._dns-sd._udp.0.10.20.172.in-addr.arpa.", true},
+ {"qq._dns-sd._udp.0.10.20.172.in-addr.arpa.", false},
+ {"0.10.20.172.in-addr.arpa.", false},
+ {"i-have-no-dot", false},
+ }
+
+ for _, test := range tests {
+ got := hasRDNSBonjourPrefix(test.in)
+ if got != test.want {
+ t.Errorf("trimRDNSBonjourPrefix(%q) = %v, want %v", test.in, got, test.want)
+ }
+ }
+}
+
+func BenchmarkFull(b *testing.B) {
+ dnsHandleFunc("test.site.", resolveToIP(testipv4, testipv6, "dns.test.site."))
+
+ server, errch := serveDNS(b, "127.0.0.1:0")
+ defer func() {
+ if err := <-errch; err != nil {
+ b.Errorf("server error: %v", err)
+ }
+ }()
+
+ if server == nil {
+ return
+ }
+ defer server.Shutdown()
+
+ r := NewResolver(ResolverConfig{Logf: b.Logf, Forward: true})
+ r.SetMap(dnsMap)
+ r.SetUpstreams([]net.Addr{server.PacketConn.LocalAddr()})
+
+ if err := r.Start(); err != nil {
+ b.Fatalf("start: %v", err)
+ }
+ defer r.Close()
+
+ tests := []struct {
+ name string
+ request []byte
+ }{
+ {"forward", dnspacket("test1.ipn.dev.", dns.TypeA)},
+ {"reverse", dnspacket("4.3.2.1.in-addr.arpa.", dns.TypePTR)},
+ {"delegated", dnspacket("test.site.", dns.TypeA)},
+ }
+
+ for _, tt := range tests {
+ b.Run(tt.name, func(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ syncRespond(r, tt.request)
+ }
+ })
+ }
+}
+
+func TestMarshalResponseFormatError(t *testing.T) {
+ resp := new(response)
+ resp.Header.RCode = dns.RCodeFormatError
+ v, err := marshalResponse(resp)
+ if err != nil {
+ t.Errorf("marshal error: %v", err)
+ }
+ t.Logf("response: %q", v)
+}