summaryrefslogtreecommitdiffhomepage
path: root/net/dns/direct_linux_test.go
blob: 8199b41f3b97356a64a61e0096acb8315d6ec442 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
// Copyright (c) Tailscale Inc & contributors
// SPDX-License-Identifier: BSD-3-Clause

//go:build linux

package dns

import (
	"context"
	"fmt"
	"net/netip"
	"os"
	"path/filepath"
	"testing"
	"testing/synctest"

	"github.com/illarion/gonotify/v3"

	"tailscale.com/util/dnsname"
	"tailscale.com/util/eventbus/eventbustest"
)

func TestDNSTrampleRecovery(t *testing.T) {
	HookWatchFile.Set(watchFile)
	synctest.Test(t, func(t *testing.T) {
		tmp := t.TempDir()
		if err := os.MkdirAll(filepath.Join(tmp, "etc"), 0700); err != nil {
			t.Fatal(err)
		}
		const resolvPath = "/etc/resolv.conf"
		fs := directFS{prefix: tmp}
		readFile := func(t *testing.T, path string) string {
			t.Helper()
			b, err := fs.ReadFile(path)
			if err != nil {
				t.Errorf("Reading DNS config: %v", err)
			}
			return string(b)
		}

		bus := eventbustest.NewBus(t)
		eventbustest.LogAllEvents(t, bus)
		m := newDirectManagerOnFS(t.Logf, nil, bus, fs)
		defer m.Close()

		if err := m.SetDNS(OSConfig{
			Nameservers:   []netip.Addr{netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("8.8.4.4")},
			SearchDomains: []dnsname.FQDN{"ts.net.", "ts-dns.test."},
			MatchDomains:  []dnsname.FQDN{"ignored."},
		}); err != nil {
			t.Fatal(err)
		}

		const want = `# resolv.conf(5) file generated by tailscale
# For more info, see https://tailscale.com/s/resolvconf-overwrite
# DO NOT EDIT THIS FILE BY HAND -- CHANGES WILL BE OVERWRITTEN

nameserver 8.8.8.8
nameserver 8.8.4.4
search ts.net ts-dns.test
`
		if got := readFile(t, resolvPath); got != want {
			t.Fatalf("resolv.conf:\n%s, want:\n%s", got, want)
		}

		tw := eventbustest.NewWatcher(t, bus)

		const trample = "Hvem er det som tramper på min bro?"
		if err := fs.WriteFile(resolvPath, []byte(trample), 0644); err != nil {
			t.Fatal(err)
		}
		synctest.Wait()

		if err := eventbustest.Expect(tw, eventbustest.Type[TrampleDNS]()); err != nil {
			t.Errorf("did not see trample event: %s", err)
		}
	})
}

// watchFile is a test replacement for the resolv.conf watcher, installed via
// HookWatchFile. It is generally copied from the production watcher, with
// one difference: it uses inotify (via gonotify) to watch dir, calls cb once
// for the first event whose name equals filename, and then returns. That
// ends the watch, and with it the test.
//
// It returns:
//   - ctx.Err() (context.Canceled) after cb has run;
//   - ctx.Err() if ctx ends before a matching event arrives;
//   - an error if the watcher cannot be created or its event channel closes
//     while ctx is still live.
func watchFile(ctx context.Context, dir, filename string, cb func()) error {
	// The watcher is created with this derived context, so canceling it on
	// return shuts the watcher down however we exit.
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	const events = gonotify.IN_ATTRIB |
		gonotify.IN_CLOSE_WRITE |
		gonotify.IN_CREATE |
		gonotify.IN_DELETE |
		gonotify.IN_MODIFY |
		gonotify.IN_MOVE

	watcher, err := gonotify.NewDirWatcher(ctx, events, dir)
	if err != nil {
		return fmt.Errorf("NewDirWatcher: %w", err)
	}

	for {
		select {
		case event, ok := <-watcher.C:
			if !ok {
				// A closed channel would otherwise yield zero-valued events
				// forever and spin this loop.
				if err := ctx.Err(); err != nil {
					return err
				}
				return fmt.Errorf("watching %s: event channel closed", dir)
			}
			if event.Name == filename {
				cb()
				// Return right away instead of looping. select chooses
				// randomly among ready cases, so looping could deliver further
				// queued events for the same file and call cb again before
				// ctx.Done is picked.
				cancel()
				return ctx.Err()
			}
		case <-ctx.Done():
			return ctx.Err()
		}
	}
}