summaryrefslogtreecommitdiffhomepage
path: root/util/reload/reload.go
blob: edcb90c12a3f20f1c340a628e9dabb4cf7d6837a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
// Copyright (c) Tailscale Inc & contributors
// SPDX-License-Identifier: BSD-3-Clause

// Package reload contains functions that allow periodically reloading a value
// (e.g. a config file) from various sources.
package reload

import (
	"context"
	"encoding/json"
	"fmt"
	"math/rand/v2"
	"os"
	"reflect"
	"time"

	"tailscale.com/syncs"
	"tailscale.com/types/logger"
)

// DefaultInterval is the reload interval used when ReloadOpts.Interval is
// left at its zero value.
const DefaultInterval = 5 * time.Minute

// ReloadOpts specifies options for reloading a value of type T. Various
// helper functions in this package (e.g. FromJSONFile) can be used to create
// one of these specialized for a given use-case.
type ReloadOpts[T any] struct {
	// Read is called to obtain the data to be unmarshaled; e.g. by reading
	// from a file, or making a network request, etc. It is passed the
	// context that was given to New.
	//
	// An error from this function is fatal when calling New, but only a
	// warning during reload.
	//
	// This value is required.
	Read func(context.Context) ([]byte, error)

	// Unmarshal is called with the data that the Read function returns and
	// should return a parsed form of the given value, or an error.
	//
	// An error from this function is fatal when calling New, but only a
	// warning during reload.
	//
	// This value is required.
	Unmarshal func([]byte) (T, error)

	// Logf is a logger used to report reload errors, changes to the stored
	// value, and the reload loop exiting. If nil, no messages are printed.
	Logf logger.Logf

	// Interval is the interval at which to reload the given data from the
	// source; if zero, DefaultInterval will be used.
	Interval time.Duration

	// IntervalJitter is the jitter to be added to the given Interval; if
	// non-zero, a random duration in [0, IntervalJitter) is added to the
	// Interval before each wait. Expected to be non-negative.
	IntervalJitter time.Duration
}

// logf forwards the formatted message to the configured Logf. It does
// nothing when no logger was provided.
func (r *ReloadOpts[T]) logf(format string, args ...any) {
	if r.Logf == nil {
		return
	}
	r.Logf(format, args...)
}

// intervalWithJitter returns how long to wait before the next reload: the
// configured Interval (or DefaultInterval when Interval is zero) plus a
// random duration in [0, IntervalJitter).
//
// A zero or negative IntervalJitter adds no jitter. Negative values have to
// be rejected here because rand.N panics when its argument is <= 0, and this
// method is called from the background reload goroutine, where an unrecovered
// panic would terminate the whole process.
func (r *ReloadOpts[T]) intervalWithJitter() time.Duration {
	interval := r.Interval
	if interval == 0 {
		interval = DefaultInterval
	}
	if r.IntervalJitter <= 0 {
		return interval
	}
	return interval + rand.N(r.IntervalJitter)
}

// New loads the initial value described by opts and then keeps it fresh in a
// background goroutine that runs until ctx is done.
//
// On success it returns a getter that yields the most recently stored value;
// the initial value is already available when New returns. It returns an
// error if opts is missing a required function or if the initial read or
// unmarshal fails, in which case no goroutine is started.
func New[T any](ctx context.Context, opts ReloadOpts[T]) (func() T, error) {
	r, err := newUnstarted(ctx, opts)
	if err != nil {
		return nil, err
	}

	// The initial value is already stored, so the periodic refresh can
	// proceed in the background while the caller uses the getter.
	go r.run()
	return r.store.Load, nil
}

// reloader holds the state for one periodically-reloaded value of type T.
type reloader[T any] struct {
	// ctx is passed to opts.Read on every reload; the run loop exits once
	// it is done.
	ctx context.Context
	// store holds the most recently loaded value. It is read through Load
	// and replaced by updateOnce.
	store syncs.AtomicValue[T]
	// opts are the options this reloader was created with.
	opts ReloadOpts[T]
}

// newUnstarted validates opts, performs the initial Read and Unmarshal, and
// returns a reloader that already holds that first value. It does not start
// the background reload loop; the caller is expected to invoke run.
//
// It returns an error if opts.Read or opts.Unmarshal is nil, or if the
// initial read or unmarshal fails. Errors from Read and Unmarshal are wrapped
// with %w so callers can inspect the cause with errors.Is / errors.As.
func newUnstarted[T any](ctx context.Context, opts ReloadOpts[T]) (*reloader[T], error) {
	if opts.Read == nil {
		return nil, fmt.Errorf("the Read function is required")
	}
	if opts.Unmarshal == nil {
		return nil, fmt.Errorf("the Unmarshal function is required")
	}

	// Start by reading and unmarshaling the value, so that the returned
	// reloader never exposes a zero value to callers.
	data, err := opts.Read(ctx)
	if err != nil {
		return nil, fmt.Errorf("reading initial value: %w", err)
	}

	initial, err := opts.Unmarshal(data)
	if err != nil {
		return nil, fmt.Errorf("unmarshaling initial value: %w", err)
	}

	r := &reloader[T]{
		ctx:  ctx,
		opts: opts,
	}
	r.store.Store(initial)
	return r, nil
}

// run is the background reload loop. It waits for the (jittered) interval,
// refreshes the stored value, and repeats until the reloader's context is
// done. Refresh errors are logged and do not stop the loop.
func (r *reloader[T]) run() {
	// A single timer is reused for the lifetime of the loop and re-armed
	// after every refresh attempt.
	t := time.NewTimer(r.opts.intervalWithJitter())
	defer t.Stop()

	for {
		select {
		case <-t.C:
			if err := r.updateOnce(); err != nil {
				r.opts.logf("error refreshing data: %v", err)
			}

			// Resetting here is safe: this branch only runs after
			// the timer's channel has been received from.
			t.Reset(r.opts.intervalWithJitter())
		case <-r.ctx.Done():
			r.opts.logf("run context is done")
			return
		}
	}
}

// updateOnce performs a single reload: it reads the raw data, unmarshals it,
// and replaces the stored value. If the new value differs from the previous
// one (per reflect.DeepEqual), the new value is logged.
//
// It returns a wrapped error if reading or unmarshaling fails, in which case
// the stored value is left untouched.
func (r *reloader[T]) updateOnce() error {
	raw, err := r.opts.Read(r.ctx)
	if err != nil {
		return fmt.Errorf("reading data: %w", err)
	}

	parsed, err := r.opts.Unmarshal(raw)
	if err != nil {
		return fmt.Errorf("unmarshaling data: %w", err)
	}

	// Always store the freshly parsed value; only log when it changed.
	if prev := r.store.Swap(parsed); !reflect.DeepEqual(prev, parsed) {
		r.opts.logf("stored new value: %+v", parsed)
	}
	return nil
}

// FromJSONFile returns a ReloadOpts that reads the file at path from disk and
// decodes its contents as JSON into a value of type T.
//
// Only Read and Unmarshal are populated; the remaining fields are left at
// their zero values for the caller to set. When decoding fails, the returned
// Unmarshal function yields the zero value of T together with the error.
func FromJSONFile[T any](path string) ReloadOpts[T] {
	readFile := func(context.Context) ([]byte, error) {
		return os.ReadFile(path)
	}

	decode := func(b []byte) (T, error) {
		var parsed T
		if err := json.Unmarshal(b, &parsed); err != nil {
			// Discard anything json.Unmarshal may have partially
			// filled in before failing.
			var zero T
			return zero, err
		}
		return parsed, nil
	}

	return ReloadOpts[T]{
		Read:      readFile,
		Unmarshal: decode,
	}
}