diff options
| author | Nick Khyl <nickk@tailscale.com> | 2024-12-05 13:16:48 -0600 |
|---|---|---|
| committer | Nick Khyl <nickk@tailscale.com> | 2024-12-05 13:16:48 -0600 |
| commit | 0267fe83b200f1702a2fa0a395442c02a053fadb (patch) | |
| tree | 63654c55225eeb834de59a5a0bc8d19033c6145b /cmd/sniproxy | |
| parent | 87546a5edf6b6503a87eeb2d666baba57398a066 (diff) | |
| download | tailscale-1.78.0.tar.xz tailscale-1.78.0.zip | |
VERSION.txt: this is v1.78.0v1.78.0
Signed-off-by: Nick Khyl <nickk@tailscale.com>
Diffstat (limited to 'cmd/sniproxy')
| -rw-r--r-- | cmd/sniproxy/.gitignore | 2 | ||||
| -rw-r--r-- | cmd/sniproxy/handlers_test.go | 318 | ||||
| -rw-r--r-- | cmd/sniproxy/server.go | 654 | ||||
| -rw-r--r-- | cmd/sniproxy/server_test.go | 190 | ||||
| -rw-r--r-- | cmd/sniproxy/sniproxy.go | 582 |
5 files changed, 873 insertions, 873 deletions
diff --git a/cmd/sniproxy/.gitignore b/cmd/sniproxy/.gitignore index b1399c881..0bca33912 100644 --- a/cmd/sniproxy/.gitignore +++ b/cmd/sniproxy/.gitignore @@ -1 +1 @@ -sniproxy +sniproxy
diff --git a/cmd/sniproxy/handlers_test.go b/cmd/sniproxy/handlers_test.go index 4f9fc6a34..8ec5b097c 100644 --- a/cmd/sniproxy/handlers_test.go +++ b/cmd/sniproxy/handlers_test.go @@ -1,159 +1,159 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package main - -import ( - "bytes" - "context" - "encoding/hex" - "io" - "net" - "net/netip" - "strings" - "testing" - - "tailscale.com/net/memnet" -) - -func echoConnOnce(conn net.Conn) { - defer conn.Close() - - b := make([]byte, 256) - n, err := conn.Read(b) - if err != nil { - return - } - - if _, err := conn.Write(b[:n]); err != nil { - return - } -} - -func TestTCPRoundRobinHandler(t *testing.T) { - h := tcpRoundRobinHandler{ - To: []string{"yeet.com"}, - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - if network != "tcp" { - t.Errorf("network = %s, want %s", network, "tcp") - } - if addr != "yeet.com:22" { - t.Errorf("addr = %s, want %s", addr, "yeet.com:22") - } - - c, s := memnet.NewConn("outbound", 1024) - go echoConnOnce(s) - return c, nil - }, - } - - cSock, sSock := memnet.NewTCPConn(netip.MustParseAddrPort("10.64.1.2:22"), netip.MustParseAddrPort("10.64.1.2:22"), 1024) - h.Handle(sSock) - - // Test data write and read, the other end will echo back - // a single stanza - want := "hello" - if _, err := io.WriteString(cSock, want); err != nil { - t.Fatal(err) - } - got := make([]byte, len(want)) - if _, err := io.ReadAtLeast(cSock, got, len(got)); err != nil { - t.Fatal(err) - } - if string(got) != want { - t.Errorf("got %q, want %q", got, want) - } - - // The other end closed the socket after the first echo, so - // any following read should error. - io.WriteString(cSock, "deadass heres some data on god fr") - if _, err := io.ReadAtLeast(cSock, got, len(got)); err == nil { - t.Error("read succeeded on closed socket") - } -} - -// Capture of first TCP data segment for a connection to https://pkgs.tailscale.com -const tlsStart = `45000239ff1840004006f9f5c0a801f2 -c726b5efcf9e01bbe803b21394e3b752 -801801f641dc00000101080ade3474f2 -2fb93ee71603010200010001fc030303 -c3acbd19d2624765bb19af4bce03365e -1d197f5bb939cdadeff26b0f8e7a0620 -295b04127b82bae46aac4ff58cffef25 -eba75a4b7a6de729532c411bd9dd0d2c -00203a3a130113021303c02bc02fc02c -c030cca9cca8c013c014009c009d002f -003501000193caca0000000a000a0008 -1a1a001d001700180010000e000c0268 -3208687474702f312e31002b0007062a -2a03040303ff01000100000d00120010 -04030804040105030805050108060601 -000b00020100002300000033002b0029 -1a1a000100001d0020d3c76bef062979 -a812ce935cfb4dbe6b3a84dc5ba9226f -23b0f34af9d1d03b4a001b0003020002 -00120000446900050003026832000000 -170015000012706b67732e7461696c73 -63616c652e636f6d002d000201010005 -00050100000000001700003a3a000100 -0015002d000000000000000000000000 -00000000000000000000000000000000 -00000000000000000000000000000000 -0000290094006f0069e76f2016f963ad -38c8632d1f240cd75e00e25fdef295d4 -7042b26f3a9a543b1c7dc74939d77803 -20527d423ff996997bda2c6383a14f49 -219eeef8a053e90a32228df37ddbe126 -eccf6b085c93890d08341d819aea6111 -0d909f4cd6b071d9ea40618e74588a33 -90d494bbb5c3002120d5a164a16c9724 -c9ef5e540d8d6f007789a7acf9f5f16f -bf6a1907a6782ed02b` - -func fakeSNIHeader() []byte { - b, err := hex.DecodeString(strings.Replace(tlsStart, "\n", "", -1)) - if err != nil { - panic(err) - } - return b[0x34:] // trim IP + TCP header -} - -func TestTCPSNIHandler(t *testing.T) { - h := tcpSNIHandler{ - Allowlist: []string{"pkgs.tailscale.com"}, - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - if network != "tcp" { - t.Errorf("network = %s, want %s", network, "tcp") - } - if addr != "pkgs.tailscale.com:443" { - t.Errorf("addr = %s, want %s", addr, "pkgs.tailscale.com:443") - } - - c, s := memnet.NewConn("outbound", 1024) - go echoConnOnce(s) - return c, nil - }, - } - - cSock, sSock := memnet.NewTCPConn(netip.MustParseAddrPort("10.64.1.2:22"), netip.MustParseAddrPort("10.64.1.2:443"), 1024) - h.Handle(sSock) - - // Fake a TLS handshake record with an SNI in it. - if _, err := cSock.Write(fakeSNIHeader()); err != nil { - t.Fatal(err) - } - - // Test read, the other end will echo back - // a single stanza, which is at least the beginning of the SNI header. - want := fakeSNIHeader()[:5] - if _, err := cSock.Write(want); err != nil { - t.Fatal(err) - } - got := make([]byte, len(want)) - if _, err := io.ReadAtLeast(cSock, got, len(got)); err != nil { - t.Fatal(err) - } - if !bytes.Equal(got, want) { - t.Errorf("got %q, want %q", got, want) - } -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package main
+
+import (
+ "bytes"
+ "context"
+ "encoding/hex"
+ "io"
+ "net"
+ "net/netip"
+ "strings"
+ "testing"
+
+ "tailscale.com/net/memnet"
+)
+
+func echoConnOnce(conn net.Conn) {
+ defer conn.Close()
+
+ b := make([]byte, 256)
+ n, err := conn.Read(b)
+ if err != nil {
+ return
+ }
+
+ if _, err := conn.Write(b[:n]); err != nil {
+ return
+ }
+}
+
+func TestTCPRoundRobinHandler(t *testing.T) {
+ h := tcpRoundRobinHandler{
+ To: []string{"yeet.com"},
+ DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
+ if network != "tcp" {
+ t.Errorf("network = %s, want %s", network, "tcp")
+ }
+ if addr != "yeet.com:22" {
+ t.Errorf("addr = %s, want %s", addr, "yeet.com:22")
+ }
+
+ c, s := memnet.NewConn("outbound", 1024)
+ go echoConnOnce(s)
+ return c, nil
+ },
+ }
+
+ cSock, sSock := memnet.NewTCPConn(netip.MustParseAddrPort("10.64.1.2:22"), netip.MustParseAddrPort("10.64.1.2:22"), 1024)
+ h.Handle(sSock)
+
+ // Test data write and read, the other end will echo back
+ // a single stanza
+ want := "hello"
+ if _, err := io.WriteString(cSock, want); err != nil {
+ t.Fatal(err)
+ }
+ got := make([]byte, len(want))
+ if _, err := io.ReadAtLeast(cSock, got, len(got)); err != nil {
+ t.Fatal(err)
+ }
+ if string(got) != want {
+ t.Errorf("got %q, want %q", got, want)
+ }
+
+ // The other end closed the socket after the first echo, so
+ // any following read should error.
+ io.WriteString(cSock, "deadass heres some data on god fr")
+ if _, err := io.ReadAtLeast(cSock, got, len(got)); err == nil {
+ t.Error("read succeeded on closed socket")
+ }
+}
+
+// Capture of first TCP data segment for a connection to https://pkgs.tailscale.com
+const tlsStart = `45000239ff1840004006f9f5c0a801f2
+c726b5efcf9e01bbe803b21394e3b752
+801801f641dc00000101080ade3474f2
+2fb93ee71603010200010001fc030303
+c3acbd19d2624765bb19af4bce03365e
+1d197f5bb939cdadeff26b0f8e7a0620
+295b04127b82bae46aac4ff58cffef25
+eba75a4b7a6de729532c411bd9dd0d2c
+00203a3a130113021303c02bc02fc02c
+c030cca9cca8c013c014009c009d002f
+003501000193caca0000000a000a0008
+1a1a001d001700180010000e000c0268
+3208687474702f312e31002b0007062a
+2a03040303ff01000100000d00120010
+04030804040105030805050108060601
+000b00020100002300000033002b0029
+1a1a000100001d0020d3c76bef062979
+a812ce935cfb4dbe6b3a84dc5ba9226f
+23b0f34af9d1d03b4a001b0003020002
+00120000446900050003026832000000
+170015000012706b67732e7461696c73
+63616c652e636f6d002d000201010005
+00050100000000001700003a3a000100
+0015002d000000000000000000000000
+00000000000000000000000000000000
+00000000000000000000000000000000
+0000290094006f0069e76f2016f963ad
+38c8632d1f240cd75e00e25fdef295d4
+7042b26f3a9a543b1c7dc74939d77803
+20527d423ff996997bda2c6383a14f49
+219eeef8a053e90a32228df37ddbe126
+eccf6b085c93890d08341d819aea6111
+0d909f4cd6b071d9ea40618e74588a33
+90d494bbb5c3002120d5a164a16c9724
+c9ef5e540d8d6f007789a7acf9f5f16f
+bf6a1907a6782ed02b`
+
+func fakeSNIHeader() []byte {
+ b, err := hex.DecodeString(strings.Replace(tlsStart, "\n", "", -1))
+ if err != nil {
+ panic(err)
+ }
+ return b[0x34:] // trim IP + TCP header
+}
+
+func TestTCPSNIHandler(t *testing.T) {
+ h := tcpSNIHandler{
+ Allowlist: []string{"pkgs.tailscale.com"},
+ DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
+ if network != "tcp" {
+ t.Errorf("network = %s, want %s", network, "tcp")
+ }
+ if addr != "pkgs.tailscale.com:443" {
+ t.Errorf("addr = %s, want %s", addr, "pkgs.tailscale.com:443")
+ }
+
+ c, s := memnet.NewConn("outbound", 1024)
+ go echoConnOnce(s)
+ return c, nil
+ },
+ }
+
+ cSock, sSock := memnet.NewTCPConn(netip.MustParseAddrPort("10.64.1.2:22"), netip.MustParseAddrPort("10.64.1.2:443"), 1024)
+ h.Handle(sSock)
+
+ // Fake a TLS handshake record with an SNI in it.
+ if _, err := cSock.Write(fakeSNIHeader()); err != nil {
+ t.Fatal(err)
+ }
+
+ // Test read, the other end will echo back
+ // a single stanza, which is at least the beginning of the SNI header.
+ want := fakeSNIHeader()[:5]
+ if _, err := cSock.Write(want); err != nil {
+ t.Fatal(err)
+ }
+ got := make([]byte, len(want))
+ if _, err := io.ReadAtLeast(cSock, got, len(got)); err != nil {
+ t.Fatal(err)
+ }
+ if !bytes.Equal(got, want) {
+ t.Errorf("got %q, want %q", got, want)
+ }
+}
diff --git a/cmd/sniproxy/server.go b/cmd/sniproxy/server.go index b322b6f4b..c89420661 100644 --- a/cmd/sniproxy/server.go +++ b/cmd/sniproxy/server.go @@ -1,327 +1,327 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package main - -import ( - "expvar" - "log" - "net" - "net/netip" - "sync" - "time" - - "golang.org/x/net/dns/dnsmessage" - "tailscale.com/metrics" - "tailscale.com/tailcfg" - "tailscale.com/types/appctype" - "tailscale.com/types/ipproto" - "tailscale.com/types/nettype" - "tailscale.com/util/clientmetric" - "tailscale.com/util/mak" -) - -var tsMBox = dnsmessage.MustNewName("support.tailscale.com.") - -// target describes the predicates which route some inbound -// traffic to the app connector to a specific handler. -type target struct { - Dest netip.Prefix - Matching tailcfg.ProtoPortRange -} - -// Server implements an App Connector as expressed in sniproxy. -type Server struct { - mu sync.RWMutex // mu guards following fields - connectors map[appctype.ConfigID]connector -} - -type appcMetrics struct { - dnsResponses expvar.Int - dnsFailures expvar.Int - tcpConns expvar.Int - sniConns expvar.Int - unhandledConns expvar.Int -} - -var getMetrics = sync.OnceValue[*appcMetrics](func() *appcMetrics { - m := appcMetrics{} - - stats := new(metrics.Set) - stats.Set("tls_sessions", &m.sniConns) - clientmetric.NewCounterFunc("sniproxy_tls_sessions", m.sniConns.Value) - stats.Set("tcp_sessions", &m.tcpConns) - clientmetric.NewCounterFunc("sniproxy_tcp_sessions", m.tcpConns.Value) - stats.Set("dns_responses", &m.dnsResponses) - clientmetric.NewCounterFunc("sniproxy_dns_responses", m.dnsResponses.Value) - stats.Set("dns_failed", &m.dnsFailures) - clientmetric.NewCounterFunc("sniproxy_dns_failed", m.dnsFailures.Value) - expvar.Publish("sniproxy", stats) - - return &m -}) - -// Configure applies the provided configuration to the app connector. -func (s *Server) Configure(cfg *appctype.AppConnectorConfig) { - s.mu.Lock() - defer s.mu.Unlock() - s.connectors = makeConnectorsFromConfig(cfg) - log.Printf("installed app connector config: %+v", s.connectors) -} - -// HandleTCPFlow implements tsnet.FallbackTCPHandler. -func (s *Server) HandleTCPFlow(src, dst netip.AddrPort) (handler func(net.Conn), intercept bool) { - m := getMetrics() - s.mu.RLock() - defer s.mu.RUnlock() - - for _, c := range s.connectors { - if handler, intercept := c.handleTCPFlow(src, dst, m); intercept { - return handler, intercept - } - } - - return nil, false -} - -// HandleDNS handles a DNS request to the app connector. -func (s *Server) HandleDNS(c nettype.ConnPacketConn) { - defer c.Close() - c.SetReadDeadline(time.Now().Add(5 * time.Second)) - m := getMetrics() - - buf := make([]byte, 1500) - n, err := c.Read(buf) - if err != nil { - log.Printf("HandleDNS: read failed: %v\n ", err) - m.dnsFailures.Add(1) - return - } - - addrPortStr := c.LocalAddr().String() - host, _, err := net.SplitHostPort(addrPortStr) - if err != nil { - log.Printf("HandleDNS: bogus addrPort %q", addrPortStr) - m.dnsFailures.Add(1) - return - } - localAddr, err := netip.ParseAddr(host) - if err != nil { - log.Printf("HandleDNS: bogus local address %q", host) - m.dnsFailures.Add(1) - return - } - - var msg dnsmessage.Message - err = msg.Unpack(buf[:n]) - if err != nil { - log.Printf("HandleDNS: dnsmessage unpack failed: %v\n ", err) - m.dnsFailures.Add(1) - return - } - - s.mu.RLock() - defer s.mu.RUnlock() - for _, connector := range s.connectors { - resp, err := connector.handleDNS(&msg, localAddr) - if err != nil { - log.Printf("HandleDNS: connector handling failed: %v\n", err) - m.dnsFailures.Add(1) - return - } - if len(resp) > 0 { - // This connector handled the DNS request - _, err = c.Write(resp) - if err != nil { - log.Printf("HandleDNS: write failed: %v\n", err) - m.dnsFailures.Add(1) - return - } - - m.dnsResponses.Add(1) - return - } - } -} - -// connector describes a logical collection of -// services which need to be proxied. -type connector struct { - Handlers map[target]handler -} - -// handleTCPFlow implements tsnet.FallbackTCPHandler. -func (c *connector) handleTCPFlow(src, dst netip.AddrPort, m *appcMetrics) (handler func(net.Conn), intercept bool) { - for t, h := range c.Handlers { - if t.Matching.Proto != 0 && t.Matching.Proto != int(ipproto.TCP) { - continue - } - if !t.Dest.Contains(dst.Addr()) { - continue - } - if !t.Matching.Ports.Contains(dst.Port()) { - continue - } - - switch h.(type) { - case *tcpSNIHandler: - m.sniConns.Add(1) - case *tcpRoundRobinHandler: - m.tcpConns.Add(1) - default: - log.Printf("handleTCPFlow: unhandled handler type %T", h) - } - - return h.Handle, true - } - - m.unhandledConns.Add(1) - return nil, false -} - -// handleDNS returns the DNS response to the given query. If this -// connector is unable to handle the request, nil is returned. -func (c *connector) handleDNS(req *dnsmessage.Message, localAddr netip.Addr) (response []byte, err error) { - for t, h := range c.Handlers { - if t.Dest.Contains(localAddr) { - return makeDNSResponse(req, h.ReachableOn()) - } - } - - // Did not match, signal 'not handled' to caller - return nil, nil -} - -func makeDNSResponse(req *dnsmessage.Message, reachableIPs []netip.Addr) (response []byte, err error) { - resp := dnsmessage.NewBuilder(response, - dnsmessage.Header{ - ID: req.Header.ID, - Response: true, - Authoritative: true, - }) - resp.EnableCompression() - - if len(req.Questions) == 0 { - response, _ = resp.Finish() - return response, nil - } - q := req.Questions[0] - err = resp.StartQuestions() - if err != nil { - return - } - resp.Question(q) - - err = resp.StartAnswers() - if err != nil { - return - } - - switch q.Type { - case dnsmessage.TypeAAAA: - for _, ip := range reachableIPs { - if ip.Is6() { - err = resp.AAAAResource( - dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120}, - dnsmessage.AAAAResource{AAAA: ip.As16()}, - ) - } - } - - case dnsmessage.TypeA: - for _, ip := range reachableIPs { - if ip.Is4() { - err = resp.AResource( - dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120}, - dnsmessage.AResource{A: ip.As4()}, - ) - } - } - - case dnsmessage.TypeSOA: - err = resp.SOAResource( - dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120}, - dnsmessage.SOAResource{NS: q.Name, MBox: tsMBox, Serial: 2023030600, - Refresh: 120, Retry: 120, Expire: 120, MinTTL: 60}, - ) - case dnsmessage.TypeNS: - err = resp.NSResource( - dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120}, - dnsmessage.NSResource{NS: tsMBox}, - ) - } - - if err != nil { - return nil, err - } - return resp.Finish() -} - -type handler interface { - // Handle handles the given socket. - Handle(c net.Conn) - - // ReachableOn returns the IP addresses this handler is reachable on. - ReachableOn() []netip.Addr -} - -func installDNATHandler(d *appctype.DNATConfig, out *connector) { - // These handlers don't actually do DNAT, they just - // proxy the data over the connection. - var dialer net.Dialer - dialer.Timeout = 5 * time.Second - h := tcpRoundRobinHandler{ - To: d.To, - DialContext: dialer.DialContext, - ReachableIPs: d.Addrs, - } - - for _, addr := range d.Addrs { - for _, protoPort := range d.IP { - t := target{ - Dest: netip.PrefixFrom(addr, addr.BitLen()), - Matching: protoPort, - } - - mak.Set(&out.Handlers, t, handler(&h)) - } - } -} - -func installSNIHandler(c *appctype.SNIProxyConfig, out *connector) { - var dialer net.Dialer - dialer.Timeout = 5 * time.Second - h := tcpSNIHandler{ - Allowlist: c.AllowedDomains, - DialContext: dialer.DialContext, - ReachableIPs: c.Addrs, - } - - for _, addr := range c.Addrs { - for _, protoPort := range c.IP { - t := target{ - Dest: netip.PrefixFrom(addr, addr.BitLen()), - Matching: protoPort, - } - - mak.Set(&out.Handlers, t, handler(&h)) - } - } -} - -func makeConnectorsFromConfig(cfg *appctype.AppConnectorConfig) map[appctype.ConfigID]connector { - var connectors map[appctype.ConfigID]connector - - for cID, d := range cfg.DNAT { - c := connectors[cID] - installDNATHandler(&d, &c) - mak.Set(&connectors, cID, c) - } - for cID, d := range cfg.SNIProxy { - c := connectors[cID] - installSNIHandler(&d, &c) - mak.Set(&connectors, cID, c) - } - - return connectors -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package main
+
+import (
+ "expvar"
+ "log"
+ "net"
+ "net/netip"
+ "sync"
+ "time"
+
+ "golang.org/x/net/dns/dnsmessage"
+ "tailscale.com/metrics"
+ "tailscale.com/tailcfg"
+ "tailscale.com/types/appctype"
+ "tailscale.com/types/ipproto"
+ "tailscale.com/types/nettype"
+ "tailscale.com/util/clientmetric"
+ "tailscale.com/util/mak"
+)
+
+var tsMBox = dnsmessage.MustNewName("support.tailscale.com.")
+
+// target describes the predicates which route some inbound
+// traffic to the app connector to a specific handler.
+type target struct {
+ Dest netip.Prefix
+ Matching tailcfg.ProtoPortRange
+}
+
+// Server implements an App Connector as expressed in sniproxy.
+type Server struct {
+ mu sync.RWMutex // mu guards following fields
+ connectors map[appctype.ConfigID]connector
+}
+
+type appcMetrics struct {
+ dnsResponses expvar.Int
+ dnsFailures expvar.Int
+ tcpConns expvar.Int
+ sniConns expvar.Int
+ unhandledConns expvar.Int
+}
+
+var getMetrics = sync.OnceValue[*appcMetrics](func() *appcMetrics {
+ m := appcMetrics{}
+
+ stats := new(metrics.Set)
+ stats.Set("tls_sessions", &m.sniConns)
+ clientmetric.NewCounterFunc("sniproxy_tls_sessions", m.sniConns.Value)
+ stats.Set("tcp_sessions", &m.tcpConns)
+ clientmetric.NewCounterFunc("sniproxy_tcp_sessions", m.tcpConns.Value)
+ stats.Set("dns_responses", &m.dnsResponses)
+ clientmetric.NewCounterFunc("sniproxy_dns_responses", m.dnsResponses.Value)
+ stats.Set("dns_failed", &m.dnsFailures)
+ clientmetric.NewCounterFunc("sniproxy_dns_failed", m.dnsFailures.Value)
+ expvar.Publish("sniproxy", stats)
+
+ return &m
+})
+
+// Configure applies the provided configuration to the app connector.
+func (s *Server) Configure(cfg *appctype.AppConnectorConfig) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.connectors = makeConnectorsFromConfig(cfg)
+ log.Printf("installed app connector config: %+v", s.connectors)
+}
+
+// HandleTCPFlow implements tsnet.FallbackTCPHandler.
+func (s *Server) HandleTCPFlow(src, dst netip.AddrPort) (handler func(net.Conn), intercept bool) {
+ m := getMetrics()
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ for _, c := range s.connectors {
+ if handler, intercept := c.handleTCPFlow(src, dst, m); intercept {
+ return handler, intercept
+ }
+ }
+
+ return nil, false
+}
+
+// HandleDNS handles a DNS request to the app connector.
+func (s *Server) HandleDNS(c nettype.ConnPacketConn) {
+ defer c.Close()
+ c.SetReadDeadline(time.Now().Add(5 * time.Second))
+ m := getMetrics()
+
+ buf := make([]byte, 1500)
+ n, err := c.Read(buf)
+ if err != nil {
+ log.Printf("HandleDNS: read failed: %v\n ", err)
+ m.dnsFailures.Add(1)
+ return
+ }
+
+ addrPortStr := c.LocalAddr().String()
+ host, _, err := net.SplitHostPort(addrPortStr)
+ if err != nil {
+ log.Printf("HandleDNS: bogus addrPort %q", addrPortStr)
+ m.dnsFailures.Add(1)
+ return
+ }
+ localAddr, err := netip.ParseAddr(host)
+ if err != nil {
+ log.Printf("HandleDNS: bogus local address %q", host)
+ m.dnsFailures.Add(1)
+ return
+ }
+
+ var msg dnsmessage.Message
+ err = msg.Unpack(buf[:n])
+ if err != nil {
+ log.Printf("HandleDNS: dnsmessage unpack failed: %v\n ", err)
+ m.dnsFailures.Add(1)
+ return
+ }
+
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ for _, connector := range s.connectors {
+ resp, err := connector.handleDNS(&msg, localAddr)
+ if err != nil {
+ log.Printf("HandleDNS: connector handling failed: %v\n", err)
+ m.dnsFailures.Add(1)
+ return
+ }
+ if len(resp) > 0 {
+ // This connector handled the DNS request
+ _, err = c.Write(resp)
+ if err != nil {
+ log.Printf("HandleDNS: write failed: %v\n", err)
+ m.dnsFailures.Add(1)
+ return
+ }
+
+ m.dnsResponses.Add(1)
+ return
+ }
+ }
+}
+
+// connector describes a logical collection of
+// services which need to be proxied.
+type connector struct {
+ Handlers map[target]handler
+}
+
+// handleTCPFlow implements tsnet.FallbackTCPHandler.
+func (c *connector) handleTCPFlow(src, dst netip.AddrPort, m *appcMetrics) (handler func(net.Conn), intercept bool) {
+ for t, h := range c.Handlers {
+ if t.Matching.Proto != 0 && t.Matching.Proto != int(ipproto.TCP) {
+ continue
+ }
+ if !t.Dest.Contains(dst.Addr()) {
+ continue
+ }
+ if !t.Matching.Ports.Contains(dst.Port()) {
+ continue
+ }
+
+ switch h.(type) {
+ case *tcpSNIHandler:
+ m.sniConns.Add(1)
+ case *tcpRoundRobinHandler:
+ m.tcpConns.Add(1)
+ default:
+ log.Printf("handleTCPFlow: unhandled handler type %T", h)
+ }
+
+ return h.Handle, true
+ }
+
+ m.unhandledConns.Add(1)
+ return nil, false
+}
+
+// handleDNS returns the DNS response to the given query. If this
+// connector is unable to handle the request, nil is returned.
+func (c *connector) handleDNS(req *dnsmessage.Message, localAddr netip.Addr) (response []byte, err error) {
+ for t, h := range c.Handlers {
+ if t.Dest.Contains(localAddr) {
+ return makeDNSResponse(req, h.ReachableOn())
+ }
+ }
+
+ // Did not match, signal 'not handled' to caller
+ return nil, nil
+}
+
+func makeDNSResponse(req *dnsmessage.Message, reachableIPs []netip.Addr) (response []byte, err error) {
+ resp := dnsmessage.NewBuilder(response,
+ dnsmessage.Header{
+ ID: req.Header.ID,
+ Response: true,
+ Authoritative: true,
+ })
+ resp.EnableCompression()
+
+ if len(req.Questions) == 0 {
+ response, _ = resp.Finish()
+ return response, nil
+ }
+ q := req.Questions[0]
+ err = resp.StartQuestions()
+ if err != nil {
+ return
+ }
+ resp.Question(q)
+
+ err = resp.StartAnswers()
+ if err != nil {
+ return
+ }
+
+ switch q.Type {
+ case dnsmessage.TypeAAAA:
+ for _, ip := range reachableIPs {
+ if ip.Is6() {
+ err = resp.AAAAResource(
+ dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120},
+ dnsmessage.AAAAResource{AAAA: ip.As16()},
+ )
+ }
+ }
+
+ case dnsmessage.TypeA:
+ for _, ip := range reachableIPs {
+ if ip.Is4() {
+ err = resp.AResource(
+ dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120},
+ dnsmessage.AResource{A: ip.As4()},
+ )
+ }
+ }
+
+ case dnsmessage.TypeSOA:
+ err = resp.SOAResource(
+ dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120},
+ dnsmessage.SOAResource{NS: q.Name, MBox: tsMBox, Serial: 2023030600,
+ Refresh: 120, Retry: 120, Expire: 120, MinTTL: 60},
+ )
+ case dnsmessage.TypeNS:
+ err = resp.NSResource(
+ dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120},
+ dnsmessage.NSResource{NS: tsMBox},
+ )
+ }
+
+ if err != nil {
+ return nil, err
+ }
+ return resp.Finish()
+}
+
+type handler interface {
+ // Handle handles the given socket.
+ Handle(c net.Conn)
+
+ // ReachableOn returns the IP addresses this handler is reachable on.
+ ReachableOn() []netip.Addr
+}
+
+func installDNATHandler(d *appctype.DNATConfig, out *connector) {
+ // These handlers don't actually do DNAT, they just
+ // proxy the data over the connection.
+ var dialer net.Dialer
+ dialer.Timeout = 5 * time.Second
+ h := tcpRoundRobinHandler{
+ To: d.To,
+ DialContext: dialer.DialContext,
+ ReachableIPs: d.Addrs,
+ }
+
+ for _, addr := range d.Addrs {
+ for _, protoPort := range d.IP {
+ t := target{
+ Dest: netip.PrefixFrom(addr, addr.BitLen()),
+ Matching: protoPort,
+ }
+
+ mak.Set(&out.Handlers, t, handler(&h))
+ }
+ }
+}
+
+func installSNIHandler(c *appctype.SNIProxyConfig, out *connector) {
+ var dialer net.Dialer
+ dialer.Timeout = 5 * time.Second
+ h := tcpSNIHandler{
+ Allowlist: c.AllowedDomains,
+ DialContext: dialer.DialContext,
+ ReachableIPs: c.Addrs,
+ }
+
+ for _, addr := range c.Addrs {
+ for _, protoPort := range c.IP {
+ t := target{
+ Dest: netip.PrefixFrom(addr, addr.BitLen()),
+ Matching: protoPort,
+ }
+
+ mak.Set(&out.Handlers, t, handler(&h))
+ }
+ }
+}
+
+func makeConnectorsFromConfig(cfg *appctype.AppConnectorConfig) map[appctype.ConfigID]connector {
+ var connectors map[appctype.ConfigID]connector
+
+ for cID, d := range cfg.DNAT {
+ c := connectors[cID]
+ installDNATHandler(&d, &c)
+ mak.Set(&connectors, cID, c)
+ }
+ for cID, d := range cfg.SNIProxy {
+ c := connectors[cID]
+ installSNIHandler(&d, &c)
+ mak.Set(&connectors, cID, c)
+ }
+
+ return connectors
+}
diff --git a/cmd/sniproxy/server_test.go b/cmd/sniproxy/server_test.go index d56f2aa75..2a51c874c 100644 --- a/cmd/sniproxy/server_test.go +++ b/cmd/sniproxy/server_test.go @@ -1,95 +1,95 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package main - -import ( - "net/netip" - "testing" - - "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" - "tailscale.com/tailcfg" - "tailscale.com/types/appctype" -) - -func TestMakeConnectorsFromConfig(t *testing.T) { - tcs := []struct { - name string - input *appctype.AppConnectorConfig - want map[appctype.ConfigID]connector - }{ - { - "empty", - &appctype.AppConnectorConfig{}, - nil, - }, - { - "DNAT", - &appctype.AppConnectorConfig{ - DNAT: map[appctype.ConfigID]appctype.DNATConfig{ - "swiggity_swooty": { - Addrs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}, - To: []string{"example.org"}, - IP: []tailcfg.ProtoPortRange{{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}}}, - }, - }, - }, - map[appctype.ConfigID]connector{ - "swiggity_swooty": { - Handlers: map[target]handler{ - { - Dest: netip.MustParsePrefix("100.64.0.1/32"), - Matching: tailcfg.ProtoPortRange{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}}, - }: &tcpRoundRobinHandler{To: []string{"example.org"}, ReachableIPs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}}, - { - Dest: netip.MustParsePrefix("fd7a:115c:a1e0::1/128"), - Matching: tailcfg.ProtoPortRange{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}}, - }: &tcpRoundRobinHandler{To: []string{"example.org"}, ReachableIPs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}}, - }, - }, - }, - }, - { - "SNIProxy", - &appctype.AppConnectorConfig{ - SNIProxy: map[appctype.ConfigID]appctype.SNIProxyConfig{ - "swiggity_swooty": { - Addrs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}, - AllowedDomains: []string{"example.org"}, - IP: []tailcfg.ProtoPortRange{{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}}}, - }, - }, - }, - map[appctype.ConfigID]connector{ - "swiggity_swooty": { - Handlers: map[target]handler{ - { - Dest: netip.MustParsePrefix("100.64.0.1/32"), - Matching: tailcfg.ProtoPortRange{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}}, - }: &tcpSNIHandler{Allowlist: []string{"example.org"}, ReachableIPs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}}, - { - Dest: netip.MustParsePrefix("fd7a:115c:a1e0::1/128"), - Matching: tailcfg.ProtoPortRange{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}}, - }: &tcpSNIHandler{Allowlist: []string{"example.org"}, ReachableIPs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}}, - }, - }, - }, - }, - } - - for _, tc := range tcs { - t.Run(tc.name, func(t *testing.T) { - connectors := makeConnectorsFromConfig(tc.input) - - if diff := cmp.Diff(connectors, tc.want, - cmpopts.IgnoreFields(tcpRoundRobinHandler{}, "DialContext"), - cmpopts.IgnoreFields(tcpSNIHandler{}, "DialContext"), - cmp.Comparer(func(x, y netip.Addr) bool { - return x == y - })); diff != "" { - t.Fatalf("mismatch (-want +got):\n%s", diff) - } - }) - } -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package main
+
+import (
+ "net/netip"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "github.com/google/go-cmp/cmp/cmpopts"
+ "tailscale.com/tailcfg"
+ "tailscale.com/types/appctype"
+)
+
+func TestMakeConnectorsFromConfig(t *testing.T) {
+ tcs := []struct {
+ name string
+ input *appctype.AppConnectorConfig
+ want map[appctype.ConfigID]connector
+ }{
+ {
+ "empty",
+ &appctype.AppConnectorConfig{},
+ nil,
+ },
+ {
+ "DNAT",
+ &appctype.AppConnectorConfig{
+ DNAT: map[appctype.ConfigID]appctype.DNATConfig{
+ "swiggity_swooty": {
+ Addrs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")},
+ To: []string{"example.org"},
+ IP: []tailcfg.ProtoPortRange{{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}}},
+ },
+ },
+ },
+ map[appctype.ConfigID]connector{
+ "swiggity_swooty": {
+ Handlers: map[target]handler{
+ {
+ Dest: netip.MustParsePrefix("100.64.0.1/32"),
+ Matching: tailcfg.ProtoPortRange{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}},
+ }: &tcpRoundRobinHandler{To: []string{"example.org"}, ReachableIPs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}},
+ {
+ Dest: netip.MustParsePrefix("fd7a:115c:a1e0::1/128"),
+ Matching: tailcfg.ProtoPortRange{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}},
+ }: &tcpRoundRobinHandler{To: []string{"example.org"}, ReachableIPs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}},
+ },
+ },
+ },
+ },
+ {
+ "SNIProxy",
+ &appctype.AppConnectorConfig{
+ SNIProxy: map[appctype.ConfigID]appctype.SNIProxyConfig{
+ "swiggity_swooty": {
+ Addrs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")},
+ AllowedDomains: []string{"example.org"},
+ IP: []tailcfg.ProtoPortRange{{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}}},
+ },
+ },
+ },
+ map[appctype.ConfigID]connector{
+ "swiggity_swooty": {
+ Handlers: map[target]handler{
+ {
+ Dest: netip.MustParsePrefix("100.64.0.1/32"),
+ Matching: tailcfg.ProtoPortRange{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}},
+ }: &tcpSNIHandler{Allowlist: []string{"example.org"}, ReachableIPs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}},
+ {
+ Dest: netip.MustParsePrefix("fd7a:115c:a1e0::1/128"),
+ Matching: tailcfg.ProtoPortRange{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}},
+ }: &tcpSNIHandler{Allowlist: []string{"example.org"}, ReachableIPs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}},
+ },
+ },
+ },
+ },
+ }
+
+ for _, tc := range tcs {
+ t.Run(tc.name, func(t *testing.T) {
+ connectors := makeConnectorsFromConfig(tc.input)
+
+ if diff := cmp.Diff(connectors, tc.want,
+ cmpopts.IgnoreFields(tcpRoundRobinHandler{}, "DialContext"),
+ cmpopts.IgnoreFields(tcpSNIHandler{}, "DialContext"),
+ cmp.Comparer(func(x, y netip.Addr) bool {
+ return x == y
+ })); diff != "" {
+ t.Fatalf("mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
diff --git a/cmd/sniproxy/sniproxy.go b/cmd/sniproxy/sniproxy.go index fa83aaf4a..c048c8e7e 100644 --- a/cmd/sniproxy/sniproxy.go +++ b/cmd/sniproxy/sniproxy.go @@ -1,291 +1,291 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// The sniproxy is an outbound SNI proxy. It receives TLS connections over -// Tailscale on one or more TCP ports and sends them out to the same SNI -// hostname & port on the internet. It can optionally forward one or more -// TCP ports to a specific destination. It only does TCP. -package main - -import ( - "context" - "errors" - "flag" - "fmt" - "log" - "net" - "net/http" - "net/netip" - "os" - "sort" - "strconv" - "strings" - - "github.com/peterbourgon/ff/v3" - "tailscale.com/client/tailscale" - "tailscale.com/hostinfo" - "tailscale.com/ipn" - "tailscale.com/tailcfg" - "tailscale.com/tsnet" - "tailscale.com/tsweb" - "tailscale.com/types/appctype" - "tailscale.com/types/ipproto" - "tailscale.com/types/nettype" - "tailscale.com/util/mak" -) - -const configCapKey = "tailscale.com/sniproxy" - -// portForward is the state for a single port forwarding entry, as passed to the --forward flag. -type portForward struct { - Port int - Proto string - Destination string -} - -// parseForward takes a proto/port/destination tuple as an input, as would be passed -// to the --forward command line flag, and returns a *portForward struct of those parameters. -func parseForward(value string) (*portForward, error) { - parts := strings.Split(value, "/") - if len(parts) != 3 { - return nil, errors.New("cannot parse: " + value) - } - - proto := parts[0] - if proto != "tcp" { - return nil, errors.New("unsupported forwarding protocol: " + proto) - } - port, err := strconv.ParseUint(parts[1], 10, 16) - if err != nil { - return nil, errors.New("bad forwarding port: " + parts[1]) - } - host := parts[2] - if host == "" { - return nil, errors.New("bad destination: " + value) - } - - return &portForward{Port: int(port), Proto: proto, Destination: host}, nil -} - -func main() { - // Parse flags - fs := flag.NewFlagSet("sniproxy", flag.ContinueOnError) - var ( - ports = fs.String("ports", "443", "comma-separated list of ports to proxy") - forwards = fs.String("forwards", "", "comma-separated list of ports to transparently forward, protocol/number/destination. For example, --forwards=tcp/22/github.com,tcp/5432/sql.example.com") - wgPort = fs.Int("wg-listen-port", 0, "UDP port to listen on for WireGuard and peer-to-peer traffic; 0 means automatically select") - promoteHTTPS = fs.Bool("promote-https", true, "promote HTTP to HTTPS") - debugPort = fs.Int("debug-port", 8893, "Listening port for debug/metrics endpoint") - hostname = fs.String("hostname", "", "Hostname to register the service under") - ) - err := ff.Parse(fs, os.Args[1:], ff.WithEnvVarPrefix("TS_APPC")) - if err != nil { - log.Fatal("ff.Parse") - } - - var ts tsnet.Server - defer ts.Close() - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - run(ctx, &ts, *wgPort, *hostname, *promoteHTTPS, *debugPort, *ports, *forwards) -} - -// run actually runs the sniproxy. Its separate from main() to assist in testing. -func run(ctx context.Context, ts *tsnet.Server, wgPort int, hostname string, promoteHTTPS bool, debugPort int, ports, forwards string) { - // Wire up Tailscale node + app connector server - hostinfo.SetApp("sniproxy") - var s sniproxy - s.ts = ts - - s.ts.Port = uint16(wgPort) - s.ts.Hostname = hostname - - lc, err := s.ts.LocalClient() - if err != nil { - log.Fatalf("LocalClient() failed: %v", err) - } - s.lc = lc - s.ts.RegisterFallbackTCPHandler(s.srv.HandleTCPFlow) - - // Start special-purpose listeners: dns, http promotion, debug server - ln, err := s.ts.Listen("udp", ":53") - if err != nil { - log.Fatalf("failed listening on port 53: %v", err) - } - defer ln.Close() - go s.serveDNS(ln) - if promoteHTTPS { - ln, err := s.ts.Listen("tcp", ":80") - if err != nil { - log.Fatalf("failed listening on port 80: %v", err) - } - defer ln.Close() - log.Printf("Promoting HTTP to HTTPS ...") - go s.promoteHTTPS(ln) - } - if debugPort != 0 { - mux := http.NewServeMux() - tsweb.Debugger(mux) - dln, err := s.ts.Listen("tcp", fmt.Sprintf(":%d", debugPort)) - if err != nil { - log.Fatalf("failed listening on debug port: %v", err) - } - defer dln.Close() - go func() { - log.Fatalf("debug serve: %v", http.Serve(dln, mux)) - }() - } - - // Finally, start mainloop to configure app connector based on information - // in the netmap. - // We set the NotifyInitialNetMap flag so we will always get woken with the - // current netmap, before only being woken on changes. - bus, err := lc.WatchIPNBus(ctx, ipn.NotifyWatchEngineUpdates|ipn.NotifyInitialNetMap|ipn.NotifyNoPrivateKeys) - if err != nil { - log.Fatalf("watching IPN bus: %v", err) - } - defer bus.Close() - for { - msg, err := bus.Next() - if err != nil { - if errors.Is(err, context.Canceled) { - return - } - log.Fatalf("reading IPN bus: %v", err) - } - - // NetMap contains app-connector configuration - if nm := msg.NetMap; nm != nil && nm.SelfNode.Valid() { - sn := nm.SelfNode.AsStruct() - - var c appctype.AppConnectorConfig - nmConf, err := tailcfg.UnmarshalNodeCapJSON[appctype.AppConnectorConfig](sn.CapMap, configCapKey) - if err != nil { - log.Printf("failed to read app connector configuration from coordination server: %v", err) - } else if len(nmConf) > 0 { - c = nmConf[0] - } - - if c.AdvertiseRoutes { - if err := s.advertiseRoutesFromConfig(ctx, &c); err != nil { - log.Printf("failed to advertise routes: %v", err) - } - } - - // Backwards compatibility: combine any configuration from control with flags specified - // on the command line. This is intentionally done after we advertise any routes - // because its never correct to advertise the nodes native IP addresses. - s.mergeConfigFromFlags(&c, ports, forwards) - s.srv.Configure(&c) - } - } -} - -type sniproxy struct { - srv Server - ts *tsnet.Server - lc *tailscale.LocalClient -} - -func (s *sniproxy) advertiseRoutesFromConfig(ctx context.Context, c *appctype.AppConnectorConfig) error { - // Collect the set of addresses to advertise, using a map - // to avoid duplicate entries. - addrs := map[netip.Addr]struct{}{} - for _, c := range c.SNIProxy { - for _, ip := range c.Addrs { - addrs[ip] = struct{}{} - } - } - for _, c := range c.DNAT { - for _, ip := range c.Addrs { - addrs[ip] = struct{}{} - } - } - - var routes []netip.Prefix - for a := range addrs { - routes = append(routes, netip.PrefixFrom(a, a.BitLen())) - } - sort.SliceStable(routes, func(i, j int) bool { - return routes[i].Addr().Less(routes[j].Addr()) // determinism r us - }) - - _, err := s.lc.EditPrefs(ctx, &ipn.MaskedPrefs{ - Prefs: ipn.Prefs{ - AdvertiseRoutes: routes, - }, - AdvertiseRoutesSet: true, - }) - return err -} - -func (s *sniproxy) mergeConfigFromFlags(out *appctype.AppConnectorConfig, ports, forwards string) { - ip4, ip6 := s.ts.TailscaleIPs() - - sniConfigFromFlags := appctype.SNIProxyConfig{ - Addrs: []netip.Addr{ip4, ip6}, - } - if ports != "" { - for _, portStr := range strings.Split(ports, ",") { - port, err := strconv.ParseUint(portStr, 10, 16) - if err != nil { - log.Fatalf("invalid port: %s", portStr) - } - sniConfigFromFlags.IP = append(sniConfigFromFlags.IP, tailcfg.ProtoPortRange{ - Proto: int(ipproto.TCP), - Ports: tailcfg.PortRange{First: uint16(port), Last: uint16(port)}, - }) - } - } - - var forwardConfigFromFlags []appctype.DNATConfig - for _, forwStr := range strings.Split(forwards, ",") { - if forwStr == "" { - continue - } - forw, err := parseForward(forwStr) - if err != nil { - log.Printf("invalid forwarding spec: %v", err) - continue - } - - forwardConfigFromFlags = append(forwardConfigFromFlags, appctype.DNATConfig{ - Addrs: []netip.Addr{ip4, ip6}, - To: []string{forw.Destination}, - IP: []tailcfg.ProtoPortRange{ - { - Proto: int(ipproto.TCP), - Ports: tailcfg.PortRange{First: uint16(forw.Port), Last: uint16(forw.Port)}, - }, - }, - }) - } - - if len(forwardConfigFromFlags) == 0 && len(sniConfigFromFlags.IP) == 0 { - return // no config specified on the command line - } - - mak.Set(&out.SNIProxy, "flags", sniConfigFromFlags) - for i, forward := range forwardConfigFromFlags { - mak.Set(&out.DNAT, appctype.ConfigID(fmt.Sprintf("flags_%d", i)), forward) - } -} - -func (s *sniproxy) serveDNS(ln net.Listener) { - for { - c, err := ln.Accept() - if err != nil { - log.Printf("serveDNS accept: %v", err) - return - } - go s.srv.HandleDNS(c.(nettype.ConnPacketConn)) - } -} - -func (s *sniproxy) promoteHTTPS(ln net.Listener) { - err := http.Serve(ln, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - http.Redirect(w, r, "https://"+r.Host+r.RequestURI, http.StatusFound) - })) - log.Fatalf("promoteHTTPS http.Serve: %v", err) -} +// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// The sniproxy is an outbound SNI proxy. It receives TLS connections over
+// Tailscale on one or more TCP ports and sends them out to the same SNI
+// hostname & port on the internet. It can optionally forward one or more
+// TCP ports to a specific destination. It only does TCP.
+package main
+
+import (
+ "context"
+ "errors"
+ "flag"
+ "fmt"
+ "log"
+ "net"
+ "net/http"
+ "net/netip"
+ "os"
+ "sort"
+ "strconv"
+ "strings"
+
+ "github.com/peterbourgon/ff/v3"
+ "tailscale.com/client/tailscale"
+ "tailscale.com/hostinfo"
+ "tailscale.com/ipn"
+ "tailscale.com/tailcfg"
+ "tailscale.com/tsnet"
+ "tailscale.com/tsweb"
+ "tailscale.com/types/appctype"
+ "tailscale.com/types/ipproto"
+ "tailscale.com/types/nettype"
+ "tailscale.com/util/mak"
+)
+
+const configCapKey = "tailscale.com/sniproxy"
+
+// portForward is the state for a single port forwarding entry, as passed to the --forward flag.
+type portForward struct {
+ Port int
+ Proto string
+ Destination string
+}
+
+// parseForward takes a proto/port/destination tuple as an input, as would be passed
+// to the --forward command line flag, and returns a *portForward struct of those parameters.
+func parseForward(value string) (*portForward, error) {
+ parts := strings.Split(value, "/")
+ if len(parts) != 3 {
+ return nil, errors.New("cannot parse: " + value)
+ }
+
+ proto := parts[0]
+ if proto != "tcp" {
+ return nil, errors.New("unsupported forwarding protocol: " + proto)
+ }
+ port, err := strconv.ParseUint(parts[1], 10, 16)
+ if err != nil {
+ return nil, errors.New("bad forwarding port: " + parts[1])
+ }
+ host := parts[2]
+ if host == "" {
+ return nil, errors.New("bad destination: " + value)
+ }
+
+ return &portForward{Port: int(port), Proto: proto, Destination: host}, nil
+}
+
+func main() {
+ // Parse flags
+ fs := flag.NewFlagSet("sniproxy", flag.ContinueOnError)
+ var (
+ ports = fs.String("ports", "443", "comma-separated list of ports to proxy")
+ forwards = fs.String("forwards", "", "comma-separated list of ports to transparently forward, protocol/number/destination. For example, --forwards=tcp/22/github.com,tcp/5432/sql.example.com")
+ wgPort = fs.Int("wg-listen-port", 0, "UDP port to listen on for WireGuard and peer-to-peer traffic; 0 means automatically select")
+ promoteHTTPS = fs.Bool("promote-https", true, "promote HTTP to HTTPS")
+ debugPort = fs.Int("debug-port", 8893, "Listening port for debug/metrics endpoint")
+ hostname = fs.String("hostname", "", "Hostname to register the service under")
+ )
+ err := ff.Parse(fs, os.Args[1:], ff.WithEnvVarPrefix("TS_APPC"))
+ if err != nil {
+ log.Fatal("ff.Parse")
+ }
+
+ var ts tsnet.Server
+ defer ts.Close()
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+ run(ctx, &ts, *wgPort, *hostname, *promoteHTTPS, *debugPort, *ports, *forwards)
+}
+
+// run actually runs the sniproxy. Its separate from main() to assist in testing.
+func run(ctx context.Context, ts *tsnet.Server, wgPort int, hostname string, promoteHTTPS bool, debugPort int, ports, forwards string) {
+ // Wire up Tailscale node + app connector server
+ hostinfo.SetApp("sniproxy")
+ var s sniproxy
+ s.ts = ts
+
+ s.ts.Port = uint16(wgPort)
+ s.ts.Hostname = hostname
+
+ lc, err := s.ts.LocalClient()
+ if err != nil {
+ log.Fatalf("LocalClient() failed: %v", err)
+ }
+ s.lc = lc
+ s.ts.RegisterFallbackTCPHandler(s.srv.HandleTCPFlow)
+
+ // Start special-purpose listeners: dns, http promotion, debug server
+ ln, err := s.ts.Listen("udp", ":53")
+ if err != nil {
+ log.Fatalf("failed listening on port 53: %v", err)
+ }
+ defer ln.Close()
+ go s.serveDNS(ln)
+ if promoteHTTPS {
+ ln, err := s.ts.Listen("tcp", ":80")
+ if err != nil {
+ log.Fatalf("failed listening on port 80: %v", err)
+ }
+ defer ln.Close()
+ log.Printf("Promoting HTTP to HTTPS ...")
+ go s.promoteHTTPS(ln)
+ }
+ if debugPort != 0 {
+ mux := http.NewServeMux()
+ tsweb.Debugger(mux)
+ dln, err := s.ts.Listen("tcp", fmt.Sprintf(":%d", debugPort))
+ if err != nil {
+ log.Fatalf("failed listening on debug port: %v", err)
+ }
+ defer dln.Close()
+ go func() {
+ log.Fatalf("debug serve: %v", http.Serve(dln, mux))
+ }()
+ }
+
+ // Finally, start mainloop to configure app connector based on information
+ // in the netmap.
+ // We set the NotifyInitialNetMap flag so we will always get woken with the
+ // current netmap, before only being woken on changes.
+ bus, err := lc.WatchIPNBus(ctx, ipn.NotifyWatchEngineUpdates|ipn.NotifyInitialNetMap|ipn.NotifyNoPrivateKeys)
+ if err != nil {
+ log.Fatalf("watching IPN bus: %v", err)
+ }
+ defer bus.Close()
+ for {
+ msg, err := bus.Next()
+ if err != nil {
+ if errors.Is(err, context.Canceled) {
+ return
+ }
+ log.Fatalf("reading IPN bus: %v", err)
+ }
+
+ // NetMap contains app-connector configuration
+ if nm := msg.NetMap; nm != nil && nm.SelfNode.Valid() {
+ sn := nm.SelfNode.AsStruct()
+
+ var c appctype.AppConnectorConfig
+ nmConf, err := tailcfg.UnmarshalNodeCapJSON[appctype.AppConnectorConfig](sn.CapMap, configCapKey)
+ if err != nil {
+ log.Printf("failed to read app connector configuration from coordination server: %v", err)
+ } else if len(nmConf) > 0 {
+ c = nmConf[0]
+ }
+
+ if c.AdvertiseRoutes {
+ if err := s.advertiseRoutesFromConfig(ctx, &c); err != nil {
+ log.Printf("failed to advertise routes: %v", err)
+ }
+ }
+
+ // Backwards compatibility: combine any configuration from control with flags specified
+ // on the command line. This is intentionally done after we advertise any routes
+ // because its never correct to advertise the nodes native IP addresses.
+ s.mergeConfigFromFlags(&c, ports, forwards)
+ s.srv.Configure(&c)
+ }
+ }
+}
+
+type sniproxy struct {
+ srv Server
+ ts *tsnet.Server
+ lc *tailscale.LocalClient
+}
+
+func (s *sniproxy) advertiseRoutesFromConfig(ctx context.Context, c *appctype.AppConnectorConfig) error {
+ // Collect the set of addresses to advertise, using a map
+ // to avoid duplicate entries.
+ addrs := map[netip.Addr]struct{}{}
+ for _, c := range c.SNIProxy {
+ for _, ip := range c.Addrs {
+ addrs[ip] = struct{}{}
+ }
+ }
+ for _, c := range c.DNAT {
+ for _, ip := range c.Addrs {
+ addrs[ip] = struct{}{}
+ }
+ }
+
+ var routes []netip.Prefix
+ for a := range addrs {
+ routes = append(routes, netip.PrefixFrom(a, a.BitLen()))
+ }
+ sort.SliceStable(routes, func(i, j int) bool {
+ return routes[i].Addr().Less(routes[j].Addr()) // determinism r us
+ })
+
+ _, err := s.lc.EditPrefs(ctx, &ipn.MaskedPrefs{
+ Prefs: ipn.Prefs{
+ AdvertiseRoutes: routes,
+ },
+ AdvertiseRoutesSet: true,
+ })
+ return err
+}
+
+func (s *sniproxy) mergeConfigFromFlags(out *appctype.AppConnectorConfig, ports, forwards string) {
+ ip4, ip6 := s.ts.TailscaleIPs()
+
+ sniConfigFromFlags := appctype.SNIProxyConfig{
+ Addrs: []netip.Addr{ip4, ip6},
+ }
+ if ports != "" {
+ for _, portStr := range strings.Split(ports, ",") {
+ port, err := strconv.ParseUint(portStr, 10, 16)
+ if err != nil {
+ log.Fatalf("invalid port: %s", portStr)
+ }
+ sniConfigFromFlags.IP = append(sniConfigFromFlags.IP, tailcfg.ProtoPortRange{
+ Proto: int(ipproto.TCP),
+ Ports: tailcfg.PortRange{First: uint16(port), Last: uint16(port)},
+ })
+ }
+ }
+
+ var forwardConfigFromFlags []appctype.DNATConfig
+ for _, forwStr := range strings.Split(forwards, ",") {
+ if forwStr == "" {
+ continue
+ }
+ forw, err := parseForward(forwStr)
+ if err != nil {
+ log.Printf("invalid forwarding spec: %v", err)
+ continue
+ }
+
+ forwardConfigFromFlags = append(forwardConfigFromFlags, appctype.DNATConfig{
+ Addrs: []netip.Addr{ip4, ip6},
+ To: []string{forw.Destination},
+ IP: []tailcfg.ProtoPortRange{
+ {
+ Proto: int(ipproto.TCP),
+ Ports: tailcfg.PortRange{First: uint16(forw.Port), Last: uint16(forw.Port)},
+ },
+ },
+ })
+ }
+
+ if len(forwardConfigFromFlags) == 0 && len(sniConfigFromFlags.IP) == 0 {
+ return // no config specified on the command line
+ }
+
+ mak.Set(&out.SNIProxy, "flags", sniConfigFromFlags)
+ for i, forward := range forwardConfigFromFlags {
+ mak.Set(&out.DNAT, appctype.ConfigID(fmt.Sprintf("flags_%d", i)), forward)
+ }
+}
+
+func (s *sniproxy) serveDNS(ln net.Listener) {
+ for {
+ c, err := ln.Accept()
+ if err != nil {
+ log.Printf("serveDNS accept: %v", err)
+ return
+ }
+ go s.srv.HandleDNS(c.(nettype.ConnPacketConn))
+ }
+}
+
+func (s *sniproxy) promoteHTTPS(ln net.Listener) {
+ err := http.Serve(ln, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ http.Redirect(w, r, "https://"+r.Host+r.RequestURI, http.StatusFound)
+ }))
+ log.Fatalf("promoteHTTPS http.Serve: %v", err)
+}
|
