1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
|
// Copyright (c) Tailscale Inc & contributors
// SPDX-License-Identifier: BSD-3-Clause
//go:build !plan9
package main
import (
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"go.uber.org/zap"
"golang.org/x/oauth2"
)
// TestNewStaticClient checks that newTSClient, given paths to files holding a
// client ID and client secret, returns a client whose requests carry a bearer
// token minted by the test API's /api/v2/oauth/token endpoint from exactly
// those credentials.
func TestNewStaticClient(t *testing.T) {
	dir := t.TempDir()
	idPath := filepath.Join(dir, "client-id")
	secretPath := filepath.Join(dir, "client-secret")

	// Write the credential fixtures in a fixed order: ID first, then secret.
	for _, fixture := range []struct{ path, contents string }{
		{idPath, "test-client-id"},
		{secretPath, "test-client-secret"},
	} {
		if err := os.WriteFile(fixture.path, []byte(fixture.contents), 0600); err != nil {
			t.Fatalf("error writing test file %q: %v", fixture.path, err)
		}
	}

	// A long expiry (one hour) means the token is fetched once and reused.
	api := testAPI(t, 3600)
	client, err := newTSClient(zap.NewNop().Sugar(), "", idPath, secretPath, api.URL)
	if err != nil {
		t.Fatalf("error creating Tailscale client: %v", err)
	}

	// The test API's root path echoes the Authorization header it received.
	res, err := client.HTTPClient.Get(api.URL)
	if err != nil {
		t.Fatalf("error making test API call: %v", err)
	}
	defer res.Body.Close()

	body, err := io.ReadAll(res.Body)
	if err != nil {
		t.Fatalf("error reading response body: %v", err)
	}

	expected := "Bearer " + testToken("/api/v2/oauth/token", "test-client-id", "test-client-secret", "")
	if got := string(body); got != expected {
		t.Errorf("got %q; want %q", got, expected)
	}
}
// TestNewWorkloadIdentityClient checks that newTSClient, given a client ID
// and empty credential-file paths, returns a client that authenticates via the
// test API's /api/v2/oauth/token-exchange endpoint using a JWT read from a
// file. The test also checks that the JWT file is re-read when the access
// token is refreshed: after the file's contents change, the next request must
// carry a token derived from the new JWT.
func TestNewWorkloadIdentityClient(t *testing.T) {
	// 5 seconds is within expiryDelta leeway, so the access token will
	// immediately be considered expired and get refreshed on each access.
	srv := testAPI(t, 5)
	cl, err := newTSClient(zap.NewNop().Sugar(), "test-client-id", "", "", srv.URL)
	if err != nil {
		t.Fatalf("error creating Tailscale client: %v", err)
	}

	// Redirect the token source so it reads its JWT from a temp file we control.
	oauth2Transport, ok := cl.HTTPClient.Transport.(*oauth2.Transport)
	if !ok {
		t.Fatalf("expected oauth2.Transport, got %T", cl.HTTPClient.Transport)
	}
	src, ok := oauth2Transport.Source.(*jwtTokenSource)
	if !ok {
		t.Fatalf("expected jwtTokenSource, got %T", oauth2Transport.Source)
	}
	jwtPath := filepath.Join(t.TempDir(), "token")
	src.jwtPath = jwtPath

	for _, jwt := range []string{"test-jwt", "updated-test-jwt"} {
		if err := os.WriteFile(jwtPath, []byte(jwt), 0600); err != nil {
			t.Fatalf("error writing test file %q: %v", jwtPath, err)
		}

		// The test API's root path echoes the Authorization header it received.
		resp, err := cl.HTTPClient.Get(srv.URL)
		if err != nil {
			t.Fatalf("error making test API call: %v", err)
		}
		// Close the body at the end of each iteration rather than deferring:
		// a defer inside the loop would keep every response open until the
		// test function returns.
		got, err := io.ReadAll(resp.Body)
		resp.Body.Close()
		if err != nil {
			t.Fatalf("error reading response body: %v", err)
		}

		if want := "Bearer " + testToken("/api/v2/oauth/token-exchange", "test-client-id", "", jwt); string(got) != want {
			t.Errorf("got %q; want %q", got, want)
		}
	}
}
// testAPI starts an httptest.Server that fakes the API endpoints these tests
// exercise, and registers the server's shutdown with t.Cleanup.
//
// Routes:
//   - /api/v2/oauth/token and /api/v2/oauth/token-exchange require HTTP basic
//     auth. They respond with a JSON body whose access_token is
//     testToken(path, basic-auth user, basic-auth password, "jwt" form value).
//     Its token_type is "Bearer" and its expires_in is expirationSeconds.
//   - / echoes the request's Authorization header back as the response body.
//   - Any other path gets a 404.
//
// The handler runs on the server's goroutines, not the test goroutine.
// t.Fatal/t.Fatalf (FailNow) must only be called from the test goroutine, so
// the handler reports failures with t.Errorf and an HTTP error response.
func testAPI(t *testing.T, expirationSeconds int) *httptest.Server {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		t.Logf("test server got request: %s %s", r.Method, r.URL.Path)
		switch r.URL.Path {
		case "/api/v2/oauth/token", "/api/v2/oauth/token-exchange":
			id, secret, ok := r.BasicAuth()
			if !ok {
				t.Errorf("missing or invalid basic auth")
				http.Error(w, "missing or invalid basic auth", http.StatusUnauthorized)
				return
			}
			w.Header().Set("Content-Type", "application/json")
			if err := json.NewEncoder(w).Encode(map[string]any{
				"access_token": testToken(r.URL.Path, id, secret, r.FormValue("jwt")),
				"token_type":   "Bearer",
				"expires_in":   expirationSeconds,
			}); err != nil {
				t.Errorf("error writing response: %v", err)
			}
		case "/":
			// Echo back the authz header for test assertions.
			if _, err := w.Write([]byte(r.Header.Get("Authorization"))); err != nil {
				t.Errorf("error writing response: %v", err)
			}
		default:
			w.WriteHeader(http.StatusNotFound)
		}
	}))
	t.Cleanup(srv.Close)
	return srv
}
// testToken builds the fake access-token string used by the test API. It joins
// the token endpoint path, client ID, client secret and JWT with "|", so an
// echoed token shows exactly which inputs produced it.
func testToken(path, id, secret, jwt string) string {
	// All operands are strings, so fmt.Sprint inserts no extra spaces.
	const sep = "|"
	return fmt.Sprint(path, sep, id, sep, secret, sep, jwt)
}
|