summaryrefslogtreecommitdiffhomepage
path: root/cmd/sniproxy
diff options
context:
space:
mode:
Diffstat (limited to 'cmd/sniproxy')
-rw-r--r--cmd/sniproxy/.gitignore2
-rw-r--r--cmd/sniproxy/handlers_test.go318
-rw-r--r--cmd/sniproxy/server.go654
-rw-r--r--cmd/sniproxy/server_test.go190
-rw-r--r--cmd/sniproxy/sniproxy.go582
5 files changed, 873 insertions, 873 deletions
diff --git a/cmd/sniproxy/.gitignore b/cmd/sniproxy/.gitignore
index b1399c881..0bca33912 100644
--- a/cmd/sniproxy/.gitignore
+++ b/cmd/sniproxy/.gitignore
@@ -1 +1 @@
-sniproxy
+sniproxy
diff --git a/cmd/sniproxy/handlers_test.go b/cmd/sniproxy/handlers_test.go
index 4f9fc6a34..8ec5b097c 100644
--- a/cmd/sniproxy/handlers_test.go
+++ b/cmd/sniproxy/handlers_test.go
@@ -1,159 +1,159 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-package main
-
-import (
- "bytes"
- "context"
- "encoding/hex"
- "io"
- "net"
- "net/netip"
- "strings"
- "testing"
-
- "tailscale.com/net/memnet"
-)
-
-func echoConnOnce(conn net.Conn) {
- defer conn.Close()
-
- b := make([]byte, 256)
- n, err := conn.Read(b)
- if err != nil {
- return
- }
-
- if _, err := conn.Write(b[:n]); err != nil {
- return
- }
-}
-
-func TestTCPRoundRobinHandler(t *testing.T) {
- h := tcpRoundRobinHandler{
- To: []string{"yeet.com"},
- DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
- if network != "tcp" {
- t.Errorf("network = %s, want %s", network, "tcp")
- }
- if addr != "yeet.com:22" {
- t.Errorf("addr = %s, want %s", addr, "yeet.com:22")
- }
-
- c, s := memnet.NewConn("outbound", 1024)
- go echoConnOnce(s)
- return c, nil
- },
- }
-
- cSock, sSock := memnet.NewTCPConn(netip.MustParseAddrPort("10.64.1.2:22"), netip.MustParseAddrPort("10.64.1.2:22"), 1024)
- h.Handle(sSock)
-
- // Test data write and read, the other end will echo back
- // a single stanza
- want := "hello"
- if _, err := io.WriteString(cSock, want); err != nil {
- t.Fatal(err)
- }
- got := make([]byte, len(want))
- if _, err := io.ReadAtLeast(cSock, got, len(got)); err != nil {
- t.Fatal(err)
- }
- if string(got) != want {
- t.Errorf("got %q, want %q", got, want)
- }
-
- // The other end closed the socket after the first echo, so
- // any following read should error.
- io.WriteString(cSock, "deadass heres some data on god fr")
- if _, err := io.ReadAtLeast(cSock, got, len(got)); err == nil {
- t.Error("read succeeded on closed socket")
- }
-}
-
-// Capture of first TCP data segment for a connection to https://pkgs.tailscale.com
-const tlsStart = `45000239ff1840004006f9f5c0a801f2
-c726b5efcf9e01bbe803b21394e3b752
-801801f641dc00000101080ade3474f2
-2fb93ee71603010200010001fc030303
-c3acbd19d2624765bb19af4bce03365e
-1d197f5bb939cdadeff26b0f8e7a0620
-295b04127b82bae46aac4ff58cffef25
-eba75a4b7a6de729532c411bd9dd0d2c
-00203a3a130113021303c02bc02fc02c
-c030cca9cca8c013c014009c009d002f
-003501000193caca0000000a000a0008
-1a1a001d001700180010000e000c0268
-3208687474702f312e31002b0007062a
-2a03040303ff01000100000d00120010
-04030804040105030805050108060601
-000b00020100002300000033002b0029
-1a1a000100001d0020d3c76bef062979
-a812ce935cfb4dbe6b3a84dc5ba9226f
-23b0f34af9d1d03b4a001b0003020002
-00120000446900050003026832000000
-170015000012706b67732e7461696c73
-63616c652e636f6d002d000201010005
-00050100000000001700003a3a000100
-0015002d000000000000000000000000
-00000000000000000000000000000000
-00000000000000000000000000000000
-0000290094006f0069e76f2016f963ad
-38c8632d1f240cd75e00e25fdef295d4
-7042b26f3a9a543b1c7dc74939d77803
-20527d423ff996997bda2c6383a14f49
-219eeef8a053e90a32228df37ddbe126
-eccf6b085c93890d08341d819aea6111
-0d909f4cd6b071d9ea40618e74588a33
-90d494bbb5c3002120d5a164a16c9724
-c9ef5e540d8d6f007789a7acf9f5f16f
-bf6a1907a6782ed02b`
-
-func fakeSNIHeader() []byte {
- b, err := hex.DecodeString(strings.Replace(tlsStart, "\n", "", -1))
- if err != nil {
- panic(err)
- }
- return b[0x34:] // trim IP + TCP header
-}
-
-func TestTCPSNIHandler(t *testing.T) {
- h := tcpSNIHandler{
- Allowlist: []string{"pkgs.tailscale.com"},
- DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
- if network != "tcp" {
- t.Errorf("network = %s, want %s", network, "tcp")
- }
- if addr != "pkgs.tailscale.com:443" {
- t.Errorf("addr = %s, want %s", addr, "pkgs.tailscale.com:443")
- }
-
- c, s := memnet.NewConn("outbound", 1024)
- go echoConnOnce(s)
- return c, nil
- },
- }
-
- cSock, sSock := memnet.NewTCPConn(netip.MustParseAddrPort("10.64.1.2:22"), netip.MustParseAddrPort("10.64.1.2:443"), 1024)
- h.Handle(sSock)
-
- // Fake a TLS handshake record with an SNI in it.
- if _, err := cSock.Write(fakeSNIHeader()); err != nil {
- t.Fatal(err)
- }
-
- // Test read, the other end will echo back
- // a single stanza, which is at least the beginning of the SNI header.
- want := fakeSNIHeader()[:5]
- if _, err := cSock.Write(want); err != nil {
- t.Fatal(err)
- }
- got := make([]byte, len(want))
- if _, err := io.ReadAtLeast(cSock, got, len(got)); err != nil {
- t.Fatal(err)
- }
- if !bytes.Equal(got, want) {
- t.Errorf("got %q, want %q", got, want)
- }
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package main
+
+import (
+ "bytes"
+ "context"
+ "encoding/hex"
+ "io"
+ "net"
+ "net/netip"
+ "strings"
+ "testing"
+
+ "tailscale.com/net/memnet"
+)
+
+func echoConnOnce(conn net.Conn) {
+ defer conn.Close()
+
+ b := make([]byte, 256)
+ n, err := conn.Read(b)
+ if err != nil {
+ return
+ }
+
+ if _, err := conn.Write(b[:n]); err != nil {
+ return
+ }
+}
+
+func TestTCPRoundRobinHandler(t *testing.T) {
+ h := tcpRoundRobinHandler{
+ To: []string{"yeet.com"},
+ DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
+ if network != "tcp" {
+ t.Errorf("network = %s, want %s", network, "tcp")
+ }
+ if addr != "yeet.com:22" {
+ t.Errorf("addr = %s, want %s", addr, "yeet.com:22")
+ }
+
+ c, s := memnet.NewConn("outbound", 1024)
+ go echoConnOnce(s)
+ return c, nil
+ },
+ }
+
+ cSock, sSock := memnet.NewTCPConn(netip.MustParseAddrPort("10.64.1.2:22"), netip.MustParseAddrPort("10.64.1.2:22"), 1024)
+ h.Handle(sSock)
+
+ // Test data write and read, the other end will echo back
+ // a single stanza
+ want := "hello"
+ if _, err := io.WriteString(cSock, want); err != nil {
+ t.Fatal(err)
+ }
+ got := make([]byte, len(want))
+ if _, err := io.ReadAtLeast(cSock, got, len(got)); err != nil {
+ t.Fatal(err)
+ }
+ if string(got) != want {
+ t.Errorf("got %q, want %q", got, want)
+ }
+
+ // The other end closed the socket after the first echo, so
+ // any following read should error.
+ io.WriteString(cSock, "deadass heres some data on god fr")
+ if _, err := io.ReadAtLeast(cSock, got, len(got)); err == nil {
+ t.Error("read succeeded on closed socket")
+ }
+}
+
+// Capture of first TCP data segment for a connection to https://pkgs.tailscale.com
+const tlsStart = `45000239ff1840004006f9f5c0a801f2
+c726b5efcf9e01bbe803b21394e3b752
+801801f641dc00000101080ade3474f2
+2fb93ee71603010200010001fc030303
+c3acbd19d2624765bb19af4bce03365e
+1d197f5bb939cdadeff26b0f8e7a0620
+295b04127b82bae46aac4ff58cffef25
+eba75a4b7a6de729532c411bd9dd0d2c
+00203a3a130113021303c02bc02fc02c
+c030cca9cca8c013c014009c009d002f
+003501000193caca0000000a000a0008
+1a1a001d001700180010000e000c0268
+3208687474702f312e31002b0007062a
+2a03040303ff01000100000d00120010
+04030804040105030805050108060601
+000b00020100002300000033002b0029
+1a1a000100001d0020d3c76bef062979
+a812ce935cfb4dbe6b3a84dc5ba9226f
+23b0f34af9d1d03b4a001b0003020002
+00120000446900050003026832000000
+170015000012706b67732e7461696c73
+63616c652e636f6d002d000201010005
+00050100000000001700003a3a000100
+0015002d000000000000000000000000
+00000000000000000000000000000000
+00000000000000000000000000000000
+0000290094006f0069e76f2016f963ad
+38c8632d1f240cd75e00e25fdef295d4
+7042b26f3a9a543b1c7dc74939d77803
+20527d423ff996997bda2c6383a14f49
+219eeef8a053e90a32228df37ddbe126
+eccf6b085c93890d08341d819aea6111
+0d909f4cd6b071d9ea40618e74588a33
+90d494bbb5c3002120d5a164a16c9724
+c9ef5e540d8d6f007789a7acf9f5f16f
+bf6a1907a6782ed02b`
+
+func fakeSNIHeader() []byte {
+ b, err := hex.DecodeString(strings.Replace(tlsStart, "\n", "", -1))
+ if err != nil {
+ panic(err)
+ }
+ return b[0x34:] // trim IP + TCP header
+}
+
+func TestTCPSNIHandler(t *testing.T) {
+ h := tcpSNIHandler{
+ Allowlist: []string{"pkgs.tailscale.com"},
+ DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
+ if network != "tcp" {
+ t.Errorf("network = %s, want %s", network, "tcp")
+ }
+ if addr != "pkgs.tailscale.com:443" {
+ t.Errorf("addr = %s, want %s", addr, "pkgs.tailscale.com:443")
+ }
+
+ c, s := memnet.NewConn("outbound", 1024)
+ go echoConnOnce(s)
+ return c, nil
+ },
+ }
+
+ cSock, sSock := memnet.NewTCPConn(netip.MustParseAddrPort("10.64.1.2:22"), netip.MustParseAddrPort("10.64.1.2:443"), 1024)
+ h.Handle(sSock)
+
+ // Fake a TLS handshake record with an SNI in it.
+ if _, err := cSock.Write(fakeSNIHeader()); err != nil {
+ t.Fatal(err)
+ }
+
+ // Test read, the other end will echo back
+ // a single stanza, which is at least the beginning of the SNI header.
+ want := fakeSNIHeader()[:5]
+ if _, err := cSock.Write(want); err != nil {
+ t.Fatal(err)
+ }
+ got := make([]byte, len(want))
+ if _, err := io.ReadAtLeast(cSock, got, len(got)); err != nil {
+ t.Fatal(err)
+ }
+ if !bytes.Equal(got, want) {
+ t.Errorf("got %q, want %q", got, want)
+ }
+}
diff --git a/cmd/sniproxy/server.go b/cmd/sniproxy/server.go
index b322b6f4b..c89420661 100644
--- a/cmd/sniproxy/server.go
+++ b/cmd/sniproxy/server.go
@@ -1,327 +1,327 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-package main
-
-import (
- "expvar"
- "log"
- "net"
- "net/netip"
- "sync"
- "time"
-
- "golang.org/x/net/dns/dnsmessage"
- "tailscale.com/metrics"
- "tailscale.com/tailcfg"
- "tailscale.com/types/appctype"
- "tailscale.com/types/ipproto"
- "tailscale.com/types/nettype"
- "tailscale.com/util/clientmetric"
- "tailscale.com/util/mak"
-)
-
-var tsMBox = dnsmessage.MustNewName("support.tailscale.com.")
-
-// target describes the predicates which route some inbound
-// traffic to the app connector to a specific handler.
-type target struct {
- Dest netip.Prefix
- Matching tailcfg.ProtoPortRange
-}
-
-// Server implements an App Connector as expressed in sniproxy.
-type Server struct {
- mu sync.RWMutex // mu guards following fields
- connectors map[appctype.ConfigID]connector
-}
-
-type appcMetrics struct {
- dnsResponses expvar.Int
- dnsFailures expvar.Int
- tcpConns expvar.Int
- sniConns expvar.Int
- unhandledConns expvar.Int
-}
-
-var getMetrics = sync.OnceValue[*appcMetrics](func() *appcMetrics {
- m := appcMetrics{}
-
- stats := new(metrics.Set)
- stats.Set("tls_sessions", &m.sniConns)
- clientmetric.NewCounterFunc("sniproxy_tls_sessions", m.sniConns.Value)
- stats.Set("tcp_sessions", &m.tcpConns)
- clientmetric.NewCounterFunc("sniproxy_tcp_sessions", m.tcpConns.Value)
- stats.Set("dns_responses", &m.dnsResponses)
- clientmetric.NewCounterFunc("sniproxy_dns_responses", m.dnsResponses.Value)
- stats.Set("dns_failed", &m.dnsFailures)
- clientmetric.NewCounterFunc("sniproxy_dns_failed", m.dnsFailures.Value)
- expvar.Publish("sniproxy", stats)
-
- return &m
-})
-
-// Configure applies the provided configuration to the app connector.
-func (s *Server) Configure(cfg *appctype.AppConnectorConfig) {
- s.mu.Lock()
- defer s.mu.Unlock()
- s.connectors = makeConnectorsFromConfig(cfg)
- log.Printf("installed app connector config: %+v", s.connectors)
-}
-
-// HandleTCPFlow implements tsnet.FallbackTCPHandler.
-func (s *Server) HandleTCPFlow(src, dst netip.AddrPort) (handler func(net.Conn), intercept bool) {
- m := getMetrics()
- s.mu.RLock()
- defer s.mu.RUnlock()
-
- for _, c := range s.connectors {
- if handler, intercept := c.handleTCPFlow(src, dst, m); intercept {
- return handler, intercept
- }
- }
-
- return nil, false
-}
-
-// HandleDNS handles a DNS request to the app connector.
-func (s *Server) HandleDNS(c nettype.ConnPacketConn) {
- defer c.Close()
- c.SetReadDeadline(time.Now().Add(5 * time.Second))
- m := getMetrics()
-
- buf := make([]byte, 1500)
- n, err := c.Read(buf)
- if err != nil {
- log.Printf("HandleDNS: read failed: %v\n ", err)
- m.dnsFailures.Add(1)
- return
- }
-
- addrPortStr := c.LocalAddr().String()
- host, _, err := net.SplitHostPort(addrPortStr)
- if err != nil {
- log.Printf("HandleDNS: bogus addrPort %q", addrPortStr)
- m.dnsFailures.Add(1)
- return
- }
- localAddr, err := netip.ParseAddr(host)
- if err != nil {
- log.Printf("HandleDNS: bogus local address %q", host)
- m.dnsFailures.Add(1)
- return
- }
-
- var msg dnsmessage.Message
- err = msg.Unpack(buf[:n])
- if err != nil {
- log.Printf("HandleDNS: dnsmessage unpack failed: %v\n ", err)
- m.dnsFailures.Add(1)
- return
- }
-
- s.mu.RLock()
- defer s.mu.RUnlock()
- for _, connector := range s.connectors {
- resp, err := connector.handleDNS(&msg, localAddr)
- if err != nil {
- log.Printf("HandleDNS: connector handling failed: %v\n", err)
- m.dnsFailures.Add(1)
- return
- }
- if len(resp) > 0 {
- // This connector handled the DNS request
- _, err = c.Write(resp)
- if err != nil {
- log.Printf("HandleDNS: write failed: %v\n", err)
- m.dnsFailures.Add(1)
- return
- }
-
- m.dnsResponses.Add(1)
- return
- }
- }
-}
-
-// connector describes a logical collection of
-// services which need to be proxied.
-type connector struct {
- Handlers map[target]handler
-}
-
-// handleTCPFlow implements tsnet.FallbackTCPHandler.
-func (c *connector) handleTCPFlow(src, dst netip.AddrPort, m *appcMetrics) (handler func(net.Conn), intercept bool) {
- for t, h := range c.Handlers {
- if t.Matching.Proto != 0 && t.Matching.Proto != int(ipproto.TCP) {
- continue
- }
- if !t.Dest.Contains(dst.Addr()) {
- continue
- }
- if !t.Matching.Ports.Contains(dst.Port()) {
- continue
- }
-
- switch h.(type) {
- case *tcpSNIHandler:
- m.sniConns.Add(1)
- case *tcpRoundRobinHandler:
- m.tcpConns.Add(1)
- default:
- log.Printf("handleTCPFlow: unhandled handler type %T", h)
- }
-
- return h.Handle, true
- }
-
- m.unhandledConns.Add(1)
- return nil, false
-}
-
-// handleDNS returns the DNS response to the given query. If this
-// connector is unable to handle the request, nil is returned.
-func (c *connector) handleDNS(req *dnsmessage.Message, localAddr netip.Addr) (response []byte, err error) {
- for t, h := range c.Handlers {
- if t.Dest.Contains(localAddr) {
- return makeDNSResponse(req, h.ReachableOn())
- }
- }
-
- // Did not match, signal 'not handled' to caller
- return nil, nil
-}
-
-func makeDNSResponse(req *dnsmessage.Message, reachableIPs []netip.Addr) (response []byte, err error) {
- resp := dnsmessage.NewBuilder(response,
- dnsmessage.Header{
- ID: req.Header.ID,
- Response: true,
- Authoritative: true,
- })
- resp.EnableCompression()
-
- if len(req.Questions) == 0 {
- response, _ = resp.Finish()
- return response, nil
- }
- q := req.Questions[0]
- err = resp.StartQuestions()
- if err != nil {
- return
- }
- resp.Question(q)
-
- err = resp.StartAnswers()
- if err != nil {
- return
- }
-
- switch q.Type {
- case dnsmessage.TypeAAAA:
- for _, ip := range reachableIPs {
- if ip.Is6() {
- err = resp.AAAAResource(
- dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120},
- dnsmessage.AAAAResource{AAAA: ip.As16()},
- )
- }
- }
-
- case dnsmessage.TypeA:
- for _, ip := range reachableIPs {
- if ip.Is4() {
- err = resp.AResource(
- dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120},
- dnsmessage.AResource{A: ip.As4()},
- )
- }
- }
-
- case dnsmessage.TypeSOA:
- err = resp.SOAResource(
- dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120},
- dnsmessage.SOAResource{NS: q.Name, MBox: tsMBox, Serial: 2023030600,
- Refresh: 120, Retry: 120, Expire: 120, MinTTL: 60},
- )
- case dnsmessage.TypeNS:
- err = resp.NSResource(
- dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120},
- dnsmessage.NSResource{NS: tsMBox},
- )
- }
-
- if err != nil {
- return nil, err
- }
- return resp.Finish()
-}
-
-type handler interface {
- // Handle handles the given socket.
- Handle(c net.Conn)
-
- // ReachableOn returns the IP addresses this handler is reachable on.
- ReachableOn() []netip.Addr
-}
-
-func installDNATHandler(d *appctype.DNATConfig, out *connector) {
- // These handlers don't actually do DNAT, they just
- // proxy the data over the connection.
- var dialer net.Dialer
- dialer.Timeout = 5 * time.Second
- h := tcpRoundRobinHandler{
- To: d.To,
- DialContext: dialer.DialContext,
- ReachableIPs: d.Addrs,
- }
-
- for _, addr := range d.Addrs {
- for _, protoPort := range d.IP {
- t := target{
- Dest: netip.PrefixFrom(addr, addr.BitLen()),
- Matching: protoPort,
- }
-
- mak.Set(&out.Handlers, t, handler(&h))
- }
- }
-}
-
-func installSNIHandler(c *appctype.SNIProxyConfig, out *connector) {
- var dialer net.Dialer
- dialer.Timeout = 5 * time.Second
- h := tcpSNIHandler{
- Allowlist: c.AllowedDomains,
- DialContext: dialer.DialContext,
- ReachableIPs: c.Addrs,
- }
-
- for _, addr := range c.Addrs {
- for _, protoPort := range c.IP {
- t := target{
- Dest: netip.PrefixFrom(addr, addr.BitLen()),
- Matching: protoPort,
- }
-
- mak.Set(&out.Handlers, t, handler(&h))
- }
- }
-}
-
-func makeConnectorsFromConfig(cfg *appctype.AppConnectorConfig) map[appctype.ConfigID]connector {
- var connectors map[appctype.ConfigID]connector
-
- for cID, d := range cfg.DNAT {
- c := connectors[cID]
- installDNATHandler(&d, &c)
- mak.Set(&connectors, cID, c)
- }
- for cID, d := range cfg.SNIProxy {
- c := connectors[cID]
- installSNIHandler(&d, &c)
- mak.Set(&connectors, cID, c)
- }
-
- return connectors
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package main
+
+import (
+ "expvar"
+ "log"
+ "net"
+ "net/netip"
+ "sync"
+ "time"
+
+ "golang.org/x/net/dns/dnsmessage"
+ "tailscale.com/metrics"
+ "tailscale.com/tailcfg"
+ "tailscale.com/types/appctype"
+ "tailscale.com/types/ipproto"
+ "tailscale.com/types/nettype"
+ "tailscale.com/util/clientmetric"
+ "tailscale.com/util/mak"
+)
+
+var tsMBox = dnsmessage.MustNewName("support.tailscale.com.")
+
+// target describes the predicates which route some inbound
+// traffic to the app connector to a specific handler.
+type target struct {
+ Dest netip.Prefix
+ Matching tailcfg.ProtoPortRange
+}
+
+// Server implements an App Connector as expressed in sniproxy.
+type Server struct {
+ mu sync.RWMutex // mu guards following fields
+ connectors map[appctype.ConfigID]connector
+}
+
+type appcMetrics struct {
+ dnsResponses expvar.Int
+ dnsFailures expvar.Int
+ tcpConns expvar.Int
+ sniConns expvar.Int
+ unhandledConns expvar.Int
+}
+
+var getMetrics = sync.OnceValue[*appcMetrics](func() *appcMetrics {
+ m := appcMetrics{}
+
+ stats := new(metrics.Set)
+ stats.Set("tls_sessions", &m.sniConns)
+ clientmetric.NewCounterFunc("sniproxy_tls_sessions", m.sniConns.Value)
+ stats.Set("tcp_sessions", &m.tcpConns)
+ clientmetric.NewCounterFunc("sniproxy_tcp_sessions", m.tcpConns.Value)
+ stats.Set("dns_responses", &m.dnsResponses)
+ clientmetric.NewCounterFunc("sniproxy_dns_responses", m.dnsResponses.Value)
+ stats.Set("dns_failed", &m.dnsFailures)
+ clientmetric.NewCounterFunc("sniproxy_dns_failed", m.dnsFailures.Value)
+ expvar.Publish("sniproxy", stats)
+
+ return &m
+})
+
+// Configure applies the provided configuration to the app connector.
+func (s *Server) Configure(cfg *appctype.AppConnectorConfig) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.connectors = makeConnectorsFromConfig(cfg)
+ log.Printf("installed app connector config: %+v", s.connectors)
+}
+
+// HandleTCPFlow implements tsnet.FallbackTCPHandler.
+func (s *Server) HandleTCPFlow(src, dst netip.AddrPort) (handler func(net.Conn), intercept bool) {
+ m := getMetrics()
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ for _, c := range s.connectors {
+ if handler, intercept := c.handleTCPFlow(src, dst, m); intercept {
+ return handler, intercept
+ }
+ }
+
+ return nil, false
+}
+
+// HandleDNS handles a DNS request to the app connector.
+func (s *Server) HandleDNS(c nettype.ConnPacketConn) {
+ defer c.Close()
+ c.SetReadDeadline(time.Now().Add(5 * time.Second))
+ m := getMetrics()
+
+ buf := make([]byte, 1500)
+ n, err := c.Read(buf)
+ if err != nil {
+ log.Printf("HandleDNS: read failed: %v\n ", err)
+ m.dnsFailures.Add(1)
+ return
+ }
+
+ addrPortStr := c.LocalAddr().String()
+ host, _, err := net.SplitHostPort(addrPortStr)
+ if err != nil {
+ log.Printf("HandleDNS: bogus addrPort %q", addrPortStr)
+ m.dnsFailures.Add(1)
+ return
+ }
+ localAddr, err := netip.ParseAddr(host)
+ if err != nil {
+ log.Printf("HandleDNS: bogus local address %q", host)
+ m.dnsFailures.Add(1)
+ return
+ }
+
+ var msg dnsmessage.Message
+ err = msg.Unpack(buf[:n])
+ if err != nil {
+ log.Printf("HandleDNS: dnsmessage unpack failed: %v\n ", err)
+ m.dnsFailures.Add(1)
+ return
+ }
+
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ for _, connector := range s.connectors {
+ resp, err := connector.handleDNS(&msg, localAddr)
+ if err != nil {
+ log.Printf("HandleDNS: connector handling failed: %v\n", err)
+ m.dnsFailures.Add(1)
+ return
+ }
+ if len(resp) > 0 {
+ // This connector handled the DNS request
+ _, err = c.Write(resp)
+ if err != nil {
+ log.Printf("HandleDNS: write failed: %v\n", err)
+ m.dnsFailures.Add(1)
+ return
+ }
+
+ m.dnsResponses.Add(1)
+ return
+ }
+ }
+}
+
+// connector describes a logical collection of
+// services which need to be proxied.
+type connector struct {
+ Handlers map[target]handler
+}
+
+// handleTCPFlow implements tsnet.FallbackTCPHandler.
+func (c *connector) handleTCPFlow(src, dst netip.AddrPort, m *appcMetrics) (handler func(net.Conn), intercept bool) {
+ for t, h := range c.Handlers {
+ if t.Matching.Proto != 0 && t.Matching.Proto != int(ipproto.TCP) {
+ continue
+ }
+ if !t.Dest.Contains(dst.Addr()) {
+ continue
+ }
+ if !t.Matching.Ports.Contains(dst.Port()) {
+ continue
+ }
+
+ switch h.(type) {
+ case *tcpSNIHandler:
+ m.sniConns.Add(1)
+ case *tcpRoundRobinHandler:
+ m.tcpConns.Add(1)
+ default:
+ log.Printf("handleTCPFlow: unhandled handler type %T", h)
+ }
+
+ return h.Handle, true
+ }
+
+ m.unhandledConns.Add(1)
+ return nil, false
+}
+
+// handleDNS returns the DNS response to the given query. If this
+// connector is unable to handle the request, nil is returned.
+func (c *connector) handleDNS(req *dnsmessage.Message, localAddr netip.Addr) (response []byte, err error) {
+ for t, h := range c.Handlers {
+ if t.Dest.Contains(localAddr) {
+ return makeDNSResponse(req, h.ReachableOn())
+ }
+ }
+
+ // Did not match, signal 'not handled' to caller
+ return nil, nil
+}
+
+func makeDNSResponse(req *dnsmessage.Message, reachableIPs []netip.Addr) (response []byte, err error) {
+ resp := dnsmessage.NewBuilder(response,
+ dnsmessage.Header{
+ ID: req.Header.ID,
+ Response: true,
+ Authoritative: true,
+ })
+ resp.EnableCompression()
+
+ if len(req.Questions) == 0 {
+ response, _ = resp.Finish()
+ return response, nil
+ }
+ q := req.Questions[0]
+ err = resp.StartQuestions()
+ if err != nil {
+ return
+ }
+ resp.Question(q)
+
+ err = resp.StartAnswers()
+ if err != nil {
+ return
+ }
+
+ switch q.Type {
+ case dnsmessage.TypeAAAA:
+ for _, ip := range reachableIPs {
+ if ip.Is6() {
+ err = resp.AAAAResource(
+ dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120},
+ dnsmessage.AAAAResource{AAAA: ip.As16()},
+ )
+ }
+ }
+
+ case dnsmessage.TypeA:
+ for _, ip := range reachableIPs {
+ if ip.Is4() {
+ err = resp.AResource(
+ dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120},
+ dnsmessage.AResource{A: ip.As4()},
+ )
+ }
+ }
+
+ case dnsmessage.TypeSOA:
+ err = resp.SOAResource(
+ dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120},
+ dnsmessage.SOAResource{NS: q.Name, MBox: tsMBox, Serial: 2023030600,
+ Refresh: 120, Retry: 120, Expire: 120, MinTTL: 60},
+ )
+ case dnsmessage.TypeNS:
+ err = resp.NSResource(
+ dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120},
+ dnsmessage.NSResource{NS: tsMBox},
+ )
+ }
+
+ if err != nil {
+ return nil, err
+ }
+ return resp.Finish()
+}
+
+type handler interface {
+ // Handle handles the given socket.
+ Handle(c net.Conn)
+
+ // ReachableOn returns the IP addresses this handler is reachable on.
+ ReachableOn() []netip.Addr
+}
+
+func installDNATHandler(d *appctype.DNATConfig, out *connector) {
+ // These handlers don't actually do DNAT, they just
+ // proxy the data over the connection.
+ var dialer net.Dialer
+ dialer.Timeout = 5 * time.Second
+ h := tcpRoundRobinHandler{
+ To: d.To,
+ DialContext: dialer.DialContext,
+ ReachableIPs: d.Addrs,
+ }
+
+ for _, addr := range d.Addrs {
+ for _, protoPort := range d.IP {
+ t := target{
+ Dest: netip.PrefixFrom(addr, addr.BitLen()),
+ Matching: protoPort,
+ }
+
+ mak.Set(&out.Handlers, t, handler(&h))
+ }
+ }
+}
+
+func installSNIHandler(c *appctype.SNIProxyConfig, out *connector) {
+ var dialer net.Dialer
+ dialer.Timeout = 5 * time.Second
+ h := tcpSNIHandler{
+ Allowlist: c.AllowedDomains,
+ DialContext: dialer.DialContext,
+ ReachableIPs: c.Addrs,
+ }
+
+ for _, addr := range c.Addrs {
+ for _, protoPort := range c.IP {
+ t := target{
+ Dest: netip.PrefixFrom(addr, addr.BitLen()),
+ Matching: protoPort,
+ }
+
+ mak.Set(&out.Handlers, t, handler(&h))
+ }
+ }
+}
+
+func makeConnectorsFromConfig(cfg *appctype.AppConnectorConfig) map[appctype.ConfigID]connector {
+ var connectors map[appctype.ConfigID]connector
+
+ for cID, d := range cfg.DNAT {
+ c := connectors[cID]
+ installDNATHandler(&d, &c)
+ mak.Set(&connectors, cID, c)
+ }
+ for cID, d := range cfg.SNIProxy {
+ c := connectors[cID]
+ installSNIHandler(&d, &c)
+ mak.Set(&connectors, cID, c)
+ }
+
+ return connectors
+}
diff --git a/cmd/sniproxy/server_test.go b/cmd/sniproxy/server_test.go
index d56f2aa75..2a51c874c 100644
--- a/cmd/sniproxy/server_test.go
+++ b/cmd/sniproxy/server_test.go
@@ -1,95 +1,95 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-package main
-
-import (
- "net/netip"
- "testing"
-
- "github.com/google/go-cmp/cmp"
- "github.com/google/go-cmp/cmp/cmpopts"
- "tailscale.com/tailcfg"
- "tailscale.com/types/appctype"
-)
-
-func TestMakeConnectorsFromConfig(t *testing.T) {
- tcs := []struct {
- name string
- input *appctype.AppConnectorConfig
- want map[appctype.ConfigID]connector
- }{
- {
- "empty",
- &appctype.AppConnectorConfig{},
- nil,
- },
- {
- "DNAT",
- &appctype.AppConnectorConfig{
- DNAT: map[appctype.ConfigID]appctype.DNATConfig{
- "swiggity_swooty": {
- Addrs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")},
- To: []string{"example.org"},
- IP: []tailcfg.ProtoPortRange{{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}}},
- },
- },
- },
- map[appctype.ConfigID]connector{
- "swiggity_swooty": {
- Handlers: map[target]handler{
- {
- Dest: netip.MustParsePrefix("100.64.0.1/32"),
- Matching: tailcfg.ProtoPortRange{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}},
- }: &tcpRoundRobinHandler{To: []string{"example.org"}, ReachableIPs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}},
- {
- Dest: netip.MustParsePrefix("fd7a:115c:a1e0::1/128"),
- Matching: tailcfg.ProtoPortRange{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}},
- }: &tcpRoundRobinHandler{To: []string{"example.org"}, ReachableIPs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}},
- },
- },
- },
- },
- {
- "SNIProxy",
- &appctype.AppConnectorConfig{
- SNIProxy: map[appctype.ConfigID]appctype.SNIProxyConfig{
- "swiggity_swooty": {
- Addrs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")},
- AllowedDomains: []string{"example.org"},
- IP: []tailcfg.ProtoPortRange{{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}}},
- },
- },
- },
- map[appctype.ConfigID]connector{
- "swiggity_swooty": {
- Handlers: map[target]handler{
- {
- Dest: netip.MustParsePrefix("100.64.0.1/32"),
- Matching: tailcfg.ProtoPortRange{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}},
- }: &tcpSNIHandler{Allowlist: []string{"example.org"}, ReachableIPs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}},
- {
- Dest: netip.MustParsePrefix("fd7a:115c:a1e0::1/128"),
- Matching: tailcfg.ProtoPortRange{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}},
- }: &tcpSNIHandler{Allowlist: []string{"example.org"}, ReachableIPs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}},
- },
- },
- },
- },
- }
-
- for _, tc := range tcs {
- t.Run(tc.name, func(t *testing.T) {
- connectors := makeConnectorsFromConfig(tc.input)
-
- if diff := cmp.Diff(connectors, tc.want,
- cmpopts.IgnoreFields(tcpRoundRobinHandler{}, "DialContext"),
- cmpopts.IgnoreFields(tcpSNIHandler{}, "DialContext"),
- cmp.Comparer(func(x, y netip.Addr) bool {
- return x == y
- })); diff != "" {
- t.Fatalf("mismatch (-want +got):\n%s", diff)
- }
- })
- }
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package main
+
+import (
+ "net/netip"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "github.com/google/go-cmp/cmp/cmpopts"
+ "tailscale.com/tailcfg"
+ "tailscale.com/types/appctype"
+)
+
+func TestMakeConnectorsFromConfig(t *testing.T) {
+ tcs := []struct {
+ name string
+ input *appctype.AppConnectorConfig
+ want map[appctype.ConfigID]connector
+ }{
+ {
+ "empty",
+ &appctype.AppConnectorConfig{},
+ nil,
+ },
+ {
+ "DNAT",
+ &appctype.AppConnectorConfig{
+ DNAT: map[appctype.ConfigID]appctype.DNATConfig{
+ "swiggity_swooty": {
+ Addrs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")},
+ To: []string{"example.org"},
+ IP: []tailcfg.ProtoPortRange{{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}}},
+ },
+ },
+ },
+ map[appctype.ConfigID]connector{
+ "swiggity_swooty": {
+ Handlers: map[target]handler{
+ {
+ Dest: netip.MustParsePrefix("100.64.0.1/32"),
+ Matching: tailcfg.ProtoPortRange{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}},
+ }: &tcpRoundRobinHandler{To: []string{"example.org"}, ReachableIPs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}},
+ {
+ Dest: netip.MustParsePrefix("fd7a:115c:a1e0::1/128"),
+ Matching: tailcfg.ProtoPortRange{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}},
+ }: &tcpRoundRobinHandler{To: []string{"example.org"}, ReachableIPs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}},
+ },
+ },
+ },
+ },
+ {
+ "SNIProxy",
+ &appctype.AppConnectorConfig{
+ SNIProxy: map[appctype.ConfigID]appctype.SNIProxyConfig{
+ "swiggity_swooty": {
+ Addrs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")},
+ AllowedDomains: []string{"example.org"},
+ IP: []tailcfg.ProtoPortRange{{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}}},
+ },
+ },
+ },
+ map[appctype.ConfigID]connector{
+ "swiggity_swooty": {
+ Handlers: map[target]handler{
+ {
+ Dest: netip.MustParsePrefix("100.64.0.1/32"),
+ Matching: tailcfg.ProtoPortRange{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}},
+ }: &tcpSNIHandler{Allowlist: []string{"example.org"}, ReachableIPs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}},
+ {
+ Dest: netip.MustParsePrefix("fd7a:115c:a1e0::1/128"),
+ Matching: tailcfg.ProtoPortRange{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}},
+ }: &tcpSNIHandler{Allowlist: []string{"example.org"}, ReachableIPs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}},
+ },
+ },
+ },
+ },
+ }
+
+ for _, tc := range tcs {
+ t.Run(tc.name, func(t *testing.T) {
+ connectors := makeConnectorsFromConfig(tc.input)
+
+ if diff := cmp.Diff(connectors, tc.want,
+ cmpopts.IgnoreFields(tcpRoundRobinHandler{}, "DialContext"),
+ cmpopts.IgnoreFields(tcpSNIHandler{}, "DialContext"),
+ cmp.Comparer(func(x, y netip.Addr) bool {
+ return x == y
+ })); diff != "" {
+ t.Fatalf("mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
diff --git a/cmd/sniproxy/sniproxy.go b/cmd/sniproxy/sniproxy.go
index fa83aaf4a..c048c8e7e 100644
--- a/cmd/sniproxy/sniproxy.go
+++ b/cmd/sniproxy/sniproxy.go
@@ -1,291 +1,291 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-// The sniproxy is an outbound SNI proxy. It receives TLS connections over
-// Tailscale on one or more TCP ports and sends them out to the same SNI
-// hostname & port on the internet. It can optionally forward one or more
-// TCP ports to a specific destination. It only does TCP.
-package main
-
-import (
- "context"
- "errors"
- "flag"
- "fmt"
- "log"
- "net"
- "net/http"
- "net/netip"
- "os"
- "sort"
- "strconv"
- "strings"
-
- "github.com/peterbourgon/ff/v3"
- "tailscale.com/client/tailscale"
- "tailscale.com/hostinfo"
- "tailscale.com/ipn"
- "tailscale.com/tailcfg"
- "tailscale.com/tsnet"
- "tailscale.com/tsweb"
- "tailscale.com/types/appctype"
- "tailscale.com/types/ipproto"
- "tailscale.com/types/nettype"
- "tailscale.com/util/mak"
-)
-
-const configCapKey = "tailscale.com/sniproxy"
-
-// portForward is the state for a single port forwarding entry, as passed to the --forward flag.
-type portForward struct {
- Port int
- Proto string
- Destination string
-}
-
-// parseForward takes a proto/port/destination tuple as an input, as would be passed
-// to the --forward command line flag, and returns a *portForward struct of those parameters.
-func parseForward(value string) (*portForward, error) {
- parts := strings.Split(value, "/")
- if len(parts) != 3 {
- return nil, errors.New("cannot parse: " + value)
- }
-
- proto := parts[0]
- if proto != "tcp" {
- return nil, errors.New("unsupported forwarding protocol: " + proto)
- }
- port, err := strconv.ParseUint(parts[1], 10, 16)
- if err != nil {
- return nil, errors.New("bad forwarding port: " + parts[1])
- }
- host := parts[2]
- if host == "" {
- return nil, errors.New("bad destination: " + value)
- }
-
- return &portForward{Port: int(port), Proto: proto, Destination: host}, nil
-}
-
-func main() {
- // Parse flags
- fs := flag.NewFlagSet("sniproxy", flag.ContinueOnError)
- var (
- ports = fs.String("ports", "443", "comma-separated list of ports to proxy")
- forwards = fs.String("forwards", "", "comma-separated list of ports to transparently forward, protocol/number/destination. For example, --forwards=tcp/22/github.com,tcp/5432/sql.example.com")
- wgPort = fs.Int("wg-listen-port", 0, "UDP port to listen on for WireGuard and peer-to-peer traffic; 0 means automatically select")
- promoteHTTPS = fs.Bool("promote-https", true, "promote HTTP to HTTPS")
- debugPort = fs.Int("debug-port", 8893, "Listening port for debug/metrics endpoint")
- hostname = fs.String("hostname", "", "Hostname to register the service under")
- )
- err := ff.Parse(fs, os.Args[1:], ff.WithEnvVarPrefix("TS_APPC"))
- if err != nil {
- log.Fatal("ff.Parse")
- }
-
- var ts tsnet.Server
- defer ts.Close()
-
- ctx, cancel := context.WithCancel(context.Background())
- defer cancel()
- run(ctx, &ts, *wgPort, *hostname, *promoteHTTPS, *debugPort, *ports, *forwards)
-}
-
-// run actually runs the sniproxy. Its separate from main() to assist in testing.
-func run(ctx context.Context, ts *tsnet.Server, wgPort int, hostname string, promoteHTTPS bool, debugPort int, ports, forwards string) {
- // Wire up Tailscale node + app connector server
- hostinfo.SetApp("sniproxy")
- var s sniproxy
- s.ts = ts
-
- s.ts.Port = uint16(wgPort)
- s.ts.Hostname = hostname
-
- lc, err := s.ts.LocalClient()
- if err != nil {
- log.Fatalf("LocalClient() failed: %v", err)
- }
- s.lc = lc
- s.ts.RegisterFallbackTCPHandler(s.srv.HandleTCPFlow)
-
- // Start special-purpose listeners: dns, http promotion, debug server
- ln, err := s.ts.Listen("udp", ":53")
- if err != nil {
- log.Fatalf("failed listening on port 53: %v", err)
- }
- defer ln.Close()
- go s.serveDNS(ln)
- if promoteHTTPS {
- ln, err := s.ts.Listen("tcp", ":80")
- if err != nil {
- log.Fatalf("failed listening on port 80: %v", err)
- }
- defer ln.Close()
- log.Printf("Promoting HTTP to HTTPS ...")
- go s.promoteHTTPS(ln)
- }
- if debugPort != 0 {
- mux := http.NewServeMux()
- tsweb.Debugger(mux)
- dln, err := s.ts.Listen("tcp", fmt.Sprintf(":%d", debugPort))
- if err != nil {
- log.Fatalf("failed listening on debug port: %v", err)
- }
- defer dln.Close()
- go func() {
- log.Fatalf("debug serve: %v", http.Serve(dln, mux))
- }()
- }
-
- // Finally, start mainloop to configure app connector based on information
- // in the netmap.
- // We set the NotifyInitialNetMap flag so we will always get woken with the
- // current netmap, before only being woken on changes.
- bus, err := lc.WatchIPNBus(ctx, ipn.NotifyWatchEngineUpdates|ipn.NotifyInitialNetMap|ipn.NotifyNoPrivateKeys)
- if err != nil {
- log.Fatalf("watching IPN bus: %v", err)
- }
- defer bus.Close()
- for {
- msg, err := bus.Next()
- if err != nil {
- if errors.Is(err, context.Canceled) {
- return
- }
- log.Fatalf("reading IPN bus: %v", err)
- }
-
- // NetMap contains app-connector configuration
- if nm := msg.NetMap; nm != nil && nm.SelfNode.Valid() {
- sn := nm.SelfNode.AsStruct()
-
- var c appctype.AppConnectorConfig
- nmConf, err := tailcfg.UnmarshalNodeCapJSON[appctype.AppConnectorConfig](sn.CapMap, configCapKey)
- if err != nil {
- log.Printf("failed to read app connector configuration from coordination server: %v", err)
- } else if len(nmConf) > 0 {
- c = nmConf[0]
- }
-
- if c.AdvertiseRoutes {
- if err := s.advertiseRoutesFromConfig(ctx, &c); err != nil {
- log.Printf("failed to advertise routes: %v", err)
- }
- }
-
- // Backwards compatibility: combine any configuration from control with flags specified
- // on the command line. This is intentionally done after we advertise any routes
- // because its never correct to advertise the nodes native IP addresses.
- s.mergeConfigFromFlags(&c, ports, forwards)
- s.srv.Configure(&c)
- }
- }
-}
-
-type sniproxy struct {
- srv Server
- ts *tsnet.Server
- lc *tailscale.LocalClient
-}
-
-func (s *sniproxy) advertiseRoutesFromConfig(ctx context.Context, c *appctype.AppConnectorConfig) error {
- // Collect the set of addresses to advertise, using a map
- // to avoid duplicate entries.
- addrs := map[netip.Addr]struct{}{}
- for _, c := range c.SNIProxy {
- for _, ip := range c.Addrs {
- addrs[ip] = struct{}{}
- }
- }
- for _, c := range c.DNAT {
- for _, ip := range c.Addrs {
- addrs[ip] = struct{}{}
- }
- }
-
- var routes []netip.Prefix
- for a := range addrs {
- routes = append(routes, netip.PrefixFrom(a, a.BitLen()))
- }
- sort.SliceStable(routes, func(i, j int) bool {
- return routes[i].Addr().Less(routes[j].Addr()) // determinism r us
- })
-
- _, err := s.lc.EditPrefs(ctx, &ipn.MaskedPrefs{
- Prefs: ipn.Prefs{
- AdvertiseRoutes: routes,
- },
- AdvertiseRoutesSet: true,
- })
- return err
-}
-
-func (s *sniproxy) mergeConfigFromFlags(out *appctype.AppConnectorConfig, ports, forwards string) {
- ip4, ip6 := s.ts.TailscaleIPs()
-
- sniConfigFromFlags := appctype.SNIProxyConfig{
- Addrs: []netip.Addr{ip4, ip6},
- }
- if ports != "" {
- for _, portStr := range strings.Split(ports, ",") {
- port, err := strconv.ParseUint(portStr, 10, 16)
- if err != nil {
- log.Fatalf("invalid port: %s", portStr)
- }
- sniConfigFromFlags.IP = append(sniConfigFromFlags.IP, tailcfg.ProtoPortRange{
- Proto: int(ipproto.TCP),
- Ports: tailcfg.PortRange{First: uint16(port), Last: uint16(port)},
- })
- }
- }
-
- var forwardConfigFromFlags []appctype.DNATConfig
- for _, forwStr := range strings.Split(forwards, ",") {
- if forwStr == "" {
- continue
- }
- forw, err := parseForward(forwStr)
- if err != nil {
- log.Printf("invalid forwarding spec: %v", err)
- continue
- }
-
- forwardConfigFromFlags = append(forwardConfigFromFlags, appctype.DNATConfig{
- Addrs: []netip.Addr{ip4, ip6},
- To: []string{forw.Destination},
- IP: []tailcfg.ProtoPortRange{
- {
- Proto: int(ipproto.TCP),
- Ports: tailcfg.PortRange{First: uint16(forw.Port), Last: uint16(forw.Port)},
- },
- },
- })
- }
-
- if len(forwardConfigFromFlags) == 0 && len(sniConfigFromFlags.IP) == 0 {
- return // no config specified on the command line
- }
-
- mak.Set(&out.SNIProxy, "flags", sniConfigFromFlags)
- for i, forward := range forwardConfigFromFlags {
- mak.Set(&out.DNAT, appctype.ConfigID(fmt.Sprintf("flags_%d", i)), forward)
- }
-}
-
-func (s *sniproxy) serveDNS(ln net.Listener) {
- for {
- c, err := ln.Accept()
- if err != nil {
- log.Printf("serveDNS accept: %v", err)
- return
- }
- go s.srv.HandleDNS(c.(nettype.ConnPacketConn))
- }
-}
-
-func (s *sniproxy) promoteHTTPS(ln net.Listener) {
- err := http.Serve(ln, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- http.Redirect(w, r, "https://"+r.Host+r.RequestURI, http.StatusFound)
- }))
- log.Fatalf("promoteHTTPS http.Serve: %v", err)
-}
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// The sniproxy is an outbound SNI proxy. It receives TLS connections over
+// Tailscale on one or more TCP ports and sends them out to the same SNI
+// hostname & port on the internet. It can optionally forward one or more
+// TCP ports to a specific destination. It only does TCP.
+package main
+
+import (
+ "context"
+ "errors"
+ "flag"
+ "fmt"
+ "log"
+ "net"
+ "net/http"
+ "net/netip"
+ "os"
+ "sort"
+ "strconv"
+ "strings"
+
+ "github.com/peterbourgon/ff/v3"
+ "tailscale.com/client/tailscale"
+ "tailscale.com/hostinfo"
+ "tailscale.com/ipn"
+ "tailscale.com/tailcfg"
+ "tailscale.com/tsnet"
+ "tailscale.com/tsweb"
+ "tailscale.com/types/appctype"
+ "tailscale.com/types/ipproto"
+ "tailscale.com/types/nettype"
+ "tailscale.com/util/mak"
+)
+
+const configCapKey = "tailscale.com/sniproxy"
+
+// portForward is the state for a single port forwarding entry, as passed to the --forward flag.
+type portForward struct {
+ Port int
+ Proto string
+ Destination string
+}
+
+// parseForward takes a proto/port/destination tuple as an input, as would be passed
+// to the --forward command line flag, and returns a *portForward struct of those parameters.
+func parseForward(value string) (*portForward, error) {
+ parts := strings.Split(value, "/")
+ if len(parts) != 3 {
+ return nil, errors.New("cannot parse: " + value)
+ }
+
+ proto := parts[0]
+ if proto != "tcp" {
+ return nil, errors.New("unsupported forwarding protocol: " + proto)
+ }
+ port, err := strconv.ParseUint(parts[1], 10, 16)
+ if err != nil {
+ return nil, errors.New("bad forwarding port: " + parts[1])
+ }
+ host := parts[2]
+ if host == "" {
+ return nil, errors.New("bad destination: " + value)
+ }
+
+ return &portForward{Port: int(port), Proto: proto, Destination: host}, nil
+}
+
+func main() {
+ // Parse flags
+ fs := flag.NewFlagSet("sniproxy", flag.ContinueOnError)
+ var (
+ ports = fs.String("ports", "443", "comma-separated list of ports to proxy")
+ forwards = fs.String("forwards", "", "comma-separated list of ports to transparently forward, protocol/number/destination. For example, --forwards=tcp/22/github.com,tcp/5432/sql.example.com")
+ wgPort = fs.Int("wg-listen-port", 0, "UDP port to listen on for WireGuard and peer-to-peer traffic; 0 means automatically select")
+ promoteHTTPS = fs.Bool("promote-https", true, "promote HTTP to HTTPS")
+ debugPort = fs.Int("debug-port", 8893, "Listening port for debug/metrics endpoint")
+ hostname = fs.String("hostname", "", "Hostname to register the service under")
+ )
+ err := ff.Parse(fs, os.Args[1:], ff.WithEnvVarPrefix("TS_APPC"))
+ if err != nil {
+ log.Fatal("ff.Parse")
+ }
+
+ var ts tsnet.Server
+ defer ts.Close()
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+ run(ctx, &ts, *wgPort, *hostname, *promoteHTTPS, *debugPort, *ports, *forwards)
+}
+
+// run actually runs the sniproxy. Its separate from main() to assist in testing.
+func run(ctx context.Context, ts *tsnet.Server, wgPort int, hostname string, promoteHTTPS bool, debugPort int, ports, forwards string) {
+ // Wire up Tailscale node + app connector server
+ hostinfo.SetApp("sniproxy")
+ var s sniproxy
+ s.ts = ts
+
+ s.ts.Port = uint16(wgPort)
+ s.ts.Hostname = hostname
+
+ lc, err := s.ts.LocalClient()
+ if err != nil {
+ log.Fatalf("LocalClient() failed: %v", err)
+ }
+ s.lc = lc
+ s.ts.RegisterFallbackTCPHandler(s.srv.HandleTCPFlow)
+
+ // Start special-purpose listeners: dns, http promotion, debug server
+ ln, err := s.ts.Listen("udp", ":53")
+ if err != nil {
+ log.Fatalf("failed listening on port 53: %v", err)
+ }
+ defer ln.Close()
+ go s.serveDNS(ln)
+ if promoteHTTPS {
+ ln, err := s.ts.Listen("tcp", ":80")
+ if err != nil {
+ log.Fatalf("failed listening on port 80: %v", err)
+ }
+ defer ln.Close()
+ log.Printf("Promoting HTTP to HTTPS ...")
+ go s.promoteHTTPS(ln)
+ }
+ if debugPort != 0 {
+ mux := http.NewServeMux()
+ tsweb.Debugger(mux)
+ dln, err := s.ts.Listen("tcp", fmt.Sprintf(":%d", debugPort))
+ if err != nil {
+ log.Fatalf("failed listening on debug port: %v", err)
+ }
+ defer dln.Close()
+ go func() {
+ log.Fatalf("debug serve: %v", http.Serve(dln, mux))
+ }()
+ }
+
+ // Finally, start mainloop to configure app connector based on information
+ // in the netmap.
+ // We set the NotifyInitialNetMap flag so we will always get woken with the
+ // current netmap, before only being woken on changes.
+ bus, err := lc.WatchIPNBus(ctx, ipn.NotifyWatchEngineUpdates|ipn.NotifyInitialNetMap|ipn.NotifyNoPrivateKeys)
+ if err != nil {
+ log.Fatalf("watching IPN bus: %v", err)
+ }
+ defer bus.Close()
+ for {
+ msg, err := bus.Next()
+ if err != nil {
+ if errors.Is(err, context.Canceled) {
+ return
+ }
+ log.Fatalf("reading IPN bus: %v", err)
+ }
+
+ // NetMap contains app-connector configuration
+ if nm := msg.NetMap; nm != nil && nm.SelfNode.Valid() {
+ sn := nm.SelfNode.AsStruct()
+
+ var c appctype.AppConnectorConfig
+ nmConf, err := tailcfg.UnmarshalNodeCapJSON[appctype.AppConnectorConfig](sn.CapMap, configCapKey)
+ if err != nil {
+ log.Printf("failed to read app connector configuration from coordination server: %v", err)
+ } else if len(nmConf) > 0 {
+ c = nmConf[0]
+ }
+
+ if c.AdvertiseRoutes {
+ if err := s.advertiseRoutesFromConfig(ctx, &c); err != nil {
+ log.Printf("failed to advertise routes: %v", err)
+ }
+ }
+
+ // Backwards compatibility: combine any configuration from control with flags specified
+ // on the command line. This is intentionally done after we advertise any routes
+ // because its never correct to advertise the nodes native IP addresses.
+ s.mergeConfigFromFlags(&c, ports, forwards)
+ s.srv.Configure(&c)
+ }
+ }
+}
+
+type sniproxy struct {
+ srv Server
+ ts *tsnet.Server
+ lc *tailscale.LocalClient
+}
+
+func (s *sniproxy) advertiseRoutesFromConfig(ctx context.Context, c *appctype.AppConnectorConfig) error {
+ // Collect the set of addresses to advertise, using a map
+ // to avoid duplicate entries.
+ addrs := map[netip.Addr]struct{}{}
+ for _, c := range c.SNIProxy {
+ for _, ip := range c.Addrs {
+ addrs[ip] = struct{}{}
+ }
+ }
+ for _, c := range c.DNAT {
+ for _, ip := range c.Addrs {
+ addrs[ip] = struct{}{}
+ }
+ }
+
+ var routes []netip.Prefix
+ for a := range addrs {
+ routes = append(routes, netip.PrefixFrom(a, a.BitLen()))
+ }
+ sort.SliceStable(routes, func(i, j int) bool {
+ return routes[i].Addr().Less(routes[j].Addr()) // determinism r us
+ })
+
+ _, err := s.lc.EditPrefs(ctx, &ipn.MaskedPrefs{
+ Prefs: ipn.Prefs{
+ AdvertiseRoutes: routes,
+ },
+ AdvertiseRoutesSet: true,
+ })
+ return err
+}
+
+func (s *sniproxy) mergeConfigFromFlags(out *appctype.AppConnectorConfig, ports, forwards string) {
+ ip4, ip6 := s.ts.TailscaleIPs()
+
+ sniConfigFromFlags := appctype.SNIProxyConfig{
+ Addrs: []netip.Addr{ip4, ip6},
+ }
+ if ports != "" {
+ for _, portStr := range strings.Split(ports, ",") {
+ port, err := strconv.ParseUint(portStr, 10, 16)
+ if err != nil {
+ log.Fatalf("invalid port: %s", portStr)
+ }
+ sniConfigFromFlags.IP = append(sniConfigFromFlags.IP, tailcfg.ProtoPortRange{
+ Proto: int(ipproto.TCP),
+ Ports: tailcfg.PortRange{First: uint16(port), Last: uint16(port)},
+ })
+ }
+ }
+
+ var forwardConfigFromFlags []appctype.DNATConfig
+ for _, forwStr := range strings.Split(forwards, ",") {
+ if forwStr == "" {
+ continue
+ }
+ forw, err := parseForward(forwStr)
+ if err != nil {
+ log.Printf("invalid forwarding spec: %v", err)
+ continue
+ }
+
+ forwardConfigFromFlags = append(forwardConfigFromFlags, appctype.DNATConfig{
+ Addrs: []netip.Addr{ip4, ip6},
+ To: []string{forw.Destination},
+ IP: []tailcfg.ProtoPortRange{
+ {
+ Proto: int(ipproto.TCP),
+ Ports: tailcfg.PortRange{First: uint16(forw.Port), Last: uint16(forw.Port)},
+ },
+ },
+ })
+ }
+
+ if len(forwardConfigFromFlags) == 0 && len(sniConfigFromFlags.IP) == 0 {
+ return // no config specified on the command line
+ }
+
+ mak.Set(&out.SNIProxy, "flags", sniConfigFromFlags)
+ for i, forward := range forwardConfigFromFlags {
+ mak.Set(&out.DNAT, appctype.ConfigID(fmt.Sprintf("flags_%d", i)), forward)
+ }
+}
+
+func (s *sniproxy) serveDNS(ln net.Listener) {
+ for {
+ c, err := ln.Accept()
+ if err != nil {
+ log.Printf("serveDNS accept: %v", err)
+ return
+ }
+ go s.srv.HandleDNS(c.(nettype.ConnPacketConn))
+ }
+}
+
+func (s *sniproxy) promoteHTTPS(ln net.Listener) {
+ err := http.Serve(ln, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ http.Redirect(w, r, "https://"+r.Host+r.RequestURI, http.StatusFound)
+ }))
+ log.Fatalf("promoteHTTPS http.Serve: %v", err)
+}