summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorBrad Fitzpatrick <bradfitz@tailscale.com>2025-06-08 18:51:41 -0700
committerBrad Fitzpatrick <brad@danga.com>2025-06-18 14:20:39 -0700
commite92eb6b17bb59cd66cd78c90db3b285015ed5e11 (patch)
tree059585ca0a791e03a16dfcb1e223ca646e19d999
parent4979ce7a94cd023db5cd03cbb556934d9652dfd2 (diff)
downloadtailscale-e92eb6b17bb59cd66cd78c90db3b285015ed5e11.tar.xz
tailscale-e92eb6b17bb59cd66cd78c90db3b285015ed5e11.zip
net/tlsdial: fix TLS cert validation of HTTPS proxies
If you had HTTPS_PROXY=https://some-valid-cert.example.com running a CONNECT proxy, we should've been able to do a TLS CONNECT request to e.g. controlplane.tailscale.com:443 through that, and I'm pretty sure it used to work, but refactorings and lack of integration tests made it regress. It probably regressed when we added the baked-in LetsEncrypt root cert validation fallback code, which was testing against the wrong hostname (the ultimate one, not the one which we were being asked to validate) Fixes #16222 Change-Id: If014e395f830e2f87f056f588edacad5c15e91bc Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
-rw-r--r--cmd/proxy-test-server/proxy-test-server.go81
-rw-r--r--control/controlclient/controlclient_test.go225
-rw-r--r--control/controlclient/direct.go7
-rw-r--r--control/controlhttp/client.go2
-rw-r--r--derp/derphttp/derphttp_client.go3
-rw-r--r--derp/derphttp/derphttp_test.go34
-rw-r--r--logpolicy/logpolicy.go4
-rw-r--r--net/bakedroots/bakedroots.go5
-rw-r--r--net/connectproxy/connectproxy.go93
-rw-r--r--net/dnscache/dnscache.go13
-rw-r--r--net/dnsfallback/dnsfallback.go2
-rw-r--r--net/tlsdial/tlsdial.go68
-rw-r--r--net/tlsdial/tlsdial_test.go2
-rw-r--r--tstest/tlstest/testdata/controlplane.tstest.key5
-rw-r--r--tstest/tlstest/testdata/proxy.tstest.key5
-rw-r--r--tstest/tlstest/testdata/root-ca.key5
-rw-r--r--tstest/tlstest/tlstest.go167
17 files changed, 672 insertions, 49 deletions
diff --git a/cmd/proxy-test-server/proxy-test-server.go b/cmd/proxy-test-server/proxy-test-server.go
new file mode 100644
index 000000000..9f8c94a38
--- /dev/null
+++ b/cmd/proxy-test-server/proxy-test-server.go
@@ -0,0 +1,81 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// The proxy-test-server command is a simple HTTP proxy server for testing
+// Tailscale's client proxy functionality.
+package main
+
+import (
+ "crypto/tls"
+ "flag"
+ "fmt"
+ "log"
+ "net"
+ "net/http"
+ "os"
+ "strings"
+
+ "golang.org/x/crypto/acme/autocert"
+ "tailscale.com/net/connectproxy"
+ "tailscale.com/tempfork/acme"
+)
+
+var (
+ listen = flag.String("listen", ":8080", "Address to listen on for HTTPS proxy requests")
+ hostname = flag.String("hostname", "localhost", "Hostname for the proxy server")
+ tailscaleOnly = flag.Bool("tailscale-only", true, "Restrict proxy to Tailscale targets only")
+ extraAllowedHosts = flag.String("allow-hosts", "", "Comma-separated list of allowed target hosts to additionally allow if --tailscale-only is true")
+)
+
+func main() {
+ flag.Parse()
+
+ am := &autocert.Manager{
+ HostPolicy: autocert.HostWhitelist(*hostname),
+ Prompt: autocert.AcceptTOS,
+ Cache: autocert.DirCache(os.ExpandEnv("$HOME/.cache/autocert/proxy-test-server")),
+ }
+ var allowTarget func(hostPort string) error
+ if *tailscaleOnly {
+ allowTarget = func(hostPort string) error {
+ host, port, err := net.SplitHostPort(hostPort)
+ if err != nil {
+ return fmt.Errorf("invalid target %q: %v", hostPort, err)
+ }
+ if port != "443" {
+ return fmt.Errorf("target %q must use port 443", hostPort)
+ }
+ for allowed := range strings.SplitSeq(*extraAllowedHosts, ",") {
+ if host == allowed {
+ return nil // explicitly allowed target
+ }
+ }
+ if !strings.HasSuffix(host, ".tailscale.com") {
+ return fmt.Errorf("target %q is not a Tailscale host", hostPort)
+ }
+ return nil // valid Tailscale target
+ }
+ }
+
+ go func() {
+ if err := http.ListenAndServe(":http", am.HTTPHandler(nil)); err != nil {
+ log.Fatalf("autocert HTTP server failed: %v", err)
+ }
+ }()
+ hs := &http.Server{
+ Addr: *listen,
+ Handler: &connectproxy.Handler{
+ Check: allowTarget,
+ Logf: log.Printf,
+ },
+ TLSConfig: &tls.Config{
+ GetCertificate: am.GetCertificate,
+ NextProtos: []string{
+ "http/1.1", // enable HTTP/2
+ acme.ALPNProto, // enable tls-alpn ACME challenges
+ },
+ },
+ }
+ log.Printf("Starting proxy-test-server on %s (hostname: %q)\n", *listen, *hostname)
+ log.Fatal(hs.ListenAndServeTLS("", "")) // cert and key are provided by autocert
+}
diff --git a/control/controlclient/controlclient_test.go b/control/controlclient/controlclient_test.go
index f8882a4e7..1107f76a4 100644
--- a/control/controlclient/controlclient_test.go
+++ b/control/controlclient/controlclient_test.go
@@ -4,13 +4,35 @@
package controlclient
import (
+ "context"
+ "crypto/tls"
"errors"
+ "flag"
"fmt"
"io"
+ "net"
+ "net/http"
+ "net/netip"
+ "net/url"
"reflect"
"slices"
+ "sync/atomic"
"testing"
+ "time"
+ "tailscale.com/control/controlknobs"
+ "tailscale.com/health"
+ "tailscale.com/net/bakedroots"
+ "tailscale.com/net/connectproxy"
+ "tailscale.com/net/netmon"
+ "tailscale.com/net/tsdial"
+ "tailscale.com/tailcfg"
+ "tailscale.com/tstest"
+ "tailscale.com/tstest/integration/testcontrol"
+ "tailscale.com/tstest/tlstest"
+ "tailscale.com/tstime"
+ "tailscale.com/types/key"
+ "tailscale.com/types/logger"
"tailscale.com/types/netmap"
"tailscale.com/types/persist"
)
@@ -188,3 +210,206 @@ func isRetryableErrorForTest(err error) bool {
}
return false
}
+
+var liveNetworkTest = flag.Bool("live-network-test", false, "run live network tests")
+
+func TestDirectProxyManual(t *testing.T) {
+ if !*liveNetworkTest {
+ t.Skip("skipping without --live-network-test")
+ }
+
+ dialer := &tsdial.Dialer{}
+ dialer.SetNetMon(netmon.NewStatic())
+
+ opts := Options{
+ Persist: persist.Persist{},
+ GetMachinePrivateKey: func() (key.MachinePrivate, error) {
+ return key.NewMachine(), nil
+ },
+ ServerURL: "https://controlplane.tailscale.com",
+ Clock: tstime.StdClock{},
+ Hostinfo: &tailcfg.Hostinfo{
+ BackendLogID: "test-backend-log-id",
+ },
+ DiscoPublicKey: key.NewDisco().Public(),
+ Logf: t.Logf,
+ HealthTracker: &health.Tracker{},
+ PopBrowserURL: func(url string) {
+ t.Logf("PopBrowserURL: %q", url)
+ },
+ Dialer: dialer,
+ ControlKnobs: &controlknobs.Knobs{},
+ }
+ d, err := NewDirect(opts)
+ if err != nil {
+ t.Fatalf("NewDirect: %v", err)
+ }
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+
+ url, err := d.TryLogin(ctx, LoginEphemeral)
+ if err != nil {
+ t.Fatalf("TryLogin: %v", err)
+ }
+ t.Logf("URL: %q", url)
+}
+
+func TestHTTPSNoProxy(t *testing.T) { testHTTPS(t, false) }
+
+// TestTLSWithProxy verifies we can connect to the control plane via
+// an HTTPS proxy.
+func TestHTTPSWithProxy(t *testing.T) { testHTTPS(t, true) }
+
+func testHTTPS(t *testing.T, withProxy bool) {
+ bakedroots.ResetForTest(t, tlstest.TestRootCA())
+
+ controlLn, err := tls.Listen("tcp", "127.0.0.1:0", tlstest.ControlPlaneKeyPair.ServerTLSConfig())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer controlLn.Close()
+
+ proxyLn, err := tls.Listen("tcp", "127.0.0.1:0", tlstest.ProxyServerKeyPair.ServerTLSConfig())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer proxyLn.Close()
+
+ const requiredAuthKey = "hunter2"
+ const someUsername = "testuser"
+ const somePassword = "testpass"
+
+ testControl := &testcontrol.Server{
+ Logf: tstest.WhileTestRunningLogger(t),
+ RequireAuthKey: requiredAuthKey,
+ }
+ controlSrv := &http.Server{
+ Handler: testControl,
+ ErrorLog: logger.StdLogger(t.Logf),
+ }
+ go controlSrv.Serve(controlLn)
+
+ const fakeControlIP = "1.2.3.4"
+ const fakeProxyIP = "5.6.7.8"
+
+ dialer := &tsdial.Dialer{}
+ dialer.SetNetMon(netmon.NewStatic())
+ dialer.SetSystemDialerForTest(func(ctx context.Context, network, addr string) (net.Conn, error) {
+ host, _, err := net.SplitHostPort(addr)
+ if err != nil {
+ return nil, fmt.Errorf("SplitHostPort(%q): %v", addr, err)
+ }
+ var d net.Dialer
+ if host == fakeControlIP {
+ return d.DialContext(ctx, network, controlLn.Addr().String())
+ }
+ if host == fakeProxyIP {
+ return d.DialContext(ctx, network, proxyLn.Addr().String())
+ }
+ return nil, fmt.Errorf("unexpected dial to %q", addr)
+ })
+
+ opts := Options{
+ Persist: persist.Persist{},
+ GetMachinePrivateKey: func() (key.MachinePrivate, error) {
+ return key.NewMachine(), nil
+ },
+ AuthKey: requiredAuthKey,
+ ServerURL: "https://controlplane.tstest",
+ Clock: tstime.StdClock{},
+ Hostinfo: &tailcfg.Hostinfo{
+ BackendLogID: "test-backend-log-id",
+ },
+ DiscoPublicKey: key.NewDisco().Public(),
+ Logf: t.Logf,
+ HealthTracker: &health.Tracker{},
+ PopBrowserURL: func(url string) {
+ t.Logf("PopBrowserURL: %q", url)
+ },
+ Dialer: dialer,
+ }
+ d, err := NewDirect(opts)
+ if err != nil {
+ t.Fatalf("NewDirect: %v", err)
+ }
+
+ d.dnsCache.LookupIPForTest = func(ctx context.Context, host string) ([]netip.Addr, error) {
+ switch host {
+ case "controlplane.tstest":
+ return []netip.Addr{netip.MustParseAddr(fakeControlIP)}, nil
+ case "proxy.tstest":
+ if !withProxy {
+ t.Errorf("unexpected DNS lookup for %q with proxy disabled", host)
+ return nil, fmt.Errorf("unexpected DNS lookup for %q", host)
+ }
+ return []netip.Addr{netip.MustParseAddr(fakeProxyIP)}, nil
+ }
+ t.Errorf("unexpected DNS query for %q", host)
+ return []netip.Addr{}, nil
+ }
+
+ var proxyReqs atomic.Int64
+ if withProxy {
+ d.httpc.Transport.(*http.Transport).Proxy = func(req *http.Request) (*url.URL, error) {
+ t.Logf("using proxy for %q", req.URL)
+ u := &url.URL{
+ Scheme: "https",
+ Host: "proxy.tstest:443",
+ User: url.UserPassword(someUsername, somePassword),
+ }
+ return u, nil
+ }
+
+ connectProxy := &http.Server{
+ Handler: connectProxyTo(t, "controlplane.tstest:443", controlLn.Addr().String(), &proxyReqs),
+ }
+ go connectProxy.Serve(proxyLn)
+ }
+
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+
+ url, err := d.TryLogin(ctx, LoginEphemeral)
+ if err != nil {
+ t.Fatalf("TryLogin: %v", err)
+ }
+ if url != "" {
+ t.Errorf("got URL %q, want empty", url)
+ }
+
+ if withProxy {
+ if got, want := proxyReqs.Load(), int64(1); got != want {
+ t.Errorf("proxy CONNECT requests = %d; want %d", got, want)
+ }
+ }
+}
+
+func connectProxyTo(t testing.TB, target, backendAddrPort string, reqs *atomic.Int64) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.RequestURI != target {
+ t.Errorf("invalid CONNECT request to %q; want %q", r.RequestURI, target)
+ http.Error(w, "bad target", http.StatusBadRequest)
+ return
+ }
+
+ r.Header.Set("Authorization", r.Header.Get("Proxy-Authorization")) // for the BasicAuth method. kinda trashy.
+ user, pass, ok := r.BasicAuth()
+ if !ok || user != "testuser" || pass != "testpass" {
+ t.Errorf("invalid CONNECT auth %q:%q; want %q:%q", user, pass, "testuser", "testpass")
+ http.Error(w, "bad auth", http.StatusUnauthorized)
+ return
+ }
+
+ (&connectproxy.Handler{
+ Dial: func(ctx context.Context, network, addr string) (net.Conn, error) {
+ var d net.Dialer
+ c, err := d.DialContext(ctx, network, backendAddrPort)
+ if err == nil {
+ reqs.Add(1)
+ }
+ return c, err
+ },
+ Logf: t.Logf,
+ }).ServeHTTP(w, r)
+ })
+}
diff --git a/control/controlclient/direct.go b/control/controlclient/direct.go
index 2d6dc6e36..4c9b04ce9 100644
--- a/control/controlclient/direct.go
+++ b/control/controlclient/direct.go
@@ -16,7 +16,6 @@ import (
"net"
"net/http"
"net/netip"
- "net/url"
"os"
"reflect"
"runtime"
@@ -240,10 +239,6 @@ func NewDirect(opts Options) (*Direct, error) {
opts.ControlKnobs = &controlknobs.Knobs{}
}
opts.ServerURL = strings.TrimRight(opts.ServerURL, "/")
- serverURL, err := url.Parse(opts.ServerURL)
- if err != nil {
- return nil, err
- }
if opts.Clock == nil {
opts.Clock = tstime.StdClock{}
}
@@ -273,7 +268,7 @@ func NewDirect(opts Options) (*Direct, error) {
tr := http.DefaultTransport.(*http.Transport).Clone()
tr.Proxy = tshttpproxy.ProxyFromEnvironment
tshttpproxy.SetTransportGetProxyConnectHeader(tr)
- tr.TLSClientConfig = tlsdial.Config(serverURL.Hostname(), opts.HealthTracker, tr.TLSClientConfig)
+ tr.TLSClientConfig = tlsdial.Config(opts.HealthTracker, tr.TLSClientConfig)
var dialFunc netx.DialFunc
dialFunc, interceptedDial = makeScreenTimeDetectingDialFunc(opts.Dialer.SystemDial)
tr.DialContext = dnscache.Dialer(dialFunc, dnsCache)
diff --git a/control/controlhttp/client.go b/control/controlhttp/client.go
index 869bcb599..1bb60d672 100644
--- a/control/controlhttp/client.go
+++ b/control/controlhttp/client.go
@@ -534,7 +534,7 @@ func (a *Dialer) tryURLUpgrade(ctx context.Context, u *url.URL, optAddr netip.Ad
// Disable HTTP2, since h2 can't do protocol switching.
tr.TLSClientConfig.NextProtos = []string{}
tr.TLSNextProto = map[string]func(string, *tls.Conn) http.RoundTripper{}
- tr.TLSClientConfig = tlsdial.Config(a.Hostname, a.HealthTracker, tr.TLSClientConfig)
+ tr.TLSClientConfig = tlsdial.Config(a.HealthTracker, tr.TLSClientConfig)
if !tr.TLSClientConfig.InsecureSkipVerify {
panic("unexpected") // should be set by tlsdial.Config
}
diff --git a/derp/derphttp/derphttp_client.go b/derp/derphttp/derphttp_client.go
index 8c42e9070..7385f0ad1 100644
--- a/derp/derphttp/derphttp_client.go
+++ b/derp/derphttp/derphttp_client.go
@@ -647,12 +647,13 @@ func (c *Client) dialRegion(ctx context.Context, reg *tailcfg.DERPRegion) (net.C
}
func (c *Client) tlsClient(nc net.Conn, node *tailcfg.DERPNode) *tls.Conn {
- tlsConf := tlsdial.Config(c.tlsServerName(node), c.HealthTracker, c.TLSConfig)
+ tlsConf := tlsdial.Config(c.HealthTracker, c.TLSConfig)
if node != nil {
if node.InsecureForTests {
tlsConf.InsecureSkipVerify = true
tlsConf.VerifyConnection = nil
}
+ tlsConf.ServerName = c.tlsServerName(node)
if node.CertName != "" {
if suf, ok := strings.CutPrefix(node.CertName, "sha256-raw:"); ok {
tlsdial.SetConfigExpectedCertHash(tlsConf, suf)
diff --git a/derp/derphttp/derphttp_test.go b/derp/derphttp/derphttp_test.go
index 252549660..7f0a7e333 100644
--- a/derp/derphttp/derphttp_test.go
+++ b/derp/derphttp/derphttp_test.go
@@ -7,10 +7,14 @@ import (
"bytes"
"context"
"crypto/tls"
+ "encoding/json"
+ "flag"
"fmt"
+ "maps"
"net"
"net/http"
"net/http/httptest"
+ "slices"
"strings"
"sync"
"testing"
@@ -19,6 +23,7 @@ import (
"tailscale.com/derp"
"tailscale.com/net/netmon"
"tailscale.com/net/netx"
+ "tailscale.com/tailcfg"
"tailscale.com/types/key"
)
@@ -556,3 +561,32 @@ func TestNotifyError(t *testing.T) {
t.Fatalf("context done before receiving error: %v", ctx.Err())
}
}
+
+var liveNetworkTest = flag.Bool("live-net-tests", false, "run live network tests")
+
+func TestManualDial(t *testing.T) {
+ if !*liveNetworkTest {
+ t.Skip("skipping live network test without --live-net-tests")
+ }
+ dm := &tailcfg.DERPMap{}
+ res, err := http.Get("https://controlplane.tailscale.com/derpmap/default")
+ if err != nil {
+ t.Fatalf("fetching DERPMap: %v", err)
+ }
+ defer res.Body.Close()
+ if err := json.NewDecoder(res.Body).Decode(dm); err != nil {
+ t.Fatalf("decoding DERPMap: %v", err)
+ }
+
+ region := slices.Sorted(maps.Keys(dm.Regions))[0]
+
+ netMon := netmon.NewStatic()
+ rc := NewRegionClient(key.NewNode(), t.Logf, netMon, func() *tailcfg.DERPRegion {
+ return dm.Regions[region]
+ })
+ defer rc.Close()
+
+ if err := rc.Connect(context.Background()); err != nil {
+ t.Fatalf("rc.Connect: %v", err)
+ }
+}
diff --git a/logpolicy/logpolicy.go b/logpolicy/logpolicy.go
index fc259a417..b84528d7b 100644
--- a/logpolicy/logpolicy.go
+++ b/logpolicy/logpolicy.go
@@ -9,7 +9,6 @@ package logpolicy
import (
"bufio"
"bytes"
- "cmp"
"context"
"crypto/tls"
"encoding/json"
@@ -911,8 +910,7 @@ func (opts TransportOptions) New() http.RoundTripper {
tr.TLSNextProto = map[string]func(authority string, c *tls.Conn) http.RoundTripper{}
}
- host := cmp.Or(opts.Host, logtail.DefaultHost)
- tr.TLSClientConfig = tlsdial.Config(host, opts.Health, tr.TLSClientConfig)
+ tr.TLSClientConfig = tlsdial.Config(opts.Health, tr.TLSClientConfig)
// Force TLS 1.3 since we know log.tailscale.com supports it.
tr.TLSClientConfig.MinVersion = tls.VersionTLS13
diff --git a/net/bakedroots/bakedroots.go b/net/bakedroots/bakedroots.go
index 42e70c0dd..8787b4a6d 100644
--- a/net/bakedroots/bakedroots.go
+++ b/net/bakedroots/bakedroots.go
@@ -7,6 +7,7 @@ package bakedroots
import (
"crypto/x509"
+ "fmt"
"sync"
"tailscale.com/util/testenv"
@@ -14,7 +15,7 @@ import (
// Get returns the baked-in roots.
//
-// As of 2025-01-21, this includes only the LetsEncrypt ISRG Root X1 root.
+// As of 2025-01-21, this includes only the LetsEncrypt ISRG Root X1 & X2 roots.
func Get() *x509.CertPool {
roots.once.Do(func() {
roots.parsePEM(append(
@@ -56,7 +57,7 @@ type rootsOnce struct {
func (r *rootsOnce) parsePEM(caPEM []byte) {
p := x509.NewCertPool()
if !p.AppendCertsFromPEM(caPEM) {
- panic("bogus PEM")
+ panic(fmt.Sprintf("bogus PEM: %q", caPEM))
}
r.p = p
}
diff --git a/net/connectproxy/connectproxy.go b/net/connectproxy/connectproxy.go
new file mode 100644
index 000000000..4bf687502
--- /dev/null
+++ b/net/connectproxy/connectproxy.go
@@ -0,0 +1,93 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Package connectproxy contains some CONNECT proxy code.
+package connectproxy
+
+import (
+ "context"
+ "io"
+ "log"
+ "net"
+ "net/http"
+ "time"
+
+ "tailscale.com/net/netx"
+ "tailscale.com/types/logger"
+)
+
+// Handler is an HTTP CONNECT proxy handler.
+type Handler struct {
+ // Dial, if non-nil, is an alternate dialer to use
+ // instead of the default dialer.
+ Dial netx.DialFunc
+
+ // Logf, if non-nil, is an alterate logger to
+ // use instead of log.Printf.
+ Logf logger.Logf
+
+ // Check, if non-nil, validates the CONNECT target.
+ Check func(hostPort string) error
+}
+
+func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+ ctx := r.Context()
+ if r.Method != "CONNECT" {
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+ return
+ }
+
+ dial := h.Dial
+ if dial == nil {
+ var d net.Dialer
+ dial = d.DialContext
+ }
+ logf := h.Logf
+ if logf == nil {
+ logf = log.Printf
+ }
+
+ hostPort := r.RequestURI
+ if h.Check != nil {
+ if err := h.Check(hostPort); err != nil {
+ logf("CONNECT target %q not allowed: %v", hostPort, err)
+ http.Error(w, "Invalid CONNECT target", http.StatusForbidden)
+ return
+ }
+ }
+
+ ctx, cancel := context.WithTimeout(ctx, 15*time.Second)
+ defer cancel()
+ back, err := dial(ctx, "tcp", hostPort)
+ if err != nil {
+ logf("error CONNECT dialing %v: %v", hostPort, err)
+ http.Error(w, "Connect failure", http.StatusBadGateway)
+ return
+ }
+ defer back.Close()
+
+ hj, ok := w.(http.Hijacker)
+ if !ok {
+ http.Error(w, "CONNECT hijack unavailable", http.StatusInternalServerError)
+ return
+ }
+ c, br, err := hj.Hijack()
+ if err != nil {
+ logf("CONNECT hijack: %v", err)
+ return
+ }
+ defer c.Close()
+
+ io.WriteString(c, "HTTP/1.1 200 OK\r\n\r\n")
+
+ errc := make(chan error, 2)
+ go func() {
+ _, err := io.Copy(c, back)
+ errc <- err
+ }()
+ go func() {
+ _, err := io.Copy(back, br)
+ errc <- err
+ }()
+ <-errc
+}
diff --git a/net/dnscache/dnscache.go b/net/dnscache/dnscache.go
index 96550cbb1..d60e92f0b 100644
--- a/net/dnscache/dnscache.go
+++ b/net/dnscache/dnscache.go
@@ -24,6 +24,7 @@ import (
"tailscale.com/util/cloudenv"
"tailscale.com/util/singleflight"
"tailscale.com/util/slicesx"
+ "tailscale.com/util/testenv"
)
var zaddr netip.Addr
@@ -63,6 +64,10 @@ type Resolver struct {
// If nil, net.DefaultResolver is used.
Forward *net.Resolver
+ // LookupIPForTest, if non-nil and in tests, handles requests instead
+ // of the usual mechanisms.
+ LookupIPForTest func(ctx context.Context, host string) ([]netip.Addr, error)
+
// LookupIPFallback optionally provides a backup DNS mechanism
// to use if Forward returns an error or no results.
LookupIPFallback func(ctx context.Context, host string) ([]netip.Addr, error)
@@ -284,7 +289,13 @@ func (r *Resolver) lookupIP(ctx context.Context, host string) (ip, ip6 netip.Add
lookupCtx, lookupCancel := context.WithTimeout(ctx, r.lookupTimeoutForHost(host))
defer lookupCancel()
- ips, err := r.fwd().LookupNetIP(lookupCtx, "ip", host)
+
+ var ips []netip.Addr
+ if r.LookupIPForTest != nil && testenv.InTest() {
+ ips, err = r.LookupIPForTest(ctx, host)
+ } else {
+ ips, err = r.fwd().LookupNetIP(lookupCtx, "ip", host)
+ }
if err != nil || len(ips) == 0 {
if resolver, ok := r.cloudHostResolver(); ok {
r.dlogf("resolving %q via cloud resolver", host)
diff --git a/net/dnsfallback/dnsfallback.go b/net/dnsfallback/dnsfallback.go
index 4c5d5fa2f..8e53c3b29 100644
--- a/net/dnsfallback/dnsfallback.go
+++ b/net/dnsfallback/dnsfallback.go
@@ -286,7 +286,7 @@ func bootstrapDNSMap(ctx context.Context, serverName string, serverIP netip.Addr
tr.DialContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
return dialer.DialContext(ctx, "tcp", net.JoinHostPort(serverIP.String(), "443"))
}
- tr.TLSClientConfig = tlsdial.Config(serverName, ht, tr.TLSClientConfig)
+ tr.TLSClientConfig = tlsdial.Config(ht, tr.TLSClientConfig)
c := &http.Client{Transport: tr}
req, err := http.NewRequestWithContext(ctx, "GET", "https://"+serverName+"/bootstrap-dns?q="+url.QueryEscape(queryName), nil)
if err != nil {
diff --git a/net/tlsdial/tlsdial.go b/net/tlsdial/tlsdial.go
index 1bd2450aa..80f3bfc06 100644
--- a/net/tlsdial/tlsdial.go
+++ b/net/tlsdial/tlsdial.go
@@ -59,18 +59,26 @@ var mitmBlockWarnable = health.Register(&health.Warnable{
ImpactsConnectivity: true,
})
-// Config returns a tls.Config for connecting to a server.
+// Config returns a tls.Config for connecting to a server that
+// uses system roots for validation but, if those fail, also tries
+// the baked-in LetsEncrypt roots as a fallback validation method.
+//
// If base is non-nil, it's cloned as the base config before
// being configured and returned.
// If ht is non-nil, it's used to report health errors.
-func Config(host string, ht *health.Tracker, base *tls.Config) *tls.Config {
+func Config(ht *health.Tracker, base *tls.Config) *tls.Config {
var conf *tls.Config
if base == nil {
conf = new(tls.Config)
} else {
conf = base.Clone()
}
- conf.ServerName = host
+
+ // Note: we do NOT set conf.ServerName here (as we accidentally did
+ // previously), as this path is also used when dialing an HTTPS proxy server
+ // (through which we'll send a CONNECT request to get a TCP connection to do
+ // the real TCP connection) because host is the ultimate hostname, but this
+ // tls.Config is used for both the proxy and the ultimate target.
if n := sslKeyLogFile; n != "" {
f, err := os.OpenFile(n, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0600)
@@ -93,7 +101,9 @@ func Config(host string, ht *health.Tracker, base *tls.Config) *tls.Config {
// (with the baked-in fallback root) in the VerifyConnection hook.
conf.InsecureSkipVerify = true
conf.VerifyConnection = func(cs tls.ConnectionState) (retErr error) {
- if host == "log.tailscale.com" && hostinfo.IsNATLabGuestVM() {
+ dialedHost := cs.ServerName
+
+ if dialedHost == "log.tailscale.com" && hostinfo.IsNATLabGuestVM() {
// Allow log.tailscale.com TLS MITM for integration tests when
// the client's running within a NATLab VM.
return nil
@@ -116,7 +126,7 @@ func Config(host string, ht *health.Tracker, base *tls.Config) *tls.Config {
// Show a dedicated warning.
m, ok := blockblame.VerifyCertificate(cert)
if ok {
- log.Printf("tlsdial: server cert for %q looks like %q equipment (could be blocking Tailscale)", host, m.Name)
+ log.Printf("tlsdial: server cert seen while dialing %q looks like %q equipment (could be blocking Tailscale)", dialedHost, m.Name)
ht.SetUnhealthy(mitmBlockWarnable, health.Args{"manufacturer": m.Name})
} else {
ht.SetHealthy(mitmBlockWarnable)
@@ -135,7 +145,7 @@ func Config(host string, ht *health.Tracker, base *tls.Config) *tls.Config {
ht.SetTLSConnectionError(cs.ServerName, nil)
if selfSignedIssuer != "" {
// Log the self-signed issuer, but don't treat it as an error.
- log.Printf("tlsdial: warning: server cert for %q passed x509 validation but is self-signed by %q", host, selfSignedIssuer)
+ log.Printf("tlsdial: warning: server cert for %q passed x509 validation but is self-signed by %q", dialedHost, selfSignedIssuer)
}
}
}()
@@ -144,7 +154,7 @@ func Config(host string, ht *health.Tracker, base *tls.Config) *tls.Config {
// First try doing x509 verification with the system's
// root CA pool.
opts := x509.VerifyOptions{
- DNSName: cs.ServerName,
+ DNSName: dialedHost,
Intermediates: x509.NewCertPool(),
}
for _, cert := range cs.PeerCertificates[1:] {
@@ -152,7 +162,7 @@ func Config(host string, ht *health.Tracker, base *tls.Config) *tls.Config {
}
_, errSys := cs.PeerCertificates[0].Verify(opts)
if debug() {
- log.Printf("tlsdial(sys %q): %v", host, errSys)
+ log.Printf("tlsdial(sys %q): %v", dialedHost, errSys)
}
// Always verify with our baked-in Let's Encrypt certificate,
@@ -161,13 +171,11 @@ func Config(host string, ht *health.Tracker, base *tls.Config) *tls.Config {
opts.Roots = bakedroots.Get()
_, bakedErr := cs.PeerCertificates[0].Verify(opts)
if debug() {
- log.Printf("tlsdial(bake %q): %v", host, bakedErr)
+ log.Printf("tlsdial(bake %q): %v", dialedHost, bakedErr)
} else if bakedErr != nil {
- if _, loaded := tlsdialWarningPrinted.LoadOrStore(host, true); !loaded {
- if errSys == nil {
- log.Printf("tlsdial: warning: server cert for %q is not a Let's Encrypt cert", host)
- } else {
- log.Printf("tlsdial: error: server cert for %q failed to verify and is not a Let's Encrypt cert", host)
+ if _, loaded := tlsdialWarningPrinted.LoadOrStore(dialedHost, true); !loaded {
+ if errSys != nil {
+ log.Printf("tlsdial: error: server cert for %q failed both system roots & Let's Encrypt root validation", dialedHost)
}
}
}
@@ -202,9 +210,6 @@ func SetConfigExpectedCert(c *tls.Config, certDNSName string) {
c.ServerName = certDNSName
return
}
- if c.VerifyPeerCertificate != nil {
- panic("refusing to override tls.Config.VerifyPeerCertificate")
- }
// Set InsecureSkipVerify to prevent crypto/tls from doing its
// own cert verification, but do the same work that it'd do
// (but using certDNSName) in the VerifyPeerCertificate hook.
@@ -257,29 +262,30 @@ func SetConfigExpectedCertHash(c *tls.Config, wantFullCertSHA256Hex string) {
if c.VerifyPeerCertificate != nil {
panic("refusing to override tls.Config.VerifyPeerCertificate")
}
+
// Set InsecureSkipVerify to prevent crypto/tls from doing its
// own cert verification, but do the same work that it'd do
- // (but using certDNSName) in the VerifyPeerCertificate hook.
+ // (but using certDNSName) in the VerifyConnection hook.
c.InsecureSkipVerify = true
- c.VerifyConnection = nil
- c.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
+
+ c.VerifyConnection = func(cs tls.ConnectionState) error {
+ dialedHost := cs.ServerName
var sawGoodCert bool
- for _, rawCert := range rawCerts {
- cert, err := x509.ParseCertificate(rawCert)
- if err != nil {
- return fmt.Errorf("ParseCertificate: %w", err)
- }
+
+ for _, cert := range cs.PeerCertificates {
if strings.HasPrefix(cert.Subject.CommonName, derpconst.MetaCertCommonNamePrefix) {
continue
}
if sawGoodCert {
return errors.New("unexpected multiple certs presented")
}
- if fmt.Sprintf("%02x", sha256.Sum256(rawCert)) != wantFullCertSHA256Hex {
+ if fmt.Sprintf("%02x", sha256.Sum256(cert.Raw)) != wantFullCertSHA256Hex {
return fmt.Errorf("cert hash does not match expected cert hash")
}
- if err := cert.VerifyHostname(c.ServerName); err != nil {
- return fmt.Errorf("cert does not match server name %q: %w", c.ServerName, err)
+ if dialedHost != "" { // it's empty when dialing a derper by IP with no hostname
+ if err := cert.VerifyHostname(dialedHost); err != nil {
+ return fmt.Errorf("cert does not match server name %q: %w", dialedHost, err)
+ }
}
now := time.Now()
if now.After(cert.NotAfter) {
@@ -302,12 +308,8 @@ func SetConfigExpectedCertHash(c *tls.Config, wantFullCertSHA256Hex string) {
func NewTransport() *http.Transport {
return &http.Transport{
DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
- host, _, err := net.SplitHostPort(addr)
- if err != nil {
- return nil, err
- }
var d tls.Dialer
- d.Config = Config(host, nil, nil)
+ d.Config = Config(nil, nil)
return d.DialContext(ctx, network, addr)
},
}
diff --git a/net/tlsdial/tlsdial_test.go b/net/tlsdial/tlsdial_test.go
index 6723b82e0..e2c4cdd4f 100644
--- a/net/tlsdial/tlsdial_test.go
+++ b/net/tlsdial/tlsdial_test.go
@@ -86,7 +86,7 @@ func TestFallbackRootWorks(t *testing.T) {
DisableKeepAlives: true, // for test cleanup ease
}
ht := new(health.Tracker)
- tr.TLSClientConfig = Config("tlsdial.test", ht, tr.TLSClientConfig)
+ tr.TLSClientConfig = Config(ht, tr.TLSClientConfig)
c := &http.Client{Transport: tr}
ctr0 := atomic.LoadInt32(&counterFallbackOK)
diff --git a/tstest/tlstest/testdata/controlplane.tstest.key b/tstest/tlstest/testdata/controlplane.tstest.key
new file mode 100644
index 000000000..dbe5ede34
--- /dev/null
+++ b/tstest/tlstest/testdata/controlplane.tstest.key
@@ -0,0 +1,5 @@
+-----BEGIN EC PRIVATE KEY-----
+MHcCAQEEIHcxOQNVyqvBSSlu7c93QW6OsyccjL+R1evW4acd32MWoAoGCCqGSM49
+AwEHoUQDQgAEIOY5/CQ8CMuKYPLf+r6OEneqfzQ5RfgPnLdkL22qhm8xb69ZCXxz
+UecawU0KEDfHLYbUYXSuhAFxxuPh9I3x5Q==
+-----END EC PRIVATE KEY-----
diff --git a/tstest/tlstest/testdata/proxy.tstest.key b/tstest/tlstest/testdata/proxy.tstest.key
new file mode 100644
index 000000000..067279089
--- /dev/null
+++ b/tstest/tlstest/testdata/proxy.tstest.key
@@ -0,0 +1,5 @@
+-----BEGIN EC PRIVATE KEY-----
+MHcCAQEEING1XBDWFXQjqBmLjhp20hXOf2rk/I0N6W7muv9RVvk3oAoGCCqGSM49
+AwEHoUQDQgAE8lxnEEeLqYikwmXbXSsIQSw20R0oLA831s960KQZEgt0P9SbWcJc
+QTk98rdfYT/QDdHn157Oh4FPcDtxmdQ4vw==
+-----END EC PRIVATE KEY-----
diff --git a/tstest/tlstest/testdata/root-ca.key b/tstest/tlstest/testdata/root-ca.key
new file mode 100644
index 000000000..ece23ddf9
--- /dev/null
+++ b/tstest/tlstest/testdata/root-ca.key
@@ -0,0 +1,5 @@
+-----BEGIN EC PRIVATE KEY-----
+MHcCAQEEIMl3xjqt1dnXBpYJSEqevirAcnSJ79I2tucdRazlrDG9oAoGCCqGSM49
+AwEHoUQDQgAEQ/+Jme+16hgO7TtPSIFHVV0Yt969ltVlARVcNUZmWc0upQaq7uiJ
+Aur5KtzwxU3YI4bhNK0593OK2TLvEEWIdw==
+-----END EC PRIVATE KEY-----
diff --git a/tstest/tlstest/tlstest.go b/tstest/tlstest/tlstest.go
new file mode 100644
index 000000000..f65c261e8
--- /dev/null
+++ b/tstest/tlstest/tlstest.go
@@ -0,0 +1,167 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Package tlstest contains code to help test Tailscale's client proxy support.
+package tlstest
+
+import (
+ "bytes"
+ "crypto/ecdsa"
+ "crypto/rand"
+ "crypto/tls"
+ "crypto/x509"
+ "crypto/x509/pkix"
+ _ "embed"
+ "encoding/pem"
+ "fmt"
+ "math/big"
+ "sync"
+ "time"
+)
+
+// Some baked-in ECDSA keys to speed up tests, not having to burn CPU to
+// generate them each time. We only make the certs (which have expiry times)
+// at runtime.
+//
+// They were made with:
+//
+// openssl ecparam -name prime256v1 -genkey -noout -out root-ca.key
+var (
+ //go:embed testdata/root-ca.key
+ rootCAKeyPEM []byte
+
+ // TestProxyServerKey is the PEM private key for [TestProxyServerCert].
+ //
+ //go:embed testdata/proxy.tstest.key
+ TestProxyServerKey []byte
+
+ // TestControlPlaneKey is the PEM private key for [TestControlPlaneCert].
+ //
+ //go:embed testdata/controlplane.tstest.key
+ TestControlPlaneKey []byte
+)
+
+// TestRootCA returns a self-signed ECDSA root CA certificate (as PEM) for
+// testing purposes.
+func TestRootCA() []byte {
+ return bytes.Clone(testRootCAOncer())
+}
+
+var testRootCAOncer = sync.OnceValue(func() []byte {
+ key := rootCAKey()
+ now := time.Now().Add(-time.Hour)
+ tpl := &x509.Certificate{
+ SerialNumber: big.NewInt(1),
+ Subject: pkix.Name{
+ CommonName: "Tailscale Unit Test ECDSA Root",
+ Organization: []string{"Tailscale Test Org"},
+ },
+ NotBefore: now,
+ NotAfter: now.AddDate(5, 0, 0),
+
+ IsCA: true,
+ BasicConstraintsValid: true,
+ KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
+ SubjectKeyId: mustSKID(&key.PublicKey),
+ }
+
+ der, err := x509.CreateCertificate(rand.Reader, tpl, tpl, &key.PublicKey, key)
+ if err != nil {
+ panic(err)
+ }
+ return pemCert(der)
+})
+
+func pemCert(der []byte) []byte {
+ var buf bytes.Buffer
+ if err := pem.Encode(&buf, &pem.Block{Type: "CERTIFICATE", Bytes: der}); err != nil {
+ panic(fmt.Sprintf("failed to encode PEM: %v", err))
+ }
+ return buf.Bytes()
+}
+
+var rootCAKey = sync.OnceValue(func() *ecdsa.PrivateKey {
+ return mustParsePEM(rootCAKeyPEM, x509.ParseECPrivateKey)
+})
+
+func mustParsePEM[T any](pemBytes []byte, parse func([]byte) (T, error)) T {
+ block, rest := pem.Decode(pemBytes)
+ if block == nil || len(rest) > 0 {
+ panic("invalid PEM")
+ }
+ v, err := parse(block.Bytes)
+ if err != nil {
+ panic(fmt.Sprintf("invalid PEM: %v", err))
+ }
+ return v
+}
+
+// KeyPair is a simple struct to hold a certificate and its private key.
+type KeyPair struct {
+ Domain string
+ KeyPEM []byte // PEM-encoded private key
+}
+
+// ServerTLSConfig returns a TLS configuration suitable for a server
+// using the KeyPair's certificate and private key.
+func (p KeyPair) ServerTLSConfig() *tls.Config {
+ cert, err := tls.X509KeyPair(p.CertPEM(), p.KeyPEM)
+ if err != nil {
+ panic("invalid TLS key pair: " + err.Error())
+ }
+ return &tls.Config{
+ Certificates: []tls.Certificate{cert},
+ }
+}
+
+// ProxyServerKeyPair is a KeyPair for a test control plane server
+// with domain name "proxy.tstest".
+var ProxyServerKeyPair = KeyPair{
+ Domain: "proxy.tstest",
+ KeyPEM: TestProxyServerKey,
+}
+
+// ControlPlaneKeyPair is a KeyPair for a test control plane server
+// with domain name "controlplane.tstest".
+var ControlPlaneKeyPair = KeyPair{
+ Domain: "controlplane.tstest",
+ KeyPEM: TestControlPlaneKey,
+}
+
+func (p KeyPair) CertPEM() []byte {
+ caCert := mustParsePEM(TestRootCA(), x509.ParseCertificate)
+ caPriv := mustParsePEM(rootCAKeyPEM, x509.ParseECPrivateKey)
+ leafKey := mustParsePEM(p.KeyPEM, x509.ParseECPrivateKey)
+
+ serial, err := rand.Int(rand.Reader, big.NewInt(0).Lsh(big.NewInt(1), 128))
+ if err != nil {
+ panic(err)
+ }
+
+ now := time.Now().Add(-time.Hour)
+ tpl := &x509.Certificate{
+ SerialNumber: serial,
+ Subject: pkix.Name{CommonName: p.Domain},
+ NotBefore: now,
+ NotAfter: now.AddDate(2, 0, 0),
+
+ KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
+ ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
+ BasicConstraintsValid: true,
+ DNSNames: []string{p.Domain},
+ }
+
+ der, err := x509.CreateCertificate(rand.Reader, tpl, caCert, &leafKey.PublicKey, caPriv)
+ if err != nil {
+ panic(err)
+ }
+ return pemCert(der)
+}
+
+func mustSKID(pub *ecdsa.PublicKey) []byte {
+ skid, err := x509.MarshalPKIXPublicKey(pub)
+ if err != nil {
+ panic(err)
+ }
+ return skid[:20] // same as x509 library
+}