diff options
| author | Naman Sood <mail@nsood.in> | 2021-03-29 14:28:08 -0400 |
|---|---|---|
| committer | Naman Sood <mail@nsood.in> | 2021-03-29 14:28:08 -0400 |
| commit | c0a88a0129ebf0f9886b93b1f4e4f04a7c3bb86f (patch) | |
| tree | 57d5aef2985e3424e5bb6f4c810628aa3ccbf5d0 /net | |
| parent | 47bd3c4cf5543fd7ecb049302c37c1001fa9f2d6 (diff) | |
| parent | a4c679e64691a3f0ba41ad9078312ca67e5e67fd (diff) | |
| download | tailscale-naman/netstack-subnet-routing.tar.xz tailscale-naman/netstack-subnet-routing.zip | |
merge with mainnaman/netstack-subnet-routing
Signed-off-by: Naman Sood <mail@nsood.in>
Diffstat (limited to 'net')
45 files changed, 5216 insertions, 212 deletions
diff --git a/net/dns/config.go b/net/dns/config.go new file mode 100644 index 000000000..db08df7aa --- /dev/null +++ b/net/dns/config.go @@ -0,0 +1,77 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package dns + +import ( + "inet.af/netaddr" + + "tailscale.com/types/logger" +) + +// Config is the set of parameters that uniquely determine +// the state to which a manager should bring system DNS settings. +type Config struct { + // Nameservers are the IP addresses of the nameservers to use. + Nameservers []netaddr.IP + // Domains are the search domains to use. + Domains []string + // PerDomain indicates whether it is preferred to use Nameservers + // only for DNS queries for subdomains of Domains. + // Note that Nameservers may still be applied to all queries + // if the manager does not support per-domain settings. + PerDomain bool + // Proxied indicates whether DNS requests are proxied through a dns.Resolver. + // This enables MagicDNS. + Proxied bool +} + +// Equal determines whether its argument and receiver +// represent equivalent DNS configurations (then DNS reconfig is a no-op). +func (lhs Config) Equal(rhs Config) bool { + if lhs.Proxied != rhs.Proxied || lhs.PerDomain != rhs.PerDomain { + return false + } + + if len(lhs.Nameservers) != len(rhs.Nameservers) { + return false + } + + if len(lhs.Domains) != len(rhs.Domains) { + return false + } + + // With how we perform resolution order shouldn't matter, + // but it is unlikely that we will encounter different orders. + for i, server := range lhs.Nameservers { + if rhs.Nameservers[i] != server { + return false + } + } + + // The order of domains, on the other hand, is significant. + for i, domain := range lhs.Domains { + if rhs.Domains[i] != domain { + return false + } + } + + return true +} + +// ManagerConfig is the set of parameters from which +// a manager implementation is chosen and initialized. +type ManagerConfig struct { + // Logf is the logger for the manager to use. + // It is wrapped with a "dns: " prefix. + Logf logger.Logf + // InterfaceName is the name of the interface with which DNS settings should be associated. + InterfaceName string + // Cleanup indicates that the manager is created for cleanup only. + // A no-op manager will be instantiated if the system needs no cleanup. + Cleanup bool + // PerDomain indicates that a manager capable of per-domain configuration is preferred. + // Certain managers are per-domain only; they will not be considered if this is false. + PerDomain bool +} diff --git a/net/dns/direct.go b/net/dns/direct.go new file mode 100644 index 000000000..bd1c03b9d --- /dev/null +++ b/net/dns/direct.go @@ -0,0 +1,188 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build linux freebsd openbsd + +package dns + +import ( + "bufio" + "bytes" + "errors" + "fmt" + "io" + "io/ioutil" + "os" + "os/exec" + "runtime" + "strings" + + "inet.af/netaddr" + "tailscale.com/atomicfile" +) + +const ( + tsConf = "/etc/resolv.tailscale.conf" + backupConf = "/etc/resolv.pre-tailscale-backup.conf" + resolvConf = "/etc/resolv.conf" +) + +// writeResolvConf writes DNS configuration in resolv.conf format to the given writer. +func writeResolvConf(w io.Writer, servers []netaddr.IP, domains []string) { + io.WriteString(w, "# resolv.conf(5) file generated by tailscale\n") + io.WriteString(w, "# DO NOT EDIT THIS FILE BY HAND -- CHANGES WILL BE OVERWRITTEN\n\n") + for _, ns := range servers { + io.WriteString(w, "nameserver ") + io.WriteString(w, ns.String()) + io.WriteString(w, "\n") + } + if len(domains) > 0 { + io.WriteString(w, "search") + for _, domain := range domains { + io.WriteString(w, " ") + io.WriteString(w, domain) + } + io.WriteString(w, "\n") + } +} + +// readResolvConf reads DNS configuration from /etc/resolv.conf. +func readResolvConf() (Config, error) { + var config Config + + f, err := os.Open("/etc/resolv.conf") + if err != nil { + return config, err + } + + scanner := bufio.NewScanner(f) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + + if strings.HasPrefix(line, "nameserver") { + nameserver := strings.TrimPrefix(line, "nameserver") + nameserver = strings.TrimSpace(nameserver) + ip, err := netaddr.ParseIP(nameserver) + if err != nil { + return config, err + } + config.Nameservers = append(config.Nameservers, ip) + continue + } + + if strings.HasPrefix(line, "search") { + domain := strings.TrimPrefix(line, "search") + domain = strings.TrimSpace(domain) + config.Domains = append(config.Domains, domain) + continue + } + } + + return config, nil +} + +// isResolvedRunning reports whether systemd-resolved is running on the system, +// even if it is not managing the system DNS settings. +func isResolvedRunning() bool { + if runtime.GOOS != "linux" { + return false + } + + // systemd-resolved is never installed without systemd. + _, err := exec.LookPath("systemctl") + if err != nil { + return false + } + + // is-active exits with code 3 if the service is not active. + err = exec.Command("systemctl", "is-active", "systemd-resolved.service").Run() + + return err == nil +} + +// directManager is a managerImpl which replaces /etc/resolv.conf with a file +// generated from the given configuration, creating a backup of its old state. +// +// This way of configuring DNS is precarious, since it does not react +// to the disappearance of the Tailscale interface. +// The caller must call Down before program shutdown +// or as cleanup if the program terminates unexpectedly. +type directManager struct{} + +func newDirectManager(mconfig ManagerConfig) managerImpl { + return directManager{} +} + +// Up implements managerImpl. +func (m directManager) Up(config Config) error { + // Write the tsConf file. + buf := new(bytes.Buffer) + writeResolvConf(buf, config.Nameservers, config.Domains) + if err := atomicfile.WriteFile(tsConf, buf.Bytes(), 0644); err != nil { + return err + } + + if linkPath, err := os.Readlink(resolvConf); err != nil { + // Remove any old backup that may exist. + os.Remove(backupConf) + + // Backup the existing /etc/resolv.conf file. + contents, err := ioutil.ReadFile(resolvConf) + // If the original did not exist, still back up an empty file. + // The presence of a backup file is the way we know that Up ran. + if err != nil && !errors.Is(err, os.ErrNotExist) { + return err + } + if err := atomicfile.WriteFile(backupConf, contents, 0644); err != nil { + return err + } + } else if linkPath != tsConf { + // Backup the existing symlink. + os.Remove(backupConf) + if err := os.Symlink(linkPath, backupConf); err != nil { + return err + } + } else { + // Nothing to do, resolvConf already points to tsConf. + return nil + } + + os.Remove(resolvConf) + if err := os.Symlink(tsConf, resolvConf); err != nil { + return err + } + + if isResolvedRunning() { + exec.Command("systemctl", "restart", "systemd-resolved.service").Run() // Best-effort. + } + + return nil +} + +// Down implements managerImpl. +func (m directManager) Down() error { + if _, err := os.Stat(backupConf); err != nil { + // If the backup file does not exist, then Up never ran successfully. + if os.IsNotExist(err) { + return nil + } + return err + } + + if ln, err := os.Readlink(resolvConf); err != nil { + return err + } else if ln != tsConf { + return fmt.Errorf("resolv.conf is not a symlink to %s", tsConf) + } + if err := os.Rename(backupConf, resolvConf); err != nil { + return err + } + os.Remove(tsConf) + + if isResolvedRunning() { + exec.Command("systemctl", "restart", "systemd-resolved.service").Run() // Best-effort. + } + + return nil +} diff --git a/net/dns/flush_windows.go b/net/dns/flush_windows.go new file mode 100644 index 000000000..3c7e7d645 --- /dev/null +++ b/net/dns/flush_windows.go @@ -0,0 +1,19 @@ +// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package dns + +import ( + "fmt" + "os/exec" +) + +// Flush clears the local resolver cache. +func Flush() error { + out, err := exec.Command("ipconfig", "/flushdns").CombinedOutput() + if err != nil { + return fmt.Errorf("%v (output: %s)", err, out) + } + return nil +} diff --git a/net/dns/forwarder.go b/net/dns/forwarder.go new file mode 100644 index 000000000..519c00027 --- /dev/null +++ b/net/dns/forwarder.go @@ -0,0 +1,474 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package dns + +import ( + "bytes" + "context" + "encoding/binary" + "errors" + "fmt" + "hash/crc32" + "math/rand" + "net" + "os" + "sync" + "time" + + "inet.af/netaddr" + "tailscale.com/logtail/backoff" + "tailscale.com/net/netns" + "tailscale.com/types/logger" +) + +// headerBytes is the number of bytes in a DNS message header. +const headerBytes = 12 + +// connCount is the number of UDP connections to use for forwarding. +const connCount = 32 + +const ( + // cleanupInterval is the interval between purged of timed-out entries from txMap. + cleanupInterval = 30 * time.Second + // responseTimeout is the maximal amount of time to wait for a DNS response. + responseTimeout = 5 * time.Second +) + +var errNoUpstreams = errors.New("upstream nameservers not set") + +var aLongTimeAgo = time.Unix(0, 1) + +type forwardingRecord struct { + src netaddr.IPPort + createdAt time.Time +} + +// txid identifies a DNS transaction. +// +// As the standard DNS Request ID is only 16 bits, we extend it: +// the lower 32 bits are the zero-extended bits of the DNS Request ID; +// the upper 32 bits are the CRC32 checksum of the first question in the request. +// This makes probability of txid collision negligible. +type txid uint64 + +// getTxID computes the txid of the given DNS packet. +func getTxID(packet []byte) txid { + if len(packet) < headerBytes { + return 0 + } + + dnsid := binary.BigEndian.Uint16(packet[0:2]) + qcount := binary.BigEndian.Uint16(packet[4:6]) + if qcount == 0 { + return txid(dnsid) + } + + offset := headerBytes + for i := uint16(0); i < qcount; i++ { + // Note: this relies on the fact that names are not compressed in questions, + // so they are guaranteed to end with a NUL byte. + // + // Justification: + // RFC 1035 doesn't seem to explicitly prohibit compressing names in questions, + // but this is exceedingly unlikely to be done in practice. A DNS request + // with multiple questions is ill-defined (which questions do the header flags apply to?) + // and a single question would have to contain a pointer to an *answer*, + // which would be excessively smart, pointless (an answer can just as well refer to the question) + // and perhaps even prohibited: a draft RFC (draft-ietf-dnsind-local-compression-05) states: + // + // > It is important that these pointers always point backwards. + // + // This is said in summarizing RFC 1035, although that phrase does not appear in the original RFC. + // Additionally, (https://cr.yp.to/djbdns/notes.html) states: + // + // > The precise rule is that a name can be compressed if it is a response owner name, + // > the name in NS data, the name in CNAME data, the name in PTR data, the name in MX data, + // > or one of the names in SOA data. + namebytes := bytes.IndexByte(packet[offset:], 0) + // ... | name | NUL | type | class + // ?? 1 2 2 + offset = offset + namebytes + 5 + if len(packet) < offset { + // Corrupt packet; don't crash. + return txid(dnsid) + } + } + + hash := crc32.ChecksumIEEE(packet[headerBytes:offset]) + return (txid(hash) << 32) | txid(dnsid) +} + +// forwarder forwards DNS packets to a number of upstream nameservers. +type forwarder struct { + logf logger.Logf + + // responses is a channel by which responses are returned. + responses chan Packet + // closed signals all goroutines to stop. + closed chan struct{} + // wg signals when all goroutines have stopped. + wg sync.WaitGroup + + // conns are the UDP connections used for forwarding. + // A random one is selected for each request, regardless of the target upstream. + conns []*fwdConn + + mu sync.Mutex + // upstreams are the nameserver addresses that should be used for forwarding. + upstreams []net.Addr + // txMap maps DNS txids to active forwarding records. + txMap map[txid]forwardingRecord +} + +func init() { + rand.Seed(time.Now().UnixNano()) +} + +func newForwarder(logf logger.Logf, responses chan Packet) *forwarder { + return &forwarder{ + logf: logger.WithPrefix(logf, "forward: "), + responses: responses, + closed: make(chan struct{}), + conns: make([]*fwdConn, connCount), + txMap: make(map[txid]forwardingRecord), + } +} + +func (f *forwarder) Start() error { + f.wg.Add(connCount + 1) + for idx := range f.conns { + f.conns[idx] = newFwdConn(f.logf, idx) + go f.recv(f.conns[idx]) + } + go f.cleanMap() + + return nil +} + +func (f *forwarder) Close() { + select { + case <-f.closed: + return + default: + // continue + } + close(f.closed) + + for _, conn := range f.conns { + conn.close() + } + + f.wg.Wait() +} + +func (f *forwarder) rebindFromNetworkChange() { + for _, c := range f.conns { + c.mu.Lock() + c.reconnectLocked() + c.mu.Unlock() + } +} + +func (f *forwarder) setUpstreams(upstreams []net.Addr) { + f.mu.Lock() + f.upstreams = upstreams + f.mu.Unlock() +} + +// send sends packet to dst. It is best effort. +func (f *forwarder) send(packet []byte, dst net.Addr) { + connIdx := rand.Intn(connCount) + conn := f.conns[connIdx] + conn.send(packet, dst) +} + +func (f *forwarder) recv(conn *fwdConn) { + defer f.wg.Done() + + for { + select { + case <-f.closed: + return + default: + } + out := make([]byte, maxResponseBytes) + n := conn.read(out) + if n == 0 { + continue + } + if n < headerBytes { + f.logf("recv: packet too small (%d bytes)", n) + } + + out = out[:n] + txid := getTxID(out) + + f.mu.Lock() + + record, found := f.txMap[txid] + // At most one nameserver will return a response: + // the first one to do so will delete txid from the map. + if !found { + f.mu.Unlock() + continue + } + delete(f.txMap, txid) + + f.mu.Unlock() + + packet := Packet{ + Payload: out, + Addr: record.src, + } + select { + case <-f.closed: + return + case f.responses <- packet: + // continue + } + } +} + +// cleanMap periodically deletes timed-out forwarding records from f.txMap to bound growth. +func (f *forwarder) cleanMap() { + defer f.wg.Done() + + t := time.NewTicker(cleanupInterval) + defer t.Stop() + + var now time.Time + for { + select { + case <-f.closed: + return + case now = <-t.C: + // continue + } + + f.mu.Lock() + for k, v := range f.txMap { + if now.Sub(v.createdAt) > responseTimeout { + delete(f.txMap, k) + } + } + f.mu.Unlock() + } +} + +// forward forwards the query to all upstream nameservers and returns the first response. +func (f *forwarder) forward(query Packet) error { + txid := getTxID(query.Payload) + + f.mu.Lock() + + upstreams := f.upstreams + if len(upstreams) == 0 { + f.mu.Unlock() + return errNoUpstreams + } + f.txMap[txid] = forwardingRecord{ + src: query.Addr, + createdAt: time.Now(), + } + + f.mu.Unlock() + + for _, upstream := range upstreams { + f.send(query.Payload, upstream) + } + + return nil +} + +// A fwdConn manages a single connection used to forward DNS requests. +// Net link changes can cause a *net.UDPConn to become permanently unusable, particularly on macOS. +// fwdConn detects such situations and transparently creates new connections. +type fwdConn struct { + // logf allows a fwdConn to log. + logf logger.Logf + + // wg tracks the number of outstanding conn.Read and conn.Write calls. + wg sync.WaitGroup + // change allows calls to read to block until a the network connection has been replaced. + change *sync.Cond + + // mu protects fields that follow it; it is also change's Locker. + mu sync.Mutex + // closed tracks whether fwdConn has been permanently closed. + closed bool + // conn is the current active connection. + conn net.PacketConn +} + +func newFwdConn(logf logger.Logf, idx int) *fwdConn { + c := new(fwdConn) + c.logf = logger.WithPrefix(logf, fmt.Sprintf("fwdConn %d: ", idx)) + c.change = sync.NewCond(&c.mu) + // c.conn is created lazily in send + return c +} + +// send sends packet to dst using c's connection. +// It is best effort. It is UDP, after all. Failures are logged. +func (c *fwdConn) send(packet []byte, dst net.Addr) { + var b *backoff.Backoff // lazily initialized, since it is not needed in the common case + backOff := func(err error) { + if b == nil { + b = backoff.NewBackoff("dns-fwdConn-send", c.logf, 30*time.Second) + } + b.BackOff(context.Background(), err) + } + + for { + // Gather the current connection. + // We can't hold the lock while we call WriteTo. + c.mu.Lock() + conn := c.conn + closed := c.closed + if closed { + c.mu.Unlock() + return + } + if conn == nil { + c.reconnectLocked() + c.mu.Unlock() + continue + } + c.mu.Unlock() + + c.wg.Add(1) + _, err := conn.WriteTo(packet, dst) + c.wg.Done() + if err == nil { + // Success + return + } + if errors.Is(err, os.ErrDeadlineExceeded) { + // We intentionally closed this connection. + // It has been replaced by a new connection. Try again. + continue + } + // Something else went wrong. + // We have three choices here: try again, give up, or create a new connection. + var opErr *net.OpError + if !errors.As(err, &opErr) { + // Weird. All errors from the net package should be *net.OpError. Bail. + c.logf("send: non-*net.OpErr %v (%T)", err, err) + return + } + if opErr.Temporary() || opErr.Timeout() { + // I doubt that either of these can happen (this is UDP), + // but go ahead and try again. + backOff(err) + continue + } + if networkIsDown(err) { + // Fail. + c.logf("send: network is down") + return + } + if networkIsUnreachable(err) { + // This can be caused by a link change. + // Replace the existing connection with a new one. + c.mu.Lock() + // It's possible that multiple senders discovered simultaneously + // that the network is unreachable. Avoid reconnecting multiple times: + // Only reconnect if the current connection is the one that we + // discovered to be problematic. + if c.conn == conn { + backOff(err) + c.reconnectLocked() + } + c.mu.Unlock() + // Try again with our new network connection. + continue + } + // Unrecognized error. Fail. + c.logf("send: unrecognized error: %v", err) + return + } +} + +// read waits for a response from c's connection. +// It returns the number of bytes read, which may be 0 +// in case of an error or a closed connection. +func (c *fwdConn) read(out []byte) int { + for { + // Gather the current connection. + // We can't hold the lock while we call ReadFrom. + c.mu.Lock() + conn := c.conn + closed := c.closed + if closed { + c.mu.Unlock() + return 0 + } + if conn == nil { + // There is no current connection. + // Wait for the connection to change, then try again. + c.change.Wait() + c.mu.Unlock() + continue + } + c.mu.Unlock() + + c.wg.Add(1) + n, _, err := conn.ReadFrom(out) + c.wg.Done() + if err == nil { + // Success. + return n + } + if errors.Is(err, os.ErrDeadlineExceeded) { + // We intentionally closed this connection. + // It has been replaced by a new connection. Try again. + continue + } + + c.logf("read: unrecognized error: %v", err) + return 0 + } +} + +// reconnectLocked replaces the current connection with a new one. +// c.mu must be locked. +func (c *fwdConn) reconnectLocked() { + c.closeConnLocked() + // Make a new connection. + conn, err := netns.Listener().ListenPacket(context.Background(), "udp", "") + if err != nil { + c.logf("ListenPacket failed: %v", err) + } else { + c.conn = conn + } + // Broadcast that a new connection is available. + c.change.Broadcast() +} + +// closeCurrentConn closes the current connection. +// c.mu must be locked. +func (c *fwdConn) closeConnLocked() { + if c.conn == nil { + return + } + // Unblock all readers/writers, wait for them, close the connection. + c.conn.SetDeadline(aLongTimeAgo) + c.wg.Wait() + c.conn.Close() + c.conn = nil +} + +// close permanently closes c. +func (c *fwdConn) close() { + c.mu.Lock() + defer c.mu.Unlock() + if c.closed { + return + } + c.closed = true + c.closeConnLocked() + // Unblock any remaining readers. + c.change.Broadcast() +} diff --git a/net/dns/manager.go b/net/dns/manager.go new file mode 100644 index 000000000..8e2fc9d71 --- /dev/null +++ b/net/dns/manager.go @@ -0,0 +1,100 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package dns + +import ( + "time" + + "tailscale.com/types/logger" +) + +// We use file-ignore below instead of ignore because on some platforms, +// the lint exception is necessary and on others it is not, +// and plain ignore complains if the exception is unnecessary. + +//lint:file-ignore U1000 reconfigTimeout is used on some platforms but not others + +// reconfigTimeout is the time interval within which Manager.{Up,Down} should complete. +// +// This is particularly useful because certain conditions can cause indefinite hangs +// (such as improper dbus auth followed by contextless dbus.Object.Call). +// Such operations should be wrapped in a timeout context. +const reconfigTimeout = time.Second + +type managerImpl interface { + // Up updates system DNS settings to match the given configuration. + Up(Config) error + // Down undoes the effects of Up. + // It is idempotent and performs no action if Up has never been called. + Down() error +} + +// Manager manages system DNS settings. +type Manager struct { + logf logger.Logf + + impl managerImpl + + config Config + mconfig ManagerConfig +} + +// NewManagers created a new manager from the given config. +func NewManager(mconfig ManagerConfig) *Manager { + mconfig.Logf = logger.WithPrefix(mconfig.Logf, "dns: ") + m := &Manager{ + logf: mconfig.Logf, + impl: newManager(mconfig), + + config: Config{PerDomain: mconfig.PerDomain}, + mconfig: mconfig, + } + + m.logf("using %T", m.impl) + return m +} + +func (m *Manager) Set(config Config) error { + if config.Equal(m.config) { + return nil + } + + m.logf("Set: %+v", config) + + if len(config.Nameservers) == 0 { + err := m.impl.Down() + // If we save the config, we will not retry next time. Only do this on success. + if err == nil { + m.config = config + } + return err + } + + // Switching to and from per-domain mode may require a change of manager. + if config.PerDomain != m.config.PerDomain { + if err := m.impl.Down(); err != nil { + return err + } + m.mconfig.PerDomain = config.PerDomain + m.impl = newManager(m.mconfig) + m.logf("switched to %T", m.impl) + } + + err := m.impl.Up(config) + // If we save the config, we will not retry next time. Only do this on success. + if err == nil { + m.config = config + } + + return err +} + +func (m *Manager) Up() error { + return m.impl.Up(m.config) +} + +func (m *Manager) Down() error { + return m.impl.Down() +} diff --git a/net/dns/manager_default.go b/net/dns/manager_default.go new file mode 100644 index 000000000..04c8bb811 --- /dev/null +++ b/net/dns/manager_default.go @@ -0,0 +1,14 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !linux,!freebsd,!openbsd,!windows + +package dns + +func newManager(mconfig ManagerConfig) managerImpl { + // TODO(dmytro): on darwin, we should use a macOS-specific method such as scutil. + // This is currently not implemented. Editing /etc/resolv.conf does not work, + // as most applications use the system resolver, which disregards it. + return newNoopManager(mconfig) +} diff --git a/net/dns/manager_freebsd.go b/net/dns/manager_freebsd.go new file mode 100644 index 000000000..232635f7e --- /dev/null +++ b/net/dns/manager_freebsd.go @@ -0,0 +1,14 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package dns + +func newManager(mconfig ManagerConfig) managerImpl { + switch { + case isResolvconfActive(): + return newResolvconfManager(mconfig) + default: + return newDirectManager(mconfig) + } +} diff --git a/net/dns/manager_linux.go b/net/dns/manager_linux.go new file mode 100644 index 000000000..f53aed7d3 --- /dev/null +++ b/net/dns/manager_linux.go @@ -0,0 +1,27 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package dns + +func newManager(mconfig ManagerConfig) managerImpl { + switch { + // systemd-resolved should only activate per-domain. + case isResolvedActive() && mconfig.PerDomain: + if mconfig.Cleanup { + return newNoopManager(mconfig) + } else { + return newResolvedManager(mconfig) + } + case isNMActive(): + if mconfig.Cleanup { + return newNoopManager(mconfig) + } else { + return newNMManager(mconfig) + } + case isResolvconfActive(): + return newResolvconfManager(mconfig) + default: + return newDirectManager(mconfig) + } +} diff --git a/net/dns/manager_openbsd.go b/net/dns/manager_openbsd.go new file mode 100644 index 000000000..228e3cca5 --- /dev/null +++ b/net/dns/manager_openbsd.go @@ -0,0 +1,9 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package dns + +func newManager(mconfig ManagerConfig) managerImpl { + return newDirectManager(mconfig) +} diff --git a/net/dns/manager_windows.go b/net/dns/manager_windows.go new file mode 100644 index 000000000..5940404e7 --- /dev/null +++ b/net/dns/manager_windows.go @@ -0,0 +1,118 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package dns + +import ( + "fmt" + "os/exec" + "strings" + "syscall" + "time" + + "golang.org/x/sys/windows/registry" + "tailscale.com/types/logger" +) + +const ( + ipv4RegBase = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters` + ipv6RegBase = `SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters` +) + +type windowsManager struct { + logf logger.Logf + guid string +} + +func newManager(mconfig ManagerConfig) managerImpl { + return windowsManager{ + logf: mconfig.Logf, + guid: mconfig.InterfaceName, + } +} + +// keyOpenTimeout is how long we wait for a registry key to +// appear. For some reason, registry keys tied to ephemeral interfaces +// can take a long while to appear after interface creation, and we +// can end up racing with that. +const keyOpenTimeout = 20 * time.Second + +func setRegistryString(path, name, value string) error { + key, err := openKeyWait(registry.LOCAL_MACHINE, path, registry.SET_VALUE, keyOpenTimeout) + if err != nil { + return fmt.Errorf("opening %s: %w", path, err) + } + defer key.Close() + + err = key.SetStringValue(name, value) + if err != nil { + return fmt.Errorf("setting %s[%s]: %w", path, name, err) + } + return nil +} + +func (m windowsManager) setNameservers(basePath string, nameservers []string) error { + path := fmt.Sprintf(`%s\Interfaces\%s`, basePath, m.guid) + value := strings.Join(nameservers, ",") + return setRegistryString(path, "NameServer", value) +} + +func (m windowsManager) setDomains(basePath string, domains []string) error { + path := fmt.Sprintf(`%s\Interfaces\%s`, basePath, m.guid) + value := strings.Join(domains, ",") + return setRegistryString(path, "SearchList", value) +} + +func (m windowsManager) Up(config Config) error { + var ipsv4 []string + var ipsv6 []string + + for _, ip := range config.Nameservers { + if ip.Is4() { + ipsv4 = append(ipsv4, ip.String()) + } else { + ipsv6 = append(ipsv6, ip.String()) + } + } + + if err := m.setNameservers(ipv4RegBase, ipsv4); err != nil { + return err + } + if err := m.setDomains(ipv4RegBase, config.Domains); err != nil { + return err + } + + if err := m.setNameservers(ipv6RegBase, ipsv6); err != nil { + return err + } + if err := m.setDomains(ipv6RegBase, config.Domains); err != nil { + return err + } + + // Force DNS re-registration in Active Directory. What we actually + // care about is that this command invokes the undocumented hidden + // function that forces Windows to notice that adapter settings + // have changed, which makes the DNS settings actually take + // effect. + // + // This command can take a few seconds to run, so run it async, best effort. + go func() { + t0 := time.Now() + m.logf("running ipconfig /registerdns ...") + cmd := exec.Command("ipconfig", "/registerdns") + cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true} + d := time.Since(t0).Round(time.Millisecond) + if err := cmd.Run(); err != nil { + m.logf("error running ipconfig /registerdns after %v: %v", d, err) + } else { + m.logf("ran ipconfig /registerdns in %v", d) + } + }() + + return nil +} + +func (m windowsManager) Down() error { + return m.Up(Config{Nameservers: nil, Domains: nil}) +} diff --git a/net/dns/map.go b/net/dns/map.go new file mode 100644 index 000000000..119b6cc0a --- /dev/null +++ b/net/dns/map.go @@ -0,0 +1,160 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package dns + +import ( + "sort" + "strings" + + "inet.af/netaddr" +) + +// Map is all the data Resolver needs to resolve DNS queries within the Tailscale network. +type Map struct { + // nameToIP is a mapping of Tailscale domain names to their IP addresses. + // For example, monitoring.tailscale.us -> 100.64.0.1. + nameToIP map[string]netaddr.IP + // ipToName is the inverse of nameToIP. + ipToName map[netaddr.IP]string + // names are the keys of nameToIP in sorted order. + names []string + // rootDomains are the domains whose subdomains should always + // be resolved locally to prevent leakage of sensitive names. + rootDomains []string // e.g. "user.provider.beta.tailscale.net." +} + +// NewMap returns a new Map with name to address mapping given by nameToIP. +// +// rootDomains are the domains whose subdomains should always be +// resolved locally to prevent leakage of sensitive names. They should +// end in a period ("user-foo.tailscale.net."). +func NewMap(initNameToIP map[string]netaddr.IP, rootDomains []string) *Map { + // TODO(dmytro): we have to allocate names and ipToName, but nameToIP can be avoided. + // It is here because control sends us names not in canonical form. Change this. + names := make([]string, 0, len(initNameToIP)) + nameToIP := make(map[string]netaddr.IP, len(initNameToIP)) + ipToName := make(map[netaddr.IP]string, len(initNameToIP)) + + for name, ip := range initNameToIP { + if len(name) == 0 { + // Nothing useful can be done with empty names. + continue + } + if name[len(name)-1] != '.' { + name += "." + } + names = append(names, name) + nameToIP[name] = ip + ipToName[ip] = name + } + sort.Strings(names) + + return &Map{ + nameToIP: nameToIP, + ipToName: ipToName, + names: names, + + rootDomains: rootDomains, + } +} + +func printSingleNameIP(buf *strings.Builder, name string, ip netaddr.IP) { + buf.WriteString(name) + buf.WriteByte('\t') + buf.WriteString(ip.String()) + buf.WriteByte('\n') +} + +func (m *Map) Pretty() string { + buf := new(strings.Builder) + for _, name := range m.names { + printSingleNameIP(buf, name, m.nameToIP[name]) + } + return buf.String() +} + +func (m *Map) PrettyDiffFrom(old *Map) string { + var ( + oldNameToIP map[string]netaddr.IP + newNameToIP map[string]netaddr.IP + oldNames []string + newNames []string + ) + if old != nil { + oldNameToIP = old.nameToIP + oldNames = old.names + } + if m != nil { + newNameToIP = m.nameToIP + newNames = m.names + } + + buf := new(strings.Builder) + space := func() bool { + return buf.Len() < (1 << 10) + } + + for len(oldNames) > 0 && len(newNames) > 0 { + var name string + + newName, oldName := newNames[0], oldNames[0] + switch { + case oldName < newName: + name = oldName + oldNames = oldNames[1:] + case oldName > newName: + name = newName + newNames = newNames[1:] + case oldNames[0] == newNames[0]: + name = oldNames[0] + oldNames = oldNames[1:] + newNames = newNames[1:] + } + if !space() { + continue + } + + ipOld, inOld := oldNameToIP[name] + ipNew, inNew := newNameToIP[name] + switch { + case !inOld: + buf.WriteByte('+') + printSingleNameIP(buf, name, ipNew) + case !inNew: + buf.WriteByte('-') + printSingleNameIP(buf, name, ipOld) + case ipOld != ipNew: + buf.WriteByte('-') + printSingleNameIP(buf, name, ipOld) + buf.WriteByte('+') + printSingleNameIP(buf, name, ipNew) + } + } + + for _, name := range oldNames { + if !space() { + break + } + if _, ok := newNameToIP[name]; !ok { + buf.WriteByte('-') + printSingleNameIP(buf, name, oldNameToIP[name]) + } + } + + for _, name := range newNames { + if !space() { + break + } + if _, ok := oldNameToIP[name]; !ok { + buf.WriteByte('+') + printSingleNameIP(buf, name, newNameToIP[name]) + } + } + if !space() { + buf.WriteString("... [truncated]\n") + } + + return buf.String() +} diff --git a/net/dns/map_test.go b/net/dns/map_test.go new file mode 100644 index 000000000..c438f95a0 --- /dev/null +++ b/net/dns/map_test.go @@ -0,0 +1,156 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package dns + +import ( + "fmt" + "strings" + "testing" + + "inet.af/netaddr" +) + +func TestPretty(t *testing.T) { + tests := []struct { + name string + dmap *Map + want string + }{ + {"empty", NewMap(nil, nil), ""}, + { + "single", + NewMap(map[string]netaddr.IP{ + "hello.ipn.dev.": netaddr.IPv4(100, 101, 102, 103), + }, nil), + "hello.ipn.dev.\t100.101.102.103\n", + }, + { + "multiple", + NewMap(map[string]netaddr.IP{ + "test1.domain.": netaddr.IPv4(100, 101, 102, 103), + "test2.sub.domain.": netaddr.IPv4(100, 99, 9, 1), + }, nil), + "test1.domain.\t100.101.102.103\ntest2.sub.domain.\t100.99.9.1\n", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.dmap.Pretty() + if tt.want != got { + t.Errorf("want %v; got %v", tt.want, got) + } + }) + } +} + +func TestPrettyDiffFrom(t *testing.T) { + tests := []struct { + name string + map1 *Map + map2 *Map + want string + }{ + { + "from_empty", + nil, + NewMap(map[string]netaddr.IP{ + "test1.ipn.dev.": netaddr.IPv4(100, 101, 102, 103), + "test2.ipn.dev.": netaddr.IPv4(100, 103, 102, 101), + }, nil), + "+test1.ipn.dev.\t100.101.102.103\n+test2.ipn.dev.\t100.103.102.101\n", + }, + { + "equal", + NewMap(map[string]netaddr.IP{ + "test1.ipn.dev.": netaddr.IPv4(100, 101, 102, 103), + "test2.ipn.dev.": netaddr.IPv4(100, 103, 102, 101), + }, nil), + NewMap(map[string]netaddr.IP{ + "test2.ipn.dev.": netaddr.IPv4(100, 103, 102, 101), + "test1.ipn.dev.": netaddr.IPv4(100, 101, 102, 103), + }, nil), + "", + }, + { + "changed_ip", + NewMap(map[string]netaddr.IP{ + "test1.ipn.dev.": netaddr.IPv4(100, 101, 102, 103), + "test2.ipn.dev.": netaddr.IPv4(100, 103, 102, 101), + }, nil), + NewMap(map[string]netaddr.IP{ + "test2.ipn.dev.": netaddr.IPv4(100, 104, 102, 101), + "test1.ipn.dev.": netaddr.IPv4(100, 101, 102, 103), + }, nil), + "-test2.ipn.dev.\t100.103.102.101\n+test2.ipn.dev.\t100.104.102.101\n", + }, + { + "new_domain", + NewMap(map[string]netaddr.IP{ + "test1.ipn.dev.": netaddr.IPv4(100, 101, 102, 103), + "test2.ipn.dev.": netaddr.IPv4(100, 103, 102, 101), + }, nil), + NewMap(map[string]netaddr.IP{ + "test3.ipn.dev.": netaddr.IPv4(100, 105, 106, 107), + "test2.ipn.dev.": netaddr.IPv4(100, 103, 102, 101), + "test1.ipn.dev.": netaddr.IPv4(100, 101, 102, 103), + }, nil), + "+test3.ipn.dev.\t100.105.106.107\n", + }, + { + "gone_domain", + NewMap(map[string]netaddr.IP{ + "test1.ipn.dev.": netaddr.IPv4(100, 101, 102, 103), + "test2.ipn.dev.": netaddr.IPv4(100, 103, 102, 101), + }, nil), + NewMap(map[string]netaddr.IP{ + "test1.ipn.dev.": netaddr.IPv4(100, 101, 102, 103), + }, nil), + "-test2.ipn.dev.\t100.103.102.101\n", + }, + { + "mixed", + NewMap(map[string]netaddr.IP{ + "test1.ipn.dev.": netaddr.IPv4(100, 101, 102, 103), + "test4.ipn.dev.": netaddr.IPv4(100, 107, 106, 105), + "test5.ipn.dev.": netaddr.IPv4(100, 64, 1, 1), + "test2.ipn.dev.": netaddr.IPv4(100, 103, 102, 101), + }, nil), + NewMap(map[string]netaddr.IP{ + "test2.ipn.dev.": netaddr.IPv4(100, 104, 102, 101), + "test1.ipn.dev.": netaddr.IPv4(100, 100, 101, 102), + "test3.ipn.dev.": netaddr.IPv4(100, 64, 1, 1), + }, nil), + "-test1.ipn.dev.\t100.101.102.103\n+test1.ipn.dev.\t100.100.101.102\n" + + "-test2.ipn.dev.\t100.103.102.101\n+test2.ipn.dev.\t100.104.102.101\n" + + "+test3.ipn.dev.\t100.64.1.1\n-test4.ipn.dev.\t100.107.106.105\n-test5.ipn.dev.\t100.64.1.1\n", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.map2.PrettyDiffFrom(tt.map1) + if tt.want != got { + t.Errorf("want %v; got %v", tt.want, got) + } + }) + } + + t.Run("truncated", func(t *testing.T) { + small := NewMap(nil, nil) + m := map[string]netaddr.IP{} + for i := 0; i < 5000; i++ { + m[fmt.Sprintf("host%d.ipn.dev.", i)] = netaddr.IPv4(100, 64, 1, 1) + } + veryBig := NewMap(m, nil) + diff := veryBig.PrettyDiffFrom(small) + if len(diff) > 3<<10 { + t.Errorf("pretty diff too large: %d bytes", len(diff)) + } + if !strings.Contains(diff, "truncated") { + t.Errorf("big diff not truncated") + } + }) +} diff --git a/net/dns/neterr_darwin.go b/net/dns/neterr_darwin.go new file mode 100644 index 000000000..7fd621fc7 --- /dev/null +++ b/net/dns/neterr_darwin.go @@ -0,0 +1,25 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package dns + +import ( + "errors" + "syscall" +) + +// Avoid allocation when calling errors.Is below +// by converting syscall.Errno to error here. +var ( + networkDown error = syscall.ENETDOWN + networkUnreachable error = syscall.ENETUNREACH +) + +func networkIsDown(err error) bool { + return errors.Is(err, networkDown) +} + +func networkIsUnreachable(err error) bool { + return errors.Is(err, networkUnreachable) +} diff --git a/net/dns/neterr_other.go b/net/dns/neterr_other.go new file mode 100644 index 000000000..b652f6e8b --- /dev/null +++ b/net/dns/neterr_other.go @@ -0,0 +1,10 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !darwin,!windows + +package dns + +func networkIsDown(err error) bool { return false } +func networkIsUnreachable(err error) bool { return false } diff --git a/net/dns/neterr_windows.go b/net/dns/neterr_windows.go new file mode 100644 index 000000000..2b197ee2b --- /dev/null +++ b/net/dns/neterr_windows.go @@ -0,0 +1,29 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package dns + +import ( + "net" + "os" + + "golang.org/x/sys/windows" +) + +func networkIsDown(err error) bool { + if oe, ok := err.(*net.OpError); ok && oe.Op == "write" { + if se, ok := oe.Err.(*os.SyscallError); ok { + if se.Syscall == "wsasendto" && se.Err == windows.WSAENETUNREACH { + return true + } + } + } + return false +} + +func networkIsUnreachable(err error) bool { + // TODO(bradfitz,josharian): something here? what is the + // difference between down and unreachable? Add comments. + return false +} diff --git a/net/dns/nm.go b/net/dns/nm.go new file mode 100644 index 000000000..a597fa60d --- /dev/null +++ b/net/dns/nm.go @@ -0,0 +1,205 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build linux + +package dns + +import ( + "bufio" + "bytes" + "context" + "fmt" + "os" + "os/exec" + + "github.com/godbus/dbus/v5" + "tailscale.com/util/endian" +) + +// isNMActive determines if NetworkManager is currently managing system DNS settings. +func isNMActive() bool { + // This is somewhat tricky because NetworkManager supports a number + // of DNS configuration modes. In all cases, we expect it to be installed + // and /etc/resolv.conf to contain a mention of NetworkManager in the comments. + _, err := exec.LookPath("NetworkManager") + if err != nil { + return false + } + + f, err := os.Open("/etc/resolv.conf") + if err != nil { + return false + } + defer f.Close() + + scanner := bufio.NewScanner(f) + for scanner.Scan() { + line := scanner.Bytes() + // Look for the word "NetworkManager" until comments end. + if len(line) > 0 && line[0] != '#' { + return false + } + if bytes.Contains(line, []byte("NetworkManager")) { + return true + } + } + return false +} + +// nmManager uses the NetworkManager DBus API. +type nmManager struct { + interfaceName string +} + +func newNMManager(mconfig ManagerConfig) managerImpl { + return nmManager{ + interfaceName: mconfig.InterfaceName, + } +} + +type nmConnectionSettings map[string]map[string]dbus.Variant + +// Up implements managerImpl. +func (m nmManager) Up(config Config) error { + ctx, cancel := context.WithTimeout(context.Background(), reconfigTimeout) + defer cancel() + + // conn is a shared connection whose lifecycle is managed by the dbus package. + // We should not interfere with that by closing it. + conn, err := dbus.SystemBus() + if err != nil { + return fmt.Errorf("connecting to system bus: %w", err) + } + + // This is how we get at the DNS settings: + // + // org.freedesktop.NetworkManager + // | + // [GetDeviceByIpIface] + // | + // v + // org.freedesktop.NetworkManager.Device <--------\ + // (describes a network interface) | + // | | + // [GetAppliedConnection] [Reapply] + // | | + // v | + // org.freedesktop.NetworkManager.Connection | + // (connection settings) ------/ + // contains {dns, dns-priority, dns-search} + // + // Ref: https://developer.gnome.org/NetworkManager/stable/settings-ipv4.html. + + nm := conn.Object( + "org.freedesktop.NetworkManager", + dbus.ObjectPath("/org/freedesktop/NetworkManager"), + ) + + var devicePath dbus.ObjectPath + err = nm.CallWithContext( + ctx, "org.freedesktop.NetworkManager.GetDeviceByIpIface", 0, + m.interfaceName, + ).Store(&devicePath) + if err != nil { + return fmt.Errorf("getDeviceByIpIface: %w", err) + } + device := conn.Object("org.freedesktop.NetworkManager", devicePath) + + var ( + settings nmConnectionSettings + version uint64 + ) + err = device.CallWithContext( + ctx, "org.freedesktop.NetworkManager.Device.GetAppliedConnection", 0, + uint32(0), + ).Store(&settings, &version) + if err != nil { + return fmt.Errorf("getAppliedConnection: %w", err) + } + + // Frustratingly, NetworkManager represents IPv4 addresses as uint32s, + // although IPv6 addresses are represented as byte arrays. + // Perform the conversion here. + var ( + dnsv4 []uint32 + dnsv6 [][]byte + ) + for _, ip := range config.Nameservers { + b := ip.As16() + if ip.Is4() { + dnsv4 = append(dnsv4, endian.Native.Uint32(b[12:])) + } else { + dnsv6 = append(dnsv6, b[:]) + } + } + + ipv4Map := settings["ipv4"] + ipv4Map["dns"] = dbus.MakeVariant(dnsv4) + ipv4Map["dns-search"] = dbus.MakeVariant(config.Domains) + // We should only request priority if we have nameservers to set. + if len(dnsv4) == 0 { + ipv4Map["dns-priority"] = dbus.MakeVariant(100) + } else { + // dns-priority = -1 ensures that we have priority + // over other interfaces, except those exploiting this same trick. + // Ref: https://bugs.launchpad.net/ubuntu/+source/network-manager/+bug/1211110/comments/92. + ipv4Map["dns-priority"] = dbus.MakeVariant(-1) + } + // In principle, we should not need set this to true, + // as our interface does not configure any automatic DNS settings (presumably via DHCP). + // All the same, better to be safe. + ipv4Map["ignore-auto-dns"] = dbus.MakeVariant(true) + + ipv6Map := settings["ipv6"] + // This is a hack. + // Methods "disabled", "ignore", "link-local" (IPv6 default) prevent us from setting DNS. + // It seems that our only recourse is "manual" or "auto". + // "manual" requires addresses, so we use "auto", which will assign us a random IPv6 /64. + ipv6Map["method"] = dbus.MakeVariant("auto") + // Our IPv6 config is a fake, so it should never become the default route. + ipv6Map["never-default"] = dbus.MakeVariant(true) + // Moreover, we should ignore all autoconfigured routes (hopefully none), as they are bogus. + ipv6Map["ignore-auto-routes"] = dbus.MakeVariant(true) + + // Finally, set the actual DNS config. + ipv6Map["dns"] = dbus.MakeVariant(dnsv6) + ipv6Map["dns-search"] = dbus.MakeVariant(config.Domains) + if len(dnsv6) == 0 { + ipv6Map["dns-priority"] = dbus.MakeVariant(100) + } else { + ipv6Map["dns-priority"] = dbus.MakeVariant(-1) + } + ipv6Map["ignore-auto-dns"] = dbus.MakeVariant(true) + + // deprecatedProperties are the properties in interface settings + // that are deprecated by NetworkManager. + // + // In practice, this means that they are returned for reading, + // but submitting a settings object with them present fails + // with hard-to-diagnose errors. They must be removed. + deprecatedProperties := []string{ + "addresses", "routes", + } + + for _, property := range deprecatedProperties { + delete(ipv4Map, property) + delete(ipv6Map, property) + } + + err = device.CallWithContext( + ctx, "org.freedesktop.NetworkManager.Device.Reapply", 0, + settings, version, uint32(0), + ).Store() + if err != nil { + return fmt.Errorf("reapply: %w", err) + } + + return nil +} + +// Down implements managerImpl. +func (m nmManager) Down() error { + return m.Up(Config{Nameservers: nil, Domains: nil}) +} diff --git a/net/dns/noop.go b/net/dns/noop.go new file mode 100644 index 000000000..35c07a232 --- /dev/null +++ b/net/dns/noop.go @@ -0,0 +1,17 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package dns + +type noopManager struct{} + +// Up implements managerImpl. +func (m noopManager) Up(Config) error { return nil } + +// Down implements managerImpl. +func (m noopManager) Down() error { return nil } + +func newNoopManager(mconfig ManagerConfig) managerImpl { + return noopManager{} +} diff --git a/net/dns/registry_windows.go b/net/dns/registry_windows.go new file mode 100644 index 000000000..f8e1f514a --- /dev/null +++ b/net/dns/registry_windows.go @@ -0,0 +1,76 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +// +// The code in this file originates from https://git.zx2c4.com/wireguard-go: +// Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved. +// Copying license: https://git.zx2c4.com/wireguard-go/tree/COPYING + +package dns + +import ( + "fmt" + "runtime" + "strings" + "time" + + "golang.org/x/sys/windows" + "golang.org/x/sys/windows/registry" +) + +func openKeyWait(k registry.Key, path string, access uint32, timeout time.Duration) (registry.Key, error) { + runtime.LockOSThread() + defer runtime.UnlockOSThread() + + deadline := time.Now().Add(timeout) + pathSpl := strings.Split(path, "\\") + for i := 0; ; i++ { + keyName := pathSpl[i] + isLast := i+1 == len(pathSpl) + + event, err := windows.CreateEvent(nil, 0, 0, nil) + if err != nil { + return 0, fmt.Errorf("windows.CreateEvent: %v", err) + } + defer windows.CloseHandle(event) + + var key registry.Key + for { + err = windows.RegNotifyChangeKeyValue(windows.Handle(k), false, windows.REG_NOTIFY_CHANGE_NAME, event, true) + if err != nil { + return 0, fmt.Errorf("windows.RegNotifyChangeKeyValue: %v", err) + } + + var accessFlags uint32 + if isLast { + accessFlags = access + } else { + accessFlags = registry.NOTIFY + } + key, err = registry.OpenKey(k, keyName, accessFlags) + if err == windows.ERROR_FILE_NOT_FOUND || err == windows.ERROR_PATH_NOT_FOUND { + timeout := time.Until(deadline) / time.Millisecond + if timeout < 0 { + timeout = 0 + } + s, err := windows.WaitForSingleObject(event, uint32(timeout)) + if err != nil { + return 0, fmt.Errorf("windows.WaitForSingleObject: %v", err) + } + if s == uint32(windows.WAIT_TIMEOUT) { // windows.WAIT_TIMEOUT status const is misclassified as error in golang.org/x/sys/windows + return 0, fmt.Errorf("timeout waiting for registry key") + } + } else if err != nil { + return 0, fmt.Errorf("registry.OpenKey(%v): %v", path, err) + } else { + if isLast { + return key, nil + } + defer key.Close() + break + } + } + + k = key + } +} diff --git a/net/dns/resolvconf.go b/net/dns/resolvconf.go new file mode 100644 index 000000000..8bf97ee88 --- /dev/null +++ b/net/dns/resolvconf.go @@ -0,0 +1,157 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build linux freebsd + +package dns + +import ( + "bufio" + "bytes" + "fmt" + "os" + "os/exec" +) + +// isResolvconfActive indicates whether the system appears to be using resolvconf. +// If this is true, then directManager should be avoided: +// resolvconf has exclusive ownership of /etc/resolv.conf. +func isResolvconfActive() bool { + // Sanity-check first: if there is no resolvconf binary, then this is fruitless. + // + // However, this binary may be a shim like the one systemd-resolved provides. + // Such a shim may not behave as expected: in particular, systemd-resolved + // does not seem to respect the exclusive mode -x, saying: + // -x Send DNS traffic preferably over this interface + // whereas e.g. openresolv sends DNS traffix _exclusively_ over that interface, + // or not at all (in case of another exclusive-mode request later in time). + // + // Moreover, resolvconf may be installed but unused, in which case we should + // not use it either, lest we clobber existing configuration. + // + // To handle all the above correctly, we scan the comments in /etc/resolv.conf + // to ensure that it was generated by a resolvconf implementation. + _, err := exec.LookPath("resolvconf") + if err != nil { + return false + } + + f, err := os.Open("/etc/resolv.conf") + if err != nil { + return false + } + defer f.Close() + + scanner := bufio.NewScanner(f) + for scanner.Scan() { + line := scanner.Bytes() + // Look for the word "resolvconf" until comments end. + if len(line) > 0 && line[0] != '#' { + return false + } + if bytes.Contains(line, []byte("resolvconf")) { + return true + } + } + return false +} + +// resolvconfImpl enumerates supported implementations of the resolvconf CLI. +type resolvconfImpl uint8 + +const ( + // resolvconfOpenresolv is the implementation packaged as "openresolv" on Ubuntu. + // It supports exclusive mode and interface metrics. + resolvconfOpenresolv resolvconfImpl = iota + // resolvconfLegacy is the implementation by Thomas Hood packaged as "resolvconf" on Ubuntu. + // It does not support exclusive mode or interface metrics. + resolvconfLegacy +) + +func (impl resolvconfImpl) String() string { + switch impl { + case resolvconfOpenresolv: + return "openresolv" + case resolvconfLegacy: + return "legacy" + default: + return "unknown" + } +} + +// getResolvconfImpl returns the implementation of resolvconf that appears to be in use. +func getResolvconfImpl() resolvconfImpl { + err := exec.Command("resolvconf", "-v").Run() + if err != nil { + if exitErr, ok := err.(*exec.ExitError); ok { + // Thomas Hood's resolvconf has a minimal flag set + // and exits with code 99 when passed an unknown flag. + if exitErr.ExitCode() == 99 { + return resolvconfLegacy + } + } + } + return resolvconfOpenresolv +} + +type resolvconfManager struct { + impl resolvconfImpl +} + +func newResolvconfManager(mconfig ManagerConfig) managerImpl { + impl := getResolvconfImpl() + mconfig.Logf("resolvconf implementation is %s", impl) + + return resolvconfManager{ + impl: impl, + } +} + +// resolvconfConfigName is the name of the config submitted to resolvconf. +// It has this form to match the "tun*" rule in interface-order +// when running resolvconfLegacy, hopefully placing our config first. +const resolvconfConfigName = "tun-tailscale.inet" + +// Up implements managerImpl. +func (m resolvconfManager) Up(config Config) error { + stdin := new(bytes.Buffer) + writeResolvConf(stdin, config.Nameservers, config.Domains) // dns_direct.go + + var cmd *exec.Cmd + switch m.impl { + case resolvconfOpenresolv: + // Request maximal priority (metric 0) and exclusive mode. + cmd = exec.Command("resolvconf", "-m", "0", "-x", "-a", resolvconfConfigName) + case resolvconfLegacy: + // This does not quite give us the desired behavior (queries leak), + // but there is nothing else we can do without messing with other interfaces' settings. + cmd = exec.Command("resolvconf", "-a", resolvconfConfigName) + } + cmd.Stdin = stdin + out, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("running %s: %s", cmd, out) + } + + return nil +} + +// Down implements managerImpl. +func (m resolvconfManager) Down() error { + var cmd *exec.Cmd + switch m.impl { + case resolvconfOpenresolv: + cmd = exec.Command("resolvconf", "-f", "-d", resolvconfConfigName) + case resolvconfLegacy: + // resolvconfLegacy lacks the -f flag. + // Instead, it succeeds even when the config does not exist. + cmd = exec.Command("resolvconf", "-d", resolvconfConfigName) + } + out, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("running %s: %s", cmd, out) + } + + return nil +} diff --git a/net/dns/resolved.go b/net/dns/resolved.go new file mode 100644 index 000000000..9d8c40d90 --- /dev/null +++ b/net/dns/resolved.go @@ -0,0 +1,188 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build linux + +package dns + +import ( + "context" + "errors" + "fmt" + "os/exec" + + "github.com/godbus/dbus/v5" + "golang.org/x/sys/unix" + "inet.af/netaddr" + "tailscale.com/net/interfaces" +) + +// resolvedListenAddr is the listen address of the resolved stub resolver. +// +// We only consider resolved to be the system resolver if the stub resolver is; +// that is, if this address is the sole nameserver in /etc/resolved.conf. +// In other cases, resolved may be managing the system DNS configuration directly. +// Then the nameserver list will be a concatenation of those for all +// the interfaces that register their interest in being a default resolver with +// SetLinkDomains([]{{"~.", true}, ...}) +// which includes at least the interface with the default route, i.e. not us. +// This does not work for us: there is a possibility of getting NXDOMAIN +// from the other nameservers before we are asked or get a chance to respond. +// We consider this case as lacking resolved support and fall through to dnsDirect. +// +// While it may seem that we need to read a config option to get at this, +// this address is, in fact, hard-coded into resolved. +var resolvedListenAddr = netaddr.IPv4(127, 0, 0, 53) + +var errNotReady = errors.New("interface not ready") + +type resolvedLinkNameserver struct { + Family int32 + Address []byte +} + +type resolvedLinkDomain struct { + Domain string + RoutingOnly bool +} + +// isResolvedActive determines if resolved is currently managing system DNS settings. +func isResolvedActive() bool { + // systemd-resolved is never installed without systemd. + _, err := exec.LookPath("systemctl") + if err != nil { + return false + } + + // is-active exits with code 3 if the service is not active. + err = exec.Command("systemctl", "is-active", "systemd-resolved").Run() + if err != nil { + return false + } + + config, err := readResolvConf() + if err != nil { + return false + } + + // The sole nameserver must be the systemd-resolved stub. + if len(config.Nameservers) == 1 && config.Nameservers[0] == resolvedListenAddr { + return true + } + + return false +} + +// resolvedManager uses the systemd-resolved DBus API. +type resolvedManager struct{} + +func newResolvedManager(mconfig ManagerConfig) managerImpl { + return resolvedManager{} +} + +// Up implements managerImpl. +func (m resolvedManager) Up(config Config) error { + ctx, cancel := context.WithTimeout(context.Background(), reconfigTimeout) + defer cancel() + + // conn is a shared connection whose lifecycle is managed by the dbus package. + // We should not interfere with that by closing it. + conn, err := dbus.SystemBus() + if err != nil { + return fmt.Errorf("connecting to system bus: %w", err) + } + + resolved := conn.Object( + "org.freedesktop.resolve1", + dbus.ObjectPath("/org/freedesktop/resolve1"), + ) + + // In principle, we could persist this in the manager struct + // if we knew that interface indices are persistent. This does not seem to be the case. + _, iface, err := interfaces.Tailscale() + if err != nil { + return fmt.Errorf("getting interface index: %w", err) + } + if iface == nil { + return errNotReady + } + + var linkNameservers = make([]resolvedLinkNameserver, len(config.Nameservers)) + for i, server := range config.Nameservers { + ip := server.As16() + if server.Is4() { + linkNameservers[i] = resolvedLinkNameserver{ + Family: unix.AF_INET, + Address: ip[12:], + } + } else { + linkNameservers[i] = resolvedLinkNameserver{ + Family: unix.AF_INET6, + Address: ip[:], + } + } + } + + err = resolved.CallWithContext( + ctx, "org.freedesktop.resolve1.Manager.SetLinkDNS", 0, + iface.Index, linkNameservers, + ).Store() + if err != nil { + return fmt.Errorf("setLinkDNS: %w", err) + } + + var linkDomains = make([]resolvedLinkDomain, len(config.Domains)) + for i, domain := range config.Domains { + linkDomains[i] = resolvedLinkDomain{ + Domain: domain, + RoutingOnly: false, + } + } + + err = resolved.CallWithContext( + ctx, "org.freedesktop.resolve1.Manager.SetLinkDomains", 0, + iface.Index, linkDomains, + ).Store() + if err != nil { + return fmt.Errorf("setLinkDomains: %w", err) + } + + return nil +} + +// Down implements managerImpl. +func (m resolvedManager) Down() error { + ctx, cancel := context.WithTimeout(context.Background(), reconfigTimeout) + defer cancel() + + // conn is a shared connection whose lifecycle is managed by the dbus package. + // We should not interfere with that by closing it. + conn, err := dbus.SystemBus() + if err != nil { + return fmt.Errorf("connecting to system bus: %w", err) + } + + resolved := conn.Object( + "org.freedesktop.resolve1", + dbus.ObjectPath("/org/freedesktop/resolve1"), + ) + + _, iface, err := interfaces.Tailscale() + if err != nil { + return fmt.Errorf("getting interface index: %w", err) + } + if iface == nil { + return errNotReady + } + + err = resolved.CallWithContext( + ctx, "org.freedesktop.resolve1.Manager.RevertLink", 0, + iface.Index, + ).Store() + if err != nil { + return fmt.Errorf("RevertLink: %w", err) + } + + return nil +} diff --git a/net/dns/tsdns.go b/net/dns/tsdns.go new file mode 100644 index 000000000..2b530b81e --- /dev/null +++ b/net/dns/tsdns.go @@ -0,0 +1,662 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package dns provides a Resolver capable of resolving +// domains on a Tailscale network. +package dns + +import ( + "encoding/hex" + "errors" + "net" + "strings" + "sync" + "time" + + dns "golang.org/x/net/dns/dnsmessage" + "inet.af/netaddr" + "tailscale.com/net/interfaces" + "tailscale.com/types/logger" + "tailscale.com/util/dnsname" + "tailscale.com/wgengine/monitor" +) + +// maxResponseBytes is the maximum size of a response from a Resolver. +const maxResponseBytes = 512 + +// queueSize is the maximal number of DNS requests that can await polling. +// If EnqueueRequest is called when this many requests are already pending, +// the request will be dropped to avoid blocking the caller. +const queueSize = 64 + +// defaultTTL is the TTL of all responses from Resolver. +const defaultTTL = 600 * time.Second + +// ErrClosed indicates that the resolver has been closed and readers should exit. +var ErrClosed = errors.New("closed") + +var ( + errFullQueue = errors.New("request queue full") + errMapNotSet = errors.New("domain map not set") + errNotForwarding = errors.New("forwarding disabled") + errNotImplemented = errors.New("query type not implemented") + errNotQuery = errors.New("not a DNS query") + errNotOurName = errors.New("not a Tailscale DNS name") +) + +// Packet represents a DNS payload together with the address of its origin. +type Packet struct { + // Payload is the application layer DNS payload. + // Resolver assumes ownership of the request payload when it is enqueued + // and cedes ownership of the response payload when it is returned from NextResponse. + Payload []byte + // Addr is the source address for a request and the destination address for a response. + Addr netaddr.IPPort +} + +// Resolver is a DNS resolver for nodes on the Tailscale network, +// associating them with domain names of the form <mynode>.<mydomain>.<root>. +// If it is asked to resolve a domain that is not of that form, +// it delegates to upstream nameservers if any are set. +type Resolver struct { + logf logger.Logf + linkMon *monitor.Mon // or nil + unregLinkMon func() // or nil + // forwarder forwards requests to upstream nameservers. + forwarder *forwarder + + // queue is a buffered channel holding DNS requests queued for resolution. + queue chan Packet + // responses is an unbuffered channel to which responses are returned. + responses chan Packet + // errors is an unbuffered channel to which errors are returned. + errors chan error + // closed signals all goroutines to stop. + closed chan struct{} + // wg signals when all goroutines have stopped. + wg sync.WaitGroup + + // mu guards the following fields from being updated while used. + mu sync.Mutex + // dnsMap is the map most recently received from the control server. + dnsMap *Map +} + +// ResolverConfig is the set of configuration options for a Resolver. +type ResolverConfig struct { + // Logf is the logger to use throughout the Resolver. + Logf logger.Logf + // Forward determines whether the resolver will forward packets to + // nameservers set with SetUpstreams if the domain name is not of a Tailscale node. + Forward bool + // LinkMonitor optionally provides a link monitor to use to rebind + // connections on link changes. + // If nil, rebinds are not performend. + LinkMonitor *monitor.Mon +} + +// NewResolver constructs a resolver associated with the given root domain. +// The root domain must be in canonical form (with a trailing period). +func NewResolver(config ResolverConfig) *Resolver { + r := &Resolver{ + logf: logger.WithPrefix(config.Logf, "dns: "), + linkMon: config.LinkMonitor, + queue: make(chan Packet, queueSize), + responses: make(chan Packet), + errors: make(chan error), + closed: make(chan struct{}), + } + + if config.Forward { + r.forwarder = newForwarder(r.logf, r.responses) + } + if r.linkMon != nil { + r.unregLinkMon = r.linkMon.RegisterChangeCallback(r.onLinkMonitorChange) + } + + return r +} + +func (r *Resolver) Start() error { + if r.forwarder != nil { + if err := r.forwarder.Start(); err != nil { + return err + } + } + + r.wg.Add(1) + go r.poll() + + return nil +} + +// Close shuts down the resolver and ensures poll goroutines have exited. +// The Resolver cannot be used again after Close is called. +func (r *Resolver) Close() { + select { + case <-r.closed: + return + default: + // continue + } + close(r.closed) + + if r.unregLinkMon != nil { + r.unregLinkMon() + } + + if r.forwarder != nil { + r.forwarder.Close() + } + + r.wg.Wait() +} + +func (r *Resolver) onLinkMonitorChange(changed bool, state *interfaces.State) { + if !changed { + return + } + if r.forwarder != nil { + r.forwarder.rebindFromNetworkChange() + } +} + +// SetMap sets the resolver's DNS map, taking ownership of it. +func (r *Resolver) SetMap(m *Map) { + r.mu.Lock() + oldMap := r.dnsMap + r.dnsMap = m + r.mu.Unlock() + r.logf("map diff:\n%s", m.PrettyDiffFrom(oldMap)) +} + +// SetUpstreams sets the addresses of the resolver's +// upstream nameservers, taking ownership of the argument. +func (r *Resolver) SetUpstreams(upstreams []net.Addr) { + if r.forwarder != nil { + r.forwarder.setUpstreams(upstreams) + } + r.logf("set upstreams: %v", upstreams) +} + +// EnqueueRequest places the given DNS request in the resolver's queue. +// It takes ownership of the payload and does not block. +// If the queue is full, the request will be dropped and an error will be returned. +func (r *Resolver) EnqueueRequest(request Packet) error { + select { + case <-r.closed: + return ErrClosed + case r.queue <- request: + return nil + default: + return errFullQueue + } +} + +// NextResponse returns a DNS response to a previously enqueued request. +// It blocks until a response is available and gives up ownership of the response payload. +func (r *Resolver) NextResponse() (Packet, error) { + select { + case <-r.closed: + return Packet{}, ErrClosed + case resp := <-r.responses: + return resp, nil + case err := <-r.errors: + return Packet{}, err + } +} + +// Resolve maps a given domain name to the IP address of the host that owns it, +// if the IP address conforms to the DNS resource type given by tp (one of A, AAAA, ALL). +// The domain name must be in canonical form (with a trailing period). +func (r *Resolver) Resolve(domain string, tp dns.Type) (netaddr.IP, dns.RCode, error) { + r.mu.Lock() + dnsMap := r.dnsMap + r.mu.Unlock() + + if dnsMap == nil { + return netaddr.IP{}, dns.RCodeServerFailure, errMapNotSet + } + + // Reject .onion domains per RFC 7686. + if dnsname.HasSuffix(domain, ".onion") { + return netaddr.IP{}, dns.RCodeNameError, nil + } + + anyHasSuffix := false + for _, suffix := range dnsMap.rootDomains { + if dnsname.HasSuffix(domain, suffix) { + anyHasSuffix = true + break + } + } + addr, found := dnsMap.nameToIP[domain] + if !found { + if !anyHasSuffix { + return netaddr.IP{}, dns.RCodeRefused, nil + } + return netaddr.IP{}, dns.RCodeNameError, nil + } + + // Refactoring note: this must happen after we check suffixes, + // otherwise we will respond with NOTIMP to requests that should be forwarded. + switch tp { + case dns.TypeA: + if !addr.Is4() { + return netaddr.IP{}, dns.RCodeSuccess, nil + } + return addr, dns.RCodeSuccess, nil + case dns.TypeAAAA: + if !addr.Is6() { + return netaddr.IP{}, dns.RCodeSuccess, nil + } + return addr, dns.RCodeSuccess, nil + case dns.TypeALL: + // Answer with whatever we've got. + // It could be IPv4, IPv6, or a zero addr. + // TODO: Return all available resolutions (A and AAAA, if we have them). + return addr, dns.RCodeSuccess, nil + + // Leave some some record types explicitly unimplemented. + // These types relate to recursive resolution or special + // DNS sematics and might be implemented in the future. + case dns.TypeNS, dns.TypeSOA, dns.TypeAXFR, dns.TypeHINFO: + return netaddr.IP{}, dns.RCodeNotImplemented, errNotImplemented + + // For everything except for the few types above that are explictly not implemented, return no records. + // This is what other DNS systems do: always return NOERROR + // without any records whenever the requested record type is unknown. + // You can try this with: + // dig -t TYPE9824 example.com + // and note that NOERROR is returned, despite that record type being made up. + default: + // no records exist of this type + return netaddr.IP{}, dns.RCodeSuccess, nil + } +} + +// ResolveReverse returns the unique domain name that maps to the given address. +// The returned domain name is in canonical form (with a trailing period). +func (r *Resolver) ResolveReverse(ip netaddr.IP) (string, dns.RCode, error) { + r.mu.Lock() + dnsMap := r.dnsMap + r.mu.Unlock() + + if dnsMap == nil { + return "", dns.RCodeServerFailure, errMapNotSet + } + name, found := dnsMap.ipToName[ip] + if !found { + return "", dns.RCodeNameError, nil + } + return name, dns.RCodeSuccess, nil +} + +func (r *Resolver) poll() { + defer r.wg.Done() + + var packet Packet + for { + select { + case <-r.closed: + return + case packet = <-r.queue: + // continue + } + + out, err := r.respond(packet.Payload) + + if err == errNotOurName { + if r.forwarder != nil { + err = r.forwarder.forward(packet) + if err == nil { + // forward will send response into r.responses, nothing to do. + continue + } + } else { + err = errNotForwarding + } + } + + if err != nil { + select { + case <-r.closed: + return + case r.errors <- err: + // continue + } + } else { + packet.Payload = out + select { + case <-r.closed: + return + case r.responses <- packet: + // continue + } + } + } +} + +type response struct { + Header dns.Header + Question dns.Question + // Name is the response to a PTR query. + Name string + // IP is the response to an A, AAAA, or ALL query. + IP netaddr.IP +} + +// parseQuery parses the query in given packet into a response struct. +func parseQuery(query []byte, resp *response) error { + var parser dns.Parser + var err error + + resp.Header, err = parser.Start(query) + if err != nil { + return err + } + + if resp.Header.Response { + return errNotQuery + } + + resp.Question, err = parser.Question() + if err != nil { + return err + } + + return nil +} + +// marshalARecord serializes an A record into an active builder. +// The caller may continue using the builder following the call. +func marshalARecord(name dns.Name, ip netaddr.IP, builder *dns.Builder) error { + var answer dns.AResource + + answerHeader := dns.ResourceHeader{ + Name: name, + Type: dns.TypeA, + Class: dns.ClassINET, + TTL: uint32(defaultTTL / time.Second), + } + ipbytes := ip.As4() + copy(answer.A[:], ipbytes[:]) + return builder.AResource(answerHeader, answer) +} + +// marshalAAAARecord serializes an AAAA record into an active builder. +// The caller may continue using the builder following the call. +func marshalAAAARecord(name dns.Name, ip netaddr.IP, builder *dns.Builder) error { + var answer dns.AAAAResource + + answerHeader := dns.ResourceHeader{ + Name: name, + Type: dns.TypeAAAA, + Class: dns.ClassINET, + TTL: uint32(defaultTTL / time.Second), + } + ipbytes := ip.As16() + copy(answer.AAAA[:], ipbytes[:]) + return builder.AAAAResource(answerHeader, answer) +} + +// marshalPTRRecord serializes a PTR record into an active builder. +// The caller may continue using the builder following the call. +func marshalPTRRecord(queryName dns.Name, name string, builder *dns.Builder) error { + var answer dns.PTRResource + var err error + + answerHeader := dns.ResourceHeader{ + Name: queryName, + Type: dns.TypePTR, + Class: dns.ClassINET, + TTL: uint32(defaultTTL / time.Second), + } + answer.PTR, err = dns.NewName(name) + if err != nil { + return err + } + return builder.PTRResource(answerHeader, answer) +} + +// marshalResponse serializes the DNS response into a new buffer. +func marshalResponse(resp *response) ([]byte, error) { + resp.Header.Response = true + resp.Header.Authoritative = true + if resp.Header.RecursionDesired { + resp.Header.RecursionAvailable = true + } + + builder := dns.NewBuilder(nil, resp.Header) + + isSuccess := resp.Header.RCode == dns.RCodeSuccess + + if resp.Question.Type != 0 || isSuccess { + err := builder.StartQuestions() + if err != nil { + return nil, err + } + + err = builder.Question(resp.Question) + if err != nil { + return nil, err + } + } + + // Only successful responses contain answers. + if !isSuccess { + return builder.Finish() + } + + err := builder.StartAnswers() + if err != nil { + return nil, err + } + + switch resp.Question.Type { + case dns.TypeA, dns.TypeAAAA, dns.TypeALL: + if resp.IP.Is4() { + err = marshalARecord(resp.Question.Name, resp.IP, &builder) + } else if resp.IP.Is6() { + err = marshalAAAARecord(resp.Question.Name, resp.IP, &builder) + } + case dns.TypePTR: + err = marshalPTRRecord(resp.Question.Name, resp.Name, &builder) + } + if err != nil { + return nil, err + } + + return builder.Finish() +} + +const ( + rdnsv4Suffix = ".in-addr.arpa." + rdnsv6Suffix = ".ip6.arpa." +) + +// hasRDNSBonjourPrefix reports whether name has a Bonjour Service Prefix.. +// +// https://tools.ietf.org/html/rfc6763 lists +// "five special RR names" for Bonjour service discovery: +// +// b._dns-sd._udp.<domain>. +// db._dns-sd._udp.<domain>. +// r._dns-sd._udp.<domain>. +// dr._dns-sd._udp.<domain>. +// lb._dns-sd._udp.<domain>. +func hasRDNSBonjourPrefix(s string) bool { + // Even the shortest name containing a Bonjour prefix is long, + // so check length (cheap) and bail early if possible. + if len(s) < len("*._dns-sd._udp.0.0.0.0.in-addr.arpa.") { + return false + } + dot := strings.IndexByte(s, '.') + if dot == -1 { + return false // shouldn't happen + } + switch s[:dot] { + case "b", "db", "r", "dr", "lb": + default: + return false + } + + return strings.HasPrefix(s[dot:], "._dns-sd._udp.") +} + +// rawNameToLower converts a raw DNS name to a string, lowercasing it. +func rawNameToLower(name []byte) string { + var sb strings.Builder + sb.Grow(len(name)) + + for _, b := range name { + if 'A' <= b && b <= 'Z' { + b = b - 'A' + 'a' + } + sb.WriteByte(b) + } + + return sb.String() +} + +// ptrNameToIPv4 transforms a PTR name representing an IPv4 address to said address. +// Such names are IPv4 labels in reverse order followed by .in-addr.arpa. +// For example, +// 4.3.2.1.in-addr.arpa +// is transformed to +// 1.2.3.4 +func rdnsNameToIPv4(name string) (ip netaddr.IP, ok bool) { + name = strings.TrimSuffix(name, rdnsv4Suffix) + ip, err := netaddr.ParseIP(string(name)) + if err != nil { + return netaddr.IP{}, false + } + if !ip.Is4() { + return netaddr.IP{}, false + } + b := ip.As4() + return netaddr.IPv4(b[3], b[2], b[1], b[0]), true +} + +// ptrNameToIPv6 transforms a PTR name representing an IPv6 address to said address. +// Such names are dot-separated nibbles in reverse order followed by .ip6.arpa. +// For example, +// b.a.9.8.7.6.5.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa. +// is transformed to +// 2001:db8::567:89ab +func rdnsNameToIPv6(name string) (ip netaddr.IP, ok bool) { + var b [32]byte + var ipb [16]byte + + name = strings.TrimSuffix(name, rdnsv6Suffix) + // 32 nibbles and 31 dots between them. + if len(name) != 63 { + return netaddr.IP{}, false + } + + // Dots and hex digits alternate. + prevDot := true + // i ranges over name backward; j ranges over b forward. + for i, j := len(name)-1, 0; i >= 0; i-- { + thisDot := (name[i] == '.') + if prevDot == thisDot { + return netaddr.IP{}, false + } + prevDot = thisDot + + if !thisDot { + // This is safe assuming alternation. + // We do not check that non-dots are hex digits: hex.Decode below will do that. + b[j] = name[i] + j++ + } + } + + _, err := hex.Decode(ipb[:], b[:]) + if err != nil { + return netaddr.IP{}, false + } + + return netaddr.IPFrom16(ipb), true +} + +// respondReverse returns a DNS response to a PTR query. +// It is assumed that resp.Question is populated by respond before this is called. +func (r *Resolver) respondReverse(query []byte, name string, resp *response) ([]byte, error) { + if hasRDNSBonjourPrefix(name) { + return nil, errNotOurName + } + + var ip netaddr.IP + var ok bool + switch { + case strings.HasSuffix(name, rdnsv4Suffix): + ip, ok = rdnsNameToIPv4(name) + case strings.HasSuffix(name, rdnsv6Suffix): + ip, ok = rdnsNameToIPv6(name) + default: + return nil, errNotOurName + } + + // It is more likely that we failed in parsing the name than that it is actually malformed. + // To avoid frustrating users, just log and delegate. + if !ok { + r.logf("parsing rdns: malformed name: %s", name) + return nil, errNotOurName + } + + var err error + resp.Name, resp.Header.RCode, err = r.ResolveReverse(ip) + if err != nil { + r.logf("resolving rdns: %v", ip, err) + } + if resp.Header.RCode == dns.RCodeNameError { + return nil, errNotOurName + } + + return marshalResponse(resp) +} + +// respond returns a DNS response to query if it can be resolved locally. +// Otherwise, it returns errNotOurName. +func (r *Resolver) respond(query []byte) ([]byte, error) { + resp := new(response) + + // ParseQuery is sufficiently fast to run on every DNS packet. + // This is considerably simpler than extracting the name by hand + // to shave off microseconds in case of delegation. + err := parseQuery(query, resp) + // We will not return this error: it is the sender's fault. + if err != nil { + if errors.Is(err, dns.ErrSectionDone) { + r.logf("parseQuery(%02x): no DNS questions", query) + } else { + r.logf("parseQuery(%02x): %v", query, err) + } + resp.Header.RCode = dns.RCodeFormatError + return marshalResponse(resp) + } + rawName := resp.Question.Name.Data[:resp.Question.Name.Length] + name := rawNameToLower(rawName) + + // Always try to handle reverse lookups; delegate inside when not found. + // This way, queries for existent nodes do not leak, + // but we behave gracefully if non-Tailscale nodes exist in CGNATRange. + if resp.Question.Type == dns.TypePTR { + return r.respondReverse(query, name, resp) + } + + resp.IP, resp.Header.RCode, err = r.Resolve(name, resp.Question.Type) + // This return code is special: it requests forwarding. + if resp.Header.RCode == dns.RCodeRefused { + return nil, errNotOurName + } + + // We will not return this error: it is the sender's fault. + if err != nil { + r.logf("resolving: %v", err) + } + + return marshalResponse(resp) +} diff --git a/net/dns/tsdns_server_test.go b/net/dns/tsdns_server_test.go new file mode 100644 index 000000000..95544ba18 --- /dev/null +++ b/net/dns/tsdns_server_test.go @@ -0,0 +1,95 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package dns + +import ( + "log" + "testing" + + "github.com/miekg/dns" + "inet.af/netaddr" +) + +// This file exists to isolate the test infrastructure +// that depends on github.com/miekg/dns +// from the rest, which only depends on dnsmessage. + +var dnsHandleFunc = dns.HandleFunc + +// resolveToIP returns a handler function which responds +// to queries of type A it receives with an A record containing ipv4, +// to queries of type AAAA with an AAAA record containing ipv6, +// to queries of type NS with an NS record containg name. +func resolveToIP(ipv4, ipv6 netaddr.IP, ns string) dns.HandlerFunc { + return func(w dns.ResponseWriter, req *dns.Msg) { + m := new(dns.Msg) + m.SetReply(req) + + if len(req.Question) != 1 { + panic("not a single-question request") + } + question := req.Question[0] + + var ans dns.RR + switch question.Qtype { + case dns.TypeA: + ans = &dns.A{ + Hdr: dns.RR_Header{ + Name: question.Name, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + }, + A: ipv4.IPAddr().IP, + } + case dns.TypeAAAA: + ans = &dns.AAAA{ + Hdr: dns.RR_Header{ + Name: question.Name, + Rrtype: dns.TypeAAAA, + Class: dns.ClassINET, + }, + AAAA: ipv6.IPAddr().IP, + } + case dns.TypeNS: + ans = &dns.NS{ + Hdr: dns.RR_Header{ + Name: question.Name, + Rrtype: dns.TypeNS, + Class: dns.ClassINET, + }, + Ns: ns, + } + } + + m.Answer = append(m.Answer, ans) + w.WriteMsg(m) + } +} + +func resolveToNXDOMAIN(w dns.ResponseWriter, req *dns.Msg) { + m := new(dns.Msg) + m.SetRcode(req, dns.RcodeNameError) + w.WriteMsg(m) +} + +func serveDNS(tb testing.TB, addr string) (*dns.Server, chan error) { + server := &dns.Server{Addr: addr, Net: "udp"} + + waitch := make(chan struct{}) + server.NotifyStartedFunc = func() { close(waitch) } + + errch := make(chan error, 1) + go func() { + err := server.ListenAndServe() + if err != nil { + log.Printf("ListenAndServe(%q): %v", addr, err) + } + errch <- err + close(errch) + }() + + <-waitch + return server, errch +} diff --git a/net/dns/tsdns_test.go b/net/dns/tsdns_test.go new file mode 100644 index 000000000..59bcd8ec1 --- /dev/null +++ b/net/dns/tsdns_test.go @@ -0,0 +1,816 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package dns + +import ( + "bytes" + "errors" + "net" + "sync" + "testing" + + dns "golang.org/x/net/dns/dnsmessage" + "inet.af/netaddr" + "tailscale.com/tstest" +) + +var testipv4 = netaddr.IPv4(1, 2, 3, 4) +var testipv6 = netaddr.IPv6Raw([16]byte{ + 0x00, 0x01, 0x02, 0x03, + 0x04, 0x05, 0x06, 0x07, + 0x08, 0x09, 0x0a, 0x0b, + 0x0c, 0x0d, 0x0e, 0x0f, +}) + +var dnsMap = NewMap( + map[string]netaddr.IP{ + "test1.ipn.dev.": testipv4, + "test2.ipn.dev.": testipv6, + }, + []string{"ipn.dev."}, +) + +func dnspacket(domain string, tp dns.Type) []byte { + var dnsHeader dns.Header + question := dns.Question{ + Name: dns.MustNewName(domain), + Type: tp, + Class: dns.ClassINET, + } + + builder := dns.NewBuilder(nil, dnsHeader) + builder.StartQuestions() + builder.Question(question) + payload, _ := builder.Finish() + + return payload +} + +type dnsResponse struct { + ip netaddr.IP + name string + rcode dns.RCode +} + +func unpackResponse(payload []byte) (dnsResponse, error) { + var response dnsResponse + var parser dns.Parser + + h, err := parser.Start(payload) + if err != nil { + return response, err + } + + if !h.Response { + return response, errors.New("not a response") + } + + response.rcode = h.RCode + if response.rcode != dns.RCodeSuccess { + return response, nil + } + + err = parser.SkipAllQuestions() + if err != nil { + return response, err + } + + ah, err := parser.AnswerHeader() + if err != nil { + return response, err + } + + switch ah.Type { + case dns.TypeA: + res, err := parser.AResource() + if err != nil { + return response, err + } + response.ip = netaddr.IPv4(res.A[0], res.A[1], res.A[2], res.A[3]) + case dns.TypeAAAA: + res, err := parser.AAAAResource() + if err != nil { + return response, err + } + response.ip = netaddr.IPv6Raw(res.AAAA) + case dns.TypeNS: + res, err := parser.NSResource() + if err != nil { + return response, err + } + response.name = res.NS.String() + default: + return response, errors.New("type not in {A, AAAA, NS}") + } + + return response, nil +} + +func syncRespond(r *Resolver, query []byte) ([]byte, error) { + request := Packet{Payload: query} + r.EnqueueRequest(request) + resp, err := r.NextResponse() + return resp.Payload, err +} + +func mustIP(str string) netaddr.IP { + ip, err := netaddr.ParseIP(str) + if err != nil { + panic(err) + } + return ip +} + +func TestRDNSNameToIPv4(t *testing.T) { + tests := []struct { + name string + input string + wantIP netaddr.IP + wantOK bool + }{ + {"valid", "4.123.24.1.in-addr.arpa.", netaddr.IPv4(1, 24, 123, 4), true}, + {"double_dot", "1..2.3.in-addr.arpa.", netaddr.IP{}, false}, + {"overflow", "1.256.3.4.in-addr.arpa.", netaddr.IP{}, false}, + {"not_ip", "sub.do.ma.in.in-addr.arpa.", netaddr.IP{}, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ip, ok := rdnsNameToIPv4(tt.input) + if ok != tt.wantOK { + t.Errorf("ok = %v; want %v", ok, tt.wantOK) + } else if ok && ip != tt.wantIP { + t.Errorf("ip = %v; want %v", ip, tt.wantIP) + } + }) + } +} + +func TestRDNSNameToIPv6(t *testing.T) { + tests := []struct { + name string + input string + wantIP netaddr.IP + wantOK bool + }{ + { + "valid", + "b.a.9.8.7.6.5.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.", + mustIP("2001:db8::567:89ab"), + true, + }, + { + "double_dot", + "b..9.8.7.6.5.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.", + netaddr.IP{}, + false, + }, + { + "double_hex", + "b.a.98.0.7.6.5.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.", + netaddr.IP{}, + false, + }, + { + "not_hex", + "b.a.g.0.7.6.5.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.", + netaddr.IP{}, + false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ip, ok := rdnsNameToIPv6(tt.input) + if ok != tt.wantOK { + t.Errorf("ok = %v; want %v", ok, tt.wantOK) + } else if ok && ip != tt.wantIP { + t.Errorf("ip = %v; want %v", ip, tt.wantIP) + } + }) + } +} + +func TestResolve(t *testing.T) { + r := NewResolver(ResolverConfig{Logf: t.Logf, Forward: false}) + r.SetMap(dnsMap) + + if err := r.Start(); err != nil { + t.Fatalf("start: %v", err) + } + defer r.Close() + + tests := []struct { + name string + qname string + qtype dns.Type + ip netaddr.IP + code dns.RCode + }{ + {"ipv4", "test1.ipn.dev.", dns.TypeA, testipv4, dns.RCodeSuccess}, + {"ipv6", "test2.ipn.dev.", dns.TypeAAAA, testipv6, dns.RCodeSuccess}, + {"no-ipv6", "test1.ipn.dev.", dns.TypeAAAA, netaddr.IP{}, dns.RCodeSuccess}, + {"nxdomain", "test3.ipn.dev.", dns.TypeA, netaddr.IP{}, dns.RCodeNameError}, + {"foreign domain", "google.com.", dns.TypeA, netaddr.IP{}, dns.RCodeRefused}, + {"all", "test1.ipn.dev.", dns.TypeA, testipv4, dns.RCodeSuccess}, + {"mx-ipv4", "test1.ipn.dev.", dns.TypeMX, netaddr.IP{}, dns.RCodeSuccess}, + {"mx-ipv6", "test2.ipn.dev.", dns.TypeMX, netaddr.IP{}, dns.RCodeSuccess}, + {"mx-nxdomain", "test3.ipn.dev.", dns.TypeMX, netaddr.IP{}, dns.RCodeNameError}, + {"ns-nxdomain", "test3.ipn.dev.", dns.TypeNS, netaddr.IP{}, dns.RCodeNameError}, + {"onion-domain", "footest.onion.", dns.TypeA, netaddr.IP{}, dns.RCodeNameError}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ip, code, err := r.Resolve(tt.qname, tt.qtype) + if err != nil { + t.Errorf("err = %v; want nil", err) + } + if code != tt.code { + t.Errorf("code = %v; want %v", code, tt.code) + } + // Only check ip for non-err + if ip != tt.ip { + t.Errorf("ip = %v; want %v", ip, tt.ip) + } + }) + } +} + +func TestResolveReverse(t *testing.T) { + r := NewResolver(ResolverConfig{Logf: t.Logf, Forward: false}) + r.SetMap(dnsMap) + + if err := r.Start(); err != nil { + t.Fatalf("start: %v", err) + } + defer r.Close() + + tests := []struct { + name string + ip netaddr.IP + want string + code dns.RCode + }{ + {"ipv4", testipv4, "test1.ipn.dev.", dns.RCodeSuccess}, + {"ipv6", testipv6, "test2.ipn.dev.", dns.RCodeSuccess}, + {"nxdomain", netaddr.IPv4(4, 3, 2, 1), "", dns.RCodeNameError}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + name, code, err := r.ResolveReverse(tt.ip) + if err != nil { + t.Errorf("err = %v; want nil", err) + } + if code != tt.code { + t.Errorf("code = %v; want %v", code, tt.code) + } + if name != tt.want { + t.Errorf("ip = %v; want %v", name, tt.want) + } + }) + } +} + +func ipv6Works() bool { + c, err := net.Listen("tcp", "[::1]:0") + if err != nil { + return false + } + c.Close() + return true +} + +func TestDelegate(t *testing.T) { + tstest.ResourceCheck(t) + + if !ipv6Works() { + t.Skip("skipping test that requires localhost IPv6") + } + + dnsHandleFunc("test.site.", resolveToIP(testipv4, testipv6, "dns.test.site.")) + dnsHandleFunc("nxdomain.site.", resolveToNXDOMAIN) + + v4server, v4errch := serveDNS(t, "127.0.0.1:0") + v6server, v6errch := serveDNS(t, "[::1]:0") + + defer func() { + if err := <-v4errch; err != nil { + t.Errorf("v4 server error: %v", err) + } + if err := <-v6errch; err != nil { + t.Errorf("v6 server error: %v", err) + } + }() + if v4server != nil { + defer v4server.Shutdown() + } + if v6server != nil { + defer v6server.Shutdown() + } + + if v4server == nil || v6server == nil { + // There is an error in at least one of the channels + // and we cannot proceed; return to see it. + return + } + + r := NewResolver(ResolverConfig{Logf: t.Logf, Forward: true}) + r.SetMap(dnsMap) + r.SetUpstreams([]net.Addr{ + v4server.PacketConn.LocalAddr(), + v6server.PacketConn.LocalAddr(), + }) + + if err := r.Start(); err != nil { + t.Fatalf("start: %v", err) + } + defer r.Close() + + tests := []struct { + title string + query []byte + response dnsResponse + }{ + { + "ipv4", + dnspacket("test.site.", dns.TypeA), + dnsResponse{ip: testipv4, rcode: dns.RCodeSuccess}, + }, + { + "ipv6", + dnspacket("test.site.", dns.TypeAAAA), + dnsResponse{ip: testipv6, rcode: dns.RCodeSuccess}, + }, + { + "ns", + dnspacket("test.site.", dns.TypeNS), + dnsResponse{name: "dns.test.site.", rcode: dns.RCodeSuccess}, + }, + { + "nxdomain", + dnspacket("nxdomain.site.", dns.TypeA), + dnsResponse{rcode: dns.RCodeNameError}, + }, + } + + for _, tt := range tests { + t.Run(tt.title, func(t *testing.T) { + payload, err := syncRespond(r, tt.query) + if err != nil { + t.Errorf("err = %v; want nil", err) + return + } + response, err := unpackResponse(payload) + if err != nil { + t.Errorf("extract: err = %v; want nil (in %x)", err, payload) + return + } + if response.rcode != tt.response.rcode { + t.Errorf("rcode = %v; want %v", response.rcode, tt.response.rcode) + } + if response.ip != tt.response.ip { + t.Errorf("ip = %v; want %v", response.ip, tt.response.ip) + } + if response.name != tt.response.name { + t.Errorf("name = %v; want %v", response.name, tt.response.name) + } + }) + } +} + +func TestDelegateCollision(t *testing.T) { + dnsHandleFunc("test.site.", resolveToIP(testipv4, testipv6, "dns.test.site.")) + + server, errch := serveDNS(t, "127.0.0.1:0") + defer func() { + if err := <-errch; err != nil { + t.Errorf("server error: %v", err) + } + }() + + if server == nil { + return + } + defer server.Shutdown() + + r := NewResolver(ResolverConfig{Logf: t.Logf, Forward: true}) + r.SetMap(dnsMap) + r.SetUpstreams([]net.Addr{server.PacketConn.LocalAddr()}) + + if err := r.Start(); err != nil { + t.Fatalf("start: %v", err) + } + defer r.Close() + + packets := []struct { + qname string + qtype dns.Type + addr netaddr.IPPort + }{ + {"test.site.", dns.TypeA, netaddr.IPPort{IP: netaddr.IPv4(1, 1, 1, 1), Port: 1001}}, + {"test.site.", dns.TypeAAAA, netaddr.IPPort{IP: netaddr.IPv4(1, 1, 1, 1), Port: 1002}}, + } + + // packets will have the same dns txid. + for _, p := range packets { + payload := dnspacket(p.qname, p.qtype) + req := Packet{Payload: payload, Addr: p.addr} + err := r.EnqueueRequest(req) + if err != nil { + t.Error(err) + } + } + + // Despite the txid collision, the answer(s) should still match the query. + resp, err := r.NextResponse() + if err != nil { + t.Error(err) + } + + var p dns.Parser + _, err = p.Start(resp.Payload) + if err != nil { + t.Error(err) + } + err = p.SkipAllQuestions() + if err != nil { + t.Error(err) + } + ans, err := p.AllAnswers() + if err != nil { + t.Error(err) + } + + var wantType dns.Type + switch ans[0].Body.(type) { + case *dns.AResource: + wantType = dns.TypeA + case *dns.AAAAResource: + wantType = dns.TypeAAAA + default: + t.Errorf("unexpected answer type: %T", ans[0].Body) + } + + for _, p := range packets { + if p.qtype == wantType && p.addr != resp.Addr { + t.Errorf("addr = %v; want %v", resp.Addr, p.addr) + } + } +} + +func TestConcurrentSetMap(t *testing.T) { + r := NewResolver(ResolverConfig{Logf: t.Logf, Forward: false}) + + if err := r.Start(); err != nil { + t.Fatalf("start: %v", err) + } + defer r.Close() + + // This is purely to ensure that Resolve does not race with SetMap. + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + r.SetMap(dnsMap) + }() + go func() { + defer wg.Done() + r.Resolve("test1.ipn.dev", dns.TypeA) + }() + wg.Wait() +} + +func TestConcurrentSetUpstreams(t *testing.T) { + dnsHandleFunc("test.site.", resolveToIP(testipv4, testipv6, "dns.test.site.")) + + server, errch := serveDNS(t, "127.0.0.1:0") + defer func() { + if err := <-errch; err != nil { + t.Errorf("server error: %v", err) + } + }() + + if server == nil { + return + } + defer server.Shutdown() + + r := NewResolver(ResolverConfig{Logf: t.Logf, Forward: true}) + r.SetMap(dnsMap) + + if err := r.Start(); err != nil { + t.Fatalf("start: %v", err) + } + defer r.Close() + + packet := dnspacket("test.site.", dns.TypeA) + // This is purely to ensure that delegation does not race with SetUpstreams. + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + r.SetUpstreams([]net.Addr{server.PacketConn.LocalAddr()}) + }() + go func() { + defer wg.Done() + syncRespond(r, packet) + }() + wg.Wait() +} + +var allResponse = []byte{ + 0x00, 0x00, // transaction id: 0 + 0x84, 0x00, // flags: response, authoritative, no error + 0x00, 0x01, // one question + 0x00, 0x01, // one answer + 0x00, 0x00, 0x00, 0x00, // no authority or additional RRs + // Question: + 0x05, 0x74, 0x65, 0x73, 0x74, 0x31, 0x03, 0x69, 0x70, 0x6e, 0x03, 0x64, 0x65, 0x76, 0x00, // name + 0x00, 0xff, 0x00, 0x01, // type ALL, class IN + // Answer: + 0x05, 0x74, 0x65, 0x73, 0x74, 0x31, 0x03, 0x69, 0x70, 0x6e, 0x03, 0x64, 0x65, 0x76, 0x00, // name + 0x00, 0x01, 0x00, 0x01, // type A, class IN + 0x00, 0x00, 0x02, 0x58, // TTL: 600 + 0x00, 0x04, // length: 4 bytes + 0x01, 0x02, 0x03, 0x04, // A: 1.2.3.4 +} + +var ipv4Response = []byte{ + 0x00, 0x00, // transaction id: 0 + 0x84, 0x00, // flags: response, authoritative, no error + 0x00, 0x01, // one question + 0x00, 0x01, // one answer + 0x00, 0x00, 0x00, 0x00, // no authority or additional RRs + // Question: + 0x05, 0x74, 0x65, 0x73, 0x74, 0x31, 0x03, 0x69, 0x70, 0x6e, 0x03, 0x64, 0x65, 0x76, 0x00, // name + 0x00, 0x01, 0x00, 0x01, // type A, class IN + // Answer: + 0x05, 0x74, 0x65, 0x73, 0x74, 0x31, 0x03, 0x69, 0x70, 0x6e, 0x03, 0x64, 0x65, 0x76, 0x00, // name + 0x00, 0x01, 0x00, 0x01, // type A, class IN + 0x00, 0x00, 0x02, 0x58, // TTL: 600 + 0x00, 0x04, // length: 4 bytes + 0x01, 0x02, 0x03, 0x04, // A: 1.2.3.4 +} + +var ipv6Response = []byte{ + 0x00, 0x00, // transaction id: 0 + 0x84, 0x00, // flags: response, authoritative, no error + 0x00, 0x01, // one question + 0x00, 0x01, // one answer + 0x00, 0x00, 0x00, 0x00, // no authority or additional RRs + // Question: + 0x05, 0x74, 0x65, 0x73, 0x74, 0x32, 0x03, 0x69, 0x70, 0x6e, 0x03, 0x64, 0x65, 0x76, 0x00, // name + 0x00, 0x1c, 0x00, 0x01, // type AAAA, class IN + // Answer: + 0x05, 0x74, 0x65, 0x73, 0x74, 0x32, 0x03, 0x69, 0x70, 0x6e, 0x03, 0x64, 0x65, 0x76, 0x00, // name + 0x00, 0x1c, 0x00, 0x01, // type AAAA, class IN + 0x00, 0x00, 0x02, 0x58, // TTL: 600 + 0x00, 0x10, // length: 16 bytes + // AAAA: 0001:0203:0405:0607:0809:0A0B:0C0D:0E0F + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0xb, 0xc, 0xd, 0xe, 0xf, +} + +var ipv4UppercaseResponse = []byte{ + 0x00, 0x00, // transaction id: 0 + 0x84, 0x00, // flags: response, authoritative, no error + 0x00, 0x01, // one question + 0x00, 0x01, // one answer + 0x00, 0x00, 0x00, 0x00, // no authority or additional RRs + // Question: + 0x05, 0x54, 0x45, 0x53, 0x54, 0x31, 0x03, 0x49, 0x50, 0x4e, 0x03, 0x44, 0x45, 0x56, 0x00, // name + 0x00, 0x01, 0x00, 0x01, // type A, class IN + // Answer: + 0x05, 0x54, 0x45, 0x53, 0x54, 0x31, 0x03, 0x49, 0x50, 0x4e, 0x03, 0x44, 0x45, 0x56, 0x00, // name + 0x00, 0x01, 0x00, 0x01, // type A, class IN + 0x00, 0x00, 0x02, 0x58, // TTL: 600 + 0x00, 0x04, // length: 4 bytes + 0x01, 0x02, 0x03, 0x04, // A: 1.2.3.4 +} + +var ptrResponse = []byte{ + 0x00, 0x00, // transaction id: 0 + 0x84, 0x00, // flags: response, authoritative, no error + 0x00, 0x01, // one question + 0x00, 0x01, // one answer + 0x00, 0x00, 0x00, 0x00, // no authority or additional RRs + // Question: 4.3.2.1.in-addr.arpa + 0x01, 0x34, 0x01, 0x33, 0x01, 0x32, 0x01, 0x31, 0x07, + 0x69, 0x6e, 0x2d, 0x61, 0x64, 0x64, 0x72, 0x04, 0x61, 0x72, 0x70, 0x61, 0x00, + 0x00, 0x0c, 0x00, 0x01, // type PTR, class IN + // Answer: 4.3.2.1.in-addr.arpa + 0x01, 0x34, 0x01, 0x33, 0x01, 0x32, 0x01, 0x31, 0x07, + 0x69, 0x6e, 0x2d, 0x61, 0x64, 0x64, 0x72, 0x04, 0x61, 0x72, 0x70, 0x61, 0x00, + 0x00, 0x0c, 0x00, 0x01, // type PTR, class IN + 0x00, 0x00, 0x02, 0x58, // TTL: 600 + 0x00, 0x0f, // length: 15 bytes + // PTR: test1.ipn.dev + 0x05, 0x74, 0x65, 0x73, 0x74, 0x31, 0x03, 0x69, 0x70, 0x6e, 0x03, 0x64, 0x65, 0x76, 0x00, +} + +var ptrResponse6 = []byte{ + 0x00, 0x00, // transaction id: 0 + 0x84, 0x00, // flags: response, authoritative, no error + 0x00, 0x01, // one question + 0x00, 0x01, // one answer + 0x00, 0x00, 0x00, 0x00, // no authority or additional RRs + // Question: f.0.e.0.d.0.c.0.b.0.a.0.9.0.8.0.7.0.6.0.5.0.4.0.3.0.2.0.1.0.0.0.ip6.arpa + 0x01, 0x66, 0x01, 0x30, 0x01, 0x65, 0x01, 0x30, + 0x01, 0x64, 0x01, 0x30, 0x01, 0x63, 0x01, 0x30, + 0x01, 0x62, 0x01, 0x30, 0x01, 0x61, 0x01, 0x30, + 0x01, 0x39, 0x01, 0x30, 0x01, 0x38, 0x01, 0x30, + 0x01, 0x37, 0x01, 0x30, 0x01, 0x36, 0x01, 0x30, + 0x01, 0x35, 0x01, 0x30, 0x01, 0x34, 0x01, 0x30, + 0x01, 0x33, 0x01, 0x30, 0x01, 0x32, 0x01, 0x30, + 0x01, 0x31, 0x01, 0x30, 0x01, 0x30, 0x01, 0x30, + 0x03, 0x69, 0x70, 0x36, + 0x04, 0x61, 0x72, 0x70, 0x61, 0x00, + 0x00, 0x0c, 0x00, 0x01, // type PTR, class IN6 + // Answer: f.0.e.0.d.0.c.0.b.0.a.0.9.0.8.0.7.0.6.0.5.0.4.0.3.0.2.0.1.0.0.0.ip6.arpa + 0x01, 0x66, 0x01, 0x30, 0x01, 0x65, 0x01, 0x30, + 0x01, 0x64, 0x01, 0x30, 0x01, 0x63, 0x01, 0x30, + 0x01, 0x62, 0x01, 0x30, 0x01, 0x61, 0x01, 0x30, + 0x01, 0x39, 0x01, 0x30, 0x01, 0x38, 0x01, 0x30, + 0x01, 0x37, 0x01, 0x30, 0x01, 0x36, 0x01, 0x30, + 0x01, 0x35, 0x01, 0x30, 0x01, 0x34, 0x01, 0x30, + 0x01, 0x33, 0x01, 0x30, 0x01, 0x32, 0x01, 0x30, + 0x01, 0x31, 0x01, 0x30, 0x01, 0x30, 0x01, 0x30, + 0x03, 0x69, 0x70, 0x36, + 0x04, 0x61, 0x72, 0x70, 0x61, 0x00, + 0x00, 0x0c, 0x00, 0x01, // type PTR, class IN + 0x00, 0x00, 0x02, 0x58, // TTL: 600 + 0x00, 0x0f, // length: 15 bytes + // PTR: test2.ipn.dev + 0x05, 0x74, 0x65, 0x73, 0x74, 0x32, 0x03, 0x69, 0x70, 0x6e, 0x03, 0x64, 0x65, 0x76, 0x00, +} + +var nxdomainResponse = []byte{ + 0x00, 0x00, // transaction id: 0 + 0x84, 0x03, // flags: response, authoritative, error: nxdomain + 0x00, 0x01, // one question + 0x00, 0x00, // no answers + 0x00, 0x00, 0x00, 0x00, // no authority or additional RRs + // Question: + 0x05, 0x74, 0x65, 0x73, 0x74, 0x33, 0x03, 0x69, 0x70, 0x6e, 0x03, 0x64, 0x65, 0x76, 0x00, // name + 0x00, 0x01, 0x00, 0x01, // type A, class IN +} + +var emptyResponse = []byte{ + 0x00, 0x00, // transaction id: 0 + 0x84, 0x00, // flags: response, authoritative, no error + 0x00, 0x01, // one question + 0x00, 0x00, // no answers + 0x00, 0x00, 0x00, 0x00, // no authority or additional RRs + // Question: + 0x05, 0x74, 0x65, 0x73, 0x74, 0x31, 0x03, 0x69, 0x70, 0x6e, 0x03, 0x64, 0x65, 0x76, 0x00, // name + 0x00, 0x1c, 0x00, 0x01, // type AAAA, class IN +} + +func TestFull(t *testing.T) { + r := NewResolver(ResolverConfig{Logf: t.Logf, Forward: false}) + r.SetMap(dnsMap) + + if err := r.Start(); err != nil { + t.Fatalf("start: %v", err) + } + defer r.Close() + + // One full packet and one error packet + tests := []struct { + name string + request []byte + response []byte + }{ + {"all", dnspacket("test1.ipn.dev.", dns.TypeALL), allResponse}, + {"ipv4", dnspacket("test1.ipn.dev.", dns.TypeA), ipv4Response}, + {"ipv6", dnspacket("test2.ipn.dev.", dns.TypeAAAA), ipv6Response}, + {"no-ipv6", dnspacket("test1.ipn.dev.", dns.TypeAAAA), emptyResponse}, + {"upper", dnspacket("TEST1.IPN.DEV.", dns.TypeA), ipv4UppercaseResponse}, + {"ptr", dnspacket("4.3.2.1.in-addr.arpa.", dns.TypePTR), ptrResponse}, + {"ptr", dnspacket("f.0.e.0.d.0.c.0.b.0.a.0.9.0.8.0.7.0.6.0.5.0.4.0.3.0.2.0.1.0.0.0.ip6.arpa.", + dns.TypePTR), ptrResponse6}, + {"nxdomain", dnspacket("test3.ipn.dev.", dns.TypeA), nxdomainResponse}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + response, err := syncRespond(r, tt.request) + if err != nil { + t.Errorf("err = %v; want nil", err) + } + if !bytes.Equal(response, tt.response) { + t.Errorf("response = %x; want %x", response, tt.response) + } + }) + } +} + +func TestAllocs(t *testing.T) { + r := NewResolver(ResolverConfig{Logf: t.Logf, Forward: false}) + r.SetMap(dnsMap) + + if err := r.Start(); err != nil { + t.Fatalf("start: %v", err) + } + defer r.Close() + + // It is seemingly pointless to test allocs in the delegate path, + // as dialer.Dial -> Read -> Write alone comprise 12 allocs. + tests := []struct { + name string + query []byte + want int + }{ + // Name lowercasing and response slice created by dns.NewBuilder. + {"forward", dnspacket("test1.ipn.dev.", dns.TypeA), 2}, + // 3 extra allocs in rdnsNameToIPv4 and one in marshalPTRRecord (dns.NewName). + {"reverse", dnspacket("4.3.2.1.in-addr.arpa.", dns.TypePTR), 5}, + } + + for _, tt := range tests { + allocs := testing.AllocsPerRun(100, func() { + syncRespond(r, tt.query) + }) + if int(allocs) > tt.want { + t.Errorf("%s: allocs = %v; want %v", tt.name, allocs, tt.want) + } + } +} + +func TestTrimRDNSBonjourPrefix(t *testing.T) { + tests := []struct { + in string + want bool + }{ + {"b._dns-sd._udp.0.10.20.172.in-addr.arpa.", true}, + {"db._dns-sd._udp.0.10.20.172.in-addr.arpa.", true}, + {"r._dns-sd._udp.0.10.20.172.in-addr.arpa.", true}, + {"dr._dns-sd._udp.0.10.20.172.in-addr.arpa.", true}, + {"lb._dns-sd._udp.0.10.20.172.in-addr.arpa.", true}, + {"qq._dns-sd._udp.0.10.20.172.in-addr.arpa.", false}, + {"0.10.20.172.in-addr.arpa.", false}, + {"i-have-no-dot", false}, + } + + for _, test := range tests { + got := hasRDNSBonjourPrefix(test.in) + if got != test.want { + t.Errorf("trimRDNSBonjourPrefix(%q) = %v, want %v", test.in, got, test.want) + } + } +} + +func BenchmarkFull(b *testing.B) { + dnsHandleFunc("test.site.", resolveToIP(testipv4, testipv6, "dns.test.site.")) + + server, errch := serveDNS(b, "127.0.0.1:0") + defer func() { + if err := <-errch; err != nil { + b.Errorf("server error: %v", err) + } + }() + + if server == nil { + return + } + defer server.Shutdown() + + r := NewResolver(ResolverConfig{Logf: b.Logf, Forward: true}) + r.SetMap(dnsMap) + r.SetUpstreams([]net.Addr{server.PacketConn.LocalAddr()}) + + if err := r.Start(); err != nil { + b.Fatalf("start: %v", err) + } + defer r.Close() + + tests := []struct { + name string + request []byte + }{ + {"forward", dnspacket("test1.ipn.dev.", dns.TypeA)}, + {"reverse", dnspacket("4.3.2.1.in-addr.arpa.", dns.TypePTR)}, + {"delegated", dnspacket("test.site.", dns.TypeA)}, + } + + for _, tt := range tests { + b.Run(tt.name, func(b *testing.B) { + for i := 0; i < b.N; i++ { + syncRespond(r, tt.request) + } + }) + } +} + +func TestMarshalResponseFormatError(t *testing.T) { + resp := new(response) + resp.Header.RCode = dns.RCodeFormatError + v, err := marshalResponse(resp) + if err != nil { + t.Errorf("marshal error: %v", err) + } + t.Logf("response: %q", v) +} diff --git a/net/flowtrack/flowtrack.go b/net/flowtrack/flowtrack.go index 5fcd4ab20..8387145d3 100644 --- a/net/flowtrack/flowtrack.go +++ b/net/flowtrack/flowtrack.go @@ -15,16 +15,18 @@ import ( "fmt" "inet.af/netaddr" + "tailscale.com/types/ipproto" ) -// Tuple is a 4-tuple of source and destination IP and port. +// Tuple is a 5-tuple of proto, source and destination IP and port. type Tuple struct { - Src netaddr.IPPort - Dst netaddr.IPPort + Proto ipproto.Proto + Src netaddr.IPPort + Dst netaddr.IPPort } func (t Tuple) String() string { - return fmt.Sprintf("(%v => %v)", t.Src, t.Dst) + return fmt.Sprintf("(%v %v => %v)", t.Proto, t.Src, t.Dst) } // Cache is an LRU cache keyed by Tuple. diff --git a/net/interfaces/interfaces.go b/net/interfaces/interfaces.go index 3a0ffeb0b..f2988af34 100644 --- a/net/interfaces/interfaces.go +++ b/net/interfaces/interfaces.go @@ -6,10 +6,10 @@ package interfaces import ( + "bytes" "fmt" "net" "net/http" - "reflect" "runtime" "sort" "strings" @@ -190,6 +190,9 @@ func ForeachInterface(fn func(Interface, []netaddr.IPPrefix)) error { } } } + sort.Slice(pfxs, func(i, j int) bool { + return pfxs[i].IP.Less(pfxs[j].IP) + }) fn(Interface{iface}, pfxs) } return nil @@ -204,7 +207,7 @@ type State struct { // IPPrefix, where the IP is the interface IP address and Bits is // the subnet mask. InterfaceIPs map[string][]netaddr.IPPrefix - InterfaceUp map[string]bool + Interface map[string]Interface // HaveV6Global is whether this machine has an IPv6 global address // on some non-Tailscale interface that's up. @@ -235,14 +238,14 @@ type State struct { func (s *State) String() string { var sb strings.Builder fmt.Fprintf(&sb, "interfaces.State{defaultRoute=%v ifs={", s.DefaultRouteInterface) - ifs := make([]string, 0, len(s.InterfaceUp)) - for k := range s.InterfaceUp { + ifs := make([]string, 0, len(s.Interface)) + for k := range s.Interface { if anyInterestingIP(s.InterfaceIPs[k]) { ifs = append(ifs, k) } } sort.Slice(ifs, func(i, j int) bool { - upi, upj := s.InterfaceUp[ifs[i]], s.InterfaceUp[ifs[j]] + upi, upj := s.Interface[ifs[i]].IsUp(), s.Interface[ifs[j]].IsUp() if upi != upj { // Up sorts before down. return upi @@ -253,7 +256,7 @@ func (s *State) String() string { if i > 0 { sb.WriteString(" ") } - if s.InterfaceUp[ifName] { + if s.Interface[ifName].IsUp() { fmt.Fprintf(&sb, "%s:[", ifName) needSpace := false for _, pfx := range s.InterfaceIPs[ifName] { @@ -286,50 +289,76 @@ func (s *State) String() string { return sb.String() } -func (s *State) Equal(s2 *State) bool { - return reflect.DeepEqual(s, s2) -} - -func (s *State) HasPAC() bool { return s != nil && s.PAC != "" } - -// AnyInterfaceUp reports whether any interface seems like it has Internet access. -func (s *State) AnyInterfaceUp() bool { - return s != nil && (s.HaveV4 || s.HaveV6Global) -} - -// RemoveUninterestingInterfacesAndAddresses removes uninteresting IPs -// from InterfaceIPs, also removing from both the InterfaceIPs and -// InterfaceUp map any interfaces that don't have any interesting IPs. -func (s *State) RemoveUninterestingInterfacesAndAddresses() { - for ifName := range s.InterfaceUp { - ips := s.InterfaceIPs[ifName] - keep := ips[:0] - for _, pfx := range ips { - if isInterestingIP(pfx.IP) { - keep = append(keep, pfx) - } - } - if len(keep) == 0 { - delete(s.InterfaceUp, ifName) - delete(s.InterfaceIPs, ifName) +// EqualFiltered reports whether s and s2 are equal, +// considering only interfaces in s for which filter returns true. +func (s *State) EqualFiltered(s2 *State, filter func(i Interface, ips []netaddr.IPPrefix) bool) bool { + if s == nil && s2 == nil { + return true + } + if s == nil || s2 == nil { + return false + } + if s.HaveV6Global != s2.HaveV6Global || + s.HaveV4 != s2.HaveV4 || + s.IsExpensive != s2.IsExpensive || + s.DefaultRouteInterface != s2.DefaultRouteInterface || + s.HTTPProxy != s2.HTTPProxy || + s.PAC != s2.PAC { + return false + } + for iname, i := range s.Interface { + ips := s.InterfaceIPs[iname] + if !filter(i, ips) { continue } - if len(keep) < len(ips) { - s.InterfaceIPs[ifName] = keep + i2, ok := s2.Interface[iname] + if !ok { + return false + } + ips2, ok := s2.InterfaceIPs[iname] + if !ok { + return false + } + if !interfacesEqual(i, i2) || !prefixesEqual(ips, ips2) { + return false } } + return true +} + +func interfacesEqual(a, b Interface) bool { + return a.Index == b.Index && + a.MTU == b.MTU && + a.Name == b.Name && + a.Flags == b.Flags && + bytes.Equal([]byte(a.HardwareAddr), []byte(b.HardwareAddr)) } -// RemoveTailscaleInterfaces modifes s to remove any interfaces that -// are owned by this process. (TODO: make this true; currently it -// uses some heuristics) -func (s *State) RemoveTailscaleInterfaces() { - for name, pfxs := range s.InterfaceIPs { - if isTailscaleInterface(name, pfxs) { - delete(s.InterfaceIPs, name) - delete(s.InterfaceUp, name) +func prefixesEqual(a, b []netaddr.IPPrefix) bool { + if len(a) != len(b) { + return false + } + for i, v := range a { + if b[i] != v { + return false } } + return true +} + +// FilterInteresting reports whether i is an interesting non-Tailscale interface. +func FilterInteresting(i Interface, ips []netaddr.IPPrefix) bool { + return !isTailscaleInterface(i.Name, ips) && anyInterestingIP(ips) +} + +// FilterAll always returns true, to use EqualFiltered against all interfaces. +func FilterAll(i Interface, ips []netaddr.IPPrefix) bool { return true } + +func (s *State) HasPAC() bool { return s != nil && s.PAC != "" } + +// AnyInterfaceUp reports whether any interface seems like it has Internet access. +func (s *State) AnyInterfaceUp() bool { + return s != nil && (s.HaveV4 || s.HaveV6Global) } func hasTailscaleIP(pfxs []netaddr.IPPrefix) bool { @@ -364,11 +393,11 @@ var getPAC func() string func GetState() (*State, error) { s := &State{ InterfaceIPs: make(map[string][]netaddr.IPPrefix), - InterfaceUp: make(map[string]bool), + Interface: make(map[string]Interface), } if err := ForeachInterface(func(ni Interface, pfxs []netaddr.IPPrefix) { ifUp := ni.IsUp() - s.InterfaceUp[ni.Name] = ifUp + s.Interface[ni.Name] = ni s.InterfaceIPs[ni.Name] = append(s.InterfaceIPs[ni.Name], pfxs...) if !ifUp || isTailscaleInterface(ni.Name, pfxs) { return diff --git a/net/interfaces/interfaces_test.go b/net/interfaces/interfaces_test.go index 88948b579..ab0ee3734 100644 --- a/net/interfaces/interfaces_test.go +++ b/net/interfaces/interfaces_test.go @@ -5,6 +5,7 @@ package interfaces import ( + "encoding/json" "testing" ) @@ -13,7 +14,11 @@ func TestGetState(t *testing.T) { if err != nil { t.Fatal(err) } - t.Logf("Got: %#v", st) + j, err := json.MarshalIndent(st, "", "\t") + if err != nil { + t.Errorf("JSON: %v", err) + } + t.Logf("Got: %s", j) t.Logf("As string: %s", st) st2, err := GetState() @@ -21,14 +26,13 @@ func TestGetState(t *testing.T) { t.Fatal(err) } - if !st.Equal(st2) { + if !st.EqualFiltered(st2, FilterAll) { // let's assume nobody was changing the system network interfaces between // the two GetState calls. t.Fatal("two States back-to-back were not equal") } - st.RemoveTailscaleInterfaces() - t.Logf("As string without Tailscale:\n\t%s", st) + t.Logf("As string:\n\t%s", st) } func TestLikelyHomeRouterIP(t *testing.T) { diff --git a/net/interfaces/interfaces_windows.go b/net/interfaces/interfaces_windows.go index 19e9b48b4..91f679c97 100644 --- a/net/interfaces/interfaces_windows.go +++ b/net/interfaces/interfaces_windows.go @@ -7,17 +7,19 @@ package interfaces import ( "fmt" "log" + "net" "net/url" - "os/exec" "syscall" "unsafe" - "go4.org/mem" "golang.org/x/sys/windows" "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" "inet.af/netaddr" "tailscale.com/tsconst" - "tailscale.com/util/lineread" +) + +const ( + fallbackInterfaceMetric = uint32(0) // Used if we cannot get the actual interface metric ) func init() { @@ -25,58 +27,76 @@ func init() { getPAC = getPACWindows } -/* -Parse out 10.0.0.1 from: - -Z:\>route print -4 -=========================================================================== -Interface List - 15...aa 15 48 ff 1c 72 ......Red Hat VirtIO Ethernet Adapter - 5...........................Tailscale Tunnel - 1...........................Software Loopback Interface 1 -=========================================================================== - -IPv4 Route Table -=========================================================================== -Active Routes: -Network Destination Netmask Gateway Interface Metric - 0.0.0.0 0.0.0.0 10.0.0.1 10.0.28.63 5 - 10.0.0.0 255.255.0.0 On-link 10.0.28.63 261 - 10.0.28.63 255.255.255.255 On-link 10.0.28.63 261 - 10.0.42.0 255.255.255.0 100.103.42.106 100.103.42.106 5 - 10.0.255.255 255.255.255.255 On-link 10.0.28.63 261 - 34.193.248.174 255.255.255.255 100.103.42.106 100.103.42.106 5 - -*/ func likelyHomeRouterIPWindows() (ret netaddr.IP, ok bool) { - cmd := exec.Command("route", "print", "-4") - cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true} - stdout, err := cmd.StdoutPipe() + rs, err := winipcfg.GetIPForwardTable2(windows.AF_INET) if err != nil { + log.Printf("routerIP/GetIPForwardTable2 error: %v", err) return } - if err := cmd.Start(); err != nil { + + var ifaceMetricCache map[winipcfg.LUID]uint32 + + getIfaceMetric := func(luid winipcfg.LUID) (metric uint32) { + if ifaceMetricCache == nil { + ifaceMetricCache = make(map[winipcfg.LUID]uint32) + } else if m, ok := ifaceMetricCache[luid]; ok { + return m + } + + if iface, err := luid.IPInterface(windows.AF_INET); err == nil { + metric = iface.Metric + } else { + log.Printf("routerIP/luid.IPInterface error: %v", err) + metric = fallbackInterfaceMetric + } + + ifaceMetricCache[luid] = metric return } - defer cmd.Wait() - var f []mem.RO - lineread.Reader(stdout, func(lineb []byte) error { - line := mem.B(lineb) - if !mem.Contains(line, mem.S("0.0.0.0")) { - return nil + unspec := net.IPv4(0, 0, 0, 0) + var best *winipcfg.MibIPforwardRow2 // best (lowest metric) found so far, or nil + + for i := range rs { + r := &rs[i] + if r.Loopback || r.DestinationPrefix.PrefixLength != 0 || !r.DestinationPrefix.Prefix.IP().Equal(unspec) { + // Not a default route, so skip + continue } - f = mem.AppendFields(f[:0], line) - if len(f) < 3 || !f[0].EqualString("0.0.0.0") || !f[1].EqualString("0.0.0.0") { - return nil + + ip, ok := netaddr.FromStdIP(r.NextHop.IP()) + if !ok { + // Not a valid gateway, so skip (won't happen though) + continue } - ipm := f[2] - ip, err := netaddr.ParseIP(string(mem.Append(nil, ipm))) - if err == nil && isPrivateIP(ip) { + + if best == nil { + best = r ret = ip + continue } - return nil - }) + + // We can get here only if there are multiple default gateways defined (rare case), + // in which case we need to calculate the effective metric. + // Effective metric is sum of interface metric and route metric offset + if ifaceMetricCache == nil { + // If we're here it means that previous route still isn't updated, so update it + best.Metric += getIfaceMetric(best.InterfaceLUID) + } + r.Metric += getIfaceMetric(r.InterfaceLUID) + + if best.Metric > r.Metric || best.Metric == r.Metric && ret.Compare(ip) > 0 { + // Pick the route with lower metric, or lower IP if metrics are equal + best = r + ret = ip + } + } + + if !ret.IsZero() && !isPrivateIP(ret) { + // Default route has a non-private gateway + return netaddr.IP{}, false + } + return ret, !ret.IsZero() } diff --git a/net/packet/header.go b/net/packet/header.go index 86680a5a7..5cf4ef650 100644 --- a/net/packet/header.go +++ b/net/packet/header.go @@ -10,6 +10,7 @@ import ( ) const tcpHeaderLength = 20 +const sctpHeaderLength = 12 // maxPacketLength is the largest length that all headers support. // IPv4 headers using uint16 for this forces an upper bound of 64KB. diff --git a/net/packet/icmp4.go b/net/packet/icmp4.go index 8a1568114..da4774887 100644 --- a/net/packet/icmp4.go +++ b/net/packet/icmp4.go @@ -4,7 +4,11 @@ package packet -import "encoding/binary" +import ( + "encoding/binary" + + "tailscale.com/types/ipproto" +) // icmp4HeaderLength is the size of the ICMPv4 packet header, not // including the outer IP layer or the variable "response data" @@ -66,7 +70,7 @@ func (h ICMP4Header) Marshal(buf []byte) error { return errLargePacket } // The caller does not need to set this. - h.IPProto = ICMPv4 + h.IPProto = ipproto.ICMPv4 buf[20] = uint8(h.Type) buf[21] = uint8(h.Code) diff --git a/net/packet/ip.go b/net/packet/ip.go deleted file mode 100644 index 34194f344..000000000 --- a/net/packet/ip.go +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package packet - -// IPProto is an IP subprotocol as defined by the IANA protocol -// numbers list -// (https://www.iana.org/assignments/protocol-numbers/protocol-numbers.xhtml), -// or the special values Unknown or Fragment. -type IPProto uint8 - -const ( - // Unknown represents an unknown or unsupported protocol; it's - // deliberately the zero value. Strictly speaking the zero - // value is IPv6 hop-by-hop extensions, but we don't support - // those, so this is still technically correct. - Unknown IPProto = 0x00 - - // Values from the IANA registry. - ICMPv4 IPProto = 0x01 - IGMP IPProto = 0x02 - ICMPv6 IPProto = 0x3a - TCP IPProto = 0x06 - UDP IPProto = 0x11 - - // TSMP is the Tailscale Message Protocol (our ICMP-ish - // thing), an IP protocol used only between Tailscale nodes - // (still encrypted by WireGuard) that communicates why things - // failed, etc. - // - // Proto number 99 is reserved for "any private encryption - // scheme". We never accept these from the host OS stack nor - // send them to the host network stack. It's only used between - // nodes. - TSMP IPProto = 99 - - // Fragment represents any non-first IP fragment, for which we - // don't have the sub-protocol header (and therefore can't - // figure out what the sub-protocol is). - // - // 0xFF is reserved in the IANA registry, so we steal it for - // internal use. - Fragment IPProto = 0xFF -) - -func (p IPProto) String() string { - switch p { - case Fragment: - return "Frag" - case ICMPv4: - return "ICMPv4" - case IGMP: - return "IGMP" - case ICMPv6: - return "ICMPv6" - case UDP: - return "UDP" - case TCP: - return "TCP" - case TSMP: - return "TSMP" - default: - return "Unknown" - } -} diff --git a/net/packet/ip4.go b/net/packet/ip4.go index 0240abaa1..2c090d9f1 100644 --- a/net/packet/ip4.go +++ b/net/packet/ip4.go @@ -9,6 +9,7 @@ import ( "errors" "inet.af/netaddr" + "tailscale.com/types/ipproto" ) // ip4HeaderLength is the length of an IPv4 header with no IP options. @@ -16,7 +17,7 @@ const ip4HeaderLength = 20 // IP4Header represents an IPv4 packet header. type IP4Header struct { - IPProto IPProto + IPProto ipproto.Proto IPID uint16 Src netaddr.IP Dst netaddr.IP diff --git a/net/packet/ip6.go b/net/packet/ip6.go index 59f605b32..e181f1dde 100644 --- a/net/packet/ip6.go +++ b/net/packet/ip6.go @@ -8,6 +8,7 @@ import ( "encoding/binary" "inet.af/netaddr" + "tailscale.com/types/ipproto" ) // ip6HeaderLength is the length of an IPv6 header with no IP options. @@ -15,7 +16,7 @@ const ip6HeaderLength = 40 // IP6Header represents an IPv6 packet header. type IP6Header struct { - IPProto IPProto + IPProto ipproto.Proto IPID uint32 // only lower 20 bits used Src netaddr.IP Dst netaddr.IP diff --git a/net/packet/packet.go b/net/packet/packet.go index a88a1af7a..05c4a382f 100644 --- a/net/packet/packet.go +++ b/net/packet/packet.go @@ -11,9 +11,12 @@ import ( "strings" "inet.af/netaddr" + "tailscale.com/types/ipproto" "tailscale.com/types/strbuilder" ) +const unknown = ipproto.Unknown + // RFC1858: prevent overlapping fragment attacks. const minFrag = 60 + 20 // max IPv4 header + basic TCP header @@ -44,7 +47,7 @@ type Parsed struct { // 6), or 0 if the packet doesn't look like IPv4 or IPv6. IPVersion uint8 // IPProto is the IP subprotocol (UDP, TCP, etc.). Valid iff IPVersion != 0. - IPProto IPProto + IPProto ipproto.Proto // SrcIP4 is the source address. Family matches IPVersion. Port is // valid iff IPProto == TCP || IPProto == UDP. Src netaddr.IPPort @@ -100,7 +103,7 @@ func (q *Parsed) Decode(b []byte) { if len(b) < 1 { q.IPVersion = 0 - q.IPProto = Unknown + q.IPProto = unknown return } @@ -112,7 +115,7 @@ func (q *Parsed) Decode(b []byte) { q.decode6(b) default: q.IPVersion = 0 - q.IPProto = Unknown + q.IPProto = unknown } } @@ -125,16 +128,16 @@ func (q *Parsed) StuffForTesting(len int) { func (q *Parsed) decode4(b []byte) { if len(b) < ip4HeaderLength { q.IPVersion = 0 - q.IPProto = Unknown + q.IPProto = unknown return } // Check that it's IPv4. - q.IPProto = IPProto(b[9]) + q.IPProto = ipproto.Proto(b[9]) q.length = int(binary.BigEndian.Uint16(b[2:4])) if len(b) < q.length { // Packet was cut off before full IPv4 length. - q.IPProto = Unknown + q.IPProto = unknown return } @@ -145,7 +148,7 @@ func (q *Parsed) decode4(b []byte) { q.subofs = int((b[0] & 0x0F) << 2) if q.subofs > q.length { // next-proto starts beyond end of packet. - q.IPProto = Unknown + q.IPProto = unknown return } sub := b[q.subofs:] @@ -170,29 +173,29 @@ func (q *Parsed) decode4(b []byte) { // This is the first fragment if moreFrags && len(sub) < minFrag { // Suspiciously short first fragment, dump it. - q.IPProto = Unknown + q.IPProto = unknown return } // otherwise, this is either non-fragmented (the usual case) // or a big enough initial fragment that we can read the // whole subprotocol header. switch q.IPProto { - case ICMPv4: + case ipproto.ICMPv4: if len(sub) < icmp4HeaderLength { - q.IPProto = Unknown + q.IPProto = unknown return } q.Src.Port = 0 q.Dst.Port = 0 q.dataofs = q.subofs + icmp4HeaderLength return - case IGMP: + case ipproto.IGMP: // Keep IPProto, but don't parse anything else // out. return - case TCP: + case ipproto.TCP: if len(sub) < tcpHeaderLength { - q.IPProto = Unknown + q.IPProto = unknown return } q.Src.Port = binary.BigEndian.Uint16(sub[0:2]) @@ -201,21 +204,29 @@ func (q *Parsed) decode4(b []byte) { headerLength := (sub[12] & 0xF0) >> 2 q.dataofs = q.subofs + int(headerLength) return - case UDP: + case ipproto.UDP: if len(sub) < udpHeaderLength { - q.IPProto = Unknown + q.IPProto = unknown return } q.Src.Port = binary.BigEndian.Uint16(sub[0:2]) q.Dst.Port = binary.BigEndian.Uint16(sub[2:4]) q.dataofs = q.subofs + udpHeaderLength return - case TSMP: + case ipproto.SCTP: + if len(sub) < sctpHeaderLength { + q.IPProto = unknown + return + } + q.Src.Port = binary.BigEndian.Uint16(sub[0:2]) + q.Dst.Port = binary.BigEndian.Uint16(sub[2:4]) + return + case ipproto.TSMP: // Inter-tailscale messages. q.dataofs = q.subofs return default: - q.IPProto = Unknown + q.IPProto = unknown return } } else { @@ -223,7 +234,7 @@ func (q *Parsed) decode4(b []byte) { if fragOfs < minFrag { // First frag was suspiciously short, so we can't // trust the followup either. - q.IPProto = Unknown + q.IPProto = unknown return } // otherwise, we have to permit the fragment to slide through. @@ -232,7 +243,7 @@ func (q *Parsed) decode4(b []byte) { // but that would require statefulness. Anyway, receivers' // kernels know to drop fragments where the initial fragment // doesn't arrive. - q.IPProto = Fragment + q.IPProto = ipproto.Fragment return } } @@ -240,15 +251,15 @@ func (q *Parsed) decode4(b []byte) { func (q *Parsed) decode6(b []byte) { if len(b) < ip6HeaderLength { q.IPVersion = 0 - q.IPProto = Unknown + q.IPProto = unknown return } - q.IPProto = IPProto(b[6]) + q.IPProto = ipproto.Proto(b[6]) q.length = int(binary.BigEndian.Uint16(b[4:6])) + ip6HeaderLength if len(b) < q.length { // Packet was cut off before the full IPv6 length. - q.IPProto = Unknown + q.IPProto = unknown return } @@ -274,17 +285,17 @@ func (q *Parsed) decode6(b []byte) { sub = sub[:len(sub):len(sub)] // help the compiler do bounds check elimination switch q.IPProto { - case ICMPv6: + case ipproto.ICMPv6: if len(sub) < icmp6HeaderLength { - q.IPProto = Unknown + q.IPProto = unknown return } q.Src.Port = 0 q.Dst.Port = 0 q.dataofs = q.subofs + icmp6HeaderLength - case TCP: + case ipproto.TCP: if len(sub) < tcpHeaderLength { - q.IPProto = Unknown + q.IPProto = unknown return } q.Src.Port = binary.BigEndian.Uint16(sub[0:2]) @@ -293,20 +304,28 @@ func (q *Parsed) decode6(b []byte) { headerLength := (sub[12] & 0xF0) >> 2 q.dataofs = q.subofs + int(headerLength) return - case UDP: + case ipproto.UDP: if len(sub) < udpHeaderLength { - q.IPProto = Unknown + q.IPProto = unknown return } q.Src.Port = binary.BigEndian.Uint16(sub[0:2]) q.Dst.Port = binary.BigEndian.Uint16(sub[2:4]) q.dataofs = q.subofs + udpHeaderLength - case TSMP: + case ipproto.SCTP: + if len(sub) < sctpHeaderLength { + q.IPProto = unknown + return + } + q.Src.Port = binary.BigEndian.Uint16(sub[0:2]) + q.Dst.Port = binary.BigEndian.Uint16(sub[2:4]) + return + case ipproto.TSMP: // Inter-tailscale messages. q.dataofs = q.subofs return default: - q.IPProto = Unknown + q.IPProto = unknown return } } @@ -324,6 +343,19 @@ func (q *Parsed) IP4Header() IP4Header { } } +func (q *Parsed) IP6Header() IP6Header { + if q.IPVersion != 6 { + panic("IP6Header called on non-IPv6 Parsed") + } + ipid := (binary.BigEndian.Uint32(q.b[:4]) << 12) >> 12 + return IP6Header{ + IPID: ipid, + IPProto: q.IPProto, + Src: q.Src.IP, + Dst: q.Dst.IP, + } +} + func (q *Parsed) ICMP4Header() ICMP4Header { if q.IPVersion != 4 { panic("IP4Header called on non-IPv4 Parsed") @@ -367,13 +399,13 @@ func (q *Parsed) IsTCPSyn() bool { // IsError reports whether q is an ICMP "Error" packet. func (q *Parsed) IsError() bool { switch q.IPProto { - case ICMPv4: + case ipproto.ICMPv4: if len(q.b) < q.subofs+8 { return false } t := ICMP4Type(q.b[q.subofs]) return t == ICMP4Unreachable || t == ICMP4TimeExceeded - case ICMPv6: + case ipproto.ICMPv6: if len(q.b) < q.subofs+8 { return false } @@ -387,9 +419,9 @@ func (q *Parsed) IsError() bool { // IsEchoRequest reports whether q is an ICMP Echo Request. func (q *Parsed) IsEchoRequest() bool { switch q.IPProto { - case ICMPv4: + case ipproto.ICMPv4: return len(q.b) >= q.subofs+8 && ICMP4Type(q.b[q.subofs]) == ICMP4EchoRequest && ICMP4Code(q.b[q.subofs+1]) == ICMP4NoCode - case ICMPv6: + case ipproto.ICMPv6: return len(q.b) >= q.subofs+8 && ICMP6Type(q.b[q.subofs]) == ICMP6EchoRequest && ICMP6Code(q.b[q.subofs+1]) == ICMP6NoCode default: return false @@ -399,9 +431,9 @@ func (q *Parsed) IsEchoRequest() bool { // IsEchoRequest reports whether q is an IPv4 ICMP Echo Response. func (q *Parsed) IsEchoResponse() bool { switch q.IPProto { - case ICMPv4: + case ipproto.ICMPv4: return len(q.b) >= q.subofs+8 && ICMP4Type(q.b[q.subofs]) == ICMP4EchoReply && ICMP4Code(q.b[q.subofs+1]) == ICMP4NoCode - case ICMPv6: + case ipproto.ICMPv6: return len(q.b) >= q.subofs+8 && ICMP6Type(q.b[q.subofs]) == ICMP6EchoReply && ICMP6Code(q.b[q.subofs+1]) == ICMP6NoCode default: return false diff --git a/net/packet/packet_test.go b/net/packet/packet_test.go index 8bac5db4a..ac4fa33f3 100644 --- a/net/packet/packet_test.go +++ b/net/packet/packet_test.go @@ -10,6 +10,19 @@ import ( "testing" "inet.af/netaddr" + "tailscale.com/types/ipproto" +) + +const ( + Unknown = ipproto.Unknown + TCP = ipproto.TCP + UDP = ipproto.UDP + SCTP = ipproto.SCTP + IGMP = ipproto.IGMP + ICMPv4 = ipproto.ICMPv4 + ICMPv6 = ipproto.ICMPv6 + TSMP = ipproto.TSMP + Fragment = ipproto.Fragment ) func mustIPPort(s string) netaddr.IPPort { @@ -305,6 +318,39 @@ var ipv4TSMPDecode = Parsed{ Dst: mustIPPort("100.74.70.3:0"), } +// IPv4 SCTP +var sctpBuffer = []byte{ + // IPv4 header: + 0x45, 0x00, + 0x00, 0x20, // 20 + 12 bytes total + 0x00, 0x00, // ID + 0x00, 0x00, // Fragment + 0x40, // TTL + byte(SCTP), + // Checksum, unchecked: + 1, 2, + // source IP: + 0x64, 0x5e, 0x0c, 0x0e, + // dest IP: + 0x64, 0x4a, 0x46, 0x03, + // Src Port, Dest Port: + 0x00, 0x7b, 0x01, 0xc8, + // Verification tag: + 1, 2, 3, 4, + // Checksum: (unchecked) + 5, 6, 7, 8, +} + +var sctpDecode = Parsed{ + b: sctpBuffer, + subofs: 20, + length: 20 + 12, + IPVersion: 4, + IPProto: SCTP, + Src: mustIPPort("100.94.12.14:123"), + Dst: mustIPPort("100.74.70.3:456"), +} + func TestParsedString(t *testing.T) { tests := []struct { name string @@ -320,6 +366,7 @@ func TestParsedString(t *testing.T) { {"igmp", igmpPacketDecode, "IGMP{192.168.1.82:0 > 224.0.0.251:0}"}, {"unknown", unknownPacketDecode, "Unknown{???}"}, {"ipv4_tsmp", ipv4TSMPDecode, "TSMP{100.94.12.14:0 > 100.74.70.3:0}"}, + {"sctp", sctpDecode, "SCTP{100.94.12.14:123 > 100.74.70.3:456}"}, } for _, tt := range tests { @@ -357,6 +404,7 @@ func TestDecode(t *testing.T) { {"unknown", unknownPacketBuffer, unknownPacketDecode}, {"invalid4", invalid4RequestBuffer, invalid4RequestDecode}, {"ipv4_tsmp", ipv4TSMPBuffer, ipv4TSMPDecode}, + {"ipv4_sctp", sctpBuffer, sctpDecode}, } for _, tt := range tests { diff --git a/net/packet/tsmp.go b/net/packet/tsmp.go index 2346c9419..fb257556c 100644 --- a/net/packet/tsmp.go +++ b/net/packet/tsmp.go @@ -17,6 +17,7 @@ import ( "inet.af/netaddr" "tailscale.com/net/flowtrack" + "tailscale.com/types/ipproto" ) // TailscaleRejectedHeader is a TSMP message that says that one @@ -39,7 +40,7 @@ type TailscaleRejectedHeader struct { IPDst netaddr.IP // IPv4 or IPv6 header's dst IP Src netaddr.IPPort // rejected flow's src Dst netaddr.IPPort // rejected flow's dst - Proto IPProto // proto that was rejected (TCP or UDP) + Proto ipproto.Proto // proto that was rejected (TCP or UDP) Reason TailscaleRejectReason // why the connection was rejected // MaybeBroken is whether the rejection is non-terminal (the @@ -57,7 +58,7 @@ type TailscaleRejectedHeader struct { const rejectFlagBitMaybeBroken = 0x1 func (rh TailscaleRejectedHeader) Flow() flowtrack.Tuple { - return flowtrack.Tuple{Src: rh.Src, Dst: rh.Dst} + return flowtrack.Tuple{Proto: rh.Proto, Src: rh.Src, Dst: rh.Dst} } func (rh TailscaleRejectedHeader) String() string { @@ -69,6 +70,12 @@ type TSMPType uint8 const ( // TSMPTypeRejectedConn is the type byte for a TailscaleRejectedHeader. TSMPTypeRejectedConn TSMPType = '!' + + // TSMPTypePing is the type byte for a TailscalePingRequest. + TSMPTypePing TSMPType = 'p' + + // TSMPTypePong is the type byte for a TailscalePongResponse. + TSMPTypePong TSMPType = 'o' ) type TailscaleRejectReason byte @@ -138,7 +145,7 @@ func (h TailscaleRejectedHeader) Marshal(buf []byte) error { } if h.Src.IP.Is4() { iph := IP4Header{ - IPProto: TSMP, + IPProto: ipproto.TSMP, Src: h.IPSrc, Dst: h.IPDst, } @@ -146,7 +153,7 @@ func (h TailscaleRejectedHeader) Marshal(buf []byte) error { buf = buf[ip4HeaderLength:] } else if h.Src.IP.Is6() { iph := IP6Header{ - IPProto: TSMP, + IPProto: ipproto.TSMP, Src: h.IPSrc, Dst: h.IPDst, } @@ -181,7 +188,7 @@ func (pp *Parsed) AsTailscaleRejectedHeader() (h TailscaleRejectedHeader, ok boo return } h = TailscaleRejectedHeader{ - Proto: IPProto(p[1]), + Proto: ipproto.Proto(p[1]), Reason: TailscaleRejectReason(p[2]), IPSrc: pp.Src.IP, IPDst: pp.Dst.IP, @@ -194,3 +201,58 @@ func (pp *Parsed) AsTailscaleRejectedHeader() (h TailscaleRejectedHeader, ok boo } return h, true } + +// TSMPPingRequest is a TSMP message that's like an ICMP ping request. +// +// On the wire, after the IP header, it's currently 9 bytes: +// * 'p' (TSMPTypePing) +// * 8 opaque ping bytes to copy back in the response +type TSMPPingRequest struct { + Data [8]byte +} + +func (pp *Parsed) AsTSMPPing() (h TSMPPingRequest, ok bool) { + if pp.IPProto != ipproto.TSMP { + return + } + p := pp.Payload() + if len(p) < 9 || p[0] != byte(TSMPTypePing) { + return + } + copy(h.Data[:], p[1:]) + return h, true +} + +type TSMPPongReply struct { + IPHeader Header + Data [8]byte +} + +func (pp *Parsed) AsTSMPPong() (data [8]byte, ok bool) { + if pp.IPProto != ipproto.TSMP { + return + } + p := pp.Payload() + if len(p) < 9 || p[0] != byte(TSMPTypePong) { + return + } + copy(data[:], p[1:]) + return data, true +} + +func (h TSMPPongReply) Len() int { + return h.IPHeader.Len() + 9 +} + +func (h TSMPPongReply) Marshal(buf []byte) error { + if len(buf) < h.Len() { + return errSmallBuffer + } + if err := h.IPHeader.Marshal(buf); err != nil { + return err + } + buf = buf[h.IPHeader.Len():] + buf[0] = byte(TSMPTypePong) + copy(buf[1:], h.Data[:]) + return nil +} diff --git a/net/packet/udp4.go b/net/packet/udp4.go index 82aa30179..ce179f89d 100644 --- a/net/packet/udp4.go +++ b/net/packet/udp4.go @@ -4,7 +4,11 @@ package packet -import "encoding/binary" +import ( + "encoding/binary" + + "tailscale.com/types/ipproto" +) // udpHeaderLength is the size of the UDP packet header, not including // the outer IP header. @@ -31,7 +35,7 @@ func (h UDP4Header) Marshal(buf []byte) error { return errLargePacket } // The caller does not need to set this. - h.IPProto = UDP + h.IPProto = ipproto.UDP length := len(buf) - h.IP4Header.Len() binary.BigEndian.PutUint16(buf[20:22], h.SrcPort) diff --git a/net/packet/udp6.go b/net/packet/udp6.go index 0450eae9e..18213c1fb 100644 --- a/net/packet/udp6.go +++ b/net/packet/udp6.go @@ -4,7 +4,11 @@ package packet -import "encoding/binary" +import ( + "encoding/binary" + + "tailscale.com/types/ipproto" +) // UDP6Header is an IPv6+UDP header. type UDP6Header struct { @@ -27,7 +31,7 @@ func (h UDP6Header) Marshal(buf []byte) error { return errLargePacket } // The caller does not need to set this. - h.IPProto = UDP + h.IPProto = ipproto.UDP length := len(buf) - h.IP6Header.Len() binary.BigEndian.PutUint16(buf[40:42], h.SrcPort) diff --git a/net/tsaddr/tsaddr.go b/net/tsaddr/tsaddr.go index 44cf5cf23..9bf81326e 100644 --- a/net/tsaddr/tsaddr.go +++ b/net/tsaddr/tsaddr.go @@ -37,7 +37,7 @@ var ( ) // TailscaleServiceIP returns the listen address of services -// provided by Tailscale itself such as the Magic DNS proxy. +// provided by Tailscale itself such as the MagicDNS proxy. func TailscaleServiceIP() netaddr.IP { serviceIP.Do(func() { mustIP(&serviceIP.v, "100.100.100.100") }) return serviceIP.v diff --git a/net/tstun/fake.go b/net/tstun/fake.go new file mode 100644 index 000000000..f9c3a9d6f --- /dev/null +++ b/net/tstun/fake.go @@ -0,0 +1,54 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tstun + +import ( + "io" + "os" + + "github.com/tailscale/wireguard-go/tun" +) + +type fakeTUN struct { + evchan chan tun.Event + closechan chan struct{} +} + +// NewFake returns a tun.Device that does nothing. +func NewFake() tun.Device { + return &fakeTUN{ + evchan: make(chan tun.Event), + closechan: make(chan struct{}), + } +} + +func (t *fakeTUN) File() *os.File { + panic("fakeTUN.File() called, which makes no sense") +} + +func (t *fakeTUN) Close() error { + close(t.closechan) + close(t.evchan) + return nil +} + +func (t *fakeTUN) Read(out []byte, offset int) (int, error) { + <-t.closechan + return 0, io.EOF +} + +func (t *fakeTUN) Write(b []byte, n int) (int, error) { + select { + case <-t.closechan: + return 0, ErrClosed + default: + } + return len(b), nil +} + +func (t *fakeTUN) Flush() error { return nil } +func (t *fakeTUN) MTU() (int, error) { return 1500, nil } +func (t *fakeTUN) Name() (string, error) { return "FakeTUN", nil } +func (t *fakeTUN) Events() chan tun.Event { return t.evchan } diff --git a/net/tstun/ifstatus_noop.go b/net/tstun/ifstatus_noop.go new file mode 100644 index 000000000..223be7949 --- /dev/null +++ b/net/tstun/ifstatus_noop.go @@ -0,0 +1,19 @@ +// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !windows + +package tstun + +import ( + "time" + + "github.com/tailscale/wireguard-go/tun" + "tailscale.com/types/logger" +) + +// Dummy implementation that does nothing. +func waitInterfaceUp(iface tun.Device, timeout time.Duration, logf logger.Logf) error { + return nil +} diff --git a/net/tstun/ifstatus_windows.go b/net/tstun/ifstatus_windows.go new file mode 100644 index 000000000..840e50f4d --- /dev/null +++ b/net/tstun/ifstatus_windows.go @@ -0,0 +1,111 @@ +// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tstun + +import ( + "fmt" + "sync" + "time" + + "github.com/tailscale/wireguard-go/tun" + "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" + "tailscale.com/types/logger" +) + +// ifaceWatcher waits for an interface to be up. +type ifaceWatcher struct { + logf logger.Logf + luid winipcfg.LUID + + mu sync.Mutex // guards following + done bool + sig chan bool +} + +// callback is the callback we register with Windows to call when IP interface changes. +func (iw *ifaceWatcher) callback(notificationType winipcfg.MibNotificationType, iface *winipcfg.MibIPInterfaceRow) { + // Probably should check only when MibParameterNotification, but just in case included MibAddInstance also. + if notificationType == winipcfg.MibParameterNotification || notificationType == winipcfg.MibAddInstance { + // Out of paranoia, start a goroutine to finish our work, to return to Windows out of this callback. + go iw.isUp() + } +} + +func (iw *ifaceWatcher) isUp() bool { + iw.mu.Lock() + defer iw.mu.Unlock() + + if iw.done { + // We already know that it's up + return true + } + + if iw.getOperStatus() != winipcfg.IfOperStatusUp { + return false + } + + iw.done = true + iw.sig <- true + return true +} + +func (iw *ifaceWatcher) getOperStatus() winipcfg.IfOperStatus { + ifc, err := iw.luid.Interface() + if err != nil { + iw.logf("iw.luid.Interface error: %v", err) + return 0 + } + return ifc.OperStatus +} + +func waitInterfaceUp(iface tun.Device, timeout time.Duration, logf logger.Logf) error { + iw := &ifaceWatcher{ + luid: winipcfg.LUID(iface.(*tun.NativeTun).LUID()), + logf: logger.WithPrefix(logf, "waitInterfaceUp: "), + } + + // Just in case check the status first + if iw.getOperStatus() == winipcfg.IfOperStatusUp { + iw.logf("TUN interface already up; no need to wait") + return nil + } + + iw.sig = make(chan bool, 1) + cb, err := winipcfg.RegisterInterfaceChangeCallback(iw.callback) + if err != nil { + iw.logf("RegisterInterfaceChangeCallback error: %v", err) + return err + } + defer cb.Unregister() + + t0 := time.Now() + expires := t0.Add(timeout) + ticker := time.NewTicker(10 * time.Second) + defer ticker.Stop() + + for { + iw.logf("waiting for TUN interface to come up...") + + select { + case <-iw.sig: + iw.logf("TUN interface is up after %v", time.Since(t0)) + return nil + case <-ticker.C: + break + } + + if iw.isUp() { + // Very unlikely to happen - either NotifyIpInterfaceChange doesn't work + // or it came up in the same moment as tick. Indicate this in the log message. + iw.logf("TUN interface is up after %v (on poll, without notification)", time.Since(t0)) + return nil + } + + if expires.Before(time.Now()) { + iw.logf("timeout waiting %v for TUN interface to come up", timeout) + return fmt.Errorf("timeout waiting for TUN interface to come up") + } + } +} diff --git a/net/tstun/tun.go b/net/tstun/tun.go new file mode 100644 index 000000000..d480d3244 --- /dev/null +++ b/net/tstun/tun.go @@ -0,0 +1,129 @@ +// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package tun creates a tuntap device, working around OS-specific +// quirks if necessary. +package tstun + +import ( + "bytes" + "os" + "os/exec" + "runtime" + "time" + + "github.com/tailscale/wireguard-go/tun" + "tailscale.com/types/logger" + "tailscale.com/version/distro" +) + +// minimalMTU is the MTU we set on tailscale's TUN +// interface. wireguard-go defaults to 1420 bytes, which only works if +// the "outer" MTU is 1500 bytes. This breaks on DSL connections +// (typically 1492 MTU) and on GCE (1460 MTU?!). +// +// 1280 is the smallest MTU allowed for IPv6, which is a sensible +// "probably works everywhere" setting until we develop proper PMTU +// discovery. +const minimalMTU = 1280 + +// New returns a tun.Device for the requested device name. +func New(logf logger.Logf, tunName string) (tun.Device, error) { + dev, err := tun.CreateTUN(tunName, minimalMTU) + if err != nil { + return nil, err + } + if err := waitInterfaceUp(dev, 90*time.Second, logf); err != nil { + return nil, err + } + return dev, nil +} + +// Diagnose tries to explain a tuntap device creation failure. +// It pokes around the system and logs some diagnostic info that might +// help debug why tun creation failed. Because device creation has +// already failed and the program's about to end, log a lot. +func Diagnose(logf logger.Logf, tunName string) { + switch runtime.GOOS { + case "linux": + diagnoseLinuxTUNFailure(tunName, logf) + case "darwin": + diagnoseDarwinTUNFailure(tunName, logf) + default: + logf("no TUN failure diagnostics for OS %q", runtime.GOOS) + } +} + +func diagnoseDarwinTUNFailure(tunName string, logf logger.Logf) { + if os.Getuid() != 0 { + logf("failed to create TUN device as non-root user; use 'sudo tailscaled', or run under launchd with 'sudo tailscaled install-system-daemon'") + } + if tunName != "utun" { + logf("failed to create TUN device %q; try using tun device \"utun\" instead for automatic selection", tunName) + } +} + +func diagnoseLinuxTUNFailure(tunName string, logf logger.Logf) { + kernel, err := exec.Command("uname", "-r").Output() + kernel = bytes.TrimSpace(kernel) + if err != nil { + logf("no TUN, and failed to look up kernel version: %v", err) + return + } + logf("Linux kernel version: %s", kernel) + + modprobeOut, err := exec.Command("/sbin/modprobe", "tun").CombinedOutput() + if err == nil { + logf("'modprobe tun' successful") + // Either tun is currently loaded, or it's statically + // compiled into the kernel (which modprobe checks + // with /lib/modules/$(uname -r)/modules.builtin) + // + // So if there's a problem at this point, it's + // probably because /dev/net/tun doesn't exist. + const dev = "/dev/net/tun" + if fi, err := os.Stat(dev); err != nil { + logf("tun module loaded in kernel, but %s does not exist", dev) + } else { + logf("%s: %v", dev, fi.Mode()) + } + + // We failed to find why it failed. Just let our + // caller report the error it got from wireguard-go. + return + } + logf("is CONFIG_TUN enabled in your kernel? `modprobe tun` failed with: %s", modprobeOut) + + switch distro.Get() { + case distro.Debian: + dpkgOut, err := exec.Command("dpkg", "-S", "kernel/drivers/net/tun.ko").CombinedOutput() + if len(bytes.TrimSpace(dpkgOut)) == 0 || err != nil { + logf("tun module not loaded nor found on disk") + return + } + if !bytes.Contains(dpkgOut, kernel) { + logf("kernel/drivers/net/tun.ko found on disk, but not for current kernel; are you in middle of a system update and haven't rebooted? found: %s", dpkgOut) + } + case distro.Arch: + findOut, err := exec.Command("find", "/lib/modules/", "-path", "*/net/tun.ko*").CombinedOutput() + if len(bytes.TrimSpace(findOut)) == 0 || err != nil { + logf("tun module not loaded nor found on disk") + return + } + if !bytes.Contains(findOut, kernel) { + logf("kernel/drivers/net/tun.ko found on disk, but not for current kernel; are you in middle of a system update and haven't rebooted? found: %s", findOut) + } + case distro.OpenWrt: + out, err := exec.Command("opkg", "list-installed").CombinedOutput() + if err != nil { + logf("error querying OpenWrt installed packages: %s", out) + return + } + for _, pkg := range []string{"kmod-tun", "ca-bundle"} { + if !bytes.Contains(out, []byte(pkg+" - ")) { + logf("Missing required package %s; run: opkg install %s", pkg, pkg) + } + } + } +} diff --git a/net/tstun/tun_windows.go b/net/tstun/tun_windows.go new file mode 100644 index 000000000..dc5fc2d79 --- /dev/null +++ b/net/tstun/tun_windows.go @@ -0,0 +1,24 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tstun + +import ( + "github.com/tailscale/wireguard-go/tun" + "github.com/tailscale/wireguard-go/tun/wintun" + "golang.org/x/sys/windows" +) + +func init() { + var err error + tun.WintunPool, err = wintun.MakePool("Tailscale") + if err != nil { + panic(err) + } + guid, err := windows.GUIDFromString("{37217669-42da-4657-a55b-0d995d328250}") + if err != nil { + panic(err) + } + tun.WintunStaticRequestedGUID = &guid +} diff --git a/net/tstun/wrap.go b/net/tstun/wrap.go new file mode 100644 index 000000000..70225c52e --- /dev/null +++ b/net/tstun/wrap.go @@ -0,0 +1,498 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package tstun provides a TUN struct implementing the tun.Device interface +// with additional features as required by wgengine. +package tstun + +import ( + "errors" + "io" + "os" + "sync" + "sync/atomic" + "time" + + "github.com/tailscale/wireguard-go/device" + "github.com/tailscale/wireguard-go/tun" + "inet.af/netaddr" + "tailscale.com/net/packet" + "tailscale.com/types/ipproto" + "tailscale.com/types/logger" + "tailscale.com/wgengine/filter" +) + +const maxBufferSize = device.MaxMessageSize + +// PacketStartOffset is the minimal amount of leading space that must exist +// before &packet[offset] in a packet passed to Read, Write, or InjectInboundDirect. +// This is necessary to avoid reallocation in wireguard-go internals. +const PacketStartOffset = device.MessageTransportHeaderSize + +// MaxPacketSize is the maximum size (in bytes) +// of a packet that can be injected into a tstun.Wrapper. +const MaxPacketSize = device.MaxContentSize + +var ( + // ErrClosed is returned when attempting an operation on a closed Wrapper. + ErrClosed = errors.New("device closed") + // ErrFiltered is returned when the acted-on packet is rejected by a filter. + ErrFiltered = errors.New("packet dropped by filter") +) + +var ( + errPacketTooBig = errors.New("packet too big") + errOffsetTooBig = errors.New("offset larger than buffer length") + errOffsetTooSmall = errors.New("offset smaller than PacketStartOffset") +) + +// parsedPacketPool holds a pool of Parsed structs for use in filtering. +// This is needed because escape analysis cannot see that parsed packets +// do not escape through {Pre,Post}Filter{In,Out}. +var parsedPacketPool = sync.Pool{New: func() interface{} { return new(packet.Parsed) }} + +// FilterFunc is a packet-filtering function with access to the Wrapper device. +// It must not hold onto the packet struct, as its backing storage will be reused. +type FilterFunc func(*packet.Parsed, *Wrapper) filter.Response + +// Wrapper augments a tun.Device with packet filtering and injection. +type Wrapper struct { + logf logger.Logf + // tdev is the underlying Wrapper device. + tdev tun.Device + + closeOnce sync.Once + + lastActivityAtomic int64 // unix seconds of last send or receive + + destIPActivity atomic.Value // of map[netaddr.IP]func() + + // buffer stores the oldest unconsumed packet from tdev. + // It is made a static buffer in order to avoid allocations. + buffer [maxBufferSize]byte + // bufferConsumed synchronizes access to buffer (shared by Read and poll). + bufferConsumed chan struct{} + + // closed signals poll (by closing) when the device is closed. + closed chan struct{} + // errors is the error queue populated by poll. + errors chan error + // outbound is the queue by which packets leave the TUN device. + // + // The directions are relative to the network, not the device: + // inbound packets arrive via UDP and are written into the TUN device; + // outbound packets are read from the TUN device and sent out via UDP. + // This queue is needed because although inbound writes are synchronous, + // the other direction must wait on a Wireguard goroutine to poll it. + // + // Empty reads are skipped by Wireguard, so it is always legal + // to discard an empty packet instead of sending it through t.outbound. + outbound chan []byte + + // fitler stores the currently active package filter + filter atomic.Value // of *filter.Filter + // filterFlags control the verbosity of logging packet drops/accepts. + filterFlags filter.RunFlags + + // PreFilterIn is the inbound filter function that runs before the main filter + // and therefore sees the packets that may be later dropped by it. + PreFilterIn FilterFunc + // PostFilterIn is the inbound filter function that runs after the main filter. + PostFilterIn FilterFunc + // PreFilterOut is the outbound filter function that runs before the main filter + // and therefore sees the packets that may be later dropped by it. + PreFilterOut FilterFunc + // PostFilterOut is the outbound filter function that runs after the main filter. + PostFilterOut FilterFunc + + // OnTSMPPongReceived, if non-nil, is called whenever a TSMP pong arrives. + OnTSMPPongReceived func(data [8]byte) + + // disableFilter disables all filtering when set. This should only be used in tests. + disableFilter bool +} + +func Wrap(logf logger.Logf, tdev tun.Device) *Wrapper { + tun := &Wrapper{ + logf: logger.WithPrefix(logf, "tstun: "), + tdev: tdev, + // bufferConsumed is conceptually a condition variable: + // a goroutine should not block when setting it, even with no listeners. + bufferConsumed: make(chan struct{}, 1), + closed: make(chan struct{}), + errors: make(chan error), + outbound: make(chan []byte), + // TODO(dmytro): (highly rate-limited) hexdumps should happen on unknown packets. + filterFlags: filter.LogAccepts | filter.LogDrops, + } + + go tun.poll() + // The buffer starts out consumed. + tun.bufferConsumed <- struct{}{} + + return tun +} + +// SetDestIPActivityFuncs sets a map of funcs to run per packet +// destination (the map keys). +// +// The map ownership passes to the Wrapper. It must be non-nil. +func (t *Wrapper) SetDestIPActivityFuncs(m map[netaddr.IP]func()) { + t.destIPActivity.Store(m) +} + +func (t *Wrapper) Close() error { + var err error + t.closeOnce.Do(func() { + // Other channels need not be closed: poll will exit gracefully after this. + close(t.closed) + + err = t.tdev.Close() + }) + return err +} + +func (t *Wrapper) Events() chan tun.Event { + return t.tdev.Events() +} + +func (t *Wrapper) File() *os.File { + return t.tdev.File() +} + +func (t *Wrapper) Flush() error { + return t.tdev.Flush() +} + +func (t *Wrapper) MTU() (int, error) { + return t.tdev.MTU() +} + +func (t *Wrapper) Name() (string, error) { + return t.tdev.Name() +} + +// poll polls t.tdev.Read, placing the oldest unconsumed packet into t.buffer. +// This is needed because t.tdev.Read in general may block (it does on Windows), +// so packets may be stuck in t.outbound if t.Read called t.tdev.Read directly. +func (t *Wrapper) poll() { + for { + select { + case <-t.closed: + return + case <-t.bufferConsumed: + // continue + } + + // Read may use memory in t.buffer before PacketStartOffset for mandatory headers. + // This is the rationale behind the tun.Wrapper.{Read,Write} interfaces + // and the reason t.buffer has size MaxMessageSize and not MaxContentSize. + n, err := t.tdev.Read(t.buffer[:], PacketStartOffset) + if err != nil { + select { + case <-t.closed: + return + case t.errors <- err: + // In principle, read errors are not fatal (but wireguard-go disagrees). + t.bufferConsumed <- struct{}{} + } + continue + } + + // Wireguard will skip an empty read, + // so we might as well do it here to avoid the send through t.outbound. + if n == 0 { + t.bufferConsumed <- struct{}{} + continue + } + + select { + case <-t.closed: + return + case t.outbound <- t.buffer[PacketStartOffset : PacketStartOffset+n]: + // continue + } + } +} + +var magicDNSIPPort = netaddr.MustParseIPPort("100.100.100.100:0") + +func (t *Wrapper) filterOut(p *packet.Parsed) filter.Response { + // Fake ICMP echo responses to MagicDNS (100.100.100.100). + if p.IsEchoRequest() && p.Dst == magicDNSIPPort { + header := p.ICMP4Header() + header.ToResponse() + outp := packet.Generate(&header, p.Payload()) + t.InjectInboundCopy(outp) + return filter.DropSilently // don't pass on to OS; already handled + } + + if t.PreFilterOut != nil { + if res := t.PreFilterOut(p, t); res.IsDrop() { + return res + } + } + + filt, _ := t.filter.Load().(*filter.Filter) + + if filt == nil { + return filter.Drop + } + + if filt.RunOut(p, t.filterFlags) != filter.Accept { + return filter.Drop + } + + if t.PostFilterOut != nil { + if res := t.PostFilterOut(p, t); res.IsDrop() { + return res + } + } + + return filter.Accept +} + +// noteActivity records that there was a read or write at the current time. +func (t *Wrapper) noteActivity() { + atomic.StoreInt64(&t.lastActivityAtomic, time.Now().Unix()) +} + +// IdleDuration reports how long it's been since the last read or write to this device. +// +// Its value is only accurate to roughly second granularity. +// If there's never been activity, the duration is since 1970. +func (t *Wrapper) IdleDuration() time.Duration { + sec := atomic.LoadInt64(&t.lastActivityAtomic) + return time.Since(time.Unix(sec, 0)) +} + +func (t *Wrapper) Read(buf []byte, offset int) (int, error) { + var n int + + wasInjectedPacket := false + + select { + case <-t.closed: + return 0, io.EOF + case err := <-t.errors: + return 0, err + case pkt := <-t.outbound: + n = copy(buf[offset:], pkt) + // t.buffer has a fixed location in memory, + // so this is the easiest way to tell when it has been consumed. + // &pkt[0] can be used because empty packets do not reach t.outbound. + if &pkt[0] == &t.buffer[PacketStartOffset] { + t.bufferConsumed <- struct{}{} + } else { + // If the packet is not from t.buffer, then it is an injected packet. + wasInjectedPacket = true + } + } + + p := parsedPacketPool.Get().(*packet.Parsed) + defer parsedPacketPool.Put(p) + p.Decode(buf[offset : offset+n]) + + if m, ok := t.destIPActivity.Load().(map[netaddr.IP]func()); ok { + if fn := m[p.Dst.IP]; fn != nil { + fn() + } + } + + // For injected packets, we return early to bypass filtering. + if wasInjectedPacket { + t.noteActivity() + return n, nil + } + + if !t.disableFilter { + response := t.filterOut(p) + if response != filter.Accept { + // Wireguard considers read errors fatal; pretend nothing was read + return 0, nil + } + } + + t.noteActivity() + return n, nil +} + +func (t *Wrapper) filterIn(buf []byte) filter.Response { + p := parsedPacketPool.Get().(*packet.Parsed) + defer parsedPacketPool.Put(p) + p.Decode(buf) + + if p.IPProto == ipproto.TSMP { + if pingReq, ok := p.AsTSMPPing(); ok { + t.noteActivity() + t.injectOutboundPong(p, pingReq) + return filter.DropSilently + } else if data, ok := p.AsTSMPPong(); ok { + if f := t.OnTSMPPongReceived; f != nil { + f(data) + } + } + } + + if t.PreFilterIn != nil { + if res := t.PreFilterIn(p, t); res.IsDrop() { + return res + } + } + + filt, _ := t.filter.Load().(*filter.Filter) + + if filt == nil { + return filter.Drop + } + + if filt.RunIn(p, t.filterFlags) != filter.Accept { + + // Tell them, via TSMP, we're dropping them due to the ACL. + // Their host networking stack can translate this into ICMP + // or whatnot as required. But notably, their GUI or tailscale CLI + // can show them a rejection history with reasons. + if p.IPVersion == 4 && p.IPProto == ipproto.TCP && p.TCPFlags&packet.TCPSyn != 0 { + rj := packet.TailscaleRejectedHeader{ + IPSrc: p.Dst.IP, + IPDst: p.Src.IP, + Src: p.Src, + Dst: p.Dst, + Proto: p.IPProto, + Reason: packet.RejectedDueToACLs, + } + if filt.ShieldsUp() { + rj.Reason = packet.RejectedDueToShieldsUp + } + pkt := packet.Generate(rj, nil) + t.InjectOutbound(pkt) + + // TODO(bradfitz): also send a TCP RST, after the TSMP message. + } + + return filter.Drop + } + + if t.PostFilterIn != nil { + if res := t.PostFilterIn(p, t); res.IsDrop() { + return res + } + } + + return filter.Accept +} + +// Write accepts an incoming packet. The packet begins at buf[offset:], +// like wireguard-go/tun.Device.Write. +func (t *Wrapper) Write(buf []byte, offset int) (int, error) { + if !t.disableFilter { + res := t.filterIn(buf[offset:]) + if res == filter.DropSilently { + return len(buf), nil + } + if res != filter.Accept { + return 0, ErrFiltered + } + } + + t.noteActivity() + return t.tdev.Write(buf, offset) +} + +func (t *Wrapper) GetFilter() *filter.Filter { + filt, _ := t.filter.Load().(*filter.Filter) + return filt +} + +func (t *Wrapper) SetFilter(filt *filter.Filter) { + t.filter.Store(filt) +} + +// InjectInboundDirect makes the Wrapper device behave as if a packet +// with the given contents was received from the network. +// It blocks and does not take ownership of the packet. +// The injected packet will not pass through inbound filters. +// +// The packet contents are to start at &buf[offset]. +// offset must be greater or equal to PacketStartOffset. +// The space before &buf[offset] will be used by Wireguard. +func (t *Wrapper) InjectInboundDirect(buf []byte, offset int) error { + if len(buf) > MaxPacketSize { + return errPacketTooBig + } + if len(buf) < offset { + return errOffsetTooBig + } + if offset < PacketStartOffset { + return errOffsetTooSmall + } + + // Write to the underlying device to skip filters. + _, err := t.tdev.Write(buf, offset) + return err +} + +// InjectInboundCopy takes a packet without leading space, +// reallocates it to conform to the InjectInboundDirect interface +// and calls InjectInboundDirect on it. Injecting a nil packet is a no-op. +func (t *Wrapper) InjectInboundCopy(packet []byte) error { + // We duplicate this check from InjectInboundDirect here + // to avoid wasting an allocation on an oversized packet. + if len(packet) > MaxPacketSize { + return errPacketTooBig + } + if len(packet) == 0 { + return nil + } + + buf := make([]byte, PacketStartOffset+len(packet)) + copy(buf[PacketStartOffset:], packet) + + return t.InjectInboundDirect(buf, PacketStartOffset) +} + +func (t *Wrapper) injectOutboundPong(pp *packet.Parsed, req packet.TSMPPingRequest) { + pong := packet.TSMPPongReply{ + Data: req.Data, + } + switch pp.IPVersion { + case 4: + h4 := pp.IP4Header() + h4.ToResponse() + pong.IPHeader = h4 + case 6: + h6 := pp.IP6Header() + h6.ToResponse() + pong.IPHeader = h6 + default: + return + } + + t.InjectOutbound(packet.Generate(pong, nil)) +} + +// InjectOutbound makes the Wrapper device behave as if a packet +// with the given contents was sent to the network. +// It does not block, but takes ownership of the packet. +// The injected packet will not pass through outbound filters. +// Injecting an empty packet is a no-op. +func (t *Wrapper) InjectOutbound(packet []byte) error { + if len(packet) > MaxPacketSize { + return errPacketTooBig + } + if len(packet) == 0 { + return nil + } + select { + case <-t.closed: + return ErrClosed + case t.outbound <- packet: + return nil + } +} + +// Unwrap returns the underlying tun.Device. +func (t *Wrapper) Unwrap() tun.Device { + return t.tdev +} diff --git a/net/tstun/wrap_test.go b/net/tstun/wrap_test.go new file mode 100644 index 000000000..4032b168b --- /dev/null +++ b/net/tstun/wrap_test.go @@ -0,0 +1,387 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tstun + +import ( + "bytes" + "fmt" + "strconv" + "strings" + "sync/atomic" + "testing" + "unsafe" + + "github.com/tailscale/wireguard-go/tun/tuntest" + "inet.af/netaddr" + "tailscale.com/net/packet" + "tailscale.com/types/ipproto" + "tailscale.com/types/logger" + "tailscale.com/wgengine/filter" +) + +func udp4(src, dst string, sport, dport uint16) []byte { + sip, err := netaddr.ParseIP(src) + if err != nil { + panic(err) + } + dip, err := netaddr.ParseIP(dst) + if err != nil { + panic(err) + } + header := &packet.UDP4Header{ + IP4Header: packet.IP4Header{ + Src: sip, + Dst: dip, + IPID: 0, + }, + SrcPort: sport, + DstPort: dport, + } + return packet.Generate(header, []byte("udp_payload")) +} + +func nets(nets ...string) (ret []netaddr.IPPrefix) { + for _, s := range nets { + if i := strings.IndexByte(s, '/'); i == -1 { + ip, err := netaddr.ParseIP(s) + if err != nil { + panic(err) + } + bits := uint8(32) + if ip.Is6() { + bits = 128 + } + ret = append(ret, netaddr.IPPrefix{IP: ip, Bits: bits}) + } else { + pfx, err := netaddr.ParseIPPrefix(s) + if err != nil { + panic(err) + } + ret = append(ret, pfx) + } + } + return ret +} + +func ports(s string) filter.PortRange { + if s == "*" { + return filter.PortRange{First: 0, Last: 65535} + } + + var fs, ls string + i := strings.IndexByte(s, '-') + if i == -1 { + fs = s + ls = fs + } else { + fs = s[:i] + ls = s[i+1:] + } + first, err := strconv.ParseInt(fs, 10, 16) + if err != nil { + panic(fmt.Sprintf("invalid NetPortRange %q", s)) + } + last, err := strconv.ParseInt(ls, 10, 16) + if err != nil { + panic(fmt.Sprintf("invalid NetPortRange %q", s)) + } + return filter.PortRange{First: uint16(first), Last: uint16(last)} +} + +func netports(netPorts ...string) (ret []filter.NetPortRange) { + for _, s := range netPorts { + i := strings.LastIndexByte(s, ':') + if i == -1 { + panic(fmt.Sprintf("invalid NetPortRange %q", s)) + } + + npr := filter.NetPortRange{ + Net: nets(s[:i])[0], + Ports: ports(s[i+1:]), + } + ret = append(ret, npr) + } + return ret +} + +func setfilter(logf logger.Logf, tun *Wrapper) { + protos := []ipproto.Proto{ + ipproto.TCP, + ipproto.UDP, + } + matches := []filter.Match{ + {IPProto: protos, Srcs: nets("5.6.7.8"), Dsts: netports("1.2.3.4:89-90")}, + {IPProto: protos, Srcs: nets("1.2.3.4"), Dsts: netports("5.6.7.8:98")}, + } + var sb netaddr.IPSetBuilder + sb.AddPrefix(netaddr.MustParseIPPrefix("1.2.0.0/16")) + tun.SetFilter(filter.New(matches, sb.IPSet(), sb.IPSet(), nil, logf)) +} + +func newChannelTUN(logf logger.Logf, secure bool) (*tuntest.ChannelTUN, *Wrapper) { + chtun := tuntest.NewChannelTUN() + tun := Wrap(logf, chtun.TUN()) + if secure { + setfilter(logf, tun) + } else { + tun.disableFilter = true + } + return chtun, tun +} + +func newFakeTUN(logf logger.Logf, secure bool) (*fakeTUN, *Wrapper) { + ftun := NewFake() + tun := Wrap(logf, ftun) + if secure { + setfilter(logf, tun) + } else { + tun.disableFilter = true + } + return ftun.(*fakeTUN), tun +} + +func TestReadAndInject(t *testing.T) { + chtun, tun := newChannelTUN(t.Logf, false) + defer tun.Close() + + const size = 2 // all payloads have this size + written := []string{"w0", "w1"} + injected := []string{"i0", "i1"} + + go func() { + for _, packet := range written { + payload := []byte(packet) + chtun.Outbound <- payload + } + }() + + for _, packet := range injected { + go func(packet string) { + payload := []byte(packet) + err := tun.InjectOutbound(payload) + if err != nil { + t.Errorf("%s: error: %v", packet, err) + } + }(packet) + } + + var buf [MaxPacketSize]byte + var seen = make(map[string]bool) + // We expect the same packets back, in no particular order. + for i := 0; i < len(written)+len(injected); i++ { + n, err := tun.Read(buf[:], 0) + if err != nil { + t.Errorf("read %d: error: %v", i, err) + } + if n != size { + t.Errorf("read %d: got size %d; want %d", i, n, size) + } + got := string(buf[:n]) + t.Logf("read %d: got %s", i, got) + seen[got] = true + } + + for _, packet := range written { + if !seen[packet] { + t.Errorf("%s not received", packet) + } + } + for _, packet := range injected { + if !seen[packet] { + t.Errorf("%s not received", packet) + } + } +} + +func TestWriteAndInject(t *testing.T) { + chtun, tun := newChannelTUN(t.Logf, false) + defer tun.Close() + + const size = 2 // all payloads have this size + written := []string{"w0", "w1"} + injected := []string{"i0", "i1"} + + go func() { + for _, packet := range written { + payload := []byte(packet) + n, err := tun.Write(payload, 0) + if err != nil { + t.Errorf("%s: error: %v", packet, err) + } + if n != size { + t.Errorf("%s: got size %d; want %d", packet, n, size) + } + } + }() + + for _, packet := range injected { + go func(packet string) { + payload := []byte(packet) + err := tun.InjectInboundCopy(payload) + if err != nil { + t.Errorf("%s: error: %v", packet, err) + } + }(packet) + } + + seen := make(map[string]bool) + // We expect the same packets back, in no particular order. + for i := 0; i < len(written)+len(injected); i++ { + packet := <-chtun.Inbound + got := string(packet) + t.Logf("read %d: got %s", i, got) + seen[got] = true + } + + for _, packet := range written { + if !seen[packet] { + t.Errorf("%s not received", packet) + } + } + for _, packet := range injected { + if !seen[packet] { + t.Errorf("%s not received", packet) + } + } +} + +func TestFilter(t *testing.T) { + chtun, tun := newChannelTUN(t.Logf, true) + defer tun.Close() + + type direction int + + const ( + in direction = iota + out + ) + + tests := []struct { + name string + dir direction + drop bool + data []byte + }{ + {"junk_in", in, true, []byte("\x45not a valid IPv4 packet")}, + {"junk_out", out, true, []byte("\x45not a valid IPv4 packet")}, + {"bad_port_in", in, true, udp4("5.6.7.8", "1.2.3.4", 22, 22)}, + {"bad_port_out", out, false, udp4("1.2.3.4", "5.6.7.8", 22, 22)}, + {"bad_ip_in", in, true, udp4("8.1.1.1", "1.2.3.4", 89, 89)}, + {"bad_ip_out", out, false, udp4("1.2.3.4", "8.1.1.1", 98, 98)}, + {"good_packet_in", in, false, udp4("5.6.7.8", "1.2.3.4", 89, 89)}, + {"good_packet_out", out, false, udp4("1.2.3.4", "5.6.7.8", 98, 98)}, + } + + // A reader on the other end of the tun. + go func() { + var recvbuf []byte + for { + select { + case <-tun.closed: + return + case recvbuf = <-chtun.Inbound: + // continue + } + for _, tt := range tests { + if tt.drop && bytes.Equal(recvbuf, tt.data) { + t.Errorf("did not drop %s", tt.name) + } + } + } + }() + + var buf [MaxPacketSize]byte + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var n int + var err error + var filtered bool + + if tt.dir == in { + _, err = tun.Write(tt.data, 0) + if err == ErrFiltered { + filtered = true + err = nil + } + } else { + chtun.Outbound <- tt.data + n, err = tun.Read(buf[:], 0) + // In the read direction, errors are fatal, so we return n = 0 instead. + filtered = (n == 0) + } + + if err != nil { + t.Errorf("got err %v; want nil", err) + } + + if filtered { + if !tt.drop { + t.Errorf("got drop; want accept") + } + } else { + if tt.drop { + t.Errorf("got accept; want drop") + } + } + }) + } +} + +func TestAllocs(t *testing.T) { + ftun, tun := newFakeTUN(t.Logf, false) + defer tun.Close() + + buf := []byte{0x00} + allocs := testing.AllocsPerRun(100, func() { + _, err := ftun.Write(buf, 0) + if err != nil { + t.Errorf("write: error: %v", err) + return + } + }) + + if allocs > 0 { + t.Errorf("read allocs = %v; want 0", allocs) + } +} + +func TestClose(t *testing.T) { + ftun, tun := newFakeTUN(t.Logf, false) + + data := udp4("1.2.3.4", "5.6.7.8", 98, 98) + _, err := ftun.Write(data, 0) + if err != nil { + t.Error(err) + } + + tun.Close() + _, err = ftun.Write(data, 0) + if err == nil { + t.Error("Expected error from ftun.Write() after Close()") + } +} + +func BenchmarkWrite(b *testing.B) { + ftun, tun := newFakeTUN(b.Logf, true) + defer tun.Close() + + packet := udp4("5.6.7.8", "1.2.3.4", 89, 89) + for i := 0; i < b.N; i++ { + _, err := ftun.Write(packet, 0) + if err != nil { + b.Errorf("err = %v; want nil", err) + } + } +} + +func TestAtomic64Alignment(t *testing.T) { + off := unsafe.Offsetof(Wrapper{}.lastActivityAtomic) + if off%8 != 0 { + t.Errorf("offset %v not 8-byte aligned", off) + } + + c := new(Wrapper) + atomic.StoreInt64(&c.lastActivityAtomic, 123) +} |
