summaryrefslogtreecommitdiffhomepage
path: root/client/web/web_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'client/web/web_test.go')
-rw-r--r--client/web/web_test.go82
1 files changed, 82 insertions, 0 deletions
diff --git a/client/web/web_test.go b/client/web/web_test.go
index b9242f6ac..291356260 100644
--- a/client/web/web_test.go
+++ b/client/web/web_test.go
@@ -11,6 +11,7 @@ import (
"fmt"
"io"
"net/http"
+ "net/http/cookiejar"
"net/http/httptest"
"net/netip"
"net/url"
@@ -20,6 +21,7 @@ import (
"time"
"github.com/google/go-cmp/cmp"
+ "github.com/gorilla/csrf"
"tailscale.com/client/local"
"tailscale.com/client/tailscale/apitype"
"tailscale.com/ipn"
@@ -1477,3 +1479,83 @@ func mockWaitAuthURL(_ context.Context, id string, src tailcfg.NodeID) (*tailcfg
return nil, errors.New("unknown id")
}
}
+
+func TestCSRFProtect(t *testing.T) {
+ s := &Server{}
+
+ mux := http.NewServeMux()
+ mux.HandleFunc("GET /test/csrf-token", func(w http.ResponseWriter, r *http.Request) {
+ token := csrf.Token(r)
+ _, err := io.WriteString(w, token)
+ if err != nil {
+ t.Fatal(err)
+ }
+ })
+ mux.HandleFunc("POST /test/csrf-protected", func(w http.ResponseWriter, r *http.Request) {
+ _, err := io.WriteString(w, "ok")
+ if err != nil {
+ t.Fatal(err)
+ }
+ })
+ h := s.withCSRF(mux)
+ ser := httptest.NewServer(h)
+ defer ser.Close()
+
+ jar, err := cookiejar.New(nil)
+ if err != nil {
+ t.Fatalf("unable to construct cookie jar: %v", err)
+ }
+
+ client := ser.Client()
+ client.Jar = jar
+
+ // make GET request to populate cookie jar
+ resp, err := client.Get(ser.URL + "/test/csrf-token")
+ if err != nil {
+ t.Fatalf("unable to make request: %v", err)
+ }
+ defer resp.Body.Close()
+ if resp.StatusCode != http.StatusOK {
+ t.Fatalf("unexpected status: %v", resp.Status)
+ }
+ tokenBytes, err := io.ReadAll(resp.Body)
+ if err != nil {
+ t.Fatalf("unable to read body: %v", err)
+ }
+
+ csrfToken := strings.TrimSpace(string(tokenBytes))
+ if csrfToken == "" {
+ t.Fatal("empty csrf token")
+ }
+
+ // make a POST request without the CSRF header; ensure it fails
+ resp, err = client.Post(ser.URL+"/test/csrf-protected", "text/plain", nil)
+ if err != nil {
+ t.Fatalf("unable to make request: %v", err)
+ }
+ if resp.StatusCode != http.StatusForbidden {
+ t.Fatalf("unexpected status: %v", resp.Status)
+ }
+
+ // make a POST request with the CSRF header; ensure it succeeds
+ req, err := http.NewRequest("POST", ser.URL+"/test/csrf-protected", nil)
+ if err != nil {
+ t.Fatalf("error building request: %v", err)
+ }
+ req.Header.Set("X-CSRF-Token", csrfToken)
+ resp, err = client.Do(req)
+ if err != nil {
+ t.Fatalf("unable to make request: %v", err)
+ }
+ if resp.StatusCode != http.StatusOK {
+ t.Fatalf("unexpected status: %v", resp.Status)
+ }
+ defer resp.Body.Close()
+ out, err := io.ReadAll(resp.Body)
+ if err != nil {
+ t.Fatalf("unable to read body: %v", err)
+ }
+ if string(out) != "ok" {
+ t.Fatalf("unexpected body: %q", out)
+ }
+}