summaryrefslogtreecommitdiffhomepage
path: root/ipn
diff options
context:
space:
mode:
Diffstat (limited to 'ipn')
-rw-r--r--ipn/localapi/localapi.go26
-rw-r--r--ipn/localapi/localapi_test.go66
2 files changed, 88 insertions, 4 deletions
diff --git a/ipn/localapi/localapi.go b/ipn/localapi/localapi.go
index 5eec66e64..c5ae3f846 100644
--- a/ipn/localapi/localapi.go
+++ b/ipn/localapi/localapi.go
@@ -1168,16 +1168,34 @@ func (h *Handler) serveDial(w http.ResponseWriter, r *http.Request) {
http.Error(w, "missing Dial-Host or Dial-Port header", http.StatusBadRequest)
return
}
+ network := cmp.Or(r.Header.Get("Dial-Network"), "tcp")
+
+ addr := net.JoinHostPort(hostStr, portStr)
+
+ // Check whether the resolved address is a Tailscale route.
+ // If not, tell the client to dial it directly so the connection
+ // comes from the calling user's UID rather than our root-owned daemon.
+ ipp, viaTailscale, err := h.b.Dialer().UserDialPlan(r.Context(), network, addr)
+ if err != nil {
+ http.Error(w, "resolve failure: "+err.Error(), http.StatusBadGateway)
+ return
+ }
+ if !viaTailscale {
+ w.Header().Set("Dial-Self", "true")
+ w.Header().Set("Dial-Addr", ipp.String())
+ w.WriteHeader(http.StatusOK)
+ return
+ }
+
hijacker, ok := w.(http.Hijacker)
if !ok {
http.Error(w, "make request over HTTP/1", http.StatusBadRequest)
return
}
- network := cmp.Or(r.Header.Get("Dial-Network"), "tcp")
-
- addr := net.JoinHostPort(hostStr, portStr)
- outConn, err := h.b.Dialer().UserDial(r.Context(), network, addr)
+ // Dial via Tailscale using the resolved IP:port to avoid a TOCTOU
+ // race with DNS re-resolution.
+ outConn, err := h.b.Dialer().UserDial(r.Context(), network, ipp.String())
if err != nil {
http.Error(w, "dial failure: "+err.Error(), http.StatusBadGateway)
return
diff --git a/ipn/localapi/localapi_test.go b/ipn/localapi/localapi_test.go
index 47e334571..a755221bf 100644
--- a/ipn/localapi/localapi_test.go
+++ b/ipn/localapi/localapi_test.go
@@ -500,3 +500,69 @@ func TestServeWithUnhealthyState(t *testing.T) {
})
}
}
+
+func TestServeDialSelf(t *testing.T) {
+ h := handlerForTest(t, &Handler{
+ PermitRead: true,
+ PermitWrite: true,
+ b: newTestLocalBackend(t),
+ })
+
+ tests := []struct {
+ name string
+ host string
+ port string
+ wantSelf bool
+ wantAddr string
+ wantStatus int
+ }{
+ {
+ name: "loopback_v4",
+ host: "127.0.0.1",
+ port: "8080",
+ wantSelf: true,
+ wantAddr: "127.0.0.1:8080",
+ wantStatus: http.StatusOK,
+ },
+ {
+ name: "loopback_v6",
+ host: "::1",
+ port: "8080",
+ wantSelf: true,
+ wantAddr: "[::1]:8080",
+ wantStatus: http.StatusOK,
+ },
+ {
+ name: "localhost",
+ host: "localhost",
+ port: "3000",
+ wantSelf: true,
+ wantStatus: http.StatusOK,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ req := httptest.NewRequest("POST", "http://local-tailscaled.sock/localapi/v0/dial", nil)
+ req.Header.Set("Connection", "upgrade")
+ req.Header.Set("Upgrade", "ts-dial")
+ req.Header.Set("Dial-Host", tt.host)
+ req.Header.Set("Dial-Port", tt.port)
+ resp := httptest.NewRecorder()
+ h.serveDial(resp, req)
+
+ if resp.Code != tt.wantStatus {
+ t.Fatalf("status = %d, want %d; body: %s", resp.Code, tt.wantStatus, resp.Body.String())
+ }
+ gotSelf := resp.Header().Get("Dial-Self") == "true"
+ if gotSelf != tt.wantSelf {
+ t.Errorf("Dial-Self = %v, want %v", gotSelf, tt.wantSelf)
+ }
+ if tt.wantAddr != "" {
+ if got := resp.Header().Get("Dial-Addr"); got != tt.wantAddr {
+ t.Errorf("Dial-Addr = %q, want %q", got, tt.wantAddr)
+ }
+ }
+ })
+ }
+}