summaryrefslogtreecommitdiffhomepage
path: root/docs/webhooks/example.go
blob: 712028362c53e5cefa6d4f52280b40a643b4954a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause

// Command webhooks provides example consumer code for Tailscale
// webhooks.
package main

import (
	"crypto/hmac"
	"crypto/sha256"
	"crypto/subtle"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log"
	"net/http"
	"strconv"
	"strings"
	"time"
)

// event is one entry of the JSON array carried in a webhook request body.
// Each field is filled from the JSON key named in its struct tag.
type event struct {
	// Timestamp is the event's "timestamp" value, kept as the raw string.
	Timestamp string `json:"timestamp"`

	// Version is the event's "version" value.
	Version int `json:"version"`

	// Type is the event's "type" value.
	Type string `json:"type"`

	// Tailnet is the event's "tailnet" value.
	Tailnet string `json:"tailnet"`

	// Message is the event's "message" value.
	Message string `json:"message"`

	// Data holds the event's "data" object as string key/value pairs.
	Data map[string]string `json:"data"`
}

// currentVersion is the only signature version key this example verifies.
// Signatures listed in the header under any other version key are ignored.
const currentVersion = "v1"

// secret is the webhook secret used as the HMAC key during verification.
// It is sensitive and is hard-coded here only to keep the example
// self-contained; real deployments should load it from configuration.
const secret = "tskey-webhook-xxxxx"

// errNotSigned is reported when a webhook request carries no usable signature.
var errNotSigned = errors.New("webhook has no signature")

// errInvalidHeader is reported when the signature header cannot be parsed.
var errInvalidHeader = errors.New("webhook has an invalid signature")

// main serves the example webhook endpoint at /webhook on port 80.
//
// An explicit http.Server is used rather than http.ListenAndServe because the
// package-level helper sets no timeouts. Without them, slow or idle clients
// can hold connections (and goroutines) open indefinitely. The handler is
// still registered on the default mux, so routing is unchanged.
func main() {
	http.HandleFunc("/webhook", webhooksHandler)

	srv := &http.Server{
		Addr:              ":80",
		Handler:           nil, // nil means http.DefaultServeMux
		ReadHeaderTimeout: 10 * time.Second,
		ReadTimeout:       30 * time.Second,
		WriteTimeout:      30 * time.Second,
		IdleTimeout:       2 * time.Minute,
	}
	if err := srv.ListenAndServe(); err != nil {
		log.Fatal(err)
	}
}

// webhooksHandler receives webhook deliveries.
//
// It verifies each delivery with verifyWebhookSignature using the package
// secret, then logs either the verification error or the decoded events.
// It never writes to w, so net/http answers with 200 OK in both cases.
func webhooksHandler(w http.ResponseWriter, req *http.Request) {
	defer req.Body.Close()

	// The handler should always report 2XX except in the case of
	// transient failures (e.g. database backend is down).
	// Otherwise your future events will be blocked by retries.
	// That is why a failed verification below is only logged.
	events, err := verifyWebhookSignature(req, secret)
	if err != nil {
		log.Printf("error validating signature: %v\n", err)
		return
	}

	log.Printf("events received %v\n", events)
	// Do something with your events. :)
}

// verifyWebhookSignature checks the request's "Tailscale-Webhook-Signature"
// header to verify that the events were signed by your webhook secret.
//
// The expected signature is the hex-encoded HMAC-SHA256, keyed by secret, of
// the header timestamp in Unix seconds, a ".", and the raw request body. It is
// compared in constant time against every currentVersion signature in the
// header. Requests whose timestamp is more than 5 minutes in the past are
// rejected.
//
// If verification fails, an error is reported. The error deliberately omits
// the expected signature: it is typically logged, and exposing it would let
// anyone with log access forge a valid signature for a body they submitted.
// If verification succeeds, the list of contained events is reported.
//
// The request body is read in full and closed.
func verifyWebhookSignature(req *http.Request, secret string) (events []event, err error) {
	defer req.Body.Close()

	// Grab the signature sent on the request header.
	timestamp, signatures, err := parseSignatureHeader(req.Header.Get("Tailscale-Webhook-Signature"))
	if err != nil {
		return nil, err
	}

	// Verify that the timestamp is recent.
	// Here, we use a threshold of 5 minutes.
	if timestamp.Before(time.Now().Add(-5 * time.Minute)) {
		return nil, fmt.Errorf("invalid header: timestamp older than 5 minutes")
	}

	// Form the expected signature over "<unix seconds>.<body>".
	b, err := io.ReadAll(req.Body)
	if err != nil {
		return nil, fmt.Errorf("reading webhook body: %w", err)
	}
	mac := hmac.New(sha256.New, []byte(secret))
	mac.Write([]byte(strconv.FormatInt(timestamp.Unix(), 10)))
	mac.Write([]byte("."))
	mac.Write(b)
	want := hex.EncodeToString(mac.Sum(nil))

	// Verify that one of the provided signatures matches. Constant-time
	// comparison avoids leaking how many leading bytes matched.
	var match bool
	for _, signature := range signatures[currentVersion] {
		if subtle.ConstantTimeCompare([]byte(signature), []byte(want)) == 1 {
			match = true
			break
		}
	}
	if !match {
		// Only the (attacker-visible) received signatures are reported;
		// never include the expected value here.
		return nil, fmt.Errorf("signature does not match: got = %q", signatures[currentVersion])
	}

	// If verified, return the events.
	if err := json.Unmarshal(b, &events); err != nil {
		return nil, fmt.Errorf("decoding webhook events: %w", err)
	}
	return events, nil
}

// parseSignatureHeader splits header into its timestamp and included signatures.
// The signatures are reported as a map of version (e.g. "v1") to a list of signatures
// found with that version.
func parseSignatureHeader(header string) (timestamp time.Time, signatures map[string][]string, err error) {
	if header == "" {
		return time.Time{}, nil, fmt.Errorf("request has no signature")
	}

	signatures = make(map[string][]string)
	pairs := strings.Split(header, ",")
	for _, pair := range pairs {
		parts := strings.Split(pair, "=")
		if len(parts) != 2 {
			return time.Time{}, nil, errNotSigned
		}

		switch parts[0] {
		case "t":
			tsint, err := strconv.ParseInt(parts[1], 10, 64)
			if err != nil {
				return time.Time{}, nil, errInvalidHeader
			}
			timestamp = time.Unix(tsint, 0)
		case currentVersion:
			signatures[parts[0]] = append(signatures[parts[0]], parts[1])
		default:
			// Ignore unknown parts of the header.
			continue
		}
	}

	if len(signatures) == 0 {
		return time.Time{}, nil, errNotSigned
	}
	return
}