summaryrefslogtreecommitdiffhomepage
path: root/cmd/tailproxy/proxy.go
blob: 1bcaa785970a2e4a0024c29ce8bafb3e9bdd94e7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
package main

import (
	"context"
	"flag"
	"fmt"
	"log"
	"net"
	"slices"
	"strings"
	"sync"

	"github.com/inetaf/tcpproxy"
	"tailscale.com/client/tailscale"
	"tailscale.com/net/netutil"
	"tailscale.com/tailcfg"
	"tailscale.com/tsnet"
	"tailscale.com/types/logger"
	"tailscale.com/util/dnsname"
	"tailscale.com/util/must"
)

// proxyGrantRule is a single grant decoded from a peer's
// "maisem.com/tailproxy" capability. It names the destination hosts the
// peer is permitted to reach through this proxy.
type proxyGrantRule struct {
	// AllowedHosts lists permitted destinations. handleConn accepts an SNI
	// name if an entry is the literal "*" or if the entry's Contains method
	// reports true for that name.
	AllowedHosts []dnsname.FQDN
}

// handleConn serves one tailnet connection c by forwarding it, based on its
// TLS SNI name, to a host the connecting peer has been granted.
//
// The peer is identified with lc.WhoIs, and its grants are decoded from the
// "maisem.com/tailproxy" capability as proxyGrantRule values. An SNI name is
// allowed when any rule lists "*" or an FQDN whose Contains method accepts
// it. Allowed connections are dialed with dialCtx to sniName on the same
// port that c arrived on. c is closed on every error path that occurs
// before the proxy has started serving it.
func handleConn(ctx context.Context, c net.Conn, lc *tailscale.LocalClient, dialCtx func(context.Context, string, string) (net.Conn, error)) {
	// reject logs why c is being dropped and then closes it.
	reject := func(format string, args ...any) {
		log.Printf(format, args...)
		c.Close()
	}

	addrPortStr := c.LocalAddr().String()
	_, port, err := net.SplitHostPort(addrPortStr)
	if err != nil {
		reject("tcpSNIHandler.Handle: bogus addrPort %q", addrPortStr)
		return
	}
	who, err := lc.WhoIs(ctx, c.RemoteAddr().String())
	if err != nil {
		reject("tcpSNIHandler.Handle: WhoIs: %v", err)
		return
	}
	rules, err := tailcfg.UnmarshalCapJSON[proxyGrantRule](who.CapMap, "maisem.com/tailproxy")
	if err != nil {
		reject("tcpSNIHandler.Handle: UnmarshalCapJSON: %v", err)
		return
	}

	var p tcpproxy.Proxy
	// Hand tcpproxy the already-accepted connection via a one-connection
	// listener instead of opening a real socket; the network and address
	// arguments are therefore unused.
	p.ListenFunc = func(_, _ string) (net.Listener, error) {
		return netutil.NewOneConnListener(c, nil), nil
	}
	p.AddSNIRouteFunc(addrPortStr, func(_ context.Context, sniName string) (tcpproxy.Target, bool) {
		sniFQDN, err := dnsname.ToFQDN(sniName)
		if err != nil {
			log.Printf("tcpSNIHandler.Handle: ToFQDN: %v", err)
			return nil, false
		}
		allowed := func(fqdn dnsname.FQDN) bool {
			return fqdn == "*" || fqdn.Contains(sniFQDN)
		}
		for _, rule := range rules {
			if slices.ContainsFunc(rule.AllowedHosts, allowed) {
				log.Printf("tcpSNIHandler.Handle: %s is allowed", sniName)
				return &tcpproxy.DialProxy{
					Addr:        net.JoinHostPort(sniName, port),
					DialContext: dialCtx,
				}, true
			}
		}
		log.Printf("tcpSNIHandler.Handle: %s is not allowed", sniName)
		return nil, false
	})
	// If Start fails, nothing will ever serve or close c, so close it here
	// rather than leaking the connection.
	if err := p.Start(); err != nil {
		reject("tcpSNIHandler.Handle: Start: %v", err)
	}
}

func main() {
	var (
		ports    = flag.String("ports", "443", "comma-separated list of ports to proxy")
		hostname = flag.String("hostname", "", "Hostname to register the service under")
	)
	flag.Parse()

	ctx := context.Background()
	s := &tsnet.Server{
		Hostname: *hostname,
		Logf:     logger.Discard,
	}
	must.Get(s.Up(ctx))
	var wg sync.WaitGroup
	log.Printf("Listening on ports: %s", *ports)
	for _, port := range strings.Split(*ports, ",") {
		wg.Add(1)
		ln := must.Get(s.Listen("tcp", ":"+port))
		lc := must.Get(s.LocalClient())
		go func() {
			defer wg.Done()
			for {
				c, err := ln.Accept()
				if err != nil {
					continue
				}
				fmt.Println("Accepted connection")
				go handleConn(ctx, c, lc, s.Dial)
			}
		}()
	}
	wg.Wait()
}