summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorNick Khyl <nickk@tailscale.com>2025-05-04 23:02:29 -0500
committerNick Khyl <nickk@tailscale.com>2025-05-04 23:15:41 -0500
commite744ea41c9d633c45187a9785e67bf66914a15fe (patch)
tree2f99a79bb6dce69b750ab90ec52d723e049e1698
parent64e5da8024661bb31cabc97902fe681e02c31318 (diff)
downloadtailscale-nickkhyl/context-with-lock.tar.xz
tailscale-nickkhyl/context-with-lock.zip
util/ctxlock: enforce mutex lock ordering defined by its ranknickkhyl/context-with-lock
Updates #12614 Signed-off-by: Nick Khyl <nickk@tailscale.com>
-rw-r--r--util/ctxlock/doc.go134
-rw-r--r--util/ctxlock/doc_test.go272
-rw-r--r--util/ctxlock/mutex.go45
-rw-r--r--util/ctxlock/mutex_test.go120
-rw-r--r--util/ctxlock/rank.go58
-rw-r--r--util/ctxlock/state.go228
-rw-r--r--util/ctxlock/state_checked.go244
-rw-r--r--util/ctxlock/state_test.go304
-rw-r--r--util/ctxlock/state_unchecked.go102
-rw-r--r--util/ctxlock/state_use_checked.go20
-rw-r--r--util/ctxlock/state_use_unchecked.go20
11 files changed, 1025 insertions, 522 deletions
diff --git a/util/ctxlock/doc.go b/util/ctxlock/doc.go
new file mode 100644
index 000000000..6cab258b8
--- /dev/null
+++ b/util/ctxlock/doc.go
@@ -0,0 +1,134 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Package ctxlock provides a [Mutex] type and allows to define lock ordering
+// and reentrancy rules for mutexes using a [Rank]. It then enforces these
+// rules at runtime using a [State] hierarchy.
+//
+// The package has two implementations: checked and unchecked.
+//
+// Both implementations support reentrancy and lock ordering,
+// but the checked implementation performs additional runtime checks
+// and ensures that:
+// - a parent [LockHandle] is not unlocked before its child,
+// - a [LockHandle] is only unlocked once, and
+// - a [State] is not used after being unlocked.
+//
+// The unchecked implementation skips these checks for improved performance,
+// and is enabled in builds with the ts_omit_ctxlock_checks build tag.
+//
+// Example:
+//
+// type Resource struct {
+// mu Mutex[Reentrant]
+// value int
+// }
+//
+// func (r *Resource) GetValue(ctx State) int {
+// lock := Lock(ctx, &r.mu)
+// defer lock.Unlock()
+// return r.value
+// }
+//
+// func (r *Resource) SetValue(ctx State, v int) {
+// lock := Lock(ctx, &r.mu)
+// defer lock.Unlock()
+// r.value = v
+// }
+//
+// func (r *Resource) Foo(ctx State, cb func(State) int) int {
+// lock := Lock(ctx, &r.mu)
+// defer lock.Unlock()
+// return cb(lock.State())
+// }
+//
+// func main() {
+// r := Resource{}
+// r.SetValue(State{}, 42)
+// v := r.Foo(State{}, func(ctx State) int {
+// return r.GetValue(ctx)
+// })
+// fmt.Println(v) // prints 42
+// }
+package ctxlock
+
+import "context"
+
+// IsChecked indicates whether the checked implementation is used.
+const IsChecked = useCheckedImpl
+
+// A Mutex is a potentially reentrant mutual exclusion lock
+// with a lock hierarchy and reentrancy rules defined by its [Rank].
+// The zero value of a Mutex is valid and represents an unlocked mutex.
+//
+// The lock state of zero or more mutexes held by a given call chain
+// is carried by a [State].
+//
+// A mutex can be locked using [Lock]. The returned [LockHandle] becomes
+// the mutex's owner if the mutex wasn't already held by an ancestor [State].
+// It can be used to unlock the mutex or access the lock state hierarchy.
+//
+// It is a runtime error to lock a mutex if its rank's CheckLockAfter
+// reports a conflict with any mutex already held along the call chain.
+type Mutex[R Rank] struct {
+ mutex[R, lockState]
+}
+
+// ReentrantMutex is a reentrant [Mutex] with no defined lock hierarchy.
+type ReentrantMutex = Mutex[Reentrant]
+
+// State is a [context.Context] that carries the lock state of zero or more mutexes.
+//
+// Its zero value is valid and represents an unlocked state and an empty context.
+type State struct {
+ stateImpl
+}
+
+// None returns a zero [State].
+func None() State {
+ return State{}
+}
+
+// FromContext returns a [State] that carries the same lock state
+// as the given [context.Context].
+//
+// It's typically used when [context.Context] already handles
+// cancellation or deadlines and can be extended to locking as well.
+func FromContext(ctx context.Context) State {
+ return State{fromContext(ctx)}
+}
+
+// Lock locks the specified mutex and becomes its owner, unless it is
+// already held by the parent or its ancestor. It returns a [LockHandle]
+// that can be used to unlock the mutex or access the modified lock [State].
+//
+// The parent can be either a [State] or a [context.Context].
+// A zero State is a valid parent.
+//
+// It is a runtime error to pass a nil mutex or to unlock the parent's
+// [LockHandle] before the returned one.
+func Lock[T context.Context, R Rank](parent T, mu *Mutex[R]) LockHandle {
+ //return LockHandle{lock(parent, &mu.mutex)}
+ if parent, ok := any(parent).(State); ok {
+ return LockHandle{lock(parent.stateImpl, &mu.mutex)}
+ }
+ return LockHandle{lock(fromContext(parent), &mu.mutex)}
+}
+
+// LockHandle allows releasing a mutex acquired with [Lock]
+// and provides access to the lock state hierarchy.
+type LockHandle struct {
+ state stateImpl
+}
+
+// State returns the current lock state.
+func (h LockHandle) State() State {
+ return State{h.state}
+}
+
+// Unlock releases the mutex owned by the handle, if any.
+// It is a runtime error to call Unlock more than once on the same handle,
+// or to unlock a [LockHandle] while its associated [State] is still in use.
+func (h LockHandle) Unlock() {
+ h.state.unlock()
+}
diff --git a/util/ctxlock/doc_test.go b/util/ctxlock/doc_test.go
index a6b0de407..d1c32f9ba 100644
--- a/util/ctxlock/doc_test.go
+++ b/util/ctxlock/doc_test.go
@@ -6,75 +6,130 @@ package ctxlock_test
import (
"context"
"fmt"
- "sync"
+ "strings"
"testing"
"tailscale.com/util/ctxlock"
)
-type Resource struct {
- mu sync.Mutex
- foo, bar string
-}
+func ExampleMutex_reentrant() {
+ var mu ctxlock.ReentrantMutex // shorthand for ctxlock.Mutex[ctxlock.Reentrant]
-func (r *Resource) GetFoo(ctx ctxlock.State) string {
- // Lock the mutex if not already held.
- defer ctxlock.Lock(ctx, &r.mu).Unlock()
- return r.foo
-}
+ // The mutex is reentrant, so foo can be called with or without holding the mu.
+ // If mu is not already held, it will be locked on entry and unlocked on exit.
+ // The [ctxlock.State] parameter carries the current lock state.
+ foo := func(ctx ctxlock.State, msg string) {
+ lock := ctxlock.Lock(ctx, &mu)
+ defer lock.Unlock()
+ fmt.Println(msg)
+ }
-func (r *Resource) SetFoo(ctx ctxlock.State, foo string) {
- // You can do it this way, if you prefer
- // or if you need to pass the state to another function.
- ctx = ctxlock.Lock(ctx, &r.mu)
- defer ctx.Unlock()
- r.foo = foo
-}
+ // Calling foo without holding the lock.
+ foo(ctxlock.None(), "no lock")
-func (r *Resource) GetBar(ctx ctxlock.State) string {
- defer ctxlock.Lock(ctx, &r.mu).Unlock()
- return r.bar
-}
+ // Locking the mutex and calling foo again.
+ lock := ctxlock.Lock(ctxlock.None(), &mu)
+ foo(lock.State(), "with lock")
+ defer lock.Unlock()
-func (r *Resource) SetBar(ctx ctxlock.State, bar string) {
- defer ctxlock.Lock(ctx, &r.mu).Unlock()
- r.bar = bar
+ // Output:
+ // no lock
+ // with lock
}
-func (r *Resource) WithLock(ctx ctxlock.State, f func(ctx ctxlock.State)) {
- // Lock the mutex if not already held, and get a new state.
- ctx = ctxlock.Lock(ctx, &r.mu)
- defer ctx.Unlock()
- f(ctx) // Call the callback with the new lock state.
+func ExampleMutex_nonReentrant() {
+ var mu ctxlock.Mutex[ctxlock.NonReentrant]
+
+ // The mutex is non-reentrant, so foo must only be called without holding the mu.
+ // If mu is already held, it will panic attempting to lock it again.
+ foo := func(ctx ctxlock.State, msg string) {
+ defer func() {
+ if r := recover(); r != nil {
+ fmt.Println("panic:", trimPanicMessage(r))
+ }
+ }()
+
+ lock := ctxlock.Lock(ctx, &mu)
+ defer lock.Unlock()
+ fmt.Println(msg)
+ }
+
+ // Calling foo without holding the lock.
+ foo(ctxlock.None(), "no lock")
+
+ // Locking the mutex and calling foo again.
+ // This will panic because the mutex is non-reentrant.
+ lock := ctxlock.Lock(ctxlock.None(), &mu)
+ foo(lock.State(), "with lock")
+ defer lock.Unlock()
+
+ // Output:
+ // no lock
+ // panic: non-reentrant mutex already locked
}
-func (r *Resource) HandleRequest(ctx context.Context, foo, bar string, f func(ls ctxlock.State) string) string {
- // Same, but with a standard [context.Context] instead of [ctxlock.State].
- // [ctxlock.Lock] is generic and works with both without allocating.
- // The ctx can be used for cancellation, etc.
- mu := ctxlock.Lock(ctx, &r.mu)
- defer mu.Unlock()
- r.foo = foo
- r.bar = bar
- return f(mu)
+func ExampleRank() {
+ var mu1 ctxlock.Mutex[rank1] // cannot be locked after mu2 or mu3
+ var mu2 ctxlock.Mutex[rank2] // cannot be locked after mu3
+ var mu3 ctxlock.Mutex[rank3]
+
+ lock := ctxlock.Lock(ctxlock.None(), &mu1)
+ defer lock.Unlock()
+ fmt.Println("locked mu1")
+
+ lock = ctxlock.Lock(lock.State(), &mu2)
+ defer lock.Unlock()
+ fmt.Println("locked mu2")
+
+ lock = ctxlock.Lock(lock.State(), &mu3)
+ defer lock.Unlock()
+ fmt.Println("locked mu3")
+
+ // Output:
+ // locked mu1
+ // locked mu2
+ // locked mu3
}
-func (r *Resource) HandleIntRequest(ctx context.Context, foo, bar string, f func(ls ctxlock.State) int) int {
- // Same, but returns an int instead of a string,
- // and must not allocate with the unchecked implementation.
- mu := ctxlock.Lock(ctx, &r.mu)
- defer mu.Unlock()
- r.foo = foo
- r.bar = bar
- return f(mu)
+func ExampleRank_lockOrderViolation() {
+ var mu1 ctxlock.Mutex[rank1] // cannot be locked after mu2 or mu3
+ var mu2 ctxlock.Mutex[rank2] // cannot be locked after mu3
+ var mu3 ctxlock.Mutex[rank3]
+
+ defer func() {
+ if r := recover(); r != nil {
+ fmt.Println("panic:", trimPanicMessage(r))
+ }
+ }()
+
+ // While we can lock mu2 first...
+ lock := ctxlock.Lock(ctxlock.None(), &mu2)
+ defer lock.Unlock()
+ fmt.Println("locked mu2")
+
+ // ...and then mu3...
+ lock = ctxlock.Lock(lock.State(), &mu3)
+ defer lock.Unlock()
+ fmt.Println("locked mu3")
+
+ // It is a lock order violation to lock mu1
+ // after either mu2 or mu3.
+ lock = ctxlock.Lock(lock.State(), &mu1)
+ defer lock.Unlock()
+ fmt.Println("locked mu1")
+
+ // Output:
+ // locked mu2
+ // locked mu3
+ // panic: cannot lock ctxlock_test.rank1 after ctxlock_test.rank3
}
-func ExampleState() {
+func ExampleState_resource() {
var r Resource
r.SetFoo(ctxlock.None(), "foo")
r.SetBar(ctxlock.None(), "bar")
r.WithLock(ctxlock.None(), func(ctx ctxlock.State) {
- // This callback is invoked with r's lock held,
+ // This callback is invoked with r's mutex held,
// and ctx carries the lock state. This means we can safely call
// other methods on r using ctx without causing a deadlock.
r.SetFoo(ctx, r.GetFoo(ctx)+r.GetBar(ctx))
@@ -88,7 +143,7 @@ func ExampleState_twoResources() {
r1.SetFoo(ctxlock.None(), "foo")
r2.SetBar(ctxlock.None(), "bar")
r1.WithLock(ctxlock.None(), func(ctx ctxlock.State) {
- // Here, r1's lock is held, but r2's lock is not.
+ // Here, r1's mutex is held, but r2's mutex is not.
// So r2 will be locked when we call r2.GetBar(ctx).
r1.SetFoo(ctx, r1.GetFoo(ctx)+r2.GetBar(ctx))
})
@@ -96,29 +151,27 @@ func ExampleState_twoResources() {
// Output: foobar
}
-func ExampleState_stdContext() {
+func ExampleState_withStdContext() {
var r Resource
ctx := context.Background()
result := r.HandleRequest(ctx, "foo", "bar", func(ctx ctxlock.State) string {
- // The r's lock is held, and ctx carries the lock state.
+ // The r's mutex is held, and ctx carries the lock state.
return r.GetFoo(ctx) + r.GetBar(ctx)
})
fmt.Println(result)
// Output: foobar
}
-func TestAllocFree(t *testing.T) {
- if ctxlock.Checked {
- t.Skip("Exported implementation is not alloc-free (use --tags=ts_omit_ctxlock_checks)")
+func TestEndToEndAllocFree(t *testing.T) {
+ if ctxlock.IsChecked {
+ t.Skip("Exported implementation is not alloc-free (use --tags=ts_omit_ctxlock_checks).")
}
var r Resource
- ctx := context.Background()
-
- const runs = 1000
- if allocs := testing.AllocsPerRun(runs, func() {
- res := r.HandleIntRequest(ctx, "foo", "bar", func(ctx ctxlock.State) int {
- // The r's lock is held, and ctx carries the lock state.
+ const N = 1000
+ if allocs := testing.AllocsPerRun(N, func() {
+ res := r.HandleIntRequest(context.Background(), "foo", "bar", func(ctx ctxlock.State) int {
+ // The r's mutex is held, and ctx carries the lock state.
return len(r.GetFoo(ctx) + r.GetBar(ctx))
})
if res != 6 {
@@ -128,3 +181,102 @@ func TestAllocFree(t *testing.T) {
t.Errorf("expected 0 allocs, got %f", allocs)
}
}
+
+type (
+ rank1 struct{}
+ rank2 struct{}
+ rank3 struct{}
+)
+
+// CheckLockAfter implements [ctxlock.Rank].
+func (r rank1) CheckLockAfter(r2 ctxlock.Rank) error {
+ switch r2.(type) {
+ case rank2, rank3:
+ return fmt.Errorf("cannot lock %T after %T", r, r2)
+ default:
+ return nil
+ }
+}
+
+// CheckLockAfter implements [ctxlock.Rank].
+func (r rank2) CheckLockAfter(r2 ctxlock.Rank) error {
+ switch r2.(type) {
+ case rank2, rank3:
+ return fmt.Errorf("cannot lock %T after %T", r, r2)
+ default:
+ return nil
+ }
+}
+
+// CheckLockAfter implements [ctxlock.Rank].
+func (a rank3) CheckLockAfter(b ctxlock.Rank) error {
+ return nil
+}
+
+type Resource struct {
+ mu ctxlock.ReentrantMutex
+ foo, bar string
+}
+
+func (r *Resource) GetFoo(ctx ctxlock.State) string {
+ // Lock the mutex if not already held,
+ // and unlock it when the function returns.
+ defer ctxlock.Lock(ctx, &r.mu).Unlock()
+ return r.foo
+}
+
+func (r *Resource) SetFoo(ctx ctxlock.State, foo string) {
+ // You can do it this way, if you prefer.
+ mu := ctxlock.Lock(ctx, &r.mu)
+ defer mu.Unlock()
+ r.foo = foo
+}
+
+func (r *Resource) GetBar(ctx ctxlock.State) string {
+ mu := ctxlock.Lock(ctx, &r.mu)
+ defer mu.Unlock()
+ return r.bar
+}
+
+func (r *Resource) SetBar(ctx ctxlock.State, bar string) {
+ mu := ctxlock.Lock(ctx, &r.mu)
+ defer mu.Unlock()
+ r.bar = bar
+}
+
+func (r *Resource) WithLock(ctx ctxlock.State, f func(ctx ctxlock.State)) {
+ mu := ctxlock.Lock(ctx, &r.mu)
+ defer mu.Unlock()
+ // Call the callback with the new lock state.
+ f(mu.State())
+}
+
+func (r *Resource) HandleRequest(ctx context.Context, foo, bar string, f func(ls ctxlock.State) string) string {
+ // Same, but with a standard [context.Context] instead of [ctxlock.State].
+ // [ctxlock.Lock] is generic and works with both without allocating.
+ // The ctx can be used for cancellation, etc.
+ mu := ctxlock.Lock(ctx, &r.mu)
+ defer mu.Unlock()
+ r.foo = foo
+ r.bar = bar
+ return f(mu.State())
+}
+
+func (r *Resource) HandleIntRequest(ctx context.Context, foo, bar string, f func(ls ctxlock.State) int) int {
+ // Same, but returns an int instead of a string.
+ // It must not allocate with the checked implementation.
+ mu := ctxlock.Lock(ctx, &r.mu)
+ defer mu.Unlock()
+ r.foo = foo
+ r.bar = bar
+ return f(mu.State())
+}
+
+func trimPanicMessage(r any) string {
+ msg := fmt.Sprintf("%v", r)
+ msg = strings.TrimSpace(msg)
+ if i := strings.IndexByte(msg, '\n'); i >= 0 {
+ return msg[:i]
+ }
+ return msg
+}
diff --git a/util/ctxlock/mutex.go b/util/ctxlock/mutex.go
new file mode 100644
index 000000000..3d00ef7e6
--- /dev/null
+++ b/util/ctxlock/mutex.go
@@ -0,0 +1,45 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package ctxlock
+
+import (
+ "sync"
+)
+
+// mutex is a wrapper around [sync.Mutex] that associates a [Rank] with the mutex
+// and provides storage for an arbitrary value (of type S) to be used by the state
+// that owns the lock while it is held. It's exported as [Mutex] in the package API.
+type mutex[R Rank, S any] struct {
+ // r is the rank of the mutex, used to check lock order.
+ r R
+ // m is the underlying mutex that provides the locking mechanism.
+ m sync.Mutex
+ // lockState is a memory region used by the state that owns the lock while it is held.
+ // It serves as pre-allocated lockState to avoid (in the [unchecked] case)
+ // or reduce (in the [checked] case) memory allocations.
+ lockState S
+}
+
+func (m *mutex[R, S]) rank() Rank {
+ return m.r
+}
+
+func (m *mutex[R, S]) lock() {
+ m.m.Lock()
+}
+
+func (m *mutex[R, S]) state() any {
+ return &m.lockState
+}
+
+func (m *mutex[R, S]) unlock() {
+ m.m.Unlock()
+}
+
+// mutexHandle is a subset of the [mutex] methods that are used once the mutex is locked.
+type mutexHandle interface {
+ rank() Rank
+ state() any
+ unlock()
+}
diff --git a/util/ctxlock/mutex_test.go b/util/ctxlock/mutex_test.go
new file mode 100644
index 000000000..5713ee8fb
--- /dev/null
+++ b/util/ctxlock/mutex_test.go
@@ -0,0 +1,120 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package ctxlock
+
+import (
+ "context"
+ "fmt"
+ "testing"
+)
+
+func BenchmarkReentrantMutex(b *testing.B) {
+ b.ReportAllocs()
+ // Does not allocate with --tags=ts_omit_ctxlock_checks.
+ b.Run("ctxlock.State", func(b *testing.B) {
+ var mu ReentrantMutex
+ for b.Loop() {
+ reentrantMutexLockUnlock(&mu, None)
+ }
+ })
+ b.Run("context.Context", func(b *testing.B) {
+ var mu ReentrantMutex
+ for b.Loop() {
+ reentrantMutexLockUnlock(&mu, context.Background)
+ }
+ })
+}
+
+func TestReentrantMutexAllocFree(t *testing.T) {
+ if IsChecked {
+ t.Skip("Exported implementation is not alloc-free (use --tags=ts_omit_ctxlock_checks).")
+ }
+
+ const N = 1000
+ t.Run("ctxlock.State", func(t *testing.T) {
+ var mu ReentrantMutex
+ if allocs := testing.AllocsPerRun(N, func() {
+ reentrantMutexLockUnlock(&mu, None)
+ }); allocs != 0 {
+ t.Errorf("expected 0 allocs, got %f", allocs)
+ }
+ })
+ t.Run("context.Context", func(t *testing.T) {
+ var mu ReentrantMutex
+ if allocs := testing.AllocsPerRun(N, func() {
+ reentrantMutexLockUnlock(&mu, context.Background)
+ }); allocs != 0 {
+ t.Errorf("expected 0 allocs, got %f", allocs)
+ }
+ })
+}
+
+func reentrantMutexLockUnlock[T context.Context](mu *ReentrantMutex, rootState func() T) {
+ parent := Lock(rootState(), mu)
+ func(state State) {
+ child := Lock(state, mu)
+ child.Unlock()
+ }(parent.State())
+ parent.Unlock()
+}
+
+func TestMutexRank(t *testing.T) {
+ var m1 mutex1
+ var m2 mutex2
+ var m3 mutex3
+ // Locking m1, m2, and m3 in order is valid.
+ lock := Lock(None(), &m1)
+ defer lock.Unlock()
+ lock = Lock(lock.State(), &m2)
+ defer lock.Unlock()
+ lock = Lock(lock.State(), &m3)
+ defer lock.Unlock()
+}
+
+func TestMutexLockOrderViolation(t *testing.T) {
+ var m1 mutex1
+ var m2 mutex2
+ var m3 mutex3
+ // Locking m2 m3, and then m1 is invalid.
+ lock := Lock(None(), &m2)
+ defer lock.Unlock()
+ lock = Lock(lock.State(), &m3)
+ defer lock.Unlock()
+ wantPanic(t, "cannot lock ctxlock.testRank1 after ctxlock.testRank3", func() {
+ lock := Lock(lock.State(), &m1)
+ defer lock.Unlock()
+ })
+}
+
+type (
+ testRank1 struct{}
+ testRank2 struct{}
+ testRank3 struct{}
+
+ mutex1 = Mutex[testRank1]
+ mutex2 = Mutex[testRank2]
+ mutex3 = Mutex[testRank3]
+)
+
+func (r testRank1) CheckLockAfter(r2 Rank) error {
+ switch r2.(type) {
+ case testRank2, testRank3:
+ return fmt.Errorf("cannot lock %T after %T", r, r2)
+ default:
+ return nil
+ }
+}
+
+func (r testRank2) CheckLockAfter(r2 Rank) error {
+ switch r2.(type) {
+ case testRank2, testRank3:
+ return fmt.Errorf("cannot lock %T after %T", r, r2)
+ default:
+ return nil
+ }
+}
+
+func (a testRank3) CheckLockAfter(b Rank) error {
+ return nil
+}
diff --git a/util/ctxlock/rank.go b/util/ctxlock/rank.go
new file mode 100644
index 000000000..0e1661f43
--- /dev/null
+++ b/util/ctxlock/rank.go
@@ -0,0 +1,58 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package ctxlock
+
+// A Rank defines the locking rules for a [Mutex].
+//
+// Typically, a distinct [Rank] type is defined for each mutex
+// that requires specific locking order.
+//
+// Example:
+//
+// type (
+// fooRank struct{} // fooRank must not be locked after barRank
+// barRank struct{}
+// )
+//
+// func (r fooRank) CheckLockAfter(r2 Rank) error {
+// switch r2.(type) {
+// case barRank:
+// return fmt.Errorf("cannot lock %T after %T", r, r2)
+// default:
+// return nil
+// }
+// }
+//
+// func (r barRank) CheckLockAfter(r2 Rank) error {
+// return nil // barRank can be locked anytime
+// }
+//
+// type Foo struct {
+// mu Mutex[fooRank]
+// }
+//
+// type Bar struct {
+// mu Mutex[barRank]
+// }
+type Rank interface {
+ // CheckLockAfter returns an error if locking the receiver
+ // after the given rank would violate lock ordering or reentrancy rules.
+ CheckLockAfter(Rank) error
+}
+
+// Reentrant is a [Rank] that does not enforce any locking order and allows reentrancy.
+//
+// It is used by a pre-defined [ReentrantMutex] type.
+type Reentrant struct {
+ noRank
+}
+
+// NonReentrant is a [Rank] that does not enforce any locking order, but disallows reentrancy.
+type NonReentrant struct {
+ noRank
+}
+
+type noRank struct{}
+
+func (noRank) CheckLockAfter(Rank) error { return nil }
diff --git a/util/ctxlock/state.go b/util/ctxlock/state.go
deleted file mode 100644
index 4ec9857e6..000000000
--- a/util/ctxlock/state.go
+++ /dev/null
@@ -1,228 +0,0 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-// Package ctxlock provides a [context.Context] implementation that carries mutex lock state
-// and enables reentrant locking. It offers two implementations: checked and unchecked.
-// The checked implementation performs runtime validation to ensure that:
-// - a parent context is not unlocked before its child,
-// - a context is only unlocked once, and
-// - a context is not used after being unlocked.
-// The unchecked implementation skips these checks for improved performance.
-// It defaults to the checked implementation unless the ts_omit_ctxlock_checks build tag is set.
-package ctxlock
-
-// This file contains both the [checked] and [unchecked] implementations of [State].
-
-import (
- "context"
- "fmt"
- "reflect"
- "sync"
- "time"
-)
-
-type ctxKey struct{ *sync.Mutex }
-
-func ctxKeyOf(mu *sync.Mutex) ctxKey {
- return ctxKey{mu}
-}
-
-// checked is an implementation of [State] that performs runtime checks
-// to ensure the correct order of locking and unlocking.
-//
-// Its zero value and a nil pointer are valid and carry no lock state
-// and an empty [context.Context].
-type checked struct {
- context.Context // nil means an empty context
-
- // mu is the mutex tracked by this state,
- // or nil if it wasn't created with [Lock].
- mu *sync.Mutex
-
- // parent is an ancestor State associated with the same mutex.
- // It may or may not own the lock (the lock could be held by a further ancestor).
- // The parent is nil if this State is the root of the hierarchy,
- // meaning it owns the lock.
- parent *checked
-
- // unlocked is whether [checked.Unlock] was called on this state.
- unlocked bool
-}
-
-func fromContextChecked(ctx context.Context) *checked {
- return &checked{ctx, nil, nil, false}
-}
-
-func lockChecked(parent *checked, mu *sync.Mutex) *checked {
- panicIfNil(mu)
- if parentState, ok := parent.Value(ctxKeyOf(mu)).(*checked); ok {
- if appearsUnlocked(mu) {
- // The parent is already unlocked, but the mutex is not.
- panic(fmt.Sprintf("%T is spuriously unlocked", mu))
- }
- return &checked{parent, mu, parentState, false}
- }
- mu.Lock()
- return &checked{parent, mu, nil, false}
-}
-
-func (c *checked) Deadline() (deadline time.Time, ok bool) {
- c.panicIfUnlocked()
- if c == nil || c.Context == nil {
- return time.Time{}, false
- }
- return c.Context.Deadline()
-}
-
-func (c *checked) Done() <-chan struct{} {
- c.panicIfUnlocked()
- if c == nil || c.Context == nil {
- return nil
- }
- return c.Context.Done()
-}
-
-func (c *checked) Err() error {
- c.panicIfUnlocked()
- if c == nil || c.Context == nil {
- return nil
- }
- return c.Context.Err()
-}
-
-func (c *checked) Value(key any) any {
- c.panicIfUnlocked()
- if c == nil {
- // No-op; zero state.
- return nil
- }
- if key, ok := key.(ctxKey); ok && key.Mutex == c.mu {
- // This is the mutex tracked by this state.
- return c
- }
- if c.Context != nil {
- // Forward the call to the parent context,
- // which may or may not be a [checked] state.
- return c.Context.Value(key)
- }
- return nil
-}
-
-func (c *checked) Unlock() {
- switch {
- case c == nil:
- // No-op; zero state.
- return
- case c.unlocked:
- panic("already unlocked")
- case c.mu == nil:
- // No-op; the state does not track a mutex lock state,
- // meaning it was not created with [Lock].
- case c.parent == nil:
- // The state own the mutex's lock; we must unlock it.
- // This triggers a fatal error if the mutex is already unlocked.
- c.mu.Unlock()
- case c.parent.unlocked:
- // The parent state is already unlocked.
- // The mutex may or may not be locked;
- // something else may have already locked it.
- panic("parent already unlocked")
- case appearsUnlocked(c.mu):
- // The mutex itself is unlocked,
- // even though the parent state is still locked.
- // It may be unlocked by an ancestor state
- // or by something else entirely.
- panic("mutex is not locked")
- default:
- // No-op; a parent or ancestor will handle unlocking.
- }
- c.unlocked = true // mark this state as unlocked
-}
-
-func (c *checked) panicIfUnlocked() {
- if c != nil && c.unlocked {
- panic("use after unlock")
- }
-}
-
-func panicIfNil[T comparable](v T) {
- if reflect.ValueOf(v).IsNil() {
- panic(fmt.Sprintf("nil %T", v))
- }
-}
-
-// unchecked is an implementation of [State] that trades runtime checks for performance.
-//
-// Its zero value carries no mutex lock state and an empty [context.Context].
-type unchecked struct {
- context.Context // nil means an empty context
- mu *sync.Mutex // non-nil if owned by this state
-}
-
-func fromContextUnchecked(ctx context.Context) unchecked {
- return unchecked{ctx, nil}
-}
-
-func lockUnchecked(parent unchecked, mu *sync.Mutex) unchecked {
- if parent.Value(ctxKeyOf(mu)) == nil {
- // There's no ancestor state associated with this mutex,
- // so we can lock it.
- mu.Lock()
- } else {
- // The mutex is already locked by a parent/ancestor state.
- mu = nil
- }
- return unchecked{parent.Context, mu}
-}
-
-func (c unchecked) Deadline() (deadline time.Time, ok bool) {
- if c.Context == nil {
- return time.Time{}, false
- }
- return c.Context.Deadline()
-}
-
-func (c unchecked) Done() <-chan struct{} {
- if c.Context == nil {
- return nil
- }
- return c.Context.Done()
-}
-
-func (c unchecked) Err() error {
- if c.Context == nil {
- return nil
- }
- return c.Context.Err()
-}
-
-func (c unchecked) Value(key any) any {
- if key, ok := key.(ctxKey); ok && key.Mutex == c.mu {
- return key
- }
- if c.Context == nil {
- return nil
- }
- return c.Context.Value(key)
-}
-
-func (c unchecked) Unlock() {
- if c.mu != nil {
- c.mu.Unlock()
- }
-}
-
-type tryLocker interface {
- TryLock() bool
- Unlock()
-}
-
-// appearsUnlocked reports whether m is unlocked.
-// It may return a false negative if m does not have a TryLock method.
-func appearsUnlocked[T sync.Locker](m T) bool {
- if m, ok := any(m).(tryLocker); ok && m.TryLock() {
- m.Unlock()
- return true
- }
- return false
-}
diff --git a/util/ctxlock/state_checked.go b/util/ctxlock/state_checked.go
index 4705b9c19..56cf193e7 100644
--- a/util/ctxlock/state_checked.go
+++ b/util/ctxlock/state_checked.go
@@ -1,53 +1,229 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
-// This file exports default, unoptimized implementation of the [State] that includes runtime checks.
-// It is used unless the build tag ts_omit_ctxlock_checks is set.
-
-//go:build !ts_omit_ctxlock_checks
-
package ctxlock
import (
"context"
- "sync"
+ "errors"
+ "fmt"
+ "runtime"
+ "time"
)
-// Checked indicates whether runtime checks are enabled for this package.
-const Checked = true
-
-// State carries the lock state of zero or more mutexes and an optional [context.Context].
-// Its zero value is valid and represents an unlocked state and an empty context.
+// checked is an implementation of [State] with additional runtime checks.
//
-// Calling [Lock] returns a derived State with the specified mutex locked. The State is considered
-// the owner of the lock if it wasn't already acquired by a parent State. Calling [State.Unlock]
-// releases the lock owned by the state. It is a runtime error to call Unlock more than once,
-// to use the State after it has been unlocked, or to unlock a parent State before its child.
-type State struct {
- *checked
+// Its zero value and a nil pointer are valid and carry no lock state
+// and an empty [context.Context].
+type checked struct {
+ context.Context // nil means an empty context
+
+ // mu is the mutex locked (or re-locked) by this state,
+ // or nil if it wasn't created with [Lock].
+ mu mutexHandle
+
+ // parent is the next state in the hierarchy associated with the same mutex.
+ // It may or may not own the lock (the lock could be held by a further ancestor).
+ //
+ // The parent is nil if this state owns the lock, or if it's a zero state.
+ parent *checked
+
+ // unlocked is whether [checked.Unlock] was called on this state.
+ unlocked bool
+
+ // lockedBy are the program counters of function invocations
+ // that locked the mutex, or nil if mu is not owned by this state.
+ lockedBy *lockCallers
}
-// None returns a [State] that carries no lock state and an empty [context.Context].
-func None() State {
- return State{}
+type (
+ lockCallers [5]uintptr
+ checkedMutex[R Rank] = mutex[R, lockCallers]
+)
+
+func fromContextChecked(ctx context.Context) *checked {
+ return &checked{Context: ctx}
}
-// FromContext returns a [State] that carries the same lock state as the provided [context.Context].
-//
-// It is typically used by methods that already accept a [context.Context] for cancellation or deadline
-// management, and would like to use it for locking as well.
-func FromContext(ctx context.Context) State {
- return State{fromContextChecked(ctx)}
+func lockChecked[R Rank](parent *checked, mu *checkedMutex[R]) *checked {
+ if mu == nil {
+ panic("nil mutex")
+ }
+ if parentState, ok := parent.isAlreadyLocked(mu); ok {
+ return &checked{parent, mu, parentState, false, nil}
+ }
+ mu.lock()
+ runtime.Callers(4, mu.lockState[:])
+ return &checked{parent, mu, nil, false, nil}
+}
+
+func (c *checked) isAlreadyLocked(m mutexHandle) (parent *checked, ok bool) {
+ switch val := c.Value(m).(type) {
+ case nil:
+ // No ancestor state associated with this mutex,
+ // and locking it does not violate the lock ordering.
+ return nil, false
+ case error:
+ // There's a lock ordering or reentrancy violation.
+ panic(val)
+ case *checked:
+ // The mutex is reentrant and is already held by a parent
+ // or ancestor state.
+ return val, true
+ default:
+ panic("unreachable")
+ }
+}
+
+func (c *checked) unlock() {
+ switch {
+ case c == nil:
+ // No-op; zero state.
+ return
+ case c.unlocked:
+ panic("already unlocked")
+ case c.mu == nil:
+ // No-op; the state does not track a mutex lock state,
+ // meaning it was not created with [Lock].
+ case c.parent == nil:
+ // The state own the mutex's lock; we must unlock it.
+ // This triggers a fatal error if the mutex is already unlocked.
+ c.mu.unlock()
+ case c.parent.unlocked:
+ // The parent state is already unlocked.
+ // The mutex may or may not be locked;
+ // something else may have already locked it.
+ panic("parent already unlocked")
+ default:
+ // No-op; a parent or ancestor will handle unlocking.
+ }
+ c.unlocked = true // mark this state as unlocked
+}
+
+func (c *checked) panicIfUnlocked() {
+ if c != nil && c.unlocked {
+ panic("use after unlock")
+ }
+}
+
+// Deadline implements [context.Context].
+func (c *checked) Deadline() (deadline time.Time, ok bool) {
+ c.panicIfUnlocked()
+ if c == nil || c.Context == nil {
+ return time.Time{}, false
+ }
+ return c.Context.Deadline()
+}
+
+// Done implements [context.Context].
+func (c *checked) Done() <-chan struct{} {
+ c.panicIfUnlocked()
+ if c == nil || c.Context == nil {
+ return nil
+ }
+ return c.Context.Done()
+}
+
+// Err implements [context.Context].
+func (c *checked) Err() error {
+ c.panicIfUnlocked()
+ if c == nil || c.Context == nil {
+ return nil
+ }
+ return c.Context.Err()
+}
+
+// Value implements [context.Context].
+func (c *checked) Value(key any) any {
+ c.panicIfUnlocked()
+ if c == nil {
+ // No-op; zero state.
+ return nil
+ }
+ if mu, ok := key.(mutexHandle); ok {
+ // Checks whether mu can be acquired after c.mu.
+ if res, done := checkLockOrder(mu, c.mu, c); done {
+ // We have a definite answer.
+ switch res := res.(type) {
+ case error:
+ // There's a lock ordering or reentrancy violation.
+ // Enrich the error with the call stack when the other mutex was locked.
+ if lockedBy, ok := c.mu.state().(*lockCallers); ok {
+ return LockOrderError{res, *lockedBy}
+ }
+ default:
+ // A reentrant mutex is already locked by a parent or ancestor state.
+ return res
+ }
+ }
+ }
+ if c.Context != nil {
+ // Forward the call to the parent context,
+ // which may or may not be a [checked] state.
+ return c.Context.Value(key)
+ }
+ return nil
}
-// Lock acquires the specified mutex and becomes its owner, unless it is already held by a parent.
-// The parent can be either a [State] or a [context.Context]. A zero [State] is a valid parent.
-// It returns a new [State] that augments the parent with the additional lock state.
+var errAlreadyLocked = errors.New("non-reentrant mutex already locked")
+
+// checkLockOrder determines whether m1 can be acquired after m2.
+// It returns an error and true if there's a lock ordering or reentrancy violation,
+// or the provided alreadyLocked value and true if m1 and m2 are the same and reentrancy is allowed,
+// or nil and false if the caller should continue checking against the next locked mutex.
+func checkLockOrder[T any](m1, m2 mutexHandle, alreadyLocked T) (res any, done bool) {
+ if m2 == nil {
+ // Nothing to check; continue search.
+ return nil, false
+ }
+ r1, r2 := m1.rank(), m2.rank()
+ if err := r1.CheckLockAfter(r2); err != nil {
+ // There's a lock ordering (or reentrancy) violation.
+ return err, true
+ }
+ if m1 != m2 {
+ // There's no lock ordering violation,
+ // but the mutex being locked is not the same as the one
+ // already locked. We need to continue checking.
+ return nil, false
+ }
+ if _, ok := r1.(NonReentrant); ok {
+ // Special handling for the [NonReentrant] rank.
+ //
+ // For user-defined ranks, reentrancy rules are enforced
+ // by the rank implementation itself, since each mutex
+ // is expected to have a distinct rank, and the rank
+ // can define its own rules. However, the predefined
+ // [NonReentrant] rank is shared by multiple mutexes.
+ return errAlreadyLocked, true
+ }
+ // The locking mutex is the same as the one already locked,
+ // and the rank allows reentrancy. We found a match.
+ return alreadyLocked, true
+}
+
+// LockOrderError represents a violation of mutex lock ordering.
//
-// It is a runtime error to pass a nil mutex or to unlock the parent state before the returned one.
-func Lock[T context.Context](parent T, mu *sync.Mutex) State {
- if parent, ok := any(parent).(State); ok {
- return State{lockChecked(parent.checked, mu)}
+// This error is not returned directly; it is used in panics to indicate a programming error
+// when lock acquisition violates the expected order.
+type LockOrderError struct {
+ error
+ violatedBy lockCallers // the call stack when the other mutex was locked
+}
+
+func (e LockOrderError) Error() string {
+ return fmt.Sprintf("%s\n\nConflicting lock held at:\n%s", e.error, e.violatedBy)
+}
+
+func (c lockCallers) String() string {
+ var output string
+ frames := runtime.CallersFrames(c[:])
+ for {
+ frame, more := frames.Next()
+ output += fmt.Sprintf("%s\n\t%s:%d\n", frame.Function, frame.File, frame.Line)
+ if !more {
+ break
+ }
}
- return State{lockChecked(fromContextChecked(parent), mu)}
+ return output
}
diff --git a/util/ctxlock/state_test.go b/util/ctxlock/state_test.go
index e9b747c97..a60862e47 100644
--- a/util/ctxlock/state_test.go
+++ b/util/ctxlock/state_test.go
@@ -5,62 +5,58 @@ package ctxlock
import (
"context"
+ "fmt"
+ "strings"
"sync"
"testing"
"tailscale.com/util/ctxkey"
)
-type state interface {
+type stateType interface {
+ *checked | unchecked
context.Context
- Unlock()
+ unlock()
}
-type impl[T state] struct {
+type lockStateType interface{ lockCallers | unchecked }
+
+type impl[T stateType, S lockStateType] struct {
None func() T
FromContext func(context.Context) T
- Lock func(T, *sync.Mutex) T
- LockCtx func(context.Context, *sync.Mutex) T
+ Lock func(T, *mutex[Reentrant, S]) T
+ LockCtx func(context.Context, *mutex[Reentrant, S]) T
}
var (
- exportedImpl = impl[State]{
- None: None,
- FromContext: FromContext,
- Lock: Lock[State],
- LockCtx: Lock[context.Context],
- }
- checkedImpl = impl[*checked]{
+ checkedImpl = impl[*checked, lockCallers]{
None: func() *checked { return nil },
FromContext: fromContextChecked,
- Lock: lockChecked,
- LockCtx: func(ctx context.Context, mu *sync.Mutex) *checked {
+ Lock: lockChecked[Reentrant],
+ LockCtx: func(ctx context.Context, mu *checkedMutex[Reentrant]) *checked {
return lockChecked(fromContextChecked(ctx), mu)
},
}
- uncheckedImpl = impl[unchecked]{
+ uncheckedImpl = impl[unchecked, unchecked]{
None: func() unchecked { return unchecked{} },
FromContext: fromContextUnchecked,
- Lock: lockUnchecked,
- LockCtx: func(ctx context.Context, mu *sync.Mutex) unchecked {
+ Lock: lockUnchecked[Reentrant],
+ LockCtx: func(ctx context.Context, mu *mutex[Reentrant, unchecked]) unchecked {
return lockUnchecked(fromContextUnchecked(ctx), mu)
},
}
)
-// BenchmarkLockUnlock benchmarks the performance of locking and unlocking a mutex.
-func BenchmarkLockUnlock(b *testing.B) {
- var mu sync.Mutex
- b.Run("Exported", func(b *testing.B) {
- benchmarkLockUnlock(b, exportedImpl)
- })
+// BenchmarkStateLockUnlock benchmarks the performance of locking and unlocking a mutex.
+func BenchmarkStateLockUnlock(b *testing.B) {
b.Run("Checked", func(b *testing.B) {
- benchmarkLockUnlock(b, checkedImpl)
+ benchmarkStateLockUnlock(b, checkedImpl)
})
b.Run("Unchecked", func(b *testing.B) {
- benchmarkLockUnlock(b, uncheckedImpl)
+ benchmarkStateLockUnlock(b, uncheckedImpl)
})
b.Run("Reference", func(b *testing.B) {
+ var mu sync.Mutex
for b.Loop() {
mu.Lock()
mu.Unlock()
@@ -68,21 +64,16 @@ func BenchmarkLockUnlock(b *testing.B) {
})
}
-func benchmarkLockUnlock[T state](b *testing.B, impl impl[T]) {
- var mu sync.Mutex
+func benchmarkStateLockUnlock[T stateType, S lockStateType](b *testing.B, impl impl[T, S]) {
+ var mu mutex[Reentrant, S]
for b.Loop() {
- ctx := impl.Lock(impl.None(), &mu)
- ctx.Unlock()
+ state := impl.Lock(impl.None(), &mu)
+ state.unlock()
}
}
// BenchmarkReentrance benchmarks the performance of reentrant locking and unlocking.
func BenchmarkReentrance(b *testing.B) {
- var mu sync.Mutex
-
- b.Run("Exported", func(b *testing.B) {
- benchmarkReentrance(b, exportedImpl)
- })
b.Run("Checked", func(b *testing.B) {
benchmarkReentrance(b, checkedImpl)
})
@@ -90,6 +81,7 @@ func BenchmarkReentrance(b *testing.B) {
benchmarkReentrance(b, uncheckedImpl)
})
b.Run("Reference", func(b *testing.B) {
+ var mu sync.Mutex
for b.Loop() {
mu.Lock()
func(mu *sync.Mutex) {
@@ -102,102 +94,68 @@ func BenchmarkReentrance(b *testing.B) {
})
}
-func benchmarkReentrance[T state](b *testing.B, impl impl[T]) {
- var mu sync.Mutex
+func benchmarkReentrance[T stateType, S lockStateType](b *testing.B, impl impl[T, S]) {
+ var mu mutex[Reentrant, S]
for b.Loop() {
parent := impl.Lock(impl.None(), &mu)
func(ctx T) {
child := impl.Lock(ctx, &mu)
- child.Unlock()
+ child.unlock()
}(parent)
- parent.Unlock()
+ parent.unlock()
}
}
-// BenchmarkGenericLock benchmarks the performance of the generic [Lock] function
-// that works with both [State] and [context.Context].
-func BenchmarkGenericLock(b *testing.B) {
- // Does not allocate with --tags=ts_omit_ctxlock_checks.
- b.Run("State", func(b *testing.B) {
- var mu sync.Mutex
- var ctx State
- for b.Loop() {
- parent := Lock(ctx, &mu)
- func(ctx State) {
- child := Lock(ctx, &mu)
- child.Unlock()
- }(parent)
- parent.Unlock()
- }
- })
- b.Run("StdContext", func(b *testing.B) {
- var mu sync.Mutex
- ctx := context.Background()
- for b.Loop() {
- parent := Lock(ctx, &mu)
- func(ctx State) {
- child := Lock(ctx, &mu)
- child.Unlock()
- }(parent)
- parent.Unlock()
- }
- })
-}
-
// TestUncheckedAllocFree tests that the exported implementation of [State] does not allocate memory
// when the ts_omit_ctxlock_checks build tag is set.
func TestUncheckedAllocFree(t *testing.T) {
- if Checked {
+ if IsChecked {
t.Skip("Exported implementation is not alloc-free (use --tags=ts_omit_ctxlock_checks)")
}
t.Run("Simple/WithState", func(t *testing.T) {
- var mu sync.Mutex
+ var mu ReentrantMutex
mustNotAllocate(t, func() {
- ctx := Lock(None(), &mu)
- ctx.Unlock()
+ mu := Lock(None(), &mu)
+ mu.Unlock()
})
})
t.Run("Simple/WithContext", func(t *testing.T) {
- var mu sync.Mutex
+ var mu ReentrantMutex
ctx := context.Background()
mustNotAllocate(t, func() {
- ctx := Lock(ctx, &mu)
- ctx.Unlock()
+ mu := Lock(ctx, &mu)
+ mu.Unlock()
})
})
t.Run("Reentrant/WithState", func(t *testing.T) {
- var mu sync.Mutex
+ var mu ReentrantMutex
mustNotAllocate(t, func() {
parent := Lock(None(), &mu)
- func(ctx State) {
- child := Lock(parent, &mu)
+ func(state State) {
+ child := Lock(state, &mu)
child.Unlock()
- }(parent)
+ }(parent.State())
parent.Unlock()
})
})
t.Run("Reentrant/WithContext", func(t *testing.T) {
- var mu sync.Mutex
+ var mu ReentrantMutex
ctx := context.Background()
mustNotAllocate(t, func() {
parent := Lock(ctx, &mu)
- func(ctx State) {
- child := Lock(ctx, &mu)
+ func(state State) {
+ child := Lock(state, &mu)
child.Unlock()
- }(parent)
+ }(parent.State())
parent.Unlock()
})
})
}
func TestHappyPath(t *testing.T) {
- t.Run("Exported", func(t *testing.T) {
- testHappyPath(t, exportedImpl)
- })
-
t.Run("Checked", func(t *testing.T) {
testHappyPath(t, checkedImpl)
})
@@ -207,32 +165,33 @@ func TestHappyPath(t *testing.T) {
})
}
-func testHappyPath[T state](t *testing.T, impl impl[T]) {
- var mu sync.Mutex
+func testHappyPath[T stateType, S lockStateType](t *testing.T, impl impl[T, S]) {
+ var mu mutex[Reentrant, S]
parent := impl.Lock(impl.None(), &mu)
wantLocked(t, &mu) // mu is locked by parent
child := impl.Lock(parent, &mu)
wantLocked(t, &mu) // mu is still locked by parent
- var mu2 sync.Mutex
+ var mu2 mutex[Reentrant, S]
ls2 := impl.Lock(child, &mu2)
- wantLocked(t, &mu2) // mu2 is locked by ls2
- ls2.Unlock() // unlocks mu2
+ wantLocked(t, &mu2) // mu2 is locked by ls2
+
+ grandchild := impl.Lock(ls2, &mu)
+ grandchild.unlock() // no-op; mu is owned by parent
+ wantLocked(t, &mu) // mu is still locked by parent
+
+ ls2.unlock() // unlocks mu2
wantUnlocked(t, &mu2) // mu2 is now unlocked
- child.Unlock() // noop
+ child.unlock() // noop
wantLocked(t, &mu) // mu is still locked by parent
- parent.Unlock() // unlocks mu
+ parent.unlock() // unlocks mu
wantUnlocked(t, &mu) // mu is now unlocked
}
func TestContextWrapping(t *testing.T) {
- t.Run("Exported", func(t *testing.T) {
- testContextWrapping(t, exportedImpl)
- })
-
t.Run("Checked", func(t *testing.T) {
testContextWrapping(t, checkedImpl)
})
@@ -242,13 +201,13 @@ func TestContextWrapping(t *testing.T) {
})
}
-func testContextWrapping[T state](t *testing.T, impl impl[T]) {
+func testContextWrapping[T stateType, S lockStateType](t *testing.T, impl impl[T, S]) {
// Create a [context.Context] with a value set in it.
wantValue := "value"
key := ctxkey.New("key", "")
ctxWithValue := key.WithValue(context.Background(), wantValue)
- var mu sync.Mutex
+ var mu mutex[Reentrant, S]
parent := impl.LockCtx(ctxWithValue, &mu)
wantLocked(t, &mu) // mu is locked by parent
@@ -268,103 +227,72 @@ func testContextWrapping[T state](t *testing.T, impl impl[T]) {
}
// ... and the lock state.
- child.Unlock() // no-op; mu is owned by parent
+ child.unlock() // no-op; mu is owned by parent
wantLocked(t, &mu) // mu is still locked by parent
- parentDup.Unlock() // no-op; mu is owned by parent
+ parentDup.unlock() // no-op; mu is owned by parent
wantLocked(t, &mu) // mu is still locked by parent
- parent.Unlock() // unlocks mu
+ parent.unlock() // unlocks mu
wantUnlocked(t, &mu) // mu is now unlocked
}
func TestNilMutex(t *testing.T) {
impl := checkedImpl
- wantPanic(t, "nil *sync.Mutex", func() { impl.Lock(impl.None(), nil) })
+ wantPanic(t, "nil mutex", func() { impl.Lock(impl.None(), nil) })
}
func TestUseUnlockedParent_Checked(t *testing.T) {
impl := checkedImpl
- var mu sync.Mutex
+ var mu checkedMutex[Reentrant]
parent := impl.Lock(impl.None(), &mu)
- parent.Unlock() // unlocks mu
+ parent.unlock() // unlocks mu
wantUnlocked(t, &mu) // mu is now unlocked
wantPanic(t, "use after unlock", func() { impl.Lock(parent, &mu) })
}
-func TestUseUnlockedMutex_Checked(t *testing.T) {
- impl := checkedImpl
-
- var mu sync.Mutex
- parent := impl.Lock(impl.None(), &mu)
- mu.Unlock() // unlock mu directly without unlocking parent
- wantPanic(t, "*sync.Mutex is spuriously unlocked", func() { impl.Lock(parent, &mu) })
-}
-
func TestUnlockParentFirst_Checked(t *testing.T) {
impl := checkedImpl
- var mu sync.Mutex
+ var mu checkedMutex[Reentrant]
parent := impl.Lock(impl.FromContext(context.Background()), &mu)
child := impl.Lock(parent, &mu)
- parent.Unlock() // unlocks mu
+ parent.unlock() // unlocks mu
wantUnlocked(t, &mu) // mu is now unlocked
- wantPanic(t, "parent already unlocked", child.Unlock)
+ wantPanic(t, "parent already unlocked", child.unlock)
}
func TestUnlockTwice_Checked(t *testing.T) {
impl := checkedImpl
unlockTwice := func(t *testing.T, ctx *checked) {
- ctx.Unlock() // unlocks mu
- wantPanic(t, "already unlocked", ctx.Unlock)
+ ctx.unlock() // unlocks mu
+ wantPanic(t, "already unlocked", ctx.unlock)
}
t.Run("Wrapped", func(t *testing.T) {
unlockTwice(t, impl.FromContext(context.Background()))
})
t.Run("Locked", func(t *testing.T) {
- var mu sync.Mutex
+ var mu checkedMutex[Reentrant]
ctx := impl.Lock(impl.None(), &mu)
unlockTwice(t, ctx)
})
- t.Run("Locked/WithReloc", func(t *testing.T) {
- var mu sync.Mutex
- ctx := impl.Lock(impl.None(), &mu)
- ctx.Unlock() // unlocks mu
- mu.Lock() // re-locks mu, but not by the state
- wantPanic(t, "already unlocked", ctx.Unlock)
- })
t.Run("Child", func(t *testing.T) {
- var mu sync.Mutex
+ var mu checkedMutex[Reentrant]
parent := impl.Lock(impl.None(), &mu)
- defer parent.Unlock()
+ defer parent.unlock()
child := impl.Lock(parent, &mu)
unlockTwice(t, child)
})
- t.Run("Child/WithReloc", func(t *testing.T) {
- var mu sync.Mutex
- parent := impl.Lock(impl.None(), &mu)
- child := impl.Lock(parent, &mu)
- parent.Unlock()
- mu.Lock() // re-locks mu, but not the parent state
- wantPanic(t, "parent already unlocked", child.Unlock)
- })
- t.Run("Child/WithManualUnlock", func(t *testing.T) {
- var mu sync.Mutex
- parent := impl.Lock(impl.None(), &mu)
- child := impl.Lock(parent, &mu)
- mu.Unlock() // unlocks mu, but not the parent state
- wantPanic(t, "mutex is not locked", child.Unlock)
- })
t.Run("Grandchild", func(t *testing.T) {
- var mu sync.Mutex
+ var mu checkedMutex[Reentrant]
parent := impl.Lock(impl.None(), &mu)
- defer parent.Unlock()
+ defer parent.unlock()
child := impl.Lock(parent, &mu)
- defer child.Unlock()
+ defer child.unlock()
grandchild := impl.Lock(child, &mu)
unlockTwice(t, grandchild)
})
@@ -373,76 +301,70 @@ func TestUnlockTwice_Checked(t *testing.T) {
func TestUseUnlocked_Checked(t *testing.T) {
impl := checkedImpl
- var mu sync.Mutex
+ var mu checkedMutex[Reentrant]
state := lockChecked(impl.None(), &mu)
- state.Unlock()
+ state.unlock()
// All of these should panic since the state is already unlocked.
- wantPanic(t, "", func() { state.Deadline() })
- wantPanic(t, "", func() { state.Done() })
- wantPanic(t, "", func() { state.Err() })
- wantPanic(t, "", func() { state.Unlock() })
- wantPanic(t, "", func() { state.Value("key") })
+ wantPanic(t, "*", func() { state.Deadline() })
+ wantPanic(t, "*", func() { state.Done() })
+ wantPanic(t, "*", func() { state.Err() })
+ wantPanic(t, "*", func() { state.unlock() })
+ wantPanic(t, "*", func() { state.Value("key") })
}
func TestUseZeroState(t *testing.T) {
- t.Run("Exported", func(t *testing.T) {
- testUseEmptyState(t, exportedImpl.None, exportedImpl)
- })
t.Run("Checked", func(t *testing.T) {
- testUseEmptyState(t, checkedImpl.None, checkedImpl)
+ testUseEmptyState(t, checkedImpl.None)
})
t.Run("Unchecked", func(t *testing.T) {
- testUseEmptyState(t, uncheckedImpl.None, uncheckedImpl)
+ testUseEmptyState(t, uncheckedImpl.None)
})
}
func TestUseWrappedBackground(t *testing.T) {
- t.Run("Exported", func(t *testing.T) {
- testUseEmptyState(t, getWrappedBackground(t, exportedImpl), exportedImpl)
- })
t.Run("Checked", func(t *testing.T) {
- testUseEmptyState(t, getWrappedBackground(t, checkedImpl), checkedImpl)
+ testUseEmptyState(t, getWrappedBackground(t, checkedImpl))
})
t.Run("Unchecked", func(t *testing.T) {
- testUseEmptyState(t, getWrappedBackground(t, uncheckedImpl), uncheckedImpl)
+ testUseEmptyState(t, getWrappedBackground(t, uncheckedImpl))
})
}
-func getWrappedBackground[T state](t *testing.T, impl impl[T]) func() T {
+func getWrappedBackground[T stateType, S lockStateType](t *testing.T, impl impl[T, S]) func() T {
t.Helper()
return func() T {
return impl.FromContext(context.Background())
}
}
-func testUseEmptyState[T state](t *testing.T, getCtx func() T, impl impl[T]) {
- // Using aan empty [State] must not panic or deadlock.
+func testUseEmptyState[T stateType](t *testing.T, getState func() T) {
+ // Using an empty [State] must not panic or deadlock.
// It should also behave like [context.Background].
for range 2 {
- ctx := getCtx()
- if gotDone := ctx.Done(); gotDone != nil {
+ state := getState()
+ if gotDone := state.Done(); gotDone != nil {
t.Errorf("ctx.Done() = %v; want nil", gotDone)
}
- if gotDeadline, ok := ctx.Deadline(); ok {
+ if gotDeadline, ok := state.Deadline(); ok {
t.Errorf("ctx.Deadline() = %v; want !ok", gotDeadline)
}
- if gotErr := ctx.Err(); gotErr != nil {
+ if gotErr := state.Err(); gotErr != nil {
t.Errorf("ctx.Err() = %v; want nil", gotErr)
}
- if gotValue := ctx.Value("test-key"); gotValue != nil {
+ if gotValue := state.Value("test-key"); gotValue != nil {
t.Errorf("ctx.Value(test-key) = %v; want nil", gotValue)
}
- ctx.Unlock()
+ state.unlock()
}
}
func wantPanic(t *testing.T, wantMsg string, fn func()) {
t.Helper()
defer func() {
- if r := recover(); wantMsg != "" {
- if gotMsg, ok := r.(string); !ok || gotMsg != wantMsg {
- t.Errorf("panic: %v; want %q", r, wantMsg)
+ if r := recover(); wantMsg != "*" {
+ if gotMsg := trimPanicMessage(r); gotMsg != wantMsg {
+ t.Errorf("panic: got %q; want %q", r, wantMsg)
}
}
}()
@@ -450,19 +372,26 @@ func wantPanic(t *testing.T, wantMsg string, fn func()) {
t.Fatal("failed to panic")
}
-func wantLocked(t *testing.T, m *sync.Mutex) {
- if m.TryLock() {
- m.Unlock()
+func (m *mutex[R, S]) isLockedForTest() bool {
+ if m.m.TryLock() {
+ m.m.Unlock()
+ return false
+ }
+ return true
+}
+
+func wantLocked[R Rank, S lockStateType](t *testing.T, m *mutex[R, S]) {
+ t.Helper()
+ if !m.isLockedForTest() {
t.Fatal("mutex is not locked")
}
}
-func wantUnlocked(t *testing.T, m *sync.Mutex) {
+func wantUnlocked[R Rank, S lockStateType](t *testing.T, m *mutex[R, S]) {
t.Helper()
- if !m.TryLock() {
+ if m.isLockedForTest() {
t.Fatal("mutex is locked")
}
- m.Unlock()
}
func mustNotAllocate(t *testing.T, steps func()) {
@@ -472,3 +401,12 @@ func mustNotAllocate(t *testing.T, steps func()) {
t.Errorf("expected 0 allocs, got %f", allocs)
}
}
+
+func trimPanicMessage(r any) string {
+ msg := fmt.Sprintf("%v", r)
+ msg = strings.TrimSpace(msg)
+ if i := strings.IndexByte(msg, '\n'); i >= 0 {
+ return msg[:i]
+ }
+ return msg
+}
diff --git a/util/ctxlock/state_unchecked.go b/util/ctxlock/state_unchecked.go
index c55150b75..522243ca0 100644
--- a/util/ctxlock/state_unchecked.go
+++ b/util/ctxlock/state_unchecked.go
@@ -1,35 +1,103 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
-// This file exports optimized implementation of the [State] that omits runtime checks.
-// It is used when the build tag ts_omit_ctxlock_checks is set.
-
-//go:build ts_omit_ctxlock_checks
-
package ctxlock
import (
"context"
- "sync"
+ "time"
)
-const Checked = false
+// unchecked is an implementation of [State] that trades additional runtime checks
+// for performance.
+//
+// Its zero value carries no mutex lock state and an empty [context.Context].
+type unchecked struct {
+ context.Context // nil means an empty context
+ mu mutexHandle // non-nil if owned by this state
+}
+
+type (
+ alreadyLocked struct{}
+ uncheckedMutex[R Rank] = mutex[R, unchecked]
+)
-type State struct {
- unchecked
+func fromContextUnchecked(ctx context.Context) unchecked {
+ return unchecked{ctx, nil}
}
-func None() State {
- return State{}
+func lockUnchecked[R Rank](parent unchecked, mu *uncheckedMutex[R]) unchecked {
+ if !parent.isAlreadyLocked(mu) {
+ mu.lock()
+ // Locking a mutex creates a new state that must be accessible from any derived state.
+ // Normally, this state would be heap-allocated, but we want to avoid allocating new memory
+ // on every lock. Instead, we use a storage region within the mutex itself.
+ mu.lockState = unchecked{parent.Context, mu}
+ return unchecked{&mu.lockState, mu}
+
+ }
+ // The mutex is already locked by a parent or ancestor state.
+ return unchecked{parent.Context, nil}
}
-func FromContext(parent context.Context) State {
- return State{fromContextUnchecked(parent)}
+func (c unchecked) isAlreadyLocked(m mutexHandle) bool {
+ switch val := c.Value(m).(type) {
+ case nil:
+ // No ancestor state associated with this mutex,
+ // and locking it does not violate the lock ordering.
+ return false
+ case error:
+ // There's a lock ordering or reentrancy violation.
+ panic(val)
+ case alreadyLocked:
+ // The mutex is reentrant and is already held by a parent
+ // or ancestor state.
+ return true
+ default:
+ panic("unreachable")
+ }
+}
+
+func (c unchecked) unlock() {
+ if c.mu != nil {
+ c.mu.unlock()
+ }
}
-func Lock[T context.Context](parent T, mu *sync.Mutex) State {
- if parent, ok := any(parent).(State); ok {
- return State{lockUnchecked(parent.unchecked, mu)}
+// Deadline implements [context.Context].
+func (c unchecked) Deadline() (deadline time.Time, ok bool) {
+ if c.Context == nil {
+ return time.Time{}, false
+ }
+ return c.Context.Deadline()
+}
+
+// Done implements [context.Context].
+func (c unchecked) Done() <-chan struct{} {
+ if c.Context == nil {
+ return nil
+ }
+ return c.Context.Done()
+}
+
+// Err implements [context.Context].
+func (c unchecked) Err() error {
+ if c.Context == nil {
+ return nil
+ }
+ return c.Context.Err()
+}
+
+// Err implements [context.Context].
+func (c unchecked) Value(key any) any {
+ if mu, ok := key.(mutexHandle); ok {
+ if res, done := checkLockOrder(mu, c.mu, alreadyLocked{}); done {
+ // We have a definite answer.
+ return res
+ }
+ }
+ if c.Context == nil {
+ return nil
}
- return State{lockUnchecked(fromContextUnchecked(parent), mu)}
+ return c.Context.Value(key)
}
diff --git a/util/ctxlock/state_use_checked.go b/util/ctxlock/state_use_checked.go
new file mode 100644
index 000000000..e5d775d14
--- /dev/null
+++ b/util/ctxlock/state_use_checked.go
@@ -0,0 +1,20 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build !ts_omit_ctxlock_checks
+
+package ctxlock
+
+const useCheckedImpl = true
+
+type (
+ stateImpl = *checked
+ lockState = lockCallers
+ _ = lockState
+)
+
+var fromContext = fromContextChecked
+
+func lock[R Rank](parent stateImpl, mu *checkedMutex[R]) stateImpl {
+ return lockChecked(parent, mu)
+}
diff --git a/util/ctxlock/state_use_unchecked.go b/util/ctxlock/state_use_unchecked.go
new file mode 100644
index 000000000..06e087d3d
--- /dev/null
+++ b/util/ctxlock/state_use_unchecked.go
@@ -0,0 +1,20 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build ts_omit_ctxlock_checks
+
+package ctxlock
+
+const useCheckedImpl = false
+
+type (
+ stateImpl = unchecked
+ lockState = unchecked
+ _ = lockState
+)
+
+var fromContext = fromContextUnchecked
+
+func lock[R Rank](parent stateImpl, mu *uncheckedMutex[R]) stateImpl {
+ return lockUnchecked(parent, mu)
+}