summaryrefslogtreecommitdiffhomepage
path: root/types/lazy/deferred.go
blob: 582090ab93112ae793bc18167c50a2d6fc2fa51e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
// Copyright (c) Tailscale Inc & contributors
// SPDX-License-Identifier: BSD-3-Clause

package lazy

import (
	"sync"
	"sync/atomic"

	"tailscale.com/types/ptr"
)

// DeferredInit allows one or more funcs to be deferred
// until [DeferredInit.Do] is called for the first time.
//
// It embeds [DeferredFuncs], so [DeferredFuncs.Defer] and
// [DeferredFuncs.MustDefer] can be called directly on a DeferredInit.
// Use [DeferredInit.Funcs] to hand out only the registration half
// of the API without exposing [DeferredInit.Do].
//
// DeferredInit is safe for concurrent use. It contains a mutex
// (via the embedded DeferredFuncs) and so must not be copied
// after first use.
type DeferredInit struct {
	DeferredFuncs
}

// DeferredFuncs allows one or more funcs to be deferred
// until the owner's [DeferredInit.Do] method is called
// for the first time.
//
// DeferredFuncs is safe for concurrent use. The execution
// order of functions deferred by different goroutines is
// unspecified and must not be relied upon.
// However, functions deferred by the same goroutine are
// executed in the same relative order they were deferred.
//
// Warning: this is the opposite of the behavior of Go's
// defer statement, which executes deferred functions in
// reverse order.
type DeferredFuncs struct {
	// m guards funcs. It is held by Defer while appending and by
	// the slow path of Do for the whole time the funcs are running.
	m     sync.Mutex
	// funcs holds the pending functions in the order they were
	// added. It is reset to nil once the slow path of Do has run.
	funcs []func() error

	// err is either:
	//    * nil, if deferred init has not yet been completed
	//    * nilErrPtr, if initialization completed successfully
	//    * non-nil and not nilErrPtr, if there was an error
	//
	// It is an atomic.Pointer so it can be read without m held.
	// It is only stored to with m held (see doSlow).
	err atomic.Pointer[error]
}

// Defer registers f to run on the first call to [DeferredInit.Do].
// It reports whether f was registered: true if it was queued,
// or false if [DeferredInit.Do] has already completed, in which
// case f is dropped and will never be called.
func (d *DeferredFuncs) Defer(f func() error) bool {
	d.m.Lock()
	defer d.m.Unlock()

	// A non-nil err pointer means Do has already finished,
	// whether it succeeded or failed.
	pending := d.err.Load() == nil
	if pending {
		d.funcs = append(d.funcs, f)
	}
	return pending
}

// MustDefer registers f exactly as [DeferredFuncs.Defer] does,
// but panics instead of returning false when [DeferredInit.Do]
// has already been called.
func (d *DeferredFuncs) MustDefer(f func() error) {
	if registered := d.Defer(f); registered {
		return
	}
	panic("deferred init already completed")
}

// Do runs the previously deferred init functions the first time it is
// called on this [DeferredInit]; later calls return the result recorded
// by that first run without invoking anything again.
// It stops at, and returns, the first error an init function returns.
//
// It is safe for concurrent use, and the deferred init is guaranteed
// to have been completed, either successfully or with an error,
// when Do() returns.
func (d *DeferredInit) Do() error {
	// Fast path: a result has already been published, no lock needed.
	if done := d.err.Load(); done != nil {
		return *done
	}
	return *d.doSlow()
}

// doSlow is the locked path of [DeferredInit.Do]. With d.m held it
// re-checks whether a result has been published, and if not runs the
// deferred funcs in order, stopping at the first one that returns an
// error. The returned pointer is never nil on a normal return: it is
// either nilErrPtr on success or a pointer to the first error.
//
// The result is named so that the deferred closure below can publish
// whatever value is being returned.
//
// NOTE(review): if a deferred func panics, no return statement has run,
// so the named result is still nil when the closure executes; it then
// stores nil into d.err and clears d.funcs before the panic propagates.
func (d *DeferredInit) doSlow() (err *error) {
	d.m.Lock()
	defer d.m.Unlock()

	// Another caller may have completed init while we waited for the lock.
	if published := d.err.Load(); published != nil {
		return published
	}

	defer func() {
		d.err.Store(err)
		d.funcs = nil // do not keep funcs alive after invoking
	}()

	pending := d.funcs
	for i := 0; i < len(pending); i++ {
		if initErr := pending[i](); initErr != nil {
			return ptr.To(initErr)
		}
	}
	return nilErrPtr
}

// Funcs returns a pointer to the [DeferredFuncs] embedded in d;
// it is equivalent to writing &d.DeferredFuncs.
// The result can safely be handed to external code that should be
// able to defer init funcs but must not be able to call [DeferredInit.Do].
func (d *DeferredInit) Funcs() *DeferredFuncs {
	embedded := &d.DeferredFuncs
	return embedded
}