summaryrefslogtreecommitdiffhomepage
path: root/wif/wif.go
blob: bb2e760f2c7b77b9bd0abf892c2742184e871374 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
// Copyright (c) Tailscale Inc & contributors
// SPDX-License-Identifier: BSD-3-Clause

// Package wif obtains ID tokens from the platform a workload runs on
// (GitHub Actions, AWS, or GCP) for use in Workload Identity Federation.
package wif

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"os"
	"strings"
	"time"

	"github.com/aws/aws-sdk-go-v2/aws"
	"github.com/aws/aws-sdk-go-v2/config"
	"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
	"github.com/aws/aws-sdk-go-v2/service/sts"
	"github.com/aws/smithy-go"
	"tailscale.com/util/httpm"
)

// Environment identifies the platform whose identity provider is used to
// obtain an ID token.
type Environment string

// Environments recognized by detectEnvironment.
const (
	// EnvGitHub is a GitHub Actions run exposing OIDC token request
	// credentials in its environment variables.
	EnvGitHub Environment = "github"

	// EnvAWS is a host whose AWS IMDSv2 token endpoint answered a probe.
	EnvAWS Environment = "aws"

	// EnvGCP is a host whose GCP metadata server answered a probe.
	EnvGCP Environment = "gcp"

	// EnvNone means no supported environment was detected.
	EnvNone Environment = "none"
)

// ObtainProviderToken detects which platform the process is running on and
// obtains an ID token for audience from that platform's identity provider.
//
// Environments are checked in this intentional order (see detectEnvironment):
//
//  1. GitHub Actions, recognized from the ACTIONS_ID_TOKEN_REQUEST_URL and
//     ACTIONS_ID_TOKEN_REQUEST_TOKEN environment variables. It goes first
//     because it is the strongest signal and a runner may itself be hosted
//     on any cloud.
//  2. AWS, recognized by a successful IMDSv2 token request (no environment
//     variables required).
//  3. GCP, recognized by the Metadata-Flavor header on the metadata
//     server's response.
//
// Surrounding whitespace in audience is ignored on every provider path; for
// GitHub Actions a blank audience is omitted from the request entirely.
//
// Any other environment, including Azure (not currently supported), yields
// an error asking the caller to pass --id-token explicitly.
func ObtainProviderToken(ctx context.Context, audience string) (string, error) {
	env := detectEnvironment(ctx)

	switch env {
	case EnvGitHub:
		return acquireGitHubActionsIDToken(ctx, audience)
	case EnvAWS:
		return acquireAWSWebIdentityToken(ctx, audience)
	case EnvGCP:
		return acquireGCPMetadataIDToken(ctx, audience)
	default:
		return "", errors.New("could not detect environment; provide --id-token explicitly")
	}
}

// detectEnvironment picks the provider ObtainProviderToken should use.
// GitHub Actions is recognized purely from its two OIDC request environment
// variables. Otherwise the AWS IMDSv2 endpoint and then the GCP metadata
// server are probed over the network, stopping at the first positive
// answer. EnvNone is returned when nothing matches.
func detectEnvironment(ctx context.Context) Environment {
	onGitHub := os.Getenv("ACTIONS_ID_TOKEN_REQUEST_URL") != "" &&
		os.Getenv("ACTIONS_ID_TOKEN_REQUEST_TOKEN") != ""
	if onGitHub {
		return EnvGitHub
	}

	// Both network probes share one client (and therefore its timeout).
	probeClient := httpClient()
	switch {
	case detectAWSIMDSv2(ctx, probeClient):
		return EnvAWS
	case detectGCPMetadata(ctx, probeClient):
		return EnvGCP
	default:
		return EnvNone
	}
}

// httpClient returns a new HTTP client whose Timeout bounds each whole
// request (connection, redirects and reading the body) to five seconds. It
// is used both for environment probes and for fetching tokens.
func httpClient() *http.Client {
	const requestTimeout = 5 * time.Second
	c := new(http.Client)
	c.Timeout = requestTimeout
	return c
}

// detectAWSIMDSv2 reports whether an AWS IMDSv2 endpoint answers at the
// link-local metadata address 169.254.169.254. It requests a session token
// with a one-second TTL, because the token is only a presence signal and is
// never read. Only an HTTP 200 counts as detection; any error building or
// sending the request means "not AWS".
func detectAWSIMDSv2(ctx context.Context, client *http.Client) bool {
	const tokenURL = "http://169.254.169.254/latest/api/token"

	probe, err := http.NewRequestWithContext(ctx, httpm.PUT, tokenURL, nil)
	if err != nil {
		return false
	}
	probe.Header.Set("X-aws-ec2-metadata-token-ttl-seconds", "1")

	res, err := client.Do(probe)
	if err != nil {
		return false
	}
	found := res.StatusCode == http.StatusOK
	res.Body.Close() // the token itself is not needed
	return found
}

// detectGCPMetadata reports whether the GCP metadata server is reachable.
// It sends a GET to http://metadata.google.internal carrying the
// Metadata-Flavor: Google request header. It treats the host as GCP when the
// response carries Metadata-Flavor: Google, whatever the status code. Any
// error building or sending the request yields false.
func detectGCPMetadata(ctx context.Context, client *http.Client) bool {
	const metadataRoot = "http://metadata.google.internal"

	probe, err := http.NewRequestWithContext(ctx, httpm.GET, metadataRoot, nil)
	if err != nil {
		return false
	}
	probe.Header.Set("Metadata-Flavor", "Google")

	res, err := client.Do(probe)
	if err != nil {
		return false
	}
	flavor := res.Header.Get("Metadata-Flavor")
	res.Body.Close() // only the response header matters
	return flavor == "Google"
}

// githubOIDCResponse is the JSON body returned by the GitHub Actions OIDC
// token endpoint. Only the "value" field is decoded; other fields are ignored.
type githubOIDCResponse struct {
	Value string `json:"value"` // the issued ID token (a JWT)
}

// acquireGitHubActionsIDToken requests an OIDC ID token from the GitHub
// Actions token service at ACTIONS_ID_TOKEN_REQUEST_URL. It authenticates
// with the bearer token in ACTIONS_ID_TOKEN_REQUEST_TOKEN. If either
// variable is unset, the returned error suggests the workflow lacks the
// "id-token: write" permission.
//
// A non-blank audience (after trimming whitespace) is added as the
// "audience" query parameter, keeping any query already in the URL. A non-2xx
// reply is reported with up to 2 KiB of its body.
func acquireGitHubActionsIDToken(ctx context.Context, audience string) (string, error) {
	endpoint := os.Getenv("ACTIONS_ID_TOKEN_REQUEST_URL")
	bearer := os.Getenv("ACTIONS_ID_TOKEN_REQUEST_TOKEN")
	if endpoint == "" || bearer == "" {
		return "", errors.New("missing ACTIONS_ID_TOKEN_REQUEST_URL/TOKEN (ensure workflow has permissions: id-token: write)")
	}

	tokenURL, err := url.Parse(endpoint)
	if err != nil {
		return "", fmt.Errorf("parse ACTIONS_ID_TOKEN_REQUEST_URL: %w", err)
	}
	if aud := strings.TrimSpace(audience); aud != "" {
		params := tokenURL.Query()
		params.Set("audience", aud)
		tokenURL.RawQuery = params.Encode()
	}

	req, err := http.NewRequestWithContext(ctx, httpm.GET, tokenURL.String(), nil)
	if err != nil {
		return "", fmt.Errorf("build request: %w", err)
	}
	req.Header.Set("Authorization", "Bearer "+bearer)
	req.Header.Set("Accept", "application/json")

	resp, err := httpClient().Do(req)
	if err != nil {
		return "", fmt.Errorf("request github oidc token: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		// A bounded body snippet helps diagnose permission problems.
		snippet, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
		return "", fmt.Errorf("github oidc token endpoint returned %s: %s", resp.Status, strings.TrimSpace(string(snippet)))
	}

	var payload githubOIDCResponse
	if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
		return "", fmt.Errorf("decode github oidc response: %w", err)
	}
	if strings.TrimSpace(payload.Value) == "" {
		return "", errors.New("github oidc response contained empty token")
	}

	// The response carries no separate expiry; callers needing it can read
	// the exp claim from the JWT itself.
	return payload.Value, nil
}

// acquireAWSWebIdentityToken obtains a web identity token for audience from
// AWS STS GetWebIdentityToken. The token is signed with ES384 and requested
// with a 300-second duration.
//
// It loads the default AWS config and checks that credentials resolve. It
// then sets the region reported by the instance metadata service before
// calling STS. STS API errors are reported together with their error code.
func acquireAWSWebIdentityToken(ctx context.Context, audience string) (string, error) {
	// The default config carries the standard credential chain (incl. IMDS).
	cfg, err := config.LoadDefaultConfig(ctx)
	if err != nil {
		return "", fmt.Errorf("load aws config: %w", err)
	}

	// Resolve credentials up front so a missing instance profile produces a
	// clear error instead of a failure inside the STS call.
	if _, credErr := cfg.Credentials.Retrieve(ctx); credErr != nil {
		return "", fmt.Errorf("AWS credentials unavailable (instance profile/IMDS?): %w", credErr)
	}

	// The STS client below is built from cfg, so give it the instance region.
	regionOut, err := imds.NewFromConfig(cfg).GetRegion(ctx, &imds.GetRegionInput{})
	if err != nil {
		return "", fmt.Errorf("couldn't get AWS region: %w", err)
	}
	cfg.Region = regionOut.Region

	out, err := sts.NewFromConfig(cfg).GetWebIdentityToken(ctx, &sts.GetWebIdentityTokenInput{
		Audience:         []string{strings.TrimSpace(audience)},
		SigningAlgorithm: aws.String("ES384"),
		DurationSeconds:  aws.Int32(300), // 5 minutes
	})
	if err != nil {
		var apiErr smithy.APIError
		if !errors.As(err, &apiErr) {
			return "", fmt.Errorf("aws sts:GetWebIdentityToken failed: %w", err)
		}
		return "", fmt.Errorf("aws sts:GetWebIdentityToken failed (%s): %w", apiErr.ErrorCode(), err)
	}

	if tok := out.WebIdentityToken; tok != nil && strings.TrimSpace(*tok) != "" {
		return *tok, nil
	}
	return "", errors.New("aws sts:GetWebIdentityToken returned empty token")
}

// acquireGCPMetadataIDToken fetches an ID token for audience from the GCP
// metadata server's identity endpoint for the default service account. It
// requests format=full. A non-2xx reply is reported with up to 2 KiB of its
// body. A successful body (read up to 1 MiB) is returned with surrounding
// whitespace removed.
func acquireGCPMetadataIDToken(ctx context.Context, audience string) (string, error) {
	const identityEndpoint = "http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/identity"
	query := url.Values{
		"audience": {strings.TrimSpace(audience)},
		"format":   {"full"},
	}

	req, err := http.NewRequestWithContext(ctx, httpm.GET, identityEndpoint+"?"+query.Encode(), nil)
	if err != nil {
		return "", fmt.Errorf("build request: %w", err)
	}
	// GCP's metadata server expects this header on every request.
	req.Header.Set("Metadata-Flavor", "Google")

	resp, err := httpClient().Do(req)
	if err != nil {
		return "", fmt.Errorf("call gcp metadata identity endpoint: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		snippet, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
		return "", fmt.Errorf("gcp metadata identity endpoint returned %s: %s", resp.Status, strings.TrimSpace(string(snippet)))
	}

	raw, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
	if err != nil {
		return "", fmt.Errorf("read gcp id token: %w", err)
	}
	token := strings.TrimSpace(string(raw))
	if token == "" {
		return "", errors.New("gcp metadata returned empty token")
	}
	return token, nil
}