summaryrefslogtreecommitdiffhomepage
path: root/util/execqueue/execqueue.go
blob: acf25f64569d6ccb3b32f50a8910ea5ec3f74324 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause

// Package execqueue implements an ordered asynchronous queue for executing functions.
package execqueue

import (
	"context"
	"errors"
	"sync"

	"tailscale.com/syncs"
)

// ExecQueue runs functions asynchronously, one at a time, in the order they
// were added. The zero value is ready to use; an ExecQueue must not be copied
// after first use (it contains a sync.Once and a mutex).
//
// At most one runner goroutine exists at a time. Add starts it on demand, and
// it exits once the queue is empty or the queue has been shut down.
type ExecQueue struct {
	regMutexOnce sync.Once // ensures registerMutex runs once

	mu         syncs.Mutex        // guards all fields below
	ctx        context.Context    // lazily created from context.Background; canceled by Shutdown
	cancel     context.CancelFunc // cancels ctx
	closed     bool               // set by Shutdown and never cleared
	inFlight   bool               // whether a goroutine is running q.run
	doneWaiter chan struct{}      // non-nil while a Wait caller is waiting; closed by run when it exits
	queue      []func()           // functions waiting to run after the in-flight one, FIFO
}

// registerMutex passes q.mu to syncs.RegisterMutex under a fixed name.
// Every entry point in this file invokes it through q.regMutexOnce, so it
// runs at most once per queue.
func (q *ExecQueue) registerMutex() {
	const mutexName = "execqueue.ExecQueue.mu"
	syncs.RegisterMutex(&q.mu, mutexName)
}

// Add enqueues f to run asynchronously after every previously added function
// has finished. Functions run one at a time, in the order they were added.
// Add does not wait for f to run.
//
// If the queue has already been shut down, f is silently dropped. Functions
// still waiting in the queue when Shutdown is called are also dropped.
func (q *ExecQueue) Add(f func()) {
	q.regMutexOnce.Do(q.registerMutex)

	q.mu.Lock()
	defer q.mu.Unlock()
	if q.closed {
		return
	}
	q.initCtxLocked()
	if q.inFlight {
		// A runner goroutine is active; it will pick f up in FIFO order.
		q.queue = append(q.queue, f)
		return
	}
	// The queue is idle: start a runner with f as its first function.
	q.inFlight = true
	go q.run(f)
}

// RunSync enqueues f behind all previously added functions and blocks until
// f has returned, ctx is done, or the queue is shut down. f itself runs on the
// queue's runner goroutine, not on the caller's goroutine.
//
// It returns nil once f has completed. It returns errExecQueueShutdown if the
// queue is, or becomes, shut down before f completes, and ctx.Err() if ctx is
// done first. When ctx is done first, f stays queued and may still run later.
// When the queue is shut down, f may or may not run.
func (q *ExecQueue) RunSync(ctx context.Context, f func()) error {
	q.regMutexOnce.Do(q.registerMutex)

	q.mu.Lock()
	q.initCtxLocked()
	shutdownCtx := q.ctx
	closed := q.closed
	q.mu.Unlock()

	// Check closed explicitly. If Shutdown ran before q.ctx was ever created,
	// the ctx initialized above has never been canceled. Add would drop f, and
	// waiting on shutdownCtx would block until the caller's ctx is done
	// (forever with context.Background).
	if closed {
		return errExecQueueShutdown
	}

	// Signal completion from the same queued function as f. Then nothing else
	// can be scheduled between f returning and done being closed, and run's
	// shutdown check cannot fall between them either.
	done := make(chan struct{})
	q.Add(func() {
		f()
		close(done)
	})

	select {
	case <-done:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	case <-shutdownCtx.Done():
		// f may have finished concurrently with Shutdown. Report success in
		// that case, since the error is meant to say f did not complete.
		select {
		case <-done:
			return nil
		default:
			return errExecQueueShutdown
		}
	}
}

// run is the body of the queue's runner goroutine. It calls f, then keeps
// popping and calling queued functions in FIFO order until the queue is empty
// or has been shut down. Functions are invoked without holding q.mu. The
// caller must have set q.inFlight before starting run. On exit, run clears
// inFlight, discards anything left in the queue, and wakes a Wait caller if
// one registered a doneWaiter.
func (q *ExecQueue) run(f func()) {
	next := f
	for {
		next()

		q.mu.Lock()
		if q.closed || len(q.queue) == 0 {
			break // exit with q.mu held
		}
		next = q.queue[0]
		q.queue[0] = nil // drop the slot's reference so it can be collected
		q.queue = q.queue[1:]
		q.mu.Unlock()
	}

	q.inFlight = false
	q.queue = nil
	if w := q.doneWaiter; w != nil {
		close(w)
		q.doneWaiter = nil
	}
	q.mu.Unlock()
}

// Shutdown asynchronously signals the queue to stop. It does not wait for a
// function that is currently running to return.
//
// After Shutdown, functions still waiting in the queue are dropped and later
// Add calls are ignored. Blocked and future RunSync and Wait calls stop
// waiting and return errExecQueueShutdown, unless they complete or their own
// ctx is done first. Shutdown may be called more than once.
func (q *ExecQueue) Shutdown() {
	q.regMutexOnce.Do(q.registerMutex)

	q.mu.Lock()
	defer q.mu.Unlock()
	q.closed = true
	// Create ctx if needed so it is always canceled. Otherwise, on a queue
	// that was never used, a later RunSync would lazily create a fresh,
	// uncanceled ctx and block until its caller's ctx is done.
	q.initCtxLocked()
	q.cancel() // repeated calls to a CancelFunc are no-ops
	// Release pending closures now; run discards them anyway once it observes
	// closed, and Add no longer appends.
	q.queue = nil
}

// initCtxLocked creates q.ctx and q.cancel the first time it is called, which
// lets the zero ExecQueue be used without a constructor. Later calls do
// nothing. The caller must hold q.mu.
func (q *ExecQueue) initCtxLocked() {
	if q.ctx != nil {
		return
	}
	q.ctx, q.cancel = context.WithCancel(context.Background())
}

// errExecQueueShutdown is the sentinel error RunSync and Wait return when the
// queue has been shut down.
var (
	errExecQueueShutdown = errors.New("execqueue shut down")
)

// Wait blocks until the runner goroutine goes idle. That happens once every
// function added so far, plus any added while Wait is blocked, has run.
//
// If the queue has been shut down, Wait returns errExecQueueShutdown at once.
// Otherwise, if nothing is running, it returns nil immediately. While
// waiting, it returns errExecQueueShutdown if the queue is shut down, and
// ctx.Err() if ctx is done first.
func (q *ExecQueue) Wait(ctx context.Context) error {
	q.regMutexOnce.Do(q.registerMutex)

	q.mu.Lock()
	q.initCtxLocked()
	if q.inFlight && q.doneWaiter == nil {
		// Ask run to close this channel when it exits.
		q.doneWaiter = make(chan struct{})
	}
	done, isClosed, stopped := q.doneWaiter, q.closed, q.ctx
	q.mu.Unlock()

	switch {
	case isClosed:
		return errExecQueueShutdown
	case done == nil:
		return nil // no runner goroutine, so nothing is queued
	}

	select {
	case <-stopped.Done():
		return errExecQueueShutdown
	case <-ctx.Done():
		return ctx.Err()
	case <-done:
		return nil
	}
}