summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--client/web/web.go53
-rw-r--r--client/web/web_test.go82
2 files changed, 112 insertions, 23 deletions
diff --git a/client/web/web.go b/client/web/web.go
index 6203b4c18..e9810ccd0 100644
--- a/client/web/web.go
+++ b/client/web/web.go
@@ -203,15 +203,25 @@ func NewServer(opts ServerOpts) (s *Server, err error) {
}
s.assetsHandler, s.assetsCleanup = assetsHandler(s.devMode)
- var metric string // clientmetric to report on startup
+ var metric string
+ s.apiHandler, metric = s.modeAPIHandler(s.mode)
+ s.apiHandler = s.withCSRF(s.apiHandler)
- // Create handler for "/api" requests with CSRF protection.
- // We don't require secure cookies, since the web client is regularly used
- // on network appliances that are served on local non-https URLs.
- // The client is secured by limiting the interface it listens on,
- // or by authenticating requests before they reach the web client.
+ // Don't block startup on reporting metric.
+ // Report in separate go routine with 5 second timeout.
+ go func() {
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+ s.lc.IncrementCounter(ctx, metric, 1)
+ }()
+
+ return s, nil
+}
+
+func (s *Server) withCSRF(h http.Handler) http.Handler {
csrfProtect := csrf.Protect(s.csrfKey(), csrf.Secure(false))
+ // ref https://github.com/tailscale/tailscale/pull/14822
// signal to the CSRF middleware that the request is being served over
// plaintext HTTP to skip TLS-only header checks.
withSetPlaintext := func(h http.Handler) http.Handler {
@@ -221,27 +231,24 @@ func NewServer(opts ServerOpts) (s *Server, err error) {
})
}
- switch s.mode {
+ // NB: the order of the withSetPlaintext and csrfProtect calls is important
+ // to ensure that we signal to the CSRF middleware that the request is being
+ // served over plaintext HTTP and not over TLS as it presumes by default.
+ return withSetPlaintext(csrfProtect(h))
+}
+
+func (s *Server) modeAPIHandler(mode ServerMode) (http.Handler, string) {
+ switch mode {
case LoginServerMode:
- s.apiHandler = csrfProtect(withSetPlaintext(http.HandlerFunc(s.serveLoginAPI)))
- metric = "web_login_client_initialization"
+ return http.HandlerFunc(s.serveLoginAPI), "web_login_client_initialization"
case ReadOnlyServerMode:
- s.apiHandler = csrfProtect(withSetPlaintext(http.HandlerFunc(s.serveLoginAPI)))
- metric = "web_readonly_client_initialization"
+ return http.HandlerFunc(s.serveLoginAPI), "web_readonly_client_initialization"
case ManageServerMode:
- s.apiHandler = csrfProtect(withSetPlaintext(http.HandlerFunc(s.serveAPI)))
- metric = "web_client_initialization"
+ return http.HandlerFunc(s.serveAPI), "web_client_initialization"
+ default: // invalid mode
+ log.Fatalf("invalid mode: %v", mode)
}
-
- // Don't block startup on reporting metric.
- // Report in separate go routine with 5 second timeout.
- go func() {
- ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
- defer cancel()
- s.lc.IncrementCounter(ctx, metric, 1)
- }()
-
- return s, nil
+ return nil, ""
}
func (s *Server) Shutdown() {
diff --git a/client/web/web_test.go b/client/web/web_test.go
index b9242f6ac..291356260 100644
--- a/client/web/web_test.go
+++ b/client/web/web_test.go
@@ -11,6 +11,7 @@ import (
"fmt"
"io"
"net/http"
+ "net/http/cookiejar"
"net/http/httptest"
"net/netip"
"net/url"
@@ -20,6 +21,7 @@ import (
"time"
"github.com/google/go-cmp/cmp"
+ "github.com/gorilla/csrf"
"tailscale.com/client/local"
"tailscale.com/client/tailscale/apitype"
"tailscale.com/ipn"
@@ -1477,3 +1479,83 @@ func mockWaitAuthURL(_ context.Context, id string, src tailcfg.NodeID) (*tailcfg
return nil, errors.New("unknown id")
}
}
+
+func TestCSRFProtect(t *testing.T) {
+ s := &Server{}
+
+ mux := http.NewServeMux()
+ mux.HandleFunc("GET /test/csrf-token", func(w http.ResponseWriter, r *http.Request) {
+ token := csrf.Token(r)
+ _, err := io.WriteString(w, token)
+ if err != nil {
+ t.Fatal(err)
+ }
+ })
+ mux.HandleFunc("POST /test/csrf-protected", func(w http.ResponseWriter, r *http.Request) {
+ _, err := io.WriteString(w, "ok")
+ if err != nil {
+ t.Fatal(err)
+ }
+ })
+ h := s.withCSRF(mux)
+ ser := httptest.NewServer(h)
+ defer ser.Close()
+
+ jar, err := cookiejar.New(nil)
+ if err != nil {
+ t.Fatalf("unable to construct cookie jar: %v", err)
+ }
+
+ client := ser.Client()
+ client.Jar = jar
+
+ // make GET request to populate cookie jar
+ resp, err := client.Get(ser.URL + "/test/csrf-token")
+ if err != nil {
+ t.Fatalf("unable to make request: %v", err)
+ }
+ defer resp.Body.Close()
+ if resp.StatusCode != http.StatusOK {
+ t.Fatalf("unexpected status: %v", resp.Status)
+ }
+ tokenBytes, err := io.ReadAll(resp.Body)
+ if err != nil {
+ t.Fatalf("unable to read body: %v", err)
+ }
+
+ csrfToken := strings.TrimSpace(string(tokenBytes))
+ if csrfToken == "" {
+ t.Fatal("empty csrf token")
+ }
+
+ // make a POST request without the CSRF header; ensure it fails
+ resp, err = client.Post(ser.URL+"/test/csrf-protected", "text/plain", nil)
+ if err != nil {
+ t.Fatalf("unable to make request: %v", err)
+ }
+ if resp.StatusCode != http.StatusForbidden {
+ t.Fatalf("unexpected status: %v", resp.Status)
+ }
+
+ // make a POST request with the CSRF header; ensure it succeeds
+ req, err := http.NewRequest("POST", ser.URL+"/test/csrf-protected", nil)
+ if err != nil {
+ t.Fatalf("error building request: %v", err)
+ }
+ req.Header.Set("X-CSRF-Token", csrfToken)
+ resp, err = client.Do(req)
+ if err != nil {
+ t.Fatalf("unable to make request: %v", err)
+ }
+ if resp.StatusCode != http.StatusOK {
+ t.Fatalf("unexpected status: %v", resp.Status)
+ }
+ defer resp.Body.Close()
+ out, err := io.ReadAll(resp.Body)
+ if err != nil {
+ t.Fatalf("unable to read body: %v", err)
+ }
+ if string(out) != "ok" {
+ t.Fatalf("unexpected body: %q", out)
+ }
+}