summaryrefslogtreecommitdiffhomepage
path: root/util/syspolicy/internal/metrics/test_handler.go
blob: 36c3f2cad876a29f1c364b0198238a10aecfc417 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause

package metrics

import (
	"strings"

	"tailscale.com/util/clientmetric"
	"tailscale.com/util/set"
	"tailscale.com/util/syspolicy/internal"
	"tailscale.com/util/testenv"
)

// TestState represents a metric name and its expected value.
// It is the unit of comparison accepted by the TestHandler assertion
// methods MustEqual and MustContain.
type TestState struct {
	Name  string // `$os` in the name will be replaced by the actual operating system name.
	Value int64  // The value the metric is expected to have.
}

// TestHandler facilitates testing of the code that uses metrics.
// It keeps metric values in an in-memory map that is updated through
// AddMetric and SetMetric, and checked through MustEqual and MustContain.
//
// NOTE(review): none of the methods in this file synchronize access to the
// map; presumably a handler is meant to be used from one goroutine at a
// time — confirm against callers.
type TestHandler struct {
	t testenv.TB // receives the Helper, Errorf and Fatalf calls made by the handler

	// m holds the current value of every recorded metric, keyed by metric name.
	m map[string]int64
}

// NewTestHandler creates a TestHandler that reports failures to t and
// starts out with no metrics recorded.
func NewTestHandler(t testenv.TB) *TestHandler {
	h := &TestHandler{
		t: t,
		m: map[string]int64{},
	}
	return h
}

// AddMetric adds delta d to the metric with the specified name and type.
//
// A negative delta on a clientmetric.TypeCounter metric is reported with
// Fatalf. A zero delta never creates an entry: it only has an effect (a
// no-op rewrite of the stored value) when the metric was already recorded.
func (h *TestHandler) AddMetric(name string, typ clientmetric.Type, d int64) {
	h.t.Helper()
	if d < 0 && typ == clientmetric.TypeCounter {
		h.t.Fatalf("an attempt was made to decrement a counter metric %q", name)
	}
	prev, exists := h.m[name]
	if !exists && d == 0 {
		// Nothing recorded yet and nothing to add: keep the metric absent.
		return
	}
	h.m[name] = prev + d
}

// SetMetric stores v as the value of the metric with the specified name and type.
//
// Setting a clientmetric.TypeCounter metric is reported with Fatalf.
// Storing zero never creates an entry: a zero value is only written when
// the metric was already recorded.
func (h *TestHandler) SetMetric(name string, typ clientmetric.Type, v int64) {
	h.t.Helper()
	if typ == clientmetric.TypeCounter {
		h.t.Fatalf("an attempt was made to set a counter metric %q", name)
	}
	if _, exists := h.m[name]; !exists && v == 0 {
		// Unknown metric being set to zero: keep it absent.
		return
	}
	h.m[name] = v
}

// MustEqual fails the test if the actual metric state differs from the specified state.
//
// It first checks, via MustContain, that every specified metric is recorded
// with the expected value, and then, via mustNoExtra, that no metric other
// than the specified ones is recorded.
func (h *TestHandler) MustEqual(metrics ...TestState) {
	h.t.Helper()
	// MustContain reports a value mismatch with Fatalf; presumably that stops
	// the test (as it does with a standard testing.T), in which case the
	// extra-metric check below is not reached — confirm for testenv.TB.
	h.MustContain(metrics...)
	h.mustNoExtra(metrics...)
}

// MustContain fails the test unless every specified metric is recorded with
// the specified value. Metrics recorded in addition to the specified ones
// are permitted.
//
// Any `$os` in a metric name is replaced with internal.OS() before lookup.
// A metric that is not recorded at all is reported with Errorf, while a
// metric recorded with a different value is reported with Fatalf.
func (h *TestHandler) MustContain(metrics ...TestState) {
	h.t.Helper()
	for i := range metrics {
		want := metrics[i].Value
		name := strings.ReplaceAll(metrics[i].Name, "$os", internal.OS())
		got, found := h.m[name]
		switch {
		case !found:
			h.t.Errorf("%q: got (none), want %v", name, want)
		case got != want:
			h.t.Fatalf("%q: got %v, want %v", name, got, want)
		}
	}
}

// mustNoExtra reports an error, via Errorf, for every recorded metric whose
// name is not among the specified metrics. Any `$os` in a specified name is
// replaced with internal.OS() before the comparison. Values of the specified
// metrics are not examined here.
func (h *TestHandler) mustNoExtra(metrics ...TestState) {
	h.t.Helper()
	expected := make(set.Set[string])
	for _, m := range metrics {
		expected.Add(strings.ReplaceAll(m.Name, "$os", internal.OS()))
	}
	for name, value := range h.m {
		if expected.Contains(name) {
			continue
		}
		h.t.Errorf("%q: got %v, want (none)", name, value)
	}
}