summaryrefslogtreecommitdiffhomepage
path: root/net/memnet/memnet.go
blob: 25b1062a19cec8c60a2d4ee4737262899ffa9398 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
// Copyright (c) Tailscale Inc & contributors
// SPDX-License-Identifier: BSD-3-Clause

// Package memnet implements an in-memory network.
// It is useful for dialing and listening on in-memory addresses
// in tests and other situations where you don't want to use the
// real network.
package memnet

import (
	"context"
	"fmt"
	"net"
	"net/netip"

	"tailscale.com/net/netx"
	"tailscale.com/syncs"
)

// Compile-time check that *Network satisfies netx.Network.
var _ netx.Network = (*Network)(nil)

// Network implements [netx.Network] using an in-memory network, usually
// used for testing.
//
// As of 2025-04-08, it only supports TCP.
//
// Its zero value is a valid [netx.Network] implementation.
type Network struct {
	mu  syncs.Mutex          // guards lns
	lns map[string]*Listener // address -> listener; allocated lazily by Listen
}

// Listen registers a new in-memory listener in m at address and returns it.
//
// Only the "tcp", "tcp4" and "tcp6" networks are supported, and address must
// be an "ip:port" string accepted by [netip.ParseAddrPort]; hostnames are not
// resolved. The listener is registered under the normalized key
// net.JoinHostPort(ip.String(), port).
//
// If the port is 0, the lowest free port at or above 33000 on that IP is
// assigned, and an error is returned if every port in that range is taken.
// Requesting a non-zero port that is already registered is an error.
func (m *Network) Listen(network, address string) (net.Listener, error) {
	if network != "tcp" && network != "tcp4" && network != "tcp6" {
		return nil, fmt.Errorf("memNetwork: Listen called with unsupported network %q", network)
	}
	ap, err := netip.ParseAddrPort(address)
	if err != nil {
		return nil, fmt.Errorf("memNetwork: Listen called with invalid address %q: %w", address, err)
	}

	const (
		firstAutoPort = 33000     // first port tried when the caller passes port 0
		lastPort      = 1<<16 - 1 // highest valid port; incrementing past it would wrap to 0
	)

	host := ap.Addr().String()
	autoPort := ap.Port() == 0
	port := ap.Port()
	if autoPort {
		port = firstAutoPort
	}

	m.mu.Lock()
	defer m.mu.Unlock()

	if m.lns == nil {
		m.lns = make(map[string]*Listener)
	}
	for {
		key := net.JoinHostPort(host, fmt.Sprint(port))
		if _, taken := m.lns[key]; !taken {
			ln := Listen(key)
			m.lns[key] = ln
			// Unregister the address when the listener's onClose hook fires
			// (presumably from the listener's Close; see Listener). The hook
			// takes m.mu itself, so it must not run while Listen holds it.
			ln.onClose = func() {
				m.mu.Lock()
				delete(m.lns, key)
				m.mu.Unlock()
			}
			return ln, nil
		}
		if !autoPort {
			return nil, fmt.Errorf("memNetwork: Listen called with duplicate address %q", address)
		}
		if port == lastPort {
			// Without this check the uint16 would wrap to 0 and restart at
			// firstAutoPort, spinning forever while holding m.mu.
			return nil, fmt.Errorf("memNetwork: Listen found no free port for address %q", address)
		}
		port++
	}
}

// NewLocalTCPListener returns a TCP listener in m on 127.0.0.1 with an
// automatically assigned port (it calls Listen with port 0).
//
// It has no error result: if Listen fails, it panics.
func (m *Network) NewLocalTCPListener() net.Listener {
	const loopbackAnyPort = "127.0.0.1:0"

	ln, err := m.Listen("tcp", loopbackAnyPort)
	if err == nil {
		return ln
	}
	panic(fmt.Sprintf("memNetwork: failed to create local TCP listener: %v", err))
}

// Dial connects to the listener registered in m at address.
//
// Only the "tcp", "tcp4" and "tcp6" networks are supported. An "ip:port"
// address is normalized the same way Listen normalizes the addresses it
// registers. As a result, textual variants of the same IP and port reach the
// same listener; for example, "[2001:DB8::1]:80" and "[2001:db8::1]:80" both work.
// Addresses that are not "ip:port" strings (such as hostnames) never match a
// registered listener and produce an error.
func (m *Network) Dial(ctx context.Context, network, address string) (net.Conn, error) {
	if network != "tcp" && network != "tcp4" && network != "tcp6" {
		return nil, fmt.Errorf("memNetwork: Dial called with unsupported network %q", network)
	}

	// Listen keys listeners by net.JoinHostPort(addr.String(), port), so build
	// the same key here. If address is already in that form, key == address.
	key := address
	if ap, err := netip.ParseAddrPort(address); err == nil {
		key = net.JoinHostPort(ap.Addr().String(), fmt.Sprint(ap.Port()))
	}

	m.mu.Lock()
	ln, ok := m.lns[key]
	m.mu.Unlock()
	if !ok {
		return nil, fmt.Errorf("memNetwork: Dial called on unknown address %q", address)
	}
	// Hand the listener the exact string it was registered under.
	return ln.Dial(ctx, network, key)
}