summaryrefslogtreecommitdiffhomepage
path: root/cmd/k8s-operator/tsclient.go
blob: 702f4cc537240abd5762512efc66e63cafb1d149 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
// Copyright (c) Tailscale Inc & contributors
// SPDX-License-Identifier: BSD-3-Clause

//go:build !plan9

package main

import (
	"context"
	"fmt"
	"net/url"
	"os"
	"sync"
	"time"

	"go.uber.org/zap"
	"golang.org/x/oauth2"
	"golang.org/x/oauth2/clientcredentials"
	"tailscale.com/client/tailscale/v2"

	"tailscale.com/ipn"
)

// oidcJWTPath is the file that newTSClient hands to jwtTokenSource as the
// location of the JWT used for workload identity federation.
const oidcJWTPath = "/var/run/secrets/tailscale/serviceaccount/token"

// newTSClient builds the Tailscale API client used by the operator.
//
// The API base URL is loginServer when it is non-empty and
// ipn.DefaultControlURL otherwise. How the client authenticates depends on
// clientID:
//
//   - clientID == "": static OAuth client credentials are read from the files
//     at clientIDPath and clientSecretPath.
//   - clientID != "": workload identity federation is used. Tokens come from a
//     jwtTokenSource that reads the JWT at oidcJWTPath; clientIDPath and
//     clientSecretPath are not read in this mode.
//
// An error is returned if the base URL cannot be parsed or if either
// credential file cannot be read.
func newTSClient(logger *zap.SugaredLogger, clientID, clientIDPath, clientSecretPath, loginServer string) (*tailscale.Client, error) {
	baseURL := ipn.DefaultControlURL
	if loginServer != "" {
		baseURL = loginServer
	}

	base, err := url.Parse(baseURL)
	if err != nil {
		return nil, fmt.Errorf("parsing control URL: %w", err)
	}

	client := &tailscale.Client{
		UserAgent: "tailscale-k8s-operator",
		BaseURL:   base,
	}

	if clientID == "" {
		// Use static client credentials mounted to disk.
		clientIDBytes, err := os.ReadFile(clientIDPath)
		if err != nil {
			return nil, fmt.Errorf("error reading client ID %q: %w", clientIDPath, err)
		}
		clientSecretBytes, err := os.ReadFile(clientSecretPath)
		if err != nil {
			return nil, fmt.Errorf("reading client secret %q: %w", clientSecretPath, err)
		}

		client.Auth = &tailscale.OAuth{
			ClientID:     string(clientIDBytes),
			ClientSecret: string(clientSecretBytes),
		}
		return client, nil
	}

	// Use workload identity federation.
	//
	// The token endpoint is derived with JoinPath rather than by string
	// concatenation so that a login server URL ending in "/" does not yield
	// a "//api/..." path.
	tokenURL := base.JoinPath("api/v2/oauth/token-exchange").String()
	tokenSrc := &jwtTokenSource{
		logger:  logger,
		jwtPath: oidcJWTPath,
		baseCfg: clientcredentials.Config{
			ClientID: clientID,
			TokenURL: tokenURL,
		},
	}

	client.Auth = &tailscale.IdentityFederation{
		ClientID: clientID,
		// Hands the client the access token currently yielded by tokenSrc.
		IDTokenFunc: func() (string, error) {
			token, err := tokenSrc.Token()
			if err != nil {
				return "", err
			}

			return token.AccessToken, nil
		},
	}

	return client, nil
}

// jwtTokenSource is an [oauth2.TokenSource] backed by a JWT stored on disk.
// Whenever the current underlying token source can no longer yield a valid
// token, the JWT is re-read from jwtPath and a replacement underlying source
// is built from it.
type jwtTokenSource struct {
	logger *zap.SugaredLogger

	// jwtPath is the path to the file containing an automatically refreshed JWT.
	jwtPath string

	// baseCfg holds config that doesn't change for the lifetime of the process.
	baseCfg clientcredentials.Config

	// mu guards underlying.
	mu sync.Mutex

	// underlying is the oauth2 client implementation. It does its own
	// separate caching of the access token.
	underlying oauth2.TokenSource
}

// Token returns a token from the current underlying source when that source
// yields a valid one. Otherwise it re-reads the JWT from s.jwtPath, builds a
// new underlying source that sends the JWT as the "jwt" endpoint parameter,
// stores it, and returns the token that new source produces.
//
// s.mu is held for the whole call, including any token request made by the
// underlying source.
func (s *jwtTokenSource) Token() (*oauth2.Token, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	// Reuse the existing source while it still works. Any failure here is
	// deliberately not returned: it falls through to a rebuild from a freshly
	// read JWT.
	if cached := s.underlying; cached != nil {
		if tok, tokErr := cached.Token(); tokErr == nil && tok != nil && tok.Valid() {
			return tok, nil
		}
	}

	s.logger.Debugf("Refreshing JWT from %s", s.jwtPath)
	rawJWT, err := os.ReadFile(s.jwtPath)
	if err != nil {
		return nil, fmt.Errorf("error reading JWT from %q: %w", s.jwtPath, err)
	}

	// Copy the long-lived config by value and attach the current JWT to the
	// copy only, leaving s.baseCfg untouched.
	cfg := s.baseCfg
	cfg.EndpointParams = url.Values{"jwt": []string{string(rawJWT)}}

	fresh := oauth2.ReuseTokenSourceWithExpiry(nil, cfg.TokenSource(context.Background()), time.Minute)
	s.underlying = fresh
	return fresh.Token()
}