summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--client/local/local.go13
-rw-r--r--client/local/local_test.go51
-rw-r--r--ipn/localapi/localapi.go26
-rw-r--r--ipn/localapi/localapi_test.go66
-rw-r--r--net/tsdial/tsdial.go27
-rw-r--r--net/tsdial/tsdial_test.go97
6 files changed, 276 insertions, 4 deletions
diff --git a/client/local/local.go b/client/local/local.go
index e72589306..f9cf96654 100644
--- a/client/local/local.go
+++ b/client/local/local.go
@@ -972,6 +972,19 @@ func (lc *Client) UserDial(ctx context.Context, network, host string, port uint1
if res.StatusCode != http.StatusSwitchingProtocols {
body, _ := io.ReadAll(res.Body)
res.Body.Close()
+ if res.StatusCode == http.StatusOK && res.Header.Get("Dial-Self") == "true" {
+ // Server told us to dial the address ourselves rather than
+ // proxying through the daemon. This happens for non-Tailscale
+ // addresses where the daemon shouldn't dial as root on the
+ // client's behalf. The server provides the resolved address
+ // to avoid a TOCTOU race with DNS re-resolution.
+ addr := res.Header.Get("Dial-Addr")
+ if addr == "" {
+ return nil, errors.New("server returned Dial-Self without Dial-Addr")
+ }
+ var d net.Dialer
+ return d.DialContext(ctx, network, addr)
+ }
return nil, fmt.Errorf("unexpected HTTP response: %s, %s", res.Status, body)
}
// From here on, the underlying net.Conn is ours to use, but there
diff --git a/client/local/local_test.go b/client/local/local_test.go
index a5377fbd6..58a87b224 100644
--- a/client/local/local_test.go
+++ b/client/local/local_test.go
@@ -61,6 +61,57 @@ func TestWhoIsPeerNotFound(t *testing.T) {
}
}
+func TestUserDialSelf(t *testing.T) {
+ // Start a real TCP listener that the client should dial directly
+ // when the server tells it to dial-self.
+ ln, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer ln.Close()
+ go func() {
+ for {
+ c, err := ln.Accept()
+ if err != nil {
+ return
+ }
+ c.Write([]byte("hello"))
+ c.Close()
+ }
+ }()
+ targetAddr := ln.Addr().(*net.TCPAddr)
+
+ // Mock LocalAPI server that returns Dial-Self response.
+ nw := nettest.GetNetwork(t)
+ ts := nettest.NewHTTPServer(nw, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Dial-Self", "true")
+ w.Header().Set("Dial-Addr", targetAddr.String())
+ w.WriteHeader(http.StatusOK)
+ }))
+ defer ts.Close()
+
+ lc := &Client{
+ Dial: func(ctx context.Context, network, addr string) (net.Conn, error) {
+ return nw.Dial(ctx, network, ts.Listener.Addr().String())
+ },
+ }
+
+ conn, err := lc.UserDial(context.Background(), "tcp", targetAddr.IP.String(), uint16(targetAddr.Port))
+ if err != nil {
+ t.Fatalf("UserDial: %v", err)
+ }
+ defer conn.Close()
+
+ buf := make([]byte, 5)
+ n, err := conn.Read(buf)
+ if err != nil {
+ t.Fatalf("Read: %v", err)
+ }
+ if got := string(buf[:n]); got != "hello" {
+ t.Errorf("got %q, want %q", got, "hello")
+ }
+}
+
func TestDeps(t *testing.T) {
deptest.DepChecker{
BadDeps: map[string]string{
diff --git a/ipn/localapi/localapi.go b/ipn/localapi/localapi.go
index 5eec66e64..c5ae3f846 100644
--- a/ipn/localapi/localapi.go
+++ b/ipn/localapi/localapi.go
@@ -1168,16 +1168,34 @@ func (h *Handler) serveDial(w http.ResponseWriter, r *http.Request) {
http.Error(w, "missing Dial-Host or Dial-Port header", http.StatusBadRequest)
return
}
+ network := cmp.Or(r.Header.Get("Dial-Network"), "tcp")
+
+ addr := net.JoinHostPort(hostStr, portStr)
+
+ // Check whether the resolved address is a Tailscale route.
+ // If not, tell the client to dial it directly so the connection
+ // comes from the calling user's UID rather than our root-owned daemon.
+ ipp, viaTailscale, err := h.b.Dialer().UserDialPlan(r.Context(), network, addr)
+ if err != nil {
+ http.Error(w, "resolve failure: "+err.Error(), http.StatusBadGateway)
+ return
+ }
+ if !viaTailscale {
+ w.Header().Set("Dial-Self", "true")
+ w.Header().Set("Dial-Addr", ipp.String())
+ w.WriteHeader(http.StatusOK)
+ return
+ }
+
hijacker, ok := w.(http.Hijacker)
if !ok {
http.Error(w, "make request over HTTP/1", http.StatusBadRequest)
return
}
- network := cmp.Or(r.Header.Get("Dial-Network"), "tcp")
-
- addr := net.JoinHostPort(hostStr, portStr)
- outConn, err := h.b.Dialer().UserDial(r.Context(), network, addr)
+ // Dial via Tailscale using the resolved IP:port to avoid a TOCTOU
+ // race with DNS re-resolution.
+ outConn, err := h.b.Dialer().UserDial(r.Context(), network, ipp.String())
if err != nil {
http.Error(w, "dial failure: "+err.Error(), http.StatusBadGateway)
return
diff --git a/ipn/localapi/localapi_test.go b/ipn/localapi/localapi_test.go
index 47e334571..a755221bf 100644
--- a/ipn/localapi/localapi_test.go
+++ b/ipn/localapi/localapi_test.go
@@ -500,3 +500,69 @@ func TestServeWithUnhealthyState(t *testing.T) {
})
}
}
+
+func TestServeDialSelf(t *testing.T) {
+ h := handlerForTest(t, &Handler{
+ PermitRead: true,
+ PermitWrite: true,
+ b: newTestLocalBackend(t),
+ })
+
+ tests := []struct {
+ name string
+ host string
+ port string
+ wantSelf bool
+ wantAddr string
+ wantStatus int
+ }{
+ {
+ name: "loopback_v4",
+ host: "127.0.0.1",
+ port: "8080",
+ wantSelf: true,
+ wantAddr: "127.0.0.1:8080",
+ wantStatus: http.StatusOK,
+ },
+ {
+ name: "loopback_v6",
+ host: "::1",
+ port: "8080",
+ wantSelf: true,
+ wantAddr: "[::1]:8080",
+ wantStatus: http.StatusOK,
+ },
+ {
+ name: "localhost",
+ host: "localhost",
+ port: "3000",
+ wantSelf: true,
+ wantStatus: http.StatusOK,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ req := httptest.NewRequest("POST", "http://local-tailscaled.sock/localapi/v0/dial", nil)
+ req.Header.Set("Connection", "upgrade")
+ req.Header.Set("Upgrade", "ts-dial")
+ req.Header.Set("Dial-Host", tt.host)
+ req.Header.Set("Dial-Port", tt.port)
+ resp := httptest.NewRecorder()
+ h.serveDial(resp, req)
+
+ if resp.Code != tt.wantStatus {
+ t.Fatalf("status = %d, want %d; body: %s", resp.Code, tt.wantStatus, resp.Body.String())
+ }
+ gotSelf := resp.Header().Get("Dial-Self") == "true"
+ if gotSelf != tt.wantSelf {
+ t.Errorf("Dial-Self = %v, want %v", gotSelf, tt.wantSelf)
+ }
+ if tt.wantAddr != "" {
+ if got := resp.Header().Get("Dial-Addr"); got != tt.wantAddr {
+ t.Errorf("Dial-Addr = %q, want %q", got, tt.wantAddr)
+ }
+ }
+ })
+ }
+}
diff --git a/net/tsdial/tsdial.go b/net/tsdial/tsdial.go
index ebbafa52b..ca08810a3 100644
--- a/net/tsdial/tsdial.go
+++ b/net/tsdial/tsdial.go
@@ -515,6 +515,33 @@ func (d *Dialer) UserDial(ctx context.Context, network, addr string) (net.Conn,
return stdDialer.DialContext(ctx, network, ipp.String())
}
+// UserDialPlan resolves addr and reports whether the dialer would
+// handle it via Tailscale. If viaTailscale is false, the resolved
+// address is not a Tailscale route and the caller may dial it directly.
+//
+// Warning: there is a TOCTOU race if addr contains a DNS name and the
+// caller subsequently passes the same DNS name to [Dialer.UserDial], as DNS
+// may resolve differently the second time. Callers who want to only
+// dial over Tailscale should call [Dialer.UserDial] with the returned
+// ipp.String() (an IP:port) rather than the original DNS name.
+func (d *Dialer) UserDialPlan(ctx context.Context, network, addr string) (ipp netip.AddrPort, viaTailscale bool, err error) {
+ ipp, err = d.userDialResolve(ctx, network, addr)
+ if err != nil {
+ return netip.AddrPort{}, false, err
+ }
+ if d.UseNetstackForIP != nil && d.UseNetstackForIP(ipp.Addr()) {
+ return ipp, true, nil
+ }
+ if routes := d.routes.Load(); routes != nil {
+ isTailscaleRoute, _ := routes.Lookup(ipp.Addr())
+ return ipp, isTailscaleRoute, nil
+ }
+ if version.IsMacGUIVariant() && tsaddr.IsTailscaleIP(ipp.Addr()) {
+ return ipp, true, nil
+ }
+ return ipp, false, nil
+}
+
// dialPeerAPI connects to a Tailscale peer's peerapi over TCP.
//
// network must a "tcp" type, and addr must be an ip:port. Name resolution
diff --git a/net/tsdial/tsdial_test.go b/net/tsdial/tsdial_test.go
new file mode 100644
index 000000000..92960acbe
--- /dev/null
+++ b/net/tsdial/tsdial_test.go
@@ -0,0 +1,97 @@
+// Copyright (c) Tailscale Inc & contributors
+// SPDX-License-Identifier: BSD-3-Clause
+
+package tsdial
+
+import (
+ "context"
+ "net/netip"
+ "testing"
+
+ "github.com/gaissmai/bart"
+)
+
+func TestUserDialPlan(t *testing.T) {
+ tests := []struct {
+ name string
+ addr string
+ routes map[netip.Prefix]bool // nil means no routes configured
+ useNetstackFor func(netip.Addr) bool // nil means not set
+ wantVia bool
+ wantAddr netip.AddrPort
+ }{
+ {
+ name: "loopback_no_routes",
+ addr: "127.0.0.1:8080",
+ wantVia: false,
+ wantAddr: netip.MustParseAddrPort("127.0.0.1:8080"),
+ },
+ {
+ name: "loopback_v6_no_routes",
+ addr: "[::1]:8080",
+ wantVia: false,
+ wantAddr: netip.MustParseAddrPort("[::1]:8080"),
+ },
+ {
+ name: "tailscale_ip_in_routes",
+ addr: "100.64.1.1:22",
+ routes: map[netip.Prefix]bool{
+ netip.MustParsePrefix("100.64.0.0/10"): true,
+ },
+ wantVia: true,
+ wantAddr: netip.MustParseAddrPort("100.64.1.1:22"),
+ },
+ {
+ name: "non_tailscale_ip_in_local_routes",
+ addr: "10.0.0.5:80",
+ routes: map[netip.Prefix]bool{
+ netip.MustParsePrefix("100.64.0.0/10"): true,
+ netip.MustParsePrefix("10.0.0.0/8"): false, // local route
+ },
+ wantVia: false,
+ wantAddr: netip.MustParseAddrPort("10.0.0.5:80"),
+ },
+ {
+ name: "loopback_with_routes_configured",
+ addr: "127.0.0.1:3000",
+ routes: map[netip.Prefix]bool{
+ netip.MustParsePrefix("100.64.0.0/10"): true,
+ },
+ wantVia: false,
+ wantAddr: netip.MustParseAddrPort("127.0.0.1:3000"),
+ },
+ {
+ name: "netstack_for_ip",
+ addr: "100.100.100.100:53",
+ useNetstackFor: func(ip netip.Addr) bool {
+ return ip == netip.MustParseAddr("100.100.100.100")
+ },
+ wantVia: true,
+ wantAddr: netip.MustParseAddrPort("100.100.100.100:53"),
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ d := &Dialer{}
+ if tt.routes != nil {
+ rt := &bart.Table[bool]{}
+ for pfx, v := range tt.routes {
+ rt.Insert(pfx, v)
+ }
+ d.routes.Store(rt)
+ }
+ d.UseNetstackForIP = tt.useNetstackFor
+
+ ipp, viaTailscale, err := d.UserDialPlan(context.Background(), "tcp", tt.addr)
+ if err != nil {
+ t.Fatalf("UserDialPlan: %v", err)
+ }
+ if viaTailscale != tt.wantVia {
+ t.Errorf("viaTailscale = %v, want %v", viaTailscale, tt.wantVia)
+ }
+ if ipp != tt.wantAddr {
+ t.Errorf("addr = %v, want %v", ipp, tt.wantAddr)
+ }
+ })
+ }
+}