summaryrefslogtreecommitdiffhomepage
path: root/control/controlbase/conn_test.go
blob: ed4642d3b179c42bcbace341ed1acc22305e8122 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause

package controlbase

import (
	"bufio"
	"bytes"
	"context"
	"encoding/binary"
	"fmt"
	"io"
	"net"
	"runtime"
	"strings"
	"sync"
	"testing"
	"testing/iotest"
	"time"

	chp "golang.org/x/crypto/chacha20poly1305"
	"golang.org/x/net/nettest"
	"tailscale.com/net/memnet"
	"tailscale.com/types/key"
)

const testProtocolVersion = 1

// TestMessageSize guards the upper bound on maxCiphertextSize.
func TestMessageSize(t *testing.T) {
	// Regression guard: it is tempting to look at maxCiphertextSize,
	// think "huh, we could be more efficient if it were larger", and
	// thereby accidentally violate the Noise spec. The ceiling below
	// is a deliberate limitation of the cryptographic protocol we use
	// (see Section 3 "Message Format" of the Noise spec). Do not
	// raise it.
	const noiseMessageLimit = 65535
	if maxCiphertextSize <= noiseMessageLimit {
		return
	}
	t.Fatalf("max ciphertext size is %d, which is larger than the maximum noise message size %d", maxCiphertextSize, noiseMessageLimit)
}

// TestConnBasic writes a short message from the client, closes the
// client, and checks that the server reads exactly that message
// followed by io.EOF.
func TestConnBasic(t *testing.T) {
	const want = "test"

	client, server := pair(t)
	received := sinkReads(server)

	_, err := io.WriteString(client, want)
	if err != nil {
		t.Fatalf("client write failed: %v", err)
	}
	client.Close()

	got := received.String(4)
	if got != want {
		t.Fatalf("wrong content received: got %q, want %q", got, want)
	}
	// Error blocks until the sink's read loop has stopped, so the
	// byte count checked afterwards is final.
	if received.Error() != io.EOF {
		t.Fatal("client close wasn't seen by server")
	}
	if total := received.Total(); total != 4 {
		t.Fatalf("wrong amount of bytes received: got %d, want 4", total)
	}
}

// bufferedWriteConn wraps a net.Conn and gives control over how
// Writes get batched out.
//
// Every Write goes into the bufio.Writer w. When manualFlush is
// false, Write flushes w after each successful write; when it is
// true, data stays buffered until the caller flushes w explicitly.
// All other net.Conn methods come from the embedded Conn.
type bufferedWriteConn struct {
	net.Conn
	w           *bufio.Writer
	manualFlush bool
}

// Write buffers bs in c.w. Unless manualFlush is set, a successful
// buffered write is followed immediately by a flush, and any flush
// error is returned alongside the number of bytes buffered.
func (c *bufferedWriteConn) Write(bs []byte) (int, error) {
	written, err := c.w.Write(bs)
	if err != nil || c.manualFlush {
		// Either buffering failed, or the caller is responsible for
		// flushing; in both cases there is nothing more to do here.
		return written, err
	}
	return written, c.w.Flush()
}

// TestFastPath exercises the Read codepath that can receive multiple
// Noise frames at once and decode each in turn without making another
// syscall.
func TestFastPath(t *testing.T) {
	const (
		packets = 10
		payload = "test"
	)

	rawClient, rawServer := memnet.NewConn("noise", 128000)
	// Start with flush-on-every-Write; batching is only switched on
	// once pairWithConns has returned the established connections.
	buffered := &bufferedWriteConn{
		Conn:        rawClient,
		w:           bufio.NewWriterSize(rawClient, 10000),
		manualFlush: false,
	}
	client, server := pairWithConns(t, buffered, rawServer)
	buffered.manualFlush = true

	received := sinkReads(server)

	// Issue many separate writes, to force separate Noise frames.
	// They all pile up in the bufio.Writer and are then sent to the
	// server as a single slice by the explicit Flush below.
	for i := 0; i < packets; i++ {
		if _, err := io.WriteString(client, payload); err != nil {
			t.Fatalf("client write1 failed: %v", err)
		}
	}
	if err := buffered.w.Flush(); err != nil {
		t.Fatalf("client flush failed: %v", err)
	}
	client.Close()

	want := strings.Repeat(payload, packets)
	got := received.String(len(want))
	if got != want {
		t.Fatalf("wrong content received: got %q, want %q", got, want)
	}
	if received.Error() != io.EOF {
		t.Fatalf("client close wasn't seen by server")
	}
}

// TestBigData writes things larger than a single Noise frame in both
// directions, to check the chunking on the encoder and decoder, and
// then verifies that the cipher nonces on both ends advanced in
// lockstep.
func TestBigData(t *testing.T) {
	client, server := pair(t)

	serverReads := sinkReads(server)
	clientReads := sinkReads(client)

	const sz = 15 * 1024 // 15KiB
	// The client sends sz bytes; the server sends twice as much.
	clientStr := strings.Repeat("abcde", sz/5)
	serverStr := strings.Repeat("fghij", sz/5*2)

	if _, err := io.WriteString(client, clientStr); err != nil {
		t.Fatalf("writing client>server: %v", err)
	}
	if _, err := io.WriteString(server, serverStr); err != nil {
		t.Fatalf("writing server>client: %v", err)
	}

	if serverGot := serverReads.String(sz); serverGot != clientStr {
		t.Error("server didn't receive what client sent")
	}
	if clientGot := clientReads.String(2 * sz); clientGot != serverStr {
		t.Error("client didn't receive what server sent")
	}

	// getNonce returns the big-endian uint64 that follows the 4-byte
	// prefix of a nonce. It panics if that prefix is not all zero.
	getNonce := func(n [chp.NonceSize]byte) uint64 {
		if binary.BigEndian.Uint32(n[:4]) != 0 {
			panic("unexpected nonce")
		}
		return binary.BigEndian.Uint64(n[4:])
	}

	// Reach into the Conns and verify the cipher nonces advanced as
	// expected.
	if getNonce(client.tx.nonce) != getNonce(server.rx.nonce) {
		t.Error("desynchronized client tx nonce")
	}
	if getNonce(server.tx.nonce) != getNonce(client.rx.nonce) {
		t.Error("desynchronized server tx nonce")
	}
	// NOTE(review): the expected values 4 and 8 presumably equal the
	// number of Noise frames each side needed for its payload —
	// confirm against the frame size limits in conn.go.
	if n := getNonce(client.tx.nonce); n != 4 {
		t.Errorf("wrong client tx nonce, got %d want 4", n)
	}
	if n := getNonce(server.tx.nonce); n != 8 {
		t.Errorf("wrong server tx nonce, got %d want 8", n)
	}
}

// readerConn wraps a net.Conn and routes its Reads through a separate
// io.Reader, r. Methods other than Read are provided by the embedded
// Conn.
type readerConn struct {
	net.Conn
	r io.Reader
}

// Read reads from c.r rather than from the embedded net.Conn.
func (c readerConn) Read(bs []byte) (int, error) {
	n, err := c.r.Read(bs)
	return n, err
}

// TestDataTrickle checks that the receiver can handle not being able
// to read an entire frame in a single syscall. The server's
// transport is wrapped in iotest.OneByteReader, so each Read on it
// yields a single byte.
func TestDataTrickle(t *testing.T) {
	const sz = 10000
	payload := strings.Repeat("abcde", sz/5)

	clientSide, serverSide := memnet.NewConn("noise", 128000)
	trickle := readerConn{
		Conn: serverSide,
		r:    iotest.OneByteReader(serverSide),
	}
	client, server := pairWithConns(t, clientSide, trickle)
	received := sinkReads(server)

	_, err := io.WriteString(client, payload)
	if err != nil {
		t.Fatalf("writing client>server: %v", err)
	}

	if received.String(sz) != payload {
		t.Error("server didn't receive what client sent")
	}
}

// TestConnStd runs the nettest.TestConn suite against a client/server
// Conn pair built over an in-memory connection. It is unconditionally
// skipped; see the comment below for why.
func TestConnStd(t *testing.T) {
	// You can run this test manually, and noise.Conn should pass all
	// of them except for TestConn/PastTimeout,
	// TestConn/FutureTimeout, TestConn/ConcurrentMethods, because
	// those tests assume that write errors are recoverable, and
	// they're not on our Conn due to cipher security.
	t.Skip("not all tests can pass on this Conn, see https://github.com/golang/go/issues/46977")
	nettest.TestConn(t, func() (c1 net.Conn, c2 net.Conn, stop func(), err error) {
		s1, s2 := memnet.NewConn("noise", 4096)
		controlKey := key.NewMachine()
		machineKey := key.NewMachine()
		serverErr := make(chan error, 1)
		// Run the server side concurrently with the client below. The
		// goroutine stores its Conn in the named result c2 and reports
		// its outcome on serverErr (its err is a local that shadows
		// the named result).
		go func() {
			var err error
			c2, err = Server(context.Background(), s2, controlKey, nil)
			serverErr <- err
		}()
		c1, err = Client(context.Background(), s1, machineKey, controlKey.Public(), testProtocolVersion)
		if err != nil {
			// Close both transports before reporting the failure.
			s1.Close()
			s2.Close()
			return nil, nil, nil, fmt.Errorf("connecting client: %w", err)
		}
		// Receiving from serverErr also orders the goroutine's write
		// to c2 before the read of c2 in the return below.
		if err := <-serverErr; err != nil {
			c1.Close()
			s1.Close()
			s2.Close()
			return nil, nil, nil, fmt.Errorf("connecting server: %w", err)
		}
		return c1, c2, func() {
			c1.Close()
			c2.Close()
		}, nil
	})
}

// TestConnMemoryOverhead tests that the idle memory overhead of a
// Conn blocked in a read is reasonable (under 2K). It was previously
// over 8KB with two 4KB buffers for rx/tx. This makes sure we don't
// regress. Hopefully it doesn't turn into a flaky test. If so, const
// max can be adjusted, or it can be deleted or reworked.
//
// The measurement is the growth in runtime.MemStats.HeapAlloc between
// a baseline taken before any pairs exist and a reading taken once
// num client/server pairs exist with each client blocked in Read,
// divided by num.
func TestConnMemoryOverhead(t *testing.T) {
	num := 1000
	if testing.Short() {
		num = 100
	}
	// Goroutine count before the test starts anything, used at the
	// end to check that the reader goroutines exit.
	ng0 := runtime.NumGoroutine()

	// Baseline heap usage, taken right after a GC.
	runtime.GC()
	var ms0 runtime.MemStats
	runtime.ReadMemStats(&ms0)

	var closers []io.Closer
	// closeAll closes every Conn created so far and empties the list,
	// so the deferred call below does no further work after the
	// explicit call later in the test.
	closeAll := func() {
		for _, c := range closers {
			c.Close()
		}
		closers = nil
	}
	defer closeAll()

	// Create num pairs and park one goroutine per pair in a Read on
	// the client side.
	for range num {
		client, server := pair(t)
		closers = append(closers, client, server)
		go func() {
			var buf [1]byte
			client.Read(buf[:])
		}()
	}

	// Poll for up to 3 seconds until at least num goroutines exist.
	t0 := time.Now()
	deadline := t0.Add(3 * time.Second)
	var ngo int
	for time.Now().Before(deadline) {
		runtime.GC()
		ngo = runtime.NumGoroutine()
		if ngo >= num {
			break
		}
		time.Sleep(10 * time.Millisecond)
	}
	if ngo < num {
		t.Fatalf("only %v goroutines; expected %v+", ngo, num)
	}
	// Second heap reading, again right after a GC; the difference
	// from the baseline is averaged over the num pairs.
	runtime.GC()
	var ms runtime.MemStats
	runtime.ReadMemStats(&ms)
	growthTotal := int64(ms.HeapAlloc) - int64(ms0.HeapAlloc)
	growthEach := float64(growthTotal) / float64(num)
	t.Logf("Alloced %v bytes, %.2f B/each", growthTotal, growthEach)
	const max = 2048
	if growthEach > max {
		t.Errorf("allocated more than expected; want max %v bytes/each", max)
	}

	closeAll()

	// And make sure our goroutines go away too. Up to num/10
	// goroutines above the starting count are tolerated.
	deadline = time.Now().Add(3 * time.Second)
	for time.Now().Before(deadline) {
		ngo = runtime.NumGoroutine()
		if ngo < ng0+num/10 {
			break
		}
		time.Sleep(10 * time.Millisecond)
	}
	if ngo >= ng0+num/10 {
		t.Errorf("goroutines didn't go back down; started at %v, now %v", ng0, ngo)
	}
}

// readSink accumulates everything read from r so that tests can wait
// for, and then inspect, the received bytes. Values are created by
// sinkReads, which starts the goroutine that fills the sink.
//
// The embedded Mutex guards bs and err. cond is built on that Mutex
// and is broadcast after every Read result is recorded. bs holds all
// bytes read so far, and err holds the error that ended the read
// loop (nil while the loop is still running).
type readSink struct {
	r io.Reader

	cond *sync.Cond
	sync.Mutex
	bs  bytes.Buffer
	err error
}

// sinkReads starts a goroutine that reads from r until Read returns
// an error, and returns the readSink collecting the results. Bytes
// returned together with the final error are still recorded.
func sinkReads(r io.Reader) *readSink {
	sink := &readSink{r: r}
	sink.cond = sync.NewCond(&sink.Mutex)

	// record stores the outcome of one Read call and wakes every
	// goroutine waiting on the sink.
	record := func(data []byte, readErr error) {
		sink.Lock()
		defer sink.Unlock()
		sink.bs.Write(data)
		if readErr != nil {
			sink.err = readErr
		}
		sink.cond.Broadcast()
	}

	go func() {
		var scratch [4096]byte
		for {
			n, err := r.Read(scratch[:])
			record(scratch[:n], err)
			if err != nil {
				return
			}
		}
	}()
	return sink
}

// String blocks until at least total bytes have been received or the
// read loop has recorded an error. Without an error it returns the
// first total bytes received; once an error has been recorded it
// returns everything buffered, whatever its length.
func (s *readSink) String(total int) string {
	s.Lock()
	defer s.Unlock()
	for {
		if s.err != nil {
			// The reader has stopped: hand back all we have.
			return string(s.bs.Bytes())
		}
		if s.bs.Len() >= total {
			return string(s.bs.Bytes()[:total])
		}
		s.cond.Wait()
	}
}

// Error blocks until the read loop has recorded an error, then
// returns that error.
func (s *readSink) Error() error {
	s.Lock()
	defer s.Unlock()
	for {
		if err := s.err; err != nil {
			return err
		}
		s.cond.Wait()
	}
}

// Total reports how many bytes the sink has received so far. It does
// not wait for more data.
func (s *readSink) Total() int {
	s.Lock()
	received := s.bs.Len()
	s.Unlock()
	return received
}

// pairWithConns runs Client over clientConn and Server over
// serverConn, using freshly generated machine keys, and returns the
// resulting client and server Conns. Server runs in its own goroutine
// while Client runs on the calling goroutine. The test is failed
// immediately if either side returns an error.
//
// It is marked as a test helper so that failures are attributed to
// the calling test rather than to this function.
func pairWithConns(t *testing.T, clientConn, serverConn net.Conn) (*Conn, *Conn) {
	t.Helper()
	var (
		controlKey = key.NewMachine()
		machineKey = key.NewMachine()
		server     *Conn
		serverErr  = make(chan error, 1)
	)
	go func() {
		var err error
		server, err = Server(context.Background(), serverConn, controlKey, nil)
		serverErr <- err
	}()

	client, err := Client(context.Background(), clientConn, machineKey, controlKey.Public(), testProtocolVersion)
	if err != nil {
		t.Fatalf("client connection failed: %v", err)
	}
	// Receiving from serverErr also orders the goroutine's write to
	// server before the read of server in the return below.
	if err := <-serverErr; err != nil {
		t.Fatalf("server connection failed: %v", err)
	}
	return client, server
}

// pair returns a connected client and server Conn running over a
// fresh in-memory connection. Like pairWithConns, it fails the test
// on error and is marked as a test helper.
func pair(t *testing.T) (*Conn, *Conn) {
	t.Helper()
	s1, s2 := memnet.NewConn("noise", 128000)
	return pairWithConns(t, s1, s2)
}