summaryrefslogtreecommitdiffhomepage
path: root/util
diff options
context:
space:
mode:
Diffstat (limited to 'util')
-rw-r--r--util/syspolicy/caching_handler.go122
-rw-r--r--util/syspolicy/caching_handler_test.go262
-rw-r--r--util/syspolicy/handler.go107
-rw-r--r--util/syspolicy/handler_test.go19
-rw-r--r--util/syspolicy/handler_windows.go105
-rw-r--r--util/syspolicy/internal/internal.go63
-rw-r--r--util/syspolicy/internal/lazyinit/lazyinit.go84
-rw-r--r--util/syspolicy/internal/loggerx/logger.go46
-rw-r--r--util/syspolicy/internal/metrics/metrics.go315
-rw-r--r--util/syspolicy/internal/metrics/metrics_test.go423
-rw-r--r--util/syspolicy/internal/metrics/test_handler.go88
-rw-r--r--util/syspolicy/policy_keys.go96
-rw-r--r--util/syspolicy/policy_keys_test.go95
-rw-r--r--util/syspolicy/policy_keys_windows.go38
-rw-r--r--util/syspolicy/rsop/change_callbacks.go109
-rw-r--r--util/syspolicy/rsop/resultant_policy.go698
-rw-r--r--util/syspolicy/rsop/resultant_policy_test.go368
-rw-r--r--util/syspolicy/setting/errors.go60
-rw-r--r--util/syspolicy/setting/key.go13
-rw-r--r--util/syspolicy/setting/origin.go71
-rw-r--r--util/syspolicy/setting/policy_scope.go195
-rw-r--r--util/syspolicy/setting/policy_scope_test.go550
-rw-r--r--util/syspolicy/setting/raw_item.go47
-rw-r--r--util/syspolicy/setting/setting.go352
-rw-r--r--util/syspolicy/setting/setting_test.go344
-rw-r--r--util/syspolicy/setting/snapshot.go153
-rw-r--r--util/syspolicy/setting/snapshot_test.go372
-rw-r--r--util/syspolicy/setting/summary.go84
-rw-r--r--util/syspolicy/setting/types.go132
-rw-r--r--util/syspolicy/source/policy_reader.go393
-rw-r--r--util/syspolicy/source/policy_reader_test.go291
-rw-r--r--util/syspolicy/source/policy_store.go146
-rw-r--r--util/syspolicy/source/policy_store_windows.go438
-rw-r--r--util/syspolicy/source/policy_store_windows_test.go298
-rw-r--r--util/syspolicy/source/test_store.go446
-rw-r--r--util/syspolicy/syspolicy.go197
-rw-r--r--util/syspolicy/syspolicy_test.go270
-rw-r--r--util/syspolicy/syspolicy_windows.go93
-rw-r--r--util/winutil/gp/policylock_windows.go1
39 files changed, 7253 insertions, 731 deletions
diff --git a/util/syspolicy/caching_handler.go b/util/syspolicy/caching_handler.go
deleted file mode 100644
index 5192958bc..000000000
--- a/util/syspolicy/caching_handler.go
+++ /dev/null
@@ -1,122 +0,0 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-package syspolicy
-
-import (
- "errors"
- "sync"
-)
-
-// CachingHandler is a handler that reads policies from an underlying handler the first time each key is requested
-// and permanently caches the result unless there is an error. If there is an ErrNoSuchKey error, that result is cached,
-// otherwise the actual error is returned and the next read for that key will retry using the handler.
-type CachingHandler struct {
- mu sync.Mutex
- strings map[string]string
- uint64s map[string]uint64
- bools map[string]bool
- strArrs map[string][]string
- notFound map[string]bool
- handler Handler
-}
-
-// NewCachingHandler creates a CachingHandler given a handler.
-func NewCachingHandler(handler Handler) *CachingHandler {
- return &CachingHandler{
- handler: handler,
- strings: make(map[string]string),
- uint64s: make(map[string]uint64),
- bools: make(map[string]bool),
- strArrs: make(map[string][]string),
- notFound: make(map[string]bool),
- }
-}
-
-// ReadString reads the policy settings value string given the key.
-// ReadString first reads from the handler's cache before resorting to using the handler.
-func (ch *CachingHandler) ReadString(key string) (string, error) {
- ch.mu.Lock()
- defer ch.mu.Unlock()
- if val, ok := ch.strings[key]; ok {
- return val, nil
- }
- if notFound := ch.notFound[key]; notFound {
- return "", ErrNoSuchKey
- }
- val, err := ch.handler.ReadString(key)
- if errors.Is(err, ErrNoSuchKey) {
- ch.notFound[key] = true
- return "", err
- } else if err != nil {
- return "", err
- }
- ch.strings[key] = val
- return val, nil
-}
-
-// ReadUInt64 reads the policy settings uint64 value given the key.
-// ReadUInt64 first reads from the handler's cache before resorting to using the handler.
-func (ch *CachingHandler) ReadUInt64(key string) (uint64, error) {
- ch.mu.Lock()
- defer ch.mu.Unlock()
- if val, ok := ch.uint64s[key]; ok {
- return val, nil
- }
- if notFound := ch.notFound[key]; notFound {
- return 0, ErrNoSuchKey
- }
- val, err := ch.handler.ReadUInt64(key)
- if errors.Is(err, ErrNoSuchKey) {
- ch.notFound[key] = true
- return 0, err
- } else if err != nil {
- return 0, err
- }
- ch.uint64s[key] = val
- return val, nil
-}
-
-// ReadBoolean reads the policy settings boolean value given the key.
-// ReadBoolean first reads from the handler's cache before resorting to using the handler.
-func (ch *CachingHandler) ReadBoolean(key string) (bool, error) {
- ch.mu.Lock()
- defer ch.mu.Unlock()
- if val, ok := ch.bools[key]; ok {
- return val, nil
- }
- if notFound := ch.notFound[key]; notFound {
- return false, ErrNoSuchKey
- }
- val, err := ch.handler.ReadBoolean(key)
- if errors.Is(err, ErrNoSuchKey) {
- ch.notFound[key] = true
- return false, err
- } else if err != nil {
- return false, err
- }
- ch.bools[key] = val
- return val, nil
-}
-
-// ReadBoolean reads the policy settings boolean value given the key.
-// ReadBoolean first reads from the handler's cache before resorting to using the handler.
-func (ch *CachingHandler) ReadStringArray(key string) ([]string, error) {
- ch.mu.Lock()
- defer ch.mu.Unlock()
- if val, ok := ch.strArrs[key]; ok {
- return val, nil
- }
- if notFound := ch.notFound[key]; notFound {
- return nil, ErrNoSuchKey
- }
- val, err := ch.handler.ReadStringArray(key)
- if errors.Is(err, ErrNoSuchKey) {
- ch.notFound[key] = true
- return nil, err
- } else if err != nil {
- return nil, err
- }
- ch.strArrs[key] = val
- return val, nil
-}
diff --git a/util/syspolicy/caching_handler_test.go b/util/syspolicy/caching_handler_test.go
deleted file mode 100644
index 881f6ff83..000000000
--- a/util/syspolicy/caching_handler_test.go
+++ /dev/null
@@ -1,262 +0,0 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-package syspolicy
-
-import (
- "testing"
-)
-
-func TestHandlerReadString(t *testing.T) {
- tests := []struct {
- name string
- key string
- handlerKey Key
- handlerValue string
- handlerError error
- preserveHandler bool
- wantValue string
- wantErr error
- strings map[string]string
- expectedCalls int
- }{
- {
- name: "read existing cached values",
- key: "test",
- handlerKey: "do not read",
- strings: map[string]string{"test": "foo"},
- wantValue: "foo",
- expectedCalls: 0,
- },
- {
- name: "read existing values not cached",
- key: "test",
- handlerKey: "test",
- handlerValue: "foo",
- wantValue: "foo",
- expectedCalls: 1,
- },
- {
- name: "error no such key",
- key: "test",
- handlerKey: "test",
- handlerError: ErrNoSuchKey,
- wantErr: ErrNoSuchKey,
- expectedCalls: 1,
- },
- {
- name: "other error",
- key: "test",
- handlerKey: "test",
- handlerError: someOtherError,
- wantErr: someOtherError,
- preserveHandler: true,
- expectedCalls: 2,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- testHandler := &testHandler{
- t: t,
- key: tt.handlerKey,
- s: tt.handlerValue,
- err: tt.handlerError,
- }
- cache := NewCachingHandler(testHandler)
- if tt.strings != nil {
- cache.strings = tt.strings
- }
- got, err := cache.ReadString(tt.key)
- if err != tt.wantErr {
- t.Errorf("err=%v want %v", err, tt.wantErr)
- }
- if got != tt.wantValue {
- t.Errorf("got %v want %v", got, cache.strings[tt.key])
- }
- if !tt.preserveHandler {
- testHandler.key, testHandler.s, testHandler.err = "do not read", "", nil
- }
- got, err = cache.ReadString(tt.key)
- if err != tt.wantErr {
- t.Errorf("repeat err=%v want %v", err, tt.wantErr)
- }
- if got != tt.wantValue {
- t.Errorf("repeat got %v want %v", got, cache.strings[tt.key])
- }
- if testHandler.calls != tt.expectedCalls {
- t.Errorf("calls=%v want %v", testHandler.calls, tt.expectedCalls)
- }
- })
- }
-}
-
-func TestHandlerReadUint64(t *testing.T) {
- tests := []struct {
- name string
- key string
- handlerKey Key
- handlerValue uint64
- handlerError error
- preserveHandler bool
- wantValue uint64
- wantErr error
- uint64s map[string]uint64
- expectedCalls int
- }{
- {
- name: "read existing cached values",
- key: "test",
- handlerKey: "do not read",
- uint64s: map[string]uint64{"test": 1},
- wantValue: 1,
- expectedCalls: 0,
- },
- {
- name: "read existing values not cached",
- key: "test",
- handlerKey: "test",
- handlerValue: 1,
- wantValue: 1,
- expectedCalls: 1,
- },
- {
- name: "error no such key",
- key: "test",
- handlerKey: "test",
- handlerError: ErrNoSuchKey,
- wantErr: ErrNoSuchKey,
- expectedCalls: 1,
- },
- {
- name: "other error",
- key: "test",
- handlerKey: "test",
- handlerError: someOtherError,
- wantErr: someOtherError,
- preserveHandler: true,
- expectedCalls: 2,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- testHandler := &testHandler{
- t: t,
- key: tt.handlerKey,
- u64: tt.handlerValue,
- err: tt.handlerError,
- }
- cache := NewCachingHandler(testHandler)
- if tt.uint64s != nil {
- cache.uint64s = tt.uint64s
- }
- got, err := cache.ReadUInt64(tt.key)
- if err != tt.wantErr {
- t.Errorf("err=%v want %v", err, tt.wantErr)
- }
- if got != tt.wantValue {
- t.Errorf("got %v want %v", got, cache.strings[tt.key])
- }
- if !tt.preserveHandler {
- testHandler.key, testHandler.s, testHandler.err = "do not read", "", nil
- }
- got, err = cache.ReadUInt64(tt.key)
- if err != tt.wantErr {
- t.Errorf("repeat err=%v want %v", err, tt.wantErr)
- }
- if got != tt.wantValue {
- t.Errorf("repeat got %v want %v", got, cache.strings[tt.key])
- }
- if testHandler.calls != tt.expectedCalls {
- t.Errorf("calls=%v want %v", testHandler.calls, tt.expectedCalls)
- }
- })
- }
-
-}
-
-func TestHandlerReadBool(t *testing.T) {
- tests := []struct {
- name string
- key string
- handlerKey Key
- handlerValue bool
- handlerError error
- preserveHandler bool
- wantValue bool
- wantErr error
- bools map[string]bool
- expectedCalls int
- }{
- {
- name: "read existing cached values",
- key: "test",
- handlerKey: "do not read",
- bools: map[string]bool{"test": true},
- wantValue: true,
- expectedCalls: 0,
- },
- {
- name: "read existing values not cached",
- key: "test",
- handlerKey: "test",
- handlerValue: true,
- wantValue: true,
- expectedCalls: 1,
- },
- {
- name: "error no such key",
- key: "test",
- handlerKey: "test",
- handlerError: ErrNoSuchKey,
- wantErr: ErrNoSuchKey,
- expectedCalls: 1,
- },
- {
- name: "other error",
- key: "test",
- handlerKey: "test",
- handlerError: someOtherError,
- wantErr: someOtherError,
- preserveHandler: true,
- expectedCalls: 2,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- testHandler := &testHandler{
- t: t,
- key: tt.handlerKey,
- b: tt.handlerValue,
- err: tt.handlerError,
- }
- cache := NewCachingHandler(testHandler)
- if tt.bools != nil {
- cache.bools = tt.bools
- }
- got, err := cache.ReadBoolean(tt.key)
- if err != tt.wantErr {
- t.Errorf("err=%v want %v", err, tt.wantErr)
- }
- if got != tt.wantValue {
- t.Errorf("got %v want %v", got, cache.strings[tt.key])
- }
- if !tt.preserveHandler {
- testHandler.key, testHandler.s, testHandler.err = "do not read", "", nil
- }
- got, err = cache.ReadBoolean(tt.key)
- if err != tt.wantErr {
- t.Errorf("repeat err=%v want %v", err, tt.wantErr)
- }
- if got != tt.wantValue {
- t.Errorf("repeat got %v want %v", got, cache.strings[tt.key])
- }
- if testHandler.calls != tt.expectedCalls {
- t.Errorf("calls=%v want %v", testHandler.calls, tt.expectedCalls)
- }
- })
- }
-
-}
diff --git a/util/syspolicy/handler.go b/util/syspolicy/handler.go
index f1fad9770..0671dc058 100644
--- a/util/syspolicy/handler.go
+++ b/util/syspolicy/handler.go
@@ -4,16 +4,15 @@
package syspolicy
import (
- "errors"
- "sync/atomic"
-)
-
-var (
- handlerUsed atomic.Bool
- handler Handler = defaultHandler{}
+ "tailscale.com/util/syspolicy/internal"
+ "tailscale.com/util/syspolicy/rsop"
+ "tailscale.com/util/syspolicy/setting"
+ "tailscale.com/util/syspolicy/source"
)
// Handler reads system policies from OS-specific storage.
+//
+// Deprecated: implementing a [Store] should be preferred.
type Handler interface {
// ReadString reads the policy setting's string value for the given key.
// It should return ErrNoSuchKey if the key does not have a value set.
@@ -29,55 +28,81 @@ type Handler interface {
ReadStringArray(key string) ([]string, error)
}
-// ErrNoSuchKey is returned by a Handler when the specified key does not have a
-// value set.
-var ErrNoSuchKey = errors.New("no such key")
+// RegisterHandler wraps and registers the specified handler as the device's
+// policy [Store] for the program's lifetime.
+//
+// Deprecated: using [RegisterStore] should be preferred.
+func RegisterHandler(h Handler) {
+ rsop.RegisterStore("DeviceHandler", setting.DeviceScope, WrapHandler(h))
+}
-// defaultHandler is the catch all syspolicy type for anything that isn't windows or apple.
-type defaultHandler struct{}
+// TB is a subset of testing.TB that we use to set up test helpers.
+// It's defined here to avoid pulling in the testing package.
+type TB = internal.TB
-func (defaultHandler) ReadString(_ string) (string, error) {
- return "", ErrNoSuchKey
+// SetHandlerForTest wraps and sets the specified handler as the device's policy
+// [Store] for the duration of tb.
+//
+// Deprecated: using [resultant.RegisterStoreForTest] should be preferred.
+func SetHandlerForTest(tb TB, h Handler) {
+ if err := setWellKnownSettingsForTest(tb); err != nil {
+ tb.Fatalf("setWellKnownSettingsForTest failed: %v", err)
+ }
+ rsop.RegisterStoreForTest(tb, "DeviceHandler-TestOnly", setting.CurrentScope(), WrapHandler(h))
}
-func (defaultHandler) ReadUInt64(_ string) (uint64, error) {
- return 0, ErrNoSuchKey
+var _ source.Store = (*handlerStore)(nil)
+
+// handlerStore is a [source.Store] that calls the underlying [Handler].
+// TODO(nickkhyl): remove it when the corp and android repos are updated.
+type handlerStore struct {
+ h Handler
}
-func (defaultHandler) ReadBoolean(_ string) (bool, error) {
- return false, ErrNoSuchKey
+// WrapHandler returns a [source.Store] that wraps the specified [Handler].
+func WrapHandler(h Handler) source.Store {
+ return handlerStore{h}
}
-func (defaultHandler) ReadStringArray(_ string) ([]string, error) {
- return nil, ErrNoSuchKey
+func (s handlerStore) Lock() error {
+ if lockable, ok := s.h.(source.Lockable); ok {
+ return lockable.Lock()
+ }
+ return nil
}
-// markHandlerInUse is called before handler methods are called.
-func markHandlerInUse() {
- handlerUsed.Store(true)
+func (s handlerStore) Unlock() {
+ if lockable, ok := s.h.(source.Lockable); ok {
+ lockable.Unlock()
+ }
}
-// RegisterHandler initializes the policy handler and ensures registration will happen once.
-func RegisterHandler(h Handler) {
- // Technically this assignment is not concurrency safe, but in the
- // event that there was any risk of a data race, we will panic due to
- // the CompareAndSwap failing.
- handler = h
- if !handlerUsed.CompareAndSwap(false, true) {
- panic("handler was already used before registration")
+func (s handlerStore) RegisterChangeCallback(callback func()) (unregister func(), err error) {
+ if lockable, ok := s.h.(source.Changeable); ok {
+ return lockable.RegisterChangeCallback(callback)
}
+ return func() {}, nil
}
-// TB is a subset of testing.TB that we use to set up test helpers.
-// It's defined here to avoid pulling in the testing package.
-type TB interface {
- Helper()
- Cleanup(func())
+func (s handlerStore) ReadString(key setting.Key) (string, error) {
+ return s.h.ReadString(string(key))
}
-func SetHandlerForTest(tb TB, h Handler) {
- tb.Helper()
- oldHandler := handler
- handler = h
- tb.Cleanup(func() { handler = oldHandler })
+func (s handlerStore) ReadUInt64(key setting.Key) (uint64, error) {
+ return s.h.ReadUInt64(string(key))
+}
+
+func (s handlerStore) ReadBoolean(key setting.Key) (bool, error) {
+ return s.h.ReadBoolean(string(key))
+}
+
+func (s handlerStore) ReadStringArray(key setting.Key) ([]string, error) {
+ return s.h.ReadStringArray(string(key))
+}
+
+func (s handlerStore) Done() <-chan struct{} {
+ if expirable, ok := s.h.(source.Expirable); ok {
+ return expirable.Done()
+ }
+ return nil
}
diff --git a/util/syspolicy/handler_test.go b/util/syspolicy/handler_test.go
deleted file mode 100644
index 39b18936f..000000000
--- a/util/syspolicy/handler_test.go
+++ /dev/null
@@ -1,19 +0,0 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-package syspolicy
-
-import "testing"
-
-func TestDefaultHandlerReadValues(t *testing.T) {
- var h defaultHandler
-
- got, err := h.ReadString(string(AdminConsoleVisibility))
- if got != "" || err != ErrNoSuchKey {
- t.Fatalf("got %v err %v", got, err)
- }
- result, err := h.ReadUInt64(string(LogSCMInteractions))
- if result != 0 || err != ErrNoSuchKey {
- t.Fatalf("got %v err %v", result, err)
- }
-}
diff --git a/util/syspolicy/handler_windows.go b/util/syspolicy/handler_windows.go
deleted file mode 100644
index 661853ead..000000000
--- a/util/syspolicy/handler_windows.go
+++ /dev/null
@@ -1,105 +0,0 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-package syspolicy
-
-import (
- "errors"
- "fmt"
-
- "tailscale.com/util/clientmetric"
- "tailscale.com/util/winutil"
-)
-
-var (
- windowsErrors = clientmetric.NewCounter("windows_syspolicy_errors")
- windowsAny = clientmetric.NewGauge("windows_syspolicy_any")
-)
-
-type windowsHandler struct{}
-
-func init() {
- RegisterHandler(NewCachingHandler(windowsHandler{}))
-
- keyList := []struct {
- isSet func(Key) bool
- keys []Key
- }{
- {
- isSet: func(k Key) bool {
- _, err := handler.ReadString(string(k))
- return err == nil
- },
- keys: stringKeys,
- },
- {
- isSet: func(k Key) bool {
- _, err := handler.ReadBoolean(string(k))
- return err == nil
- },
- keys: boolKeys,
- },
- {
- isSet: func(k Key) bool {
- _, err := handler.ReadUInt64(string(k))
- return err == nil
- },
- keys: uint64Keys,
- },
- }
-
- var anySet bool
- for _, l := range keyList {
- for _, k := range l.keys {
- if !l.isSet(k) {
- continue
- }
- clientmetric.NewGauge(fmt.Sprintf("windows_syspolicy_%s", k)).Set(1)
- anySet = true
- }
- }
- if anySet {
- windowsAny.Set(1)
- }
-}
-
-func (windowsHandler) ReadString(key string) (string, error) {
- s, err := winutil.GetPolicyString(key)
- if errors.Is(err, winutil.ErrNoValue) {
- err = ErrNoSuchKey
- } else if err != nil {
- windowsErrors.Add(1)
- }
-
- return s, err
-}
-
-func (windowsHandler) ReadUInt64(key string) (uint64, error) {
- value, err := winutil.GetPolicyInteger(key)
- if errors.Is(err, winutil.ErrNoValue) {
- err = ErrNoSuchKey
- } else if err != nil {
- windowsErrors.Add(1)
- }
- return value, err
-}
-
-func (windowsHandler) ReadBoolean(key string) (bool, error) {
- value, err := winutil.GetPolicyInteger(key)
- if errors.Is(err, winutil.ErrNoValue) {
- err = ErrNoSuchKey
- } else if err != nil {
- windowsErrors.Add(1)
- }
- return value != 0, err
-}
-
-func (windowsHandler) ReadStringArray(key string) ([]string, error) {
- value, err := winutil.GetPolicyStringArray(key)
- if errors.Is(err, winutil.ErrNoValue) {
- err = ErrNoSuchKey
- } else if err != nil {
- windowsErrors.Add(1)
- }
- return value, err
-}
diff --git a/util/syspolicy/internal/internal.go b/util/syspolicy/internal/internal.go
new file mode 100644
index 000000000..4c3e28d39
--- /dev/null
+++ b/util/syspolicy/internal/internal.go
@@ -0,0 +1,63 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Package internal contains miscellaneous functions and types
+// that are internal to the syspolicy packages.
+package internal
+
+import (
+ "bytes"
+
+ "github.com/go-json-experiment/json/jsontext"
+ "tailscale.com/types/lazy"
+ "tailscale.com/version"
+)
+
+// OSForTesting is the operating system override used for testing.
+// It follows the same naming convention as [version.OS].
+var OSForTesting lazy.SyncValue[string]
+
+// OS is like [version.OS], but supports a test hook.
+func OS() string {
+ return OSForTesting.Get(version.OS)
+}
+
+// TB is a subset of testing.TB that we use to set up test helpers.
+// It's defined here to avoid pulling in the testing package.
+type TB interface {
+ Helper()
+ Cleanup(func())
+ Logf(format string, args ...any)
+ Error(args ...any)
+ Errorf(format string, args ...any)
+ Fatal(args ...any)
+ Fatalf(format string, args ...any)
+}
+
+// EqualJSONForTest compares the JSON in j1 and j2 for semantic equality.
+// It returns "", "", true if j1 and j2 are equal. Otherwise, it returns
+// indented versions of j1 and j2 and false.
+func EqualJSONForTest(tb TB, j1, j2 jsontext.Value) (s1, s2 string, equal bool) {
+ tb.Helper()
+ j1 = j1.Clone()
+ j2 = j2.Clone()
+ // Canonicalize JSON values for comparison.
+ if err := j1.Canonicalize(); err != nil {
+ tb.Error(err)
+ }
+ if err := j2.Canonicalize(); err != nil {
+ tb.Error(err)
+ }
+ // Check and return true if the two values are structurally equal.
+ if bytes.Equal(j1, j2) {
+ return "", "", true
+ }
+ // Otherwise, format the values for display and return false.
+ if err := j1.Indent("", "\t"); err != nil {
+ tb.Fatal(err)
+ }
+ if err := j2.Indent("", "\t"); err != nil {
+ tb.Fatal(err)
+ }
+ return j1.String(), j2.String(), false
+}
diff --git a/util/syspolicy/internal/lazyinit/lazyinit.go b/util/syspolicy/internal/lazyinit/lazyinit.go
new file mode 100644
index 000000000..94c16c238
--- /dev/null
+++ b/util/syspolicy/internal/lazyinit/lazyinit.go
@@ -0,0 +1,84 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// The lazyinit package facilitates deferred package initialization.
+package lazyinit
+
+import (
+ "sync"
+ "sync/atomic"
+)
+
+var packageInit deferredOnce
+
+// Defer defers the specified action until [Do] is called.
+// It returns a boolean indicating whether [Do] has already been called.
+func Defer(action func() error) bool {
+ return packageInit.Defer(action)
+}
+
+// DeferWithCleanup is like [Defer], but the action function returns a cleanup
+// function to be called in case of an error.
+func DeferWithCleanup(action func() (cleanup func(), err error)) bool {
+ return packageInit.DeferWithCleanup(action)
+}
+
+// Do runs all deferred functions and returns an error if any of them fail.
+func Do() error {
+ return packageInit.Do()
+}
+
+type deferredOnce struct {
+ done atomic.Uint32
+ err error
+ m sync.Mutex
+ funcs []func() (cleanup func(), err error)
+}
+
+func (o *deferredOnce) Defer(action func() error) bool {
+ return o.DeferWithCleanup(func() (cleanup func(), err error) {
+ return nil, action()
+ })
+}
+
+func (o *deferredOnce) DeferWithCleanup(action func() (cleanup func(), err error)) bool {
+ o.m.Lock()
+ defer o.m.Unlock()
+ if o.done.Load() != 0 {
+ return false
+ }
+ o.funcs = append(o.funcs, action)
+ return true
+}
+
+func (o *deferredOnce) Do() error {
+ if o.done.Load() == 0 {
+ o.doSlow()
+ }
+ return o.err
+}
+
+func (o *deferredOnce) doSlow() (err error) {
+ o.m.Lock()
+ defer o.m.Unlock()
+ if o.done.Load() == 0 {
+ defer func() {
+ o.done.Store(1)
+ o.err = err
+ }()
+ for _, f := range o.funcs {
+ cleanup, err := f()
+ if err != nil {
+ return err
+ }
+ if cleanup != nil {
+ defer func() {
+ if err != nil {
+ cleanup()
+ }
+ }()
+ }
+ }
+ }
+ return o.err
+}
diff --git a/util/syspolicy/internal/loggerx/logger.go b/util/syspolicy/internal/loggerx/logger.go
new file mode 100644
index 000000000..b28610826
--- /dev/null
+++ b/util/syspolicy/internal/loggerx/logger.go
@@ -0,0 +1,46 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Package loggerx provides logging functions to the rest of the syspolicy packages.
+package loggerx
+
+import (
+ "log"
+
+ "tailscale.com/types/lazy"
+ "tailscale.com/types/logger"
+ "tailscale.com/util/syspolicy/internal"
+)
+
+const (
+ errorPrefix = "syspolicy: "
+ verbosePrefix = "syspolicy: [v2] "
+)
+
+var (
+ lazyErrorf lazy.SyncValue[logger.Logf]
+ lazyVerbosef lazy.SyncValue[logger.Logf]
+)
+
+// Errorf formats and writes an error message to the log.
+func Errorf(format string, args ...any) {
+ errorf := lazyErrorf.Get(func() logger.Logf {
+ return logger.WithPrefix(log.Printf, errorPrefix)
+ })
+ errorf(format, args...)
+}
+
+// Verbosef formats and writes an optional, verbose message to the log.
+func Verbosef(format string, args ...any) {
+ verbosef := lazyVerbosef.Get(func() logger.Logf {
+ return logger.WithPrefix(log.Printf, verbosePrefix)
+ })
+ verbosef(format, args...)
+}
+
+// SetForTest sets the specified errorf and verbosef functions for the duration
+// of tb and its subtests.
+func SetForTest(tb internal.TB, errorf, verbosef logger.Logf) {
+ lazyErrorf.SetForTest(tb, errorf, nil)
+ lazyVerbosef.SetForTest(tb, verbosef, nil)
+}
diff --git a/util/syspolicy/internal/metrics/metrics.go b/util/syspolicy/internal/metrics/metrics.go
new file mode 100644
index 000000000..4f2bf5396
--- /dev/null
+++ b/util/syspolicy/internal/metrics/metrics.go
@@ -0,0 +1,315 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Package metrics provides logging and reporting for policy settings and scopes.
+package metrics
+
+import (
+ "strings"
+ "sync"
+
+ xmaps "golang.org/x/exp/maps"
+
+ "tailscale.com/syncs"
+ "tailscale.com/types/lazy"
+ "tailscale.com/util/clientmetric"
+ "tailscale.com/util/mak"
+ "tailscale.com/util/syspolicy/internal"
+ "tailscale.com/util/syspolicy/internal/loggerx"
+ "tailscale.com/util/syspolicy/setting"
+ "tailscale.com/util/testenv"
+)
+
+var lazyReportMetrics lazy.SyncValue[bool] // used as a test hook
+
+// ShouldReport reports whether metrics should be reported on the current environment.
+func ShouldReport() bool {
+ return lazyReportMetrics.Get(func() bool {
+ // macOS, iOS and tvOS create their own metrics,
+ // and we don't have syspolicy on any other platforms.
+ return setting.PlatformList{"android", "windows"}.HasCurrent()
+ })
+}
+
+// Reset metrics for the specified policy origin.
+func Reset(origin *setting.Origin) {
+ scopeMetrics(origin).Reset()
+}
+
+// ReportConfigured updates metrics and logs that the specified setting is
+// configured with the given value in the origin.
+func ReportConfigured(origin *setting.Origin, setting *setting.Definition, value any) {
+ settingMetricsFor(setting).ReportValue(origin, value)
+}
+
+// ReportError updates metrics and logs that the specified setting has an error
+// in the origin.
+func ReportError(origin *setting.Origin, setting *setting.Definition, err error) {
+ settingMetricsFor(setting).ReportError(origin, err)
+}
+
+// ReportNotConfigured updates metrics and logs that the specified setting is
+// not configured in the origin.
+func ReportNotConfigured(origin *setting.Origin, setting *setting.Definition) {
+ settingMetricsFor(setting).Reset(origin)
+}
+
+// metric is an interface implemented by [clientmetric.Metric] and [funcMetric].
+type metric interface {
+ Add(v int64)
+ Set(v int64)
+}
+
+// policyScopeMetrics are metrics that apply to an entire policy scope rather
+// than a specific policy setting.
+type policyScopeMetrics struct {
+ hasAny metric
+ numErrored metric
+}
+
+func newScopeMetrics(scope setting.Scope) *policyScopeMetrics {
+ prefix := metricScopeName(scope)
+ if prefix != "" {
+ prefix += "_"
+ }
+ // {os}_syspolicy_{scope_unless_device}_any
+ // Example: windows_syspolicy_any or windows_syspolicy_user_any.
+ hasAny := newMetric(prefix+"any", clientmetric.TypeGauge)
+ // {os}_syspolicy_{scope_unless_device}_errors
+ // Example: windows_syspolicy_errors or windows_syspolicy_user_errors.
+ //
+ // TODO(nickkhyl): maybe make the `{os}_syspolicy_errors` metric a gauge rather than a counter?
+ // It was a counter prior to https://github.com/tailscale/tailscale/issues/12687, so I kept it as such.
+ // But I think a gauge makes more sense: syspolicy errors indicate a mismatch between the expected
+ // policy value type or format and the actual value read from the underlying store (like the Windows Registry).
+ // We'll encounter the same error every time we re-read the policy setting from the backing store
+ // until the policy value is corrected by the user, or until we fix the bug in the code or ADMX.
+ // There's probably no reason to count and accumulate them over time.
+ numErrored := newMetric(prefix+"errors", clientmetric.TypeCounter)
+ return &policyScopeMetrics{hasAny, numErrored}
+}
+
+// ReportHasSettings is called when there's any configured policy setting in the scope.
+func (m *policyScopeMetrics) ReportHasSettings() {
+ if m != nil {
+ m.hasAny.Set(1)
+ }
+}
+
+// ReportError is called when there's any errored policy setting in the scope.
+func (m *policyScopeMetrics) ReportError() {
+ if m != nil {
+ m.numErrored.Add(1)
+ }
+}
+
+// Reset is called to reset the policy scope metrics, such as when the policy scope
+// is about to be reloaded.
+func (m *policyScopeMetrics) Reset() {
+ if m != nil {
+ m.hasAny.Set(0)
+ // numErrored is a counter and cannot be (re-)set.
+ }
+}
+
+// settingMetrics are metrics for a single policy setting in one or more scopes.
+type settingMetrics struct {
+ definition *setting.Definition
+ isSet []metric // by scope
+ hasErrors []metric // by scope
+}
+
+// ReportValue is called when the policy setting is found to be configured in the specified source.
+func (m *settingMetrics) ReportValue(origin *setting.Origin, v any) {
+ if m == nil {
+ return
+ }
+ if scope := origin.Scope().Kind(); int(scope) < len(m.isSet) {
+ m.isSet[scope].Set(1)
+ m.hasErrors[scope].Set(0)
+ }
+ scopeMetrics(origin).ReportHasSettings()
+ loggerx.Verbosef("%v(%q) = %v\n", origin, m.definition.Key(), v)
+}
+
+// ReportError is called when there's an error with the policy setting in the specified source.
+func (m *settingMetrics) ReportError(origin *setting.Origin, err error) {
+ if m == nil {
+ return
+ }
+ if scope := origin.Scope().Kind(); int(scope) < len(m.hasErrors) {
+ m.isSet[scope].Set(0)
+ m.hasErrors[scope].Set(1)
+ }
+ scopeMetrics(origin).ReportError()
+ loggerx.Errorf("%v(%q): %v\n", origin, m.definition.Key(), err)
+}
+
+// Reset is called to reset the policy setting's metrics, such as when
+// the policy setting does not exist or the source containing the policy
+// is about to be reloaded.
+func (m *settingMetrics) Reset(origin *setting.Origin) {
+ if m == nil {
+ return
+ }
+ if scope := origin.Scope().Kind(); int(scope) < len(m.isSet) {
+ m.isSet[scope].Set(0)
+ m.hasErrors[scope].Set(0)
+ }
+}
+
+// metricFn is a function that adds or sets a metric value.
+type metricFn = func(name string, typ clientmetric.Type, v int64)
+
+// funcMetric implements [metric] by calling the specified add and set functions.
+// Used for testing, and with nil functions on platforms that do not support
+// syspolicy, and on platforms that report policy metrics from the GUI.
+type funcMetric struct {
+ name string
+ typ clientmetric.Type
+ add, set metricFn
+}
+
+func (m funcMetric) Add(v int64) {
+ if m.add != nil {
+ m.add(m.name, m.typ, v)
+ }
+}
+
+func (m funcMetric) Set(v int64) {
+ if m.set != nil {
+ m.set(m.name, m.typ, v)
+ }
+}
+
+var (
+ lazyDeviceMetrics lazy.SyncValue[*policyScopeMetrics]
+ lazyProfileMetrics lazy.SyncValue[*policyScopeMetrics]
+ lazyUserMetrics lazy.SyncValue[*policyScopeMetrics]
+)
+
+func scopeMetrics(origin *setting.Origin) *policyScopeMetrics {
+ switch origin.Scope().Kind() {
+ case setting.DeviceSetting:
+ return lazyDeviceMetrics.Get(func() *policyScopeMetrics {
+ return newScopeMetrics(setting.DeviceSetting)
+ })
+ case setting.ProfileSetting:
+ return lazyProfileMetrics.Get(func() *policyScopeMetrics {
+ return newScopeMetrics(setting.ProfileSetting)
+ })
+ case setting.UserSetting:
+ return lazyUserMetrics.Get(func() *policyScopeMetrics {
+ return newScopeMetrics(setting.UserSetting)
+ })
+ default:
+ panic("unreachable")
+ }
+}
+
+var (
+ settingMetricsMu sync.RWMutex
+ settingMetricsMap map[setting.Key]*settingMetrics
+)
+
+func settingMetricsFor(setting *setting.Definition) *settingMetrics {
+ settingMetricsMu.RLock()
+ if metrics, ok := settingMetricsMap[setting.Key()]; ok {
+ settingMetricsMu.RUnlock()
+ return metrics
+ }
+ settingMetricsMu.RUnlock()
+ return settingMetricsForSlow(setting)
+}
+
+func settingMetricsForSlow(d *setting.Definition) *settingMetrics {
+ settingMetricsMu.Lock()
+ defer settingMetricsMu.Unlock()
+ if metrics, ok := settingMetricsMap[d.Key()]; ok {
+ return metrics
+ }
+
+ isSet := make([]metric, d.Scope()+1)
+ hasErrors := make([]metric, d.Scope()+1)
+ for i := range isSet {
+ scope := setting.Scope(i)
+ // {os}_syspolicy_{key}_{scope_unless_device}
+ // Example: windows_syspolicy_AdminConsole or windows_syspolicy_AdminConsole_user.
+ isSet[i] = newSettingMetric(d.Key(), scope, "", clientmetric.TypeGauge)
+ // {os}_syspolicy_{key}_{scope_unless_device}_error
+ // Example: windows_syspolicy_AdminConsole_error or windows_syspolicy_TestSetting01_user_error.
+ hasErrors[i] = newSettingMetric(d.Key(), scope, "error", clientmetric.TypeGauge)
+ }
+ metrics := &settingMetrics{d, isSet, hasErrors}
+ mak.Set(&settingMetricsMap, d.Key(), metrics)
+ return metrics
+}
+
+// hooks for testing
+var addMetricTestHook, setMetricTestHook syncs.AtomicValue[metricFn]
+
+// SetHooksForTest sets the specified addMetric and setMetric functions
+// as the metric functions for the duration of tb and all its subtests.
+func SetHooksForTest(tb internal.TB, addMetric, setMetric metricFn) {
+ oldAddMetric := addMetricTestHook.Swap(addMetric)
+ oldSetMetric := setMetricTestHook.Swap(setMetric)
+ tb.Cleanup(func() {
+ addMetricTestHook.Store(oldAddMetric)
+ setMetricTestHook.Store(oldSetMetric)
+ })
+
+ settingMetricsMu.Lock()
+ oldSettingMetricsMap := xmaps.Clone(settingMetricsMap)
+ clear(settingMetricsMap)
+ settingMetricsMu.Unlock()
+ tb.Cleanup(func() {
+ settingMetricsMu.Lock()
+ settingMetricsMap = oldSettingMetricsMap
+ settingMetricsMu.Unlock()
+ })
+
+ // (re-)set the scope metrics to use the test hooks for the duration of tb.
+ lazyDeviceMetrics.SetForTest(tb, newScopeMetrics(setting.DeviceSetting), nil)
+ lazyProfileMetrics.SetForTest(tb, newScopeMetrics(setting.ProfileSetting), nil)
+ lazyUserMetrics.SetForTest(tb, newScopeMetrics(setting.UserSetting), nil)
+}
+
+func newSettingMetric(key setting.Key, scope setting.Scope, suffix string, typ clientmetric.Type) metric {
+ name := strings.ReplaceAll(string(key), setting.KeyPathSeparator, "_")
+ if tag := metricScopeName(scope); tag != "" {
+ name += "_" + tag
+ }
+ if suffix != "" {
+ name += "_" + suffix
+ }
+ return newMetric(name, typ)
+}
+
+func newMetric(name string, typ clientmetric.Type) metric {
+ name = internal.OS() + "_syspolicy_" + name
+ switch {
+ case !ShouldReport():
+ return &funcMetric{name: name, typ: typ}
+ case testenv.InTest():
+ return &funcMetric{name, typ, addMetricTestHook.Load(), setMetricTestHook.Load()}
+ case typ == clientmetric.TypeCounter:
+ return clientmetric.NewCounter(name)
+ case typ == clientmetric.TypeGauge:
+ return clientmetric.NewGauge(name)
+ default:
+ panic("unreachable")
+ }
+}
+
+func metricScopeName(scope setting.Scope) string {
+ switch scope {
+ case setting.DeviceSetting:
+ return ""
+ case setting.ProfileSetting:
+ return "profile"
+ case setting.UserSetting:
+ return "user"
+ default:
+ panic("unreachable")
+ }
+}
diff --git a/util/syspolicy/internal/metrics/metrics_test.go b/util/syspolicy/internal/metrics/metrics_test.go
new file mode 100644
index 000000000..07be4773c
--- /dev/null
+++ b/util/syspolicy/internal/metrics/metrics_test.go
@@ -0,0 +1,423 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package metrics
+
+import (
+ "errors"
+ "testing"
+
+ "tailscale.com/types/lazy"
+ "tailscale.com/util/clientmetric"
+ "tailscale.com/util/syspolicy/internal"
+ "tailscale.com/util/syspolicy/setting"
+)
+
+func TestSettingMetricNames(t *testing.T) {
+ tests := []struct {
+ name string
+ key setting.Key
+ scope setting.Scope
+ suffix string
+ typ clientmetric.Type
+ osOverride string
+ wantMetricName string
+ }{
+ {
+ name: "windows-device-no-suffix",
+ key: "AdminConsole",
+ scope: setting.DeviceSetting,
+ suffix: "",
+ typ: clientmetric.TypeCounter,
+ osOverride: "windows",
+ wantMetricName: "windows_syspolicy_AdminConsole",
+ },
+ {
+ name: "windows-user-no-suffix",
+ key: "AdminConsole",
+ scope: setting.UserSetting,
+ suffix: "",
+ typ: clientmetric.TypeCounter,
+ osOverride: "windows",
+ wantMetricName: "windows_syspolicy_AdminConsole_user",
+ },
+ {
+ name: "windows-profile-no-suffix",
+ key: "AdminConsole",
+ scope: setting.ProfileSetting,
+ suffix: "",
+ typ: clientmetric.TypeCounter,
+ osOverride: "windows",
+ wantMetricName: "windows_syspolicy_AdminConsole_profile",
+ },
+ {
+ name: "windows-profile-err",
+ key: "AdminConsole",
+ scope: setting.ProfileSetting,
+ suffix: "error",
+ typ: clientmetric.TypeCounter,
+ osOverride: "windows",
+ wantMetricName: "windows_syspolicy_AdminConsole_profile_error",
+ },
+ {
+ name: "android-device-no-suffix",
+ key: "AdminConsole",
+ scope: setting.DeviceSetting,
+ suffix: "",
+ typ: clientmetric.TypeCounter,
+ osOverride: "android",
+ wantMetricName: "android_syspolicy_AdminConsole",
+ },
+ {
+ name: "key-path",
+ key: "category/subcategory/setting",
+ scope: setting.DeviceSetting,
+ suffix: "",
+ typ: clientmetric.TypeCounter,
+ osOverride: "fakeos",
+ wantMetricName: "fakeos_syspolicy_category_subcategory_setting",
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ internal.OSForTesting.SetForTest(t, tt.osOverride, nil)
+ metric, ok := newSettingMetric(tt.key, tt.scope, tt.suffix, tt.typ).(*funcMetric)
+ if !ok {
+ t.Fatal("metric is not a funcMetric")
+ }
+ if metric.name != tt.wantMetricName {
+ t.Errorf("got %q, want %q", metric.name, tt.wantMetricName)
+ }
+ })
+ }
+}
+
+func TestScopeMetrics(t *testing.T) {
+ tests := []struct {
+ name string
+ scope setting.Scope
+ osOverride string
+ wantHasAnyName string
+ wantNumErroredName string
+ wantHasAnyType clientmetric.Type
+ wantNumErroredType clientmetric.Type
+ }{
+ {
+ name: "windows-device",
+ scope: setting.DeviceSetting,
+ osOverride: "windows",
+ wantHasAnyName: "windows_syspolicy_any",
+ wantHasAnyType: clientmetric.TypeGauge,
+ wantNumErroredName: "windows_syspolicy_errors",
+ wantNumErroredType: clientmetric.TypeCounter,
+ },
+ {
+ name: "windows-profile",
+ scope: setting.ProfileSetting,
+ osOverride: "windows",
+ wantHasAnyName: "windows_syspolicy_profile_any",
+ wantHasAnyType: clientmetric.TypeGauge,
+ wantNumErroredName: "windows_syspolicy_profile_errors",
+ wantNumErroredType: clientmetric.TypeCounter,
+ },
+ {
+ name: "windows-user",
+ scope: setting.UserSetting,
+ osOverride: "windows",
+ wantHasAnyName: "windows_syspolicy_user_any",
+ wantHasAnyType: clientmetric.TypeGauge,
+ wantNumErroredName: "windows_syspolicy_user_errors",
+ wantNumErroredType: clientmetric.TypeCounter,
+ },
+ {
+ name: "android-device",
+ scope: setting.DeviceSetting,
+ osOverride: "android",
+ wantHasAnyName: "android_syspolicy_any",
+ wantHasAnyType: clientmetric.TypeGauge,
+ wantNumErroredName: "android_syspolicy_errors",
+ wantNumErroredType: clientmetric.TypeCounter,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ internal.OSForTesting.SetForTest(t, tt.osOverride, nil)
+ metrics := newScopeMetrics(tt.scope)
+ hasAny, ok := metrics.hasAny.(*funcMetric)
+ if !ok {
+ t.Fatal("hasAny is not a funcMetric")
+ }
+ numErrored, ok := metrics.numErrored.(*funcMetric)
+ if !ok {
+ t.Fatal("numErrored is not a funcMetric")
+ }
+ if hasAny.name != tt.wantHasAnyName {
+ t.Errorf("hasAny.Name: got %q, want %q", hasAny.name, tt.wantHasAnyName)
+ }
+ if hasAny.typ != tt.wantHasAnyType {
+ t.Errorf("hasAny.Type: got %q, want %q", hasAny.typ, tt.wantHasAnyType)
+ }
+ if numErrored.name != tt.wantNumErroredName {
+ t.Errorf("numErrored.Name: got %q, want %q", numErrored.name, tt.wantNumErroredName)
+ }
+ if numErrored.typ != tt.wantNumErroredType {
+ t.Errorf("hasAny.Type: got %q, want %q", numErrored.typ, tt.wantNumErroredType)
+ }
+ })
+ }
+}
+
+type testSettingDetails struct {
+ definition *setting.Definition
+ origin *setting.Origin
+ value any
+ err error
+}
+
+func TestReportMetrics(t *testing.T) {
+ tests := []struct {
+ name string
+ osOverride string
+ useMetrics bool
+ settings []testSettingDetails
+ wantMetrics []TestState
+ wantResetMetrics []TestState
+ }{
+ {
+ name: "none",
+ osOverride: "windows",
+ settings: []testSettingDetails{},
+ wantMetrics: []TestState{},
+ },
+ {
+ name: "single-value",
+ osOverride: "windows",
+ settings: []testSettingDetails{
+ {
+ definition: setting.NewDefinition("TestSetting01", setting.DeviceSetting, setting.IntegerValue),
+ origin: setting.NewOrigin(setting.DeviceScope),
+ value: 42,
+ },
+ },
+ wantMetrics: []TestState{
+ {"windows_syspolicy_any", 1},
+ {"windows_syspolicy_TestSetting01", 1},
+ },
+ wantResetMetrics: []TestState{
+ {"windows_syspolicy_any", 0},
+ {"windows_syspolicy_TestSetting01", 0},
+ },
+ },
+ {
+ name: "single-error",
+ osOverride: "windows",
+ settings: []testSettingDetails{
+ {
+ definition: setting.NewDefinition("TestSetting02", setting.DeviceSetting, setting.IntegerValue),
+ origin: setting.NewOrigin(setting.DeviceScope),
+ err: errors.New("bang!"),
+ },
+ },
+ wantMetrics: []TestState{
+ {"windows_syspolicy_errors", 1},
+ {"windows_syspolicy_TestSetting02_error", 1},
+ },
+ wantResetMetrics: []TestState{
+ {"windows_syspolicy_errors", 1},
+ {"windows_syspolicy_TestSetting02_error", 0},
+ },
+ },
+ {
+ name: "value-and-error",
+ osOverride: "windows",
+ settings: []testSettingDetails{
+ {
+ definition: setting.NewDefinition("TestSetting01", setting.DeviceSetting, setting.IntegerValue),
+ origin: setting.NewOrigin(setting.DeviceScope),
+ value: 42,
+ },
+ {
+ definition: setting.NewDefinition("TestSetting02", setting.DeviceSetting, setting.IntegerValue),
+ origin: setting.NewOrigin(setting.DeviceScope),
+ err: errors.New("bang!"),
+ },
+ },
+
+ wantMetrics: []TestState{
+ {"windows_syspolicy_any", 1},
+ {"windows_syspolicy_errors", 1},
+ {"windows_syspolicy_TestSetting01", 1},
+ {"windows_syspolicy_TestSetting02_error", 1},
+ },
+ wantResetMetrics: []TestState{
+ {"windows_syspolicy_any", 0},
+ {"windows_syspolicy_errors", 1},
+ {"windows_syspolicy_TestSetting01", 0},
+ {"windows_syspolicy_TestSetting02_error", 0},
+ },
+ },
+ {
+ name: "two-values",
+ osOverride: "windows",
+ settings: []testSettingDetails{
+ {
+ definition: setting.NewDefinition("TestSetting01", setting.DeviceSetting, setting.IntegerValue),
+ origin: setting.NewOrigin(setting.DeviceScope),
+ value: 42,
+ },
+ {
+ definition: setting.NewDefinition("TestSetting02", setting.DeviceSetting, setting.IntegerValue),
+ origin: setting.NewOrigin(setting.DeviceScope),
+ value: 17,
+ },
+ },
+ wantMetrics: []TestState{
+ {"windows_syspolicy_any", 1},
+ {"windows_syspolicy_TestSetting01", 1},
+ {"windows_syspolicy_TestSetting02", 1},
+ },
+ wantResetMetrics: []TestState{
+ {"windows_syspolicy_any", 0},
+ {"windows_syspolicy_TestSetting01", 0},
+ {"windows_syspolicy_TestSetting02", 0},
+ },
+ },
+ {
+ name: "two-errors",
+ osOverride: "windows",
+ settings: []testSettingDetails{
+ {
+ definition: setting.NewDefinition("TestSetting01", setting.DeviceSetting, setting.IntegerValue),
+ origin: setting.NewOrigin(setting.DeviceScope),
+ err: errors.New("bang!"),
+ },
+ {
+ definition: setting.NewDefinition("TestSetting02", setting.DeviceSetting, setting.IntegerValue),
+ origin: setting.NewOrigin(setting.DeviceScope),
+ err: errors.New("bang!"),
+ },
+ },
+ wantMetrics: []TestState{
+ {"windows_syspolicy_errors", 2},
+ {"windows_syspolicy_TestSetting01_error", 1},
+ {"windows_syspolicy_TestSetting02_error", 1},
+ },
+ wantResetMetrics: []TestState{
+ {"windows_syspolicy_errors", 2},
+ {"windows_syspolicy_TestSetting01_error", 0},
+ {"windows_syspolicy_TestSetting02_error", 0},
+ },
+ },
+ {
+ name: "multi-scope",
+ osOverride: "windows",
+ settings: []testSettingDetails{
+ {
+ definition: setting.NewDefinition("TestSetting01", setting.ProfileSetting, setting.IntegerValue),
+ origin: setting.NewOrigin(setting.DeviceScope),
+ value: 42,
+ },
+ {
+ definition: setting.NewDefinition("TestSetting02", setting.ProfileSetting, setting.IntegerValue),
+ origin: setting.NewOrigin(setting.CurrentProfileScope),
+ err: errors.New("bang!"),
+ },
+ {
+ definition: setting.NewDefinition("TestSetting03", setting.UserSetting, setting.IntegerValue),
+ origin: setting.NewOrigin(setting.CurrentUserScope),
+ value: 17,
+ },
+ },
+ wantMetrics: []TestState{
+ {"windows_syspolicy_any", 1},
+ {"windows_syspolicy_profile_errors", 1},
+ {"windows_syspolicy_user_any", 1},
+ {"windows_syspolicy_TestSetting01", 1},
+ {"windows_syspolicy_TestSetting02_profile_error", 1},
+ {"windows_syspolicy_TestSetting03_user", 1},
+ },
+ wantResetMetrics: []TestState{
+ {"windows_syspolicy_any", 0},
+ {"windows_syspolicy_profile_errors", 1},
+ {"windows_syspolicy_user_any", 0},
+ {"windows_syspolicy_TestSetting01", 0},
+ {"windows_syspolicy_TestSetting02_profile_error", 0},
+ {"windows_syspolicy_TestSetting03_user", 0},
+ },
+ },
+ {
+ name: "report-metrics-on-android",
+ osOverride: "android",
+ settings: []testSettingDetails{
+ {
+ definition: setting.NewDefinition("TestSetting01", setting.DeviceSetting, setting.IntegerValue),
+ origin: setting.NewOrigin(setting.DeviceScope),
+ value: 42,
+ },
+ },
+ wantMetrics: []TestState{
+ {"android_syspolicy_any", 1},
+ {"android_syspolicy_TestSetting01", 1},
+ },
+ wantResetMetrics: []TestState{
+ {"android_syspolicy_any", 0},
+ {"android_syspolicy_TestSetting01", 0},
+ },
+ },
+ {
+ name: "do-not-report-metrics-on-macos",
+ osOverride: "macos",
+ settings: []testSettingDetails{
+ {
+ definition: setting.NewDefinition("TestSetting01", setting.DeviceSetting, setting.IntegerValue),
+ origin: setting.NewOrigin(setting.DeviceScope),
+ value: 42,
+ },
+ },
+
+ wantMetrics: []TestState{}, // none reported
+ },
+ {
+ name: "do-not-report-metrics-on-ios",
+ osOverride: "ios",
+ settings: []testSettingDetails{
+ {
+ definition: setting.NewDefinition("TestSetting01", setting.DeviceSetting, setting.IntegerValue),
+ origin: setting.NewOrigin(setting.DeviceScope),
+ value: 42,
+ },
+ },
+
+ wantMetrics: []TestState{}, // none reported
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ // Reset the lazy value so it'll be re-evaluated with the osOverride.
+ lazyReportMetrics = lazy.SyncValue[bool]{}
+ t.Cleanup(func() {
+ // Also reset it during the cleanup.
+ lazyReportMetrics = lazy.SyncValue[bool]{}
+ })
+ internal.OSForTesting.SetForTest(t, tt.osOverride, nil)
+
+ h := NewTestHandler(t)
+ SetHooksForTest(t, h.AddMetric, h.SetMetric)
+
+ for _, s := range tt.settings {
+ if s.err != nil {
+ ReportError(s.origin, s.definition, s.err)
+ } else {
+ ReportConfigured(s.origin, s.definition, s.value)
+ }
+ }
+ h.MustEqual(tt.wantMetrics...)
+
+ for _, s := range tt.settings {
+ Reset(s.origin)
+ ReportNotConfigured(s.origin, s.definition)
+ }
+ h.MustEqual(tt.wantResetMetrics...)
+ })
+ }
+}
diff --git a/util/syspolicy/internal/metrics/test_handler.go b/util/syspolicy/internal/metrics/test_handler.go
new file mode 100644
index 000000000..50ee42bbe
--- /dev/null
+++ b/util/syspolicy/internal/metrics/test_handler.go
@@ -0,0 +1,88 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package metrics
+
+import (
+ "strings"
+
+ "tailscale.com/util/clientmetric"
+ "tailscale.com/util/set"
+ "tailscale.com/util/syspolicy/internal"
+)
+
+// TestState represents a metric name and its expected value.
+type TestState struct {
+ Name string // `$os` in the name will be replaced by the actual operating system name.`
+ Value int64
+}
+
+// TestHandler facilitates testing of the code that uses metrics.
+type TestHandler struct {
+ t internal.TB
+
+ m map[string]int64
+}
+
+// NewTestHandler returns a new TestHandler.
+func NewTestHandler(t internal.TB) *TestHandler {
+ return &TestHandler{t, make(map[string]int64)}
+}
+
+// AddMetric increments the metric with the specified name and type by delta d.
+func (h *TestHandler) AddMetric(name string, typ clientmetric.Type, d int64) {
+ h.t.Helper()
+ if typ == clientmetric.TypeCounter && d < 0 {
+ h.t.Fatalf("an attempt was made to decrement a counter metric %q", name)
+ }
+ if v, ok := h.m[name]; ok || d != 0 {
+ h.m[name] = v + d
+ }
+}
+
+// SetMetric sets the metric with the specified name and type to the value v.
+func (h *TestHandler) SetMetric(name string, typ clientmetric.Type, v int64) {
+ h.t.Helper()
+ if typ == clientmetric.TypeCounter {
+ h.t.Fatalf("an attempt was made to set a counter metric %q", name)
+ }
+ if _, ok := h.m[name]; ok || v != 0 {
+ h.m[name] = v
+ }
+}
+
+// MustEqual fails the test if the actual metric state differs from the specified state.
+func (h *TestHandler) MustEqual(metrics ...TestState) {
+ h.t.Helper()
+ h.MustContain(metrics...)
+ h.mustNoExtra(metrics...)
+}
+
+// MustContain fails the test if the specified metrics are not set or have
+// different values than specified. It permits other metrics to be set in
+// addition to the ones being tested.
+func (h *TestHandler) MustContain(metrics ...TestState) {
+ h.t.Helper()
+ for _, m := range metrics {
+ name := strings.ReplaceAll(m.Name, "$os", internal.OS())
+ v, ok := h.m[name]
+ if !ok {
+ h.t.Errorf("%q: got (none), want %v", name, m.Value)
+ } else if v != m.Value {
+ h.t.Fatalf("%q: got %v, want %v", name, v, m.Value)
+ }
+ }
+}
+
+func (h *TestHandler) mustNoExtra(metrics ...TestState) {
+ h.t.Helper()
+ s := make(set.Set[string])
+ for i := range metrics {
+ s.Add(strings.ReplaceAll(metrics[i].Name, "$os", internal.OS()))
+ }
+ for n, v := range h.m {
+ if !s.Contains(n) {
+ h.t.Errorf("%q: got %v, want (none)", n, v)
+ }
+ }
+}
diff --git a/util/syspolicy/policy_keys.go b/util/syspolicy/policy_keys.go
index ef0cfed8f..cf5685c01 100644
--- a/util/syspolicy/policy_keys.go
+++ b/util/syspolicy/policy_keys.go
@@ -3,7 +3,21 @@
package syspolicy
-type Key string
+import (
+ "tailscale.com/types/lazy"
+ "tailscale.com/util/syspolicy/internal/lazyinit"
+ "tailscale.com/util/syspolicy/setting"
+ "tailscale.com/util/testenv"
+)
+
+type Key = setting.Key
+
+// The const block below lists known policy keys.
+// When adding a key to this list, remember to add a corresponding
+// [setting.Definition] to [implicitDefinitions] below.
+// Otherwise, the [TestKnownKeysRegistered] test will fail as a reminder.
+// Preferably, use a strongly typed policy hierarchy, such as [Policy],
+// instead of adding each key to the list below.
const (
// Keys with a string value
@@ -96,3 +110,83 @@ const (
// AllowedSuggestedExitNodes's string array value is a list of exit node IDs that restricts which exit nodes are considered when generating suggestions for exit nodes.
AllowedSuggestedExitNodes Key = "AllowedSuggestedExitNodes"
)
+
+// implicitDefinitions is a list of [setting.Definition] that will be registered
+// automatically by [settingDefinitions] as soon as the package needs to ready a policy.
+var implicitDefinitions = []*setting.Definition{
+ // Device policy settings
+ setting.NewDefinition(AllowedSuggestedExitNodes, setting.DeviceSetting, setting.StringListValue),
+ setting.NewDefinition(ApplyUpdates, setting.DeviceSetting, setting.PreferenceOptionValue),
+ setting.NewDefinition(CheckUpdates, setting.DeviceSetting, setting.PreferenceOptionValue),
+ setting.NewDefinition(ControlURL, setting.DeviceSetting, setting.StringValue),
+ setting.NewDefinition(DeviceSerialNumber, setting.DeviceSetting, setting.StringValue),
+ setting.NewDefinition(EnableIncomingConnections, setting.DeviceSetting, setting.PreferenceOptionValue),
+ setting.NewDefinition(EnableRunExitNode, setting.DeviceSetting, setting.PreferenceOptionValue),
+ setting.NewDefinition(EnableServerMode, setting.DeviceSetting, setting.PreferenceOptionValue),
+ setting.NewDefinition(EnableTailscaleDNS, setting.DeviceSetting, setting.PreferenceOptionValue),
+ setting.NewDefinition(EnableTailscaleSubnets, setting.DeviceSetting, setting.PreferenceOptionValue),
+ setting.NewDefinition(ExitNodeAllowLANAccess, setting.DeviceSetting, setting.PreferenceOptionValue),
+ setting.NewDefinition(ExitNodeID, setting.DeviceSetting, setting.StringValue),
+ setting.NewDefinition(ExitNodeIP, setting.DeviceSetting, setting.StringValue),
+ setting.NewDefinition(FlushDNSOnSessionUnlock, setting.DeviceSetting, setting.BooleanValue),
+ setting.NewDefinition(LogSCMInteractions, setting.DeviceSetting, setting.BooleanValue),
+ setting.NewDefinition(LogTarget, setting.DeviceSetting, setting.StringValue),
+ setting.NewDefinition(PostureChecking, setting.DeviceSetting, setting.PreferenceOptionValue),
+ setting.NewDefinition(Tailnet, setting.DeviceSetting, setting.StringValue),
+
+ // User policy settings
+ setting.NewDefinition(AdminConsoleVisibility, setting.UserSetting, setting.VisibilityValue),
+ setting.NewDefinition(AutoUpdateVisibility, setting.UserSetting, setting.VisibilityValue),
+ setting.NewDefinition(ExitNodeMenuVisibility, setting.UserSetting, setting.VisibilityValue),
+ setting.NewDefinition(KeyExpirationNoticeTime, setting.UserSetting, setting.DurationValue),
+ setting.NewDefinition(ManagedByCaption, setting.UserSetting, setting.StringValue),
+ setting.NewDefinition(ManagedByOrganizationName, setting.UserSetting, setting.StringValue),
+ setting.NewDefinition(ManagedByURL, setting.UserSetting, setting.StringValue),
+ setting.NewDefinition(NetworkDevicesVisibility, setting.UserSetting, setting.VisibilityValue),
+ setting.NewDefinition(PreferencesMenuVisibility, setting.UserSetting, setting.VisibilityValue),
+ setting.NewDefinition(ResetToDefaultsVisibility, setting.UserSetting, setting.VisibilityValue),
+ setting.NewDefinition(RunExitNodeVisibility, setting.UserSetting, setting.VisibilityValue),
+ setting.NewDefinition(SuggestedExitNodeVisibility, setting.UserSetting, setting.VisibilityValue),
+ setting.NewDefinition(TestMenuVisibility, setting.UserSetting, setting.VisibilityValue),
+ setting.NewDefinition(UpdateMenuVisibility, setting.UserSetting, setting.VisibilityValue),
+}
+
+func init() {
+ lazyinit.Defer(func() error {
+ // Avoid implicit [SettingDefinition] registration during tests.
+ // Each test should control which policy settings to register.
+ // Use [setting.SetDefinitionsForTest] to specify necessary definitions,
+ // or [setWellKnownSettingsForTest] to set implicit definitions for the test duration.
+ if testenv.InTest() {
+ return nil
+ }
+ for _, d := range implicitDefinitions {
+ setting.RegisterDefinition(d)
+ }
+ return nil
+ })
+}
+
+var implicitDefinitionMap lazy.SyncValue[setting.DefinitionMap]
+
+// WellKnownSettingDefinition returns a well-known, implicit setting definition by its key,
+// or an [ErrNoSuchKey] if a policy setting with the specified key does not exist
+// among implicit policy definitions.
+func WellKnownSettingDefinition(k Key) (*setting.Definition, error) {
+ m, err := implicitDefinitionMap.GetErr(func() (setting.DefinitionMap, error) {
+ return setting.DefinitionMapOf(implicitDefinitions)
+ })
+ if err != nil {
+ return nil, err
+ }
+ if d, ok := m[k]; ok {
+ return d, nil
+ }
+ return nil, ErrNoSuchKey
+}
+
+// setWellKnownSettingsForTest registers all implicit setting definitions
+// for the duration of the test.
+func setWellKnownSettingsForTest(tb lazy.TB) error {
+ return setting.SetDefinitionsForTest(tb, implicitDefinitions...)
+}
diff --git a/util/syspolicy/policy_keys_test.go b/util/syspolicy/policy_keys_test.go
new file mode 100644
index 000000000..4d3260f3e
--- /dev/null
+++ b/util/syspolicy/policy_keys_test.go
@@ -0,0 +1,95 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package syspolicy
+
+import (
+ "fmt"
+ "go/ast"
+ "go/parser"
+ "go/token"
+ "go/types"
+ "os"
+ "reflect"
+ "strconv"
+ "testing"
+
+ "tailscale.com/util/syspolicy/setting"
+)
+
+func TestKnownKeysRegistered(t *testing.T) {
+ keyConsts, err := listStringConsts[Key]("policy_keys.go")
+ if err != nil {
+ t.Fatalf("listStringConsts failed: %v", err)
+ }
+
+ m, err := setting.DefinitionMapOf(implicitDefinitions)
+ if err != nil {
+ t.Fatalf("definitionMapOf failed: %v", err)
+ }
+
+ for _, key := range keyConsts {
+ t.Run(string(key), func(t *testing.T) {
+ d := m[key]
+ if d == nil {
+ t.Fatalf("%q was not registered", key)
+ }
+ if d.Key() != key {
+ t.Fatalf("d.Key got: %s, want %s", d.Key(), key)
+ }
+ })
+ }
+}
+
+func TestNotAWellKnownSetting(t *testing.T) {
+ d, err := WellKnownSettingDefinition("TestSettingDoesNotExist")
+ if d != nil || err == nil {
+ t.Fatalf("got %v, %v; want nil, %v", d, err, ErrNoSuchKey)
+ }
+}
+
+func listStringConsts[T ~string](filename string) (map[string]T, error) {
+ fset := token.NewFileSet()
+ src, err := os.ReadFile(filename)
+ if err != nil {
+ return nil, err
+ }
+
+ f, err := parser.ParseFile(fset, filename, src, 0)
+ if err != nil {
+ return nil, err
+ }
+
+ consts := make(map[string]T)
+ typeName := reflect.TypeFor[T]().Name()
+ for _, d := range f.Decls {
+ g, ok := d.(*ast.GenDecl)
+ if !ok || g.Tok != token.CONST {
+ continue
+ }
+
+ for _, s := range g.Specs {
+ vs, ok := s.(*ast.ValueSpec)
+ if !ok || len(vs.Names) != len(vs.Values) {
+ continue
+ }
+ if typ, ok := vs.Type.(*ast.Ident); !ok || typ.Name != typeName {
+ continue
+ }
+
+ for i, n := range vs.Names {
+ lit, ok := vs.Values[i].(*ast.BasicLit)
+ if !ok {
+ return nil, fmt.Errorf("unexpected string literal: %v = %v", n.Name, types.ExprString(vs.Values[i]))
+ }
+ val, err := strconv.Unquote(lit.Value)
+ if err != nil {
+ return nil, fmt.Errorf("unexpected string literal: %v = %v", n.Name, lit.Value)
+ }
+ consts[n.Name] = T(val)
+ }
+ }
+ }
+
+ return consts, nil
+}
diff --git a/util/syspolicy/policy_keys_windows.go b/util/syspolicy/policy_keys_windows.go
deleted file mode 100644
index 5e9a71695..000000000
--- a/util/syspolicy/policy_keys_windows.go
+++ /dev/null
@@ -1,38 +0,0 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-package syspolicy
-
-var stringKeys = []Key{
- ControlURL,
- LogTarget,
- Tailnet,
- ExitNodeID,
- ExitNodeIP,
- EnableIncomingConnections,
- EnableServerMode,
- ExitNodeAllowLANAccess,
- EnableTailscaleDNS,
- EnableTailscaleSubnets,
- AdminConsoleVisibility,
- NetworkDevicesVisibility,
- TestMenuVisibility,
- UpdateMenuVisibility,
- RunExitNodeVisibility,
- PreferencesMenuVisibility,
- ExitNodeMenuVisibility,
- AutoUpdateVisibility,
- ResetToDefaultsVisibility,
- KeyExpirationNoticeTime,
- PostureChecking,
- ManagedByOrganizationName,
- ManagedByCaption,
- ManagedByURL,
-}
-
-var boolKeys = []Key{
- LogSCMInteractions,
- FlushDNSOnSessionUnlock,
-}
-
-var uint64Keys = []Key{}
diff --git a/util/syspolicy/rsop/change_callbacks.go b/util/syspolicy/rsop/change_callbacks.go
new file mode 100644
index 000000000..e46ee38f6
--- /dev/null
+++ b/util/syspolicy/rsop/change_callbacks.go
@@ -0,0 +1,109 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package rsop
+
+import (
+ "reflect"
+ "slices"
+ "sync"
+ "time"
+
+ "tailscale.com/util/set"
+ "tailscale.com/util/syspolicy/internal/loggerx"
+ "tailscale.com/util/syspolicy/setting"
+)
+
+// Change represents a change from the Old to the New value of type T.
+type Change[T any] struct {
+ New, Old T
+}
+
+// PolicyChangeCallback is a function called whenever a policy changes.
+type PolicyChangeCallback func(*PolicyChange)
+
+// PolicyChange describes a policy change.
+type PolicyChange struct {
+ snapshots Change[*setting.Snapshot]
+}
+
+// New returns the [setting.Snapshot] after the change.
+func (c PolicyChange) New() *setting.Snapshot {
+ return c.snapshots.New
+}
+
+// Old returns the [setting.Snapshot] before the change.
+func (c PolicyChange) Old() *setting.Snapshot {
+ return c.snapshots.Old
+}
+
+// HasChanged reports whether a policy setting with the specified [setting.Key], has changed.
+func (c PolicyChange) HasChanged(key setting.Key) bool {
+ new, newErr := c.snapshots.New.GetErr(key)
+ old, oldErr := c.snapshots.Old.GetErr(key)
+ if newErr != nil && oldErr != nil {
+ return false
+ }
+ if newErr != nil || oldErr != nil {
+ return true
+ }
+ switch newVal := new.(type) {
+ case bool, uint64, string, setting.Visibility, setting.PreferenceOption, time.Duration:
+ return newVal != old
+ case []string:
+ if oldVal, ok := old.([]string); ok {
+ return slices.Equal(newVal, oldVal)
+ }
+ return false
+ default:
+ loggerx.Errorf("%q has an unsupported value type: %T", newVal)
+ return reflect.DeepEqual(new, old)
+ }
+}
+
+// policyChangeCallbacks are the callbacks to invoke when the resultant policy changes.
+// It is safe for concurrent use.
+type policyChangeCallbacks struct {
+ mu sync.RWMutex
+ cbs set.HandleSet[PolicyChangeCallback]
+}
+
+// Register adds the specified callback to be invoked whenever the policy changes.
+func (c *policyChangeCallbacks) Register(callback PolicyChangeCallback) (unregister func()) {
+ c.mu.Lock()
+ handle := c.cbs.Add(callback)
+ c.mu.Unlock()
+ return func() {
+ c.mu.Lock()
+ delete(c.cbs, handle)
+ c.mu.Unlock()
+ }
+}
+
+// Invoke calls the registered callback functions with the specified policy change info.
+func (c *policyChangeCallbacks) Invoke(snapshots Change[*setting.Snapshot]) {
+ var wg sync.WaitGroup
+ defer wg.Wait()
+
+ c.mu.RLock()
+ defer c.mu.RUnlock()
+
+ wg.Add(len(c.cbs))
+ change := &PolicyChange{snapshots: snapshots}
+ for _, cb := range c.cbs {
+ go func() {
+ defer wg.Done()
+ cb(change)
+ }()
+ }
+}
+
+// Close awaits the completion of active callbacks and prevents any further invocations.
+func (c *policyChangeCallbacks) Close() {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ if c.cbs != nil {
+ clear(c.cbs)
+ c.cbs = nil
+ }
+}
diff --git a/util/syspolicy/rsop/resultant_policy.go b/util/syspolicy/rsop/resultant_policy.go
new file mode 100644
index 000000000..9191f80cb
--- /dev/null
+++ b/util/syspolicy/rsop/resultant_policy.go
@@ -0,0 +1,698 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Package rsop facilitates [source.Store] registration via [RegisterStore]
+// and provides access to the resultant policy merged from all registered sources
+// via [PolicyFor].
+package rsop
+
+import (
+ "errors"
+ "fmt"
+ "reflect"
+ "slices"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "tailscale.com/syncs"
+ "tailscale.com/types/lazy"
+ "tailscale.com/util/slicesx"
+ "tailscale.com/util/syspolicy/internal"
+ "tailscale.com/util/syspolicy/internal/lazyinit"
+ "tailscale.com/util/syspolicy/internal/loggerx"
+ "tailscale.com/util/syspolicy/setting"
+
+ "tailscale.com/util/syspolicy/source"
+)
+
+var errResultantPolicyClosed = errors.New("resultant policy closed")
+
+// The minimum and maximum wait times after detecting a policy change
+// before reloading the policy.
+// Policy changes occurring within [policyReloadMinDelay] of each other
+// will be batched together, resulting in a single policy reload
+// no later than [policyReloadMaxDelay] after the first detected change.
+// In other words, the resultant policy will be reloaded no more often than once
+// every 5 seconds, but at most 15 seconds after an underlying [source.Store]
+// has issued a policy change callback.
+// See [Policy.watchReload].
+const (
+ defaultPolicyReloadMinDelay = 5 * time.Second
+ defaultPolicyReloadMaxDelay = 15 * time.Second
+)
+
+// policyReloadMinDelay and policyReloadMaxDelay are test hooks.
+// Their values default to [defaultPolicyReloadMinDelay] and [defaultPolicyReloadMaxDelay].
+var (
+ policyReloadMinDelay, policyReloadMaxDelay lazy.SyncValue[time.Duration]
+)
+
+// Policy provides access to the current resultant [setting.Snapshot] for a given
+// scope and allows to reload it from the underlying [source.Store]s. It also allows to
+// subscribe and receive a callback whenever the resultant [setting.Snapshot] is
+// changed. It is safe for concurrent use.
+type Policy struct {
+ scope setting.PolicyScope
+
+ reloadCh chan reloadRequest // 1-buffered; written to when a policy reload is required
+ changeSourceCh chan sourceChangeRequest // written to to add a new or remove an existing source
+ closeCh chan struct{} // closed to signal that the Policy is being closed
+ doneCh chan struct{} // closed by closeInternal when watchReload exits
+
+ // resultant is the most recent version of the [setting.Snapshot] containing policy settings
+ // merged from all applicable sources.
+ resultant atomic.Pointer[setting.Snapshot]
+
+ changeCallbacks policyChangeCallbacks
+
+ mu sync.RWMutex
+ sources source.ReadableSources
+ closing bool // Close was called (even if we're still closing)
+}
+
+// newPolicy returns a new [Policy] for the specified [setting.PolicyScope]
+// that tracks changes and merges policy settings read from the specified sources.
+func newPolicy(scope setting.PolicyScope, sources ...*source.Source) (p *Policy, err error) {
+ readableSources := source.ReadableSources(make([]source.ReadableSource, len(sources)))
+ for i, s := range sources {
+ reader, err := s.Reader()
+ if err != nil {
+ return nil, fmt.Errorf("failed to get a store reader: %v", err)
+ }
+ session, err := reader.OpenSession()
+ if err != nil {
+ return nil, fmt.Errorf("failed to open a reading session: %v", err)
+ }
+
+ readableSource := source.ReadableSource{
+ Source: s,
+ ReadingSession: session,
+ }
+ readableSources[i] = readableSource
+ defer func() {
+ if err != nil {
+ readableSource.Close()
+ }
+ }()
+ }
+
+ // Sort policy sources by their precedence from lower to higher.
+ // For example, {UserPolicy},{ProfilePolicy},{DevicePolicy}.
+ readableSources.StableSort()
+
+ p = &Policy{
+ scope: scope,
+ sources: readableSources,
+ reloadCh: make(chan reloadRequest, 1),
+ changeSourceCh: make(chan sourceChangeRequest),
+ closeCh: make(chan struct{}),
+ doneCh: make(chan struct{}),
+ }
+ if err := p.start(); err != nil {
+ return nil, err
+ }
+ return p, nil
+}
+
+// IsValid reports whether p is in a valid state and has not been closed.
+func (p *Policy) IsValid() bool {
+ select {
+ case <-p.closeCh:
+ return false
+ default:
+ return true
+ }
+}
+
+// Scope returns the [setting.PolicyScope] that this resultant policy applies to.
+func (p *Policy) Scope() setting.PolicyScope {
+ return p.scope
+}
+
+// Get returns the most recent resultant [setting.Snapshot].
+func (p *Policy) Get() *setting.Snapshot {
+ return p.resultant.Load()
+}
+
+// RegisterChangeCallback adds a function to be called whenever the resultant
+// policy changes. The returned function can be used to unregister the callback.
+func (p *Policy) RegisterChangeCallback(callback PolicyChangeCallback) (unregister func()) {
+ return p.changeCallbacks.Register(callback)
+}
+
+// Reload synchronously re-reads policy settings from the underlying policy
+// [source.Store], constructing a new merged [setting.Snapshot] even if the policy remains
+// unchanged. In most scenarios, there's no need to re-read the policy manually.
+// Instead, it is recommended to register a policy change callback, or to use
+// the most recent [setting.Snapshot] returned by the [Policy.Get] method.
+func (p *Policy) Reload() (*setting.Snapshot, error) {
+ return p.reload(true)
+}
+
+// reload is like Reload, but allows to specify whether to re-read policy settings
+// from unchanged policy sources.
+func (p *Policy) reload(force bool) (*setting.Snapshot, error) {
+ respCh := make(chan reloadResponse, 1)
+ select {
+ case p.reloadCh <- reloadRequest{force: force, respCh: respCh}:
+ // continue
+ case <-p.closeCh:
+ return nil, errResultantPolicyClosed
+ }
+ select {
+ case resp := <-respCh:
+ return resp.policy, resp.err
+ case <-p.closeCh:
+ return nil, errResultantPolicyClosed
+ }
+}
+
+// Done returns a channel that is closed when the [Policy] is closed.
+func (p *Policy) Done() <-chan struct{} {
+ return p.doneCh
+}
+
+func (p *Policy) start() error {
+ if _, err := p.reloadNow(false); err != nil {
+ return err
+ }
+ go p.watchPolicyChanges()
+ go p.watchReload()
+ return nil
+}
+
+// readAndMerge reads and merges policy settings from the underlying sources,
+// returning a [setting.Snapshot] with the merged result.
+// If the force parameter is true, it re-reads policy settings from each store
+// even if no policy change was observed, and returns an error if the read
+// operation fails.
+func (p *Policy) readAndMerge(force bool) (*setting.Snapshot, error) {
+ p.mu.RLock()
+ defer p.mu.RUnlock()
+ // Start with an empty policy in the target scope.
+ resultant := setting.NewSnapshot(nil, setting.SummaryWith(p.scope))
+ // Then merge policy settings from all sources.
+ // Policy sources with the highest precedence (e.g., the device policy) are merged last,
+ // overriding any conflicting policy settings with lower precedence.
+ for _, s := range p.sources {
+ var policy *setting.Snapshot
+ if force {
+ var err error
+ if policy, err = s.ReadSettings(); err != nil {
+ return nil, err
+ }
+ } else {
+ policy = s.GetSettings()
+ }
+ resultant = setting.MergeSnapshots(resultant, policy)
+ }
+ return resultant, nil
+}
+
+// reloadAsync requests an asynchronous background policy reload.
+// The policy will be reloaded no later than in [policyReloadMaxDelay].
+func (p *Policy) reloadAsync() {
+ select {
+ case p.reloadCh <- reloadRequest{}:
+ // Sent.
+ default:
+ // A reload request is already en route.
+ }
+}
+
+// reloadNow loads and merges policies from all sources, updating the resultant policy.
+// If the force parameter is true, it forcibly reloads policies
+// from the underlying policy store, even if no policy changes were detected.
+//
+// Except for the initial policy reload during the [Policy] creation,
+// this method should only be called from the [Policy.watchReload] goroutine.
+func (p *Policy) reloadNow(force bool) (*setting.Snapshot, error) {
+ new, err := p.readAndMerge(force)
+ if err != nil {
+ return nil, err
+ }
+ old := p.resultant.Swap(new)
+ // A nil old value indicates the initial policy load rather than a policy change.
+ // Additionally, we should not invoke the policy change callbacks unless the
+ // policy items have actually changed.
+ if old != nil && !old.EqualItems(new) {
+ snapshots := Change[*setting.Snapshot]{New: new, Old: old}
+ p.changeCallbacks.Invoke(snapshots)
+ }
+ return new, nil
+}
+
+// AddSource adds the specified source to the list of sources used by p,
+// and triggers a synchronous policy refresh. It returns an error
+// if the source is not a valid source for this resultant policy,
+// or if the resultant policy is being closed,
+// or if policy refresh fails with an error.
+func (p *Policy) AddSource(source *source.Source) error {
+ return p.changeSource(source, nil)
+}
+
+// RemoveSource removes the specified source from the list of sources used by p,
+// and triggers a synchronous policy refresh. It returns an error if the
+// resultant policy is being closed, or if policy refresh fails with an error.
+func (p *Policy) RemoveSource(source *source.Source) error {
+ return p.changeSource(nil, source)
+}
+
+// ReplaceSource replaces the old source with the new source atomically,
+// and triggers a synchronous policy refresh. It returns an error
+// if the source is not a valid source for this resultant policy,
+// or if the resultant policy is being closed,
+// or if policy refresh fails with an error.
+func (p *Policy) ReplaceSource(old, new *source.Source) error {
+ return p.changeSource(new, old)
+}
+
+func (p *Policy) changeSource(toAdd, toRemove *source.Source) error {
+ if toAdd == toRemove {
+ return nil
+ }
+ if toAdd != nil && !p.scope.IsWithinOf(toAdd.Scope()) {
+ return errors.New("scope mismatch")
+ }
+ respCh := make(chan error, 1)
+ req := sourceChangeRequest{toAdd, toRemove, respCh}
+ select {
+ case p.changeSourceCh <- req:
+ return <-respCh
+ case <-p.closeCh:
+ return errResultantPolicyClosed
+ }
+}
+
+// watchPolicyChanges awaits a policy change notification from any of the sources
+// and calls reloadAsync whenever a notification is received.
+func (p *Policy) watchPolicyChanges() {
+ const (
+ closeIdx = iota
+ changeSourceIdx
+ policyChangedOffset
+ )
+
+ // The cases are Close, ChangeSource, PolicyChanged[0],...,PolicyChanged[N-1].
+ p.mu.RLock()
+ cases := make([]reflect.SelectCase, len(p.sources)+policyChangedOffset)
+ // Add the PolicyChanged[N] cases.
+ for i, source := range p.sources {
+ cases[i+policyChangedOffset] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(source.PolicyChanged())}
+ }
+ // Add the Close case.
+ cases[closeIdx] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(p.closeCh)}
+ // Add the ChangeSource case.
+ cases[changeSourceIdx] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(p.changeSourceCh)}
+ p.mu.RUnlock()
+
+ for {
+ switch chosen, recv, ok := reflect.Select(cases); chosen {
+ case closeIdx: // Close
+ // Exit the watch as the closeCh was closed, indicating that
+ // the [Policy] is being closed.
+ return
+ case changeSourceIdx: // ChangeSource
+ // We've received a source change request from one of the AddSource,
+ // RemoveSource, or ReplaceSource methods, meaning that we need to:
+ // - Open a reader session if a new source is being added;
+ // - Update the p.sources slice;
+ // - Update the cases slice;
+ // - Trigger a synchronous policy reload;
+ // - Report an error, if any, back to the caller.
+ req := recv.Interface().(sourceChangeRequest)
+ needClose, err := func() (close bool, err error) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ if req.toAdd != nil {
+ if !p.sources.Contains(req.toAdd) {
+ reader, err := req.toAdd.Reader()
+ if err != nil {
+ return false, fmt.Errorf("failed to get a store reader: %v", err)
+ }
+ session, err := reader.OpenSession()
+ if err != nil {
+ return false, fmt.Errorf("failed to open a reading session: %v", err)
+ }
+
+ addAt := p.sources.InsertionIndexOf(req.toAdd)
+ toAdd := source.ReadableSource{
+ Source: req.toAdd,
+ ReadingSession: session,
+ }
+ p.sources = slices.Insert(p.sources, addAt, toAdd)
+ newCase := reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(toAdd.PolicyChanged())}
+ caseIndex := addAt + policyChangedOffset
+ cases = slices.Insert(cases, caseIndex, newCase)
+ }
+ }
+ if req.toDelete != nil {
+ if deleteAt := p.sources.IndexOf(req.toDelete); deleteAt != -1 {
+ p.sources.DeleteAt(deleteAt)
+ caseIndex := deleteAt + policyChangedOffset
+ cases = slices.Delete(cases, caseIndex, caseIndex+1)
+ }
+ }
+ return len(p.sources) == 0, nil
+ }()
+ if err == nil {
+ if needClose {
+ // Close the resultant policy if the last policy source was deleted.
+ p.Close()
+ } else {
+ // Otherwise, reload the policy synchronously.
+ _, err = p.reload(false)
+ }
+ }
+ req.respCh <- err
+ default: // PolicyChanged[N]
+ if !ok {
+ // One of the PolicyChanged channels was closed, indicating that
+ // the corresponding [source.Source] is no longer valid.
+ // We can no longer keep this [Policy] up to date
+ // and should close it.
+ p.Close()
+ return
+ }
+
+ // One of the PolicyChanged channels was signaled.
+ // We should request an asynchronous policy reload.
+ p.reloadAsync()
+ }
+ }
+}
+
+// watchReload processes incoming synchronous and asynchronous policy reload requests.
+// Synchronous requests (with a non-nil respCh) are served immediately.
+// Asynchronous requests are debounced and throttled: they are executed at least
+// [policyReloadMinDelay] after the last request, but no later than [policyReloadMaxDelay]
+// after the first request in a batch.
+func (p *Policy) watchReload() {
+ force := false // whether a forced refresh was requested
+ var delayCh, timeoutCh <-chan time.Time
+ reload := func(respCh chan<- reloadResponse) {
+ delayCh, timeoutCh = nil, nil
+ policy, err := p.reloadNow(force)
+ if err != nil {
+ loggerx.Errorf("%v policy reload failed: %v\n", p.scope, err)
+ }
+ if respCh != nil {
+ respCh <- reloadResponse{policy: policy, err: err}
+ }
+ force = false
+ }
+
+loop:
+ for {
+ select {
+ case req := <-p.reloadCh:
+ if req.force {
+ force = true
+ }
+ if req.respCh != nil {
+ reload(req.respCh)
+ continue
+ }
+ if delayCh == nil {
+ timeoutCh = time.After(policyReloadMaxDelay.Get(func() time.Duration { return defaultPolicyReloadMaxDelay }))
+ }
+ delayCh = time.After(policyReloadMinDelay.Get(func() time.Duration { return defaultPolicyReloadMinDelay }))
+ case <-delayCh:
+ reload(nil)
+ case <-timeoutCh:
+ reload(nil)
+ case <-p.closeCh:
+ break loop
+ }
+ }
+
+ p.closeInternal()
+}
+
+func (p *Policy) closeInternal() {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ p.sources.Close()
+ p.changeCallbacks.Close()
+ close(p.doneCh)
+}
+
+// Close initiates the closing of the resultant policy.
+// The actual closing is performed by closeInternal when watchReload exits,
+// and the Done() channel is closed when closeInternal finishes.
+func (p *Policy) Close() {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ if p.closing {
+ return
+ }
+ p.closing = true
+ close(p.closeCh)
+}
+
+// sourceChangeRequest is a request to add and/or remove source from a [Policy].
+type sourceChangeRequest struct {
+ toAdd, toDelete *source.Source
+ respCh chan<- error
+}
+
+// reloadRequest describes a policy reload request.
+type reloadRequest struct {
+ // force triggers an immediate synchronous policy reload,
+ // reloading the policy regardless of whether a policy change was detected.
+ force bool
+ // respCh is an optional channel. If non-nil, it makes the reload request
+ // synchronous and receives the result.
+ respCh chan<- reloadResponse
+}
+
+type reloadResponse struct {
+ policy *setting.Snapshot
+ err error
+}
+
+var (
+ policyMu sync.RWMutex
+ policySources []*source.Source
+ resultantPolicies []*Policy
+
+ resultantPolicyLRU [setting.MaxSettingScope + 1]syncs.AtomicValue[*Policy] // by [Scope.Kind]
+)
+
+// registerSource registers the specified [source.Source] to be used by the package.
+// It updates existing [Policy]s returned by [PolicyFor] to use this source if
+// they are within the source's [setting.PolicyScope].
+func registerSource(source *source.Source) error {
+ policyMu.Lock()
+ defer policyMu.Unlock()
+ if slices.Contains(policySources, source) {
+ return nil
+ }
+ policySources = append(policySources, source)
+ return forEachResultantPolicyLocked(func(policy *Policy) error {
+ if !policy.Scope().IsWithinOf(source.Scope()) {
+ return nil
+ }
+ return policy.AddSource(source)
+ })
+}
+
+// replaceSource is like [unregisterSource](old) followed by [registerSource](new),
+// but is atomic from the perspective of each [Policy].
+func replaceSource(old, new *source.Source) error {
+ policyMu.Lock()
+ defer policyMu.Unlock()
+ oldIndex := slices.Index(policySources, old)
+ if oldIndex == -1 {
+ return fmt.Errorf("the source is not registered: %v", old)
+ }
+ policySources[oldIndex] = new
+ return forEachResultantPolicyLocked(func(policy *Policy) error {
+ if policy.Scope().IsWithinOf(old.Scope()) || policy.Scope().IsWithinOf(new.Scope()) {
+ return nil
+ }
+ return policy.ReplaceSource(old, new)
+ })
+}
+
+// unregisterSource unregisters the specified [source.Source],
+// so that it won't be used by any new or existing [Policy].
+func unregisterSource(source *source.Source) error {
+ policyMu.Lock()
+ defer policyMu.Unlock()
+ index := slices.Index(policySources, source)
+ if index == -1 {
+ return nil
+ }
+ policySources = slices.Delete(policySources, index, index+1)
+ return forEachResultantPolicyLocked(func(policy *Policy) error {
+ if !policy.Scope().IsWithinOf(source.Scope()) {
+ return nil
+ }
+ return policy.RemoveSource(source)
+ })
+}
+
+// forEachResultantPolicyLocked calls fn for every [Policy] in [resultantPolicies].
+// It accumulates the returned errors, except for [errResultantPolicyClosed],
+// and returns an error that wraps all errors returned by fn.
+// The [policyMu] mutex must be held while this function is executed.
+func forEachResultantPolicyLocked(fn func(p *Policy) error) error {
+ var errs []error
+ for _, policy := range resultantPolicies {
+ err := fn(policy)
+ if err != nil && !errors.Is(err, errResultantPolicyClosed) {
+ errs = append(errs, err)
+ }
+ }
+ return errors.Join(errs...)
+}
+
+// PolicyFor returns the [Policy] for the specified scope,
+// creating one from the registered [source.Store]s if it does not exist.
+func PolicyFor(scope setting.PolicyScope) (*Policy, error) {
+ if err := lazyinit.Do(); err != nil {
+ return nil, err
+ }
+ if policy := resultantPolicyLRU[scope.Kind()].Load(); policy != nil && policy.Scope() == scope && policy.IsValid() {
+ return policy, nil
+ }
+ return policyForSlow(scope)
+}
+
+func policyForSlow(scope setting.PolicyScope) (policy *Policy, err error) {
+ defer func() {
+ if policy != nil {
+ resultantPolicyLRU[scope.Kind()].Store(policy)
+ }
+ }()
+
+ policyMu.RLock()
+ if policy, ok := findPolicyByScopeLocked(scope); ok {
+ policyMu.RUnlock()
+ return policy, nil
+ }
+ policyMu.RUnlock()
+
+ policyMu.Lock()
+ defer policyMu.Unlock()
+ if policy, ok := findPolicyByScopeLocked(scope); ok {
+ return policy, nil
+ }
+ sources := slicesx.Filter(nil, policySources, func(source *source.Source) bool {
+ return scope.IsWithinOf(source.Scope())
+ })
+ policy, err = newPolicy(scope, sources...)
+ if err != nil {
+ return nil, err
+ }
+ resultantPolicies = append(resultantPolicies, policy)
+ go func() {
+ <-policy.Done()
+ deletePolicy(policy)
+ }()
+ return policy, nil
+}
+
+// findPolicyByScopeLocked returns a policy with the specified scope and true if
+// one exists, otherwise it returns nil, false.
+// [policyMu] must be held.
+func findPolicyByScopeLocked(target setting.PolicyScope) (policy *Policy, ok bool) {
+ for _, policy := range resultantPolicies {
+ if policy.Scope() == target && policy.IsValid() {
+ return policy, true
+ }
+ }
+ return nil, false
+}
+
+// deletePolicy deletes the specified resultant policy from the [resultantPolicies] list.
+func deletePolicy(policy *Policy) {
+ policyMu.Lock()
+ if i := slices.Index(resultantPolicies, policy); i != -1 {
+ resultantPolicies = slices.Delete(resultantPolicies, i, i+1)
+ }
+ resultantPolicyLRU[policy.Scope().Kind()].CompareAndSwap(policy, nil)
+ policyMu.Unlock()
+}
+
+// ErrAlreadyConsumed is the error returned when [StoreRegistration.ReplaceStore]
+// or [StoreRegistration.Unregister] is called more than once.
+var ErrAlreadyConsumed = errors.New("the store registration is no longer valid")
+
+// StoreRegistration is a [source.Store] registered for use in the specified scope.
+// It can be used to unregister the store, or replace it with another one.
+type StoreRegistration struct {
+ source *source.Source
+ consumed atomic.Uint32
+ m sync.Mutex
+}
+
+// RegisterStore registers a new policy [source.Store] with the specified name and [setting.PolicyScope].
+func RegisterStore(name string, scope setting.PolicyScope, store source.Store) (*StoreRegistration, error) {
+ return newStoreRegistration(name, scope, store)
+}
+
+// RegisterStoreForTest is like [RegisterStore], but unregisters the store when
+// tb and all its subtests complete.
+func RegisterStoreForTest(tb internal.TB, name string, scope setting.PolicyScope, store source.Store) (*StoreRegistration, error) {
+ reg, err := RegisterStore(name, scope, store)
+ if err == nil {
+ tb.Cleanup(func() {
+ if err := reg.Unregister(); err != nil && !errors.Is(err, ErrAlreadyConsumed) {
+ tb.Fatalf("Unregister failed: %v", err)
+ }
+ })
+ }
+ return reg, err // may be nil or non-nil
+}
+
+func newStoreRegistration(name string, scope setting.PolicyScope, store source.Store) (*StoreRegistration, error) {
+ source := source.NewSource(name, scope, store)
+ if err := registerSource(source); err != nil {
+ return nil, err
+ }
+ return &StoreRegistration{source: source}, nil
+}
+
+// ReplaceStore replaces the registered store with the new one,
+// returning a new [StoreRegistration] or an error.
+func (r *StoreRegistration) ReplaceStore(new source.Store) (*StoreRegistration, error) {
+ var res *StoreRegistration
+ err := r.consume(func() error {
+ newSource := source.NewSource(r.source.Name(), r.source.Scope(), new)
+ if err := replaceSource(r.source, newSource); err != nil {
+ return err
+ }
+ res = &StoreRegistration{source: newSource}
+ return nil
+ })
+ return res, err
+}
+
+// Unregister reverts the registration.
+func (r *StoreRegistration) Unregister() error {
+ return r.consume(func() error { return unregisterSource(r.source) })
+}
+
+// consume invokes fn, consuming r if no error is returned.
+// It returns [ErrAlreadyConsumed] on subsequent calls after the first successful call.
+func (r *StoreRegistration) consume(fn func() error) (err error) {
+ if r.consumed.Load() != 0 {
+ return ErrAlreadyConsumed
+ }
+ return r.consumeSlow(fn)
+}
+
+func (r *StoreRegistration) consumeSlow(fn func() error) (err error) {
+ r.m.Lock()
+ defer r.m.Unlock()
+ if r.consumed.Load() != 0 {
+ return ErrAlreadyConsumed
+ }
+ if err = fn(); err == nil {
+ r.consumed.Store(1)
+ }
+ return err // may be nil or non-nil
+}
diff --git a/util/syspolicy/rsop/resultant_policy_test.go b/util/syspolicy/rsop/resultant_policy_test.go
new file mode 100644
index 000000000..744d4bfe9
--- /dev/null
+++ b/util/syspolicy/rsop/resultant_policy_test.go
@@ -0,0 +1,368 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package rsop
+
+import (
+ "slices"
+ "sort"
+ "testing"
+
+ "tailscale.com/util/syspolicy/setting"
+
+ "tailscale.com/util/syspolicy/source"
+)
+
+func TestRegisterSourceAndGetResultantPolicy(t *testing.T) {
+ type sourceConfig struct {
+ name string
+ scope setting.PolicyScope
+ settingKey setting.Key
+ settingValue string
+ wantEffective bool
+ }
+ tests := []struct {
+ name string
+ scope setting.PolicyScope
+ initialSources []sourceConfig
+ additionalSources []sourceConfig
+ wantSnapshot *setting.Snapshot
+ }{
+ {
+ name: "DevicePolicy/NoSources",
+ scope: setting.DeviceScope,
+ wantSnapshot: setting.NewSnapshot(nil, setting.DeviceScope),
+ },
+ {
+ name: "UserScope/NoSources",
+ scope: setting.CurrentUserScope,
+ wantSnapshot: setting.NewSnapshot(nil, setting.CurrentUserScope),
+ },
+ {
+ name: "DevicePolicy/OneInitialSource",
+ scope: setting.DeviceScope,
+ initialSources: []sourceConfig{
+ {
+ name: "TestSourceA",
+ scope: setting.DeviceScope,
+ settingKey: "TestKeyA",
+ settingValue: "TestValueA",
+ wantEffective: true,
+ },
+ },
+ wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{
+ "TestKeyA": setting.RawItemWith("TestValueA", nil, setting.NewNamedOrigin("TestSourceA", setting.DeviceScope)),
+ }, setting.NewNamedOrigin("TestSourceA", setting.DeviceScope)),
+ },
+ {
+ name: "DevicePolicy/OneAdditionalSource",
+ scope: setting.DeviceScope,
+ additionalSources: []sourceConfig{
+ {
+ name: "TestSourceA",
+ scope: setting.DeviceScope,
+ settingKey: "TestKeyA",
+ settingValue: "TestValueA",
+ wantEffective: true,
+ },
+ },
+ wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{
+ "TestKeyA": setting.RawItemWith("TestValueA", nil, setting.NewNamedOrigin("TestSourceA", setting.DeviceScope)),
+ }, setting.NewNamedOrigin("TestSourceA", setting.DeviceScope)),
+ },
+ {
+ name: "DevicePolicy/ManyInitialSources/NoConflicts",
+ scope: setting.DeviceScope,
+ initialSources: []sourceConfig{
+ {
+ name: "TestSourceA",
+ scope: setting.DeviceScope,
+ settingKey: "TestKeyA",
+ settingValue: "TestValueA",
+ wantEffective: true,
+ },
+ {
+ name: "TestSourceB",
+ scope: setting.DeviceScope,
+ settingKey: "TestKeyB",
+ settingValue: "TestValueB",
+ wantEffective: true,
+ },
+ {
+ name: "TestSourceC",
+ scope: setting.DeviceScope,
+ settingKey: "TestKeyC",
+ settingValue: "TestValueC",
+ wantEffective: true,
+ },
+ },
+ wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{
+ "TestKeyA": setting.RawItemWith("TestValueA", nil, setting.NewNamedOrigin("TestSourceA", setting.DeviceScope)),
+ "TestKeyB": setting.RawItemWith("TestValueB", nil, setting.NewNamedOrigin("TestSourceB", setting.DeviceScope)),
+ "TestKeyC": setting.RawItemWith("TestValueC", nil, setting.NewNamedOrigin("TestSourceC", setting.DeviceScope)),
+ }, setting.DeviceScope),
+ },
+ {
+ name: "DevicePolicy/ManyInitialSources/Conflicts",
+ scope: setting.DeviceScope,
+ initialSources: []sourceConfig{
+ {
+ name: "TestSourceA",
+ scope: setting.DeviceScope,
+ settingKey: "TestKeyA",
+ settingValue: "TestValueA",
+ wantEffective: true,
+ },
+ {
+ name: "TestSourceB",
+ scope: setting.DeviceScope,
+ settingKey: "TestKeyB",
+ settingValue: "TestValueB",
+ wantEffective: true,
+ },
+ {
+ name: "TestSourceC",
+ scope: setting.DeviceScope,
+ settingKey: "TestKeyA",
+ settingValue: "TestValueC",
+ wantEffective: true,
+ },
+ },
+ wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{
+ "TestKeyA": setting.RawItemWith("TestValueC", nil, setting.NewNamedOrigin("TestSourceC", setting.DeviceScope)),
+ "TestKeyB": setting.RawItemWith("TestValueB", nil, setting.NewNamedOrigin("TestSourceB", setting.DeviceScope)),
+ }, setting.DeviceScope),
+ },
+ {
+ name: "DevicePolicy/MixedSources/Conflicts",
+ scope: setting.DeviceScope,
+ initialSources: []sourceConfig{
+ {
+ name: "TestSourceA",
+ scope: setting.DeviceScope,
+ settingKey: "TestKeyA",
+ settingValue: "TestValueA",
+ wantEffective: true,
+ },
+ {
+ name: "TestSourceB",
+ scope: setting.DeviceScope,
+ settingKey: "TestKeyB",
+ settingValue: "TestValueB",
+ wantEffective: true,
+ },
+ {
+ name: "TestSourceC",
+ scope: setting.DeviceScope,
+ settingKey: "TestKeyA",
+ settingValue: "TestValueC",
+ wantEffective: true,
+ },
+ },
+ additionalSources: []sourceConfig{
+ {
+ name: "TestSourceD",
+ scope: setting.DeviceScope,
+ settingKey: "TestKeyA",
+ settingValue: "TestValueD",
+ wantEffective: true,
+ },
+ {
+ name: "TestSourceE",
+ scope: setting.DeviceScope,
+ settingKey: "TestKeyC",
+ settingValue: "TestValueE",
+ wantEffective: true,
+ },
+ {
+ name: "TestSourceF",
+ scope: setting.DeviceScope,
+ settingKey: "TestKeyA",
+ settingValue: "TestValueF",
+ wantEffective: true,
+ },
+ },
+ wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{
+ "TestKeyA": setting.RawItemWith("TestValueF", nil, setting.NewNamedOrigin("TestSourceF", setting.DeviceScope)),
+ "TestKeyB": setting.RawItemWith("TestValueB", nil, setting.NewNamedOrigin("TestSourceB", setting.DeviceScope)),
+ "TestKeyC": setting.RawItemWith("TestValueE", nil, setting.NewNamedOrigin("TestSourceE", setting.DeviceScope)),
+ }, setting.DeviceScope),
+ },
+ {
+ name: "UserScope/Init-DeviceSource",
+ scope: setting.CurrentUserScope,
+ initialSources: []sourceConfig{
+ {
+ name: "TestSourceDevice",
+ scope: setting.DeviceScope,
+ settingKey: "TestKeyA",
+ settingValue: "DeviceValue",
+ wantEffective: true,
+ },
+ },
+ wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{
+ "TestKeyA": setting.RawItemWith("DeviceValue", nil, setting.NewNamedOrigin("TestSourceDevice", setting.DeviceScope)),
+ }, setting.CurrentUserScope, setting.NewNamedOrigin("TestSourceDevice", setting.DeviceScope)),
+ },
+ {
+ name: "UserScope/Init-DeviceSource/Add-UserSource",
+ scope: setting.CurrentUserScope,
+ initialSources: []sourceConfig{
+ {
+ name: "TestSourceDevice",
+ scope: setting.DeviceScope,
+ settingKey: "TestKeyA",
+ settingValue: "DeviceValue",
+ wantEffective: true,
+ },
+ },
+ additionalSources: []sourceConfig{
+ {
+ name: "TestSourceUser",
+ scope: setting.CurrentUserScope,
+ settingKey: "TestKeyB",
+ settingValue: "UserValue",
+ wantEffective: true,
+ },
+ },
+ wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{
+ "TestKeyA": setting.RawItemWith("DeviceValue", nil, setting.NewNamedOrigin("TestSourceDevice", setting.DeviceScope)),
+ "TestKeyB": setting.RawItemWith("UserValue", nil, setting.NewNamedOrigin("TestSourceUser", setting.CurrentUserScope)),
+ }, setting.CurrentUserScope),
+ },
+ {
+ name: "UserScope/Init-DeviceSource/Add-UserSource-and-ProfileSource",
+ scope: setting.CurrentUserScope,
+ initialSources: []sourceConfig{
+ {
+ name: "TestSourceDevice",
+ scope: setting.DeviceScope,
+ settingKey: "TestKeyA",
+ settingValue: "DeviceValue",
+ wantEffective: true,
+ },
+ },
+ additionalSources: []sourceConfig{
+ {
+ name: "TestSourceProfile",
+ scope: setting.CurrentProfileScope,
+ settingKey: "TestKeyB",
+ settingValue: "ProfileValue",
+ wantEffective: true,
+ },
+ {
+ name: "TestSourceUser",
+ scope: setting.CurrentUserScope,
+ settingKey: "TestKeyB",
+ settingValue: "UserValue",
+ wantEffective: true,
+ },
+ },
+ wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{
+ "TestKeyA": setting.RawItemWith("DeviceValue", nil, setting.NewNamedOrigin("TestSourceDevice", setting.DeviceScope)),
+ "TestKeyB": setting.RawItemWith("ProfileValue", nil, setting.NewNamedOrigin("TestSourceProfile", setting.CurrentProfileScope)),
+ }, setting.CurrentUserScope),
+ },
+ {
+ name: "DevicePolicy/User-Source-does-not-apply",
+ scope: setting.DeviceScope,
+ initialSources: []sourceConfig{
+ {
+ name: "TestSourceDevice",
+ scope: setting.DeviceScope,
+ settingKey: "TestKeyA",
+ settingValue: "DeviceValue",
+ wantEffective: true,
+ },
+ },
+ additionalSources: []sourceConfig{
+ {
+ name: "TestSourceUser",
+ scope: setting.CurrentUserScope,
+ settingKey: "TestKeyA",
+ settingValue: "UserValue",
+ wantEffective: false, // Registering a user source should have no impact on the device policy.
+ },
+ },
+ wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{
+ "TestKeyA": setting.RawItemWith("DeviceValue", nil, setting.NewNamedOrigin("TestSourceDevice", setting.DeviceScope)),
+ }, setting.NewNamedOrigin("TestSourceDevice", setting.DeviceScope)),
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ // Register all settings that we use in this test.
+ var definitions []*setting.Definition
+ for _, source := range slices.Concat(tt.initialSources, tt.additionalSources) {
+ definitions = append(definitions, setting.NewDefinition(source.settingKey, tt.scope.Kind(), setting.StringValue))
+ }
+ if err := setting.SetDefinitionsForTest(t, definitions...); err != nil {
+ t.Fatalf("SetDefinitionsForTest failed: %v", err)
+ }
+
+ // Add the initial policy sources.
+ var wantSources []*source.Source
+ for _, s := range tt.initialSources {
+ store := source.NewTestStoreOf(t, source.TestSettingOf(s.settingKey, s.settingValue))
+ source := source.NewSource(s.name, s.scope, store)
+ if err := registerSource(source); err != nil {
+ t.Fatalf("failed to register policy source: %v", source)
+ }
+ if s.wantEffective {
+ wantSources = append(wantSources, source)
+ }
+ t.Cleanup(func() { unregisterSource(source) })
+ }
+
+ // Retrieve the resultant policy.
+ policy, err := resultantPolicyForTest(t, tt.scope)
+ if err != nil {
+ t.Fatalf("failed to get resultant policy for %v", tt.scope)
+ }
+
+ // Add additional setting sources one by one, and check the policy settings at each step.
+ for _, s := range tt.additionalSources {
+ store := source.NewTestStoreOf(t, source.TestSettingOf(s.settingKey, s.settingValue))
+ source := source.NewSource(s.name, s.scope, store)
+ if err := registerSource(source); err != nil {
+ t.Fatalf("failed to register additional policy source: %v", source)
+ }
+ if s.wantEffective {
+ wantSources = append(wantSources, source)
+ }
+ t.Cleanup(func() { unregisterSource(source) })
+ }
+
+ sort.SliceStable(wantSources, func(i, j int) bool {
+ return wantSources[i].Compare(wantSources[j]) < 0
+ })
+ gotSources := make([]*source.Source, len(policy.sources))
+ for i, s := range policy.sources {
+ gotSources[i] = s.Source
+ }
+ if !slices.Equal(gotSources, wantSources) {
+ t.Errorf("Sources: got %v; want %v", gotSources, wantSources)
+ }
+
+ // Verify the final resultant settings snapshots.
+ if got := policy.Get(); !got.Equal(tt.wantSnapshot) {
+ t.Errorf("Snapshot: got %v; want %v", got, tt.wantSnapshot)
+ }
+ })
+ }
+}
+
+// resultantPolicyForTest is like [resultantPolicyFor], but it deletes the policy
+// when tb and all its subtests complete.
+func resultantPolicyForTest(tb testing.TB, target setting.PolicyScope) (*Policy, error) {
+ policy, err := PolicyFor(target)
+ if err != nil {
+ return nil, err
+ }
+ tb.Cleanup(func() {
+ policy.Close()
+ <-policy.Done()
+ deletePolicy(policy)
+ })
+ return policy, nil
+}
diff --git a/util/syspolicy/setting/errors.go b/util/syspolicy/setting/errors.go
new file mode 100644
index 000000000..8d5e73754
--- /dev/null
+++ b/util/syspolicy/setting/errors.go
@@ -0,0 +1,60 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package setting
+
+import "errors"
+
+var (
+ // ErrNotConfigured is returned when the requested policy setting is not configured.
+ ErrNotConfigured = errors.New("not configured")
+ // ErrTypeMismatch is returned when there's a type mismatch between the actual type
+ // of the setting value and the expected type.
+ ErrTypeMismatch = errors.New("type mismatch")
+ // ErrNoSuchKey is returned by [DefinitionOf] when no policy setting
+ // has been registered with the specified key.
+ //
+ // Until 2024-08-02, this error was also returned by a [Handler] when the specified
+ // key did not have a value set. While the package maintains compatibility with this
+ // usage of ErrNoSuchKey, it is recommended to return [ErrNotConfigured] from newer
+ // [source.Store] implementations.
+ ErrNoSuchKey = errors.New("no such key")
+)
+
+// Error is an error when reading or parsing a policy setting.
+type Error struct {
+ text string
+}
+
+// NewError returns a [Error] with the specified error message.
+func NewError(text string) *Error {
+ return &Error{text}
+}
+
+// WrapError returns an [Error] with the text of the specified error,
+// or nil if err is nil, [ErrNotConfigured], or [ErrNoSuchKey].
+func WrapError(err error) *Error {
+ if err == nil || errors.Is(err, ErrNotConfigured) || errors.Is(err, ErrNoSuchKey) {
+ return nil
+ }
+ if err, ok := err.(*Error); ok {
+ return err
+ }
+ return &Error{err.Error()}
+}
+
+// Error implements error.
+func (e Error) Error() string {
+ return e.text
+}
+
+// MarshalText implements [encoding.TextMarshaler].
+func (e Error) MarshalText() (text []byte, err error) {
+ return []byte(e.Error()), nil
+}
+
+// UnmarshalText implements [encoding.TextUnmarshaler].
+func (e *Error) UnmarshalText(text []byte) error {
+ e.text = string(text)
+ return nil
+}
diff --git a/util/syspolicy/setting/key.go b/util/syspolicy/setting/key.go
new file mode 100644
index 000000000..406fde132
--- /dev/null
+++ b/util/syspolicy/setting/key.go
@@ -0,0 +1,13 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package setting
+
+// Key is a string that uniquely identifies a policy and must remain unchanged
+// once established and documented for a given policy setting. It may contain
+// alphanumeric characters and zero or more [KeyPathSeparator]s to group
+// individual policy settings into categories.
+type Key string
+
+// KeyPathSeparator allows logical grouping of policy settings into categories.
+const KeyPathSeparator = "/"
diff --git a/util/syspolicy/setting/origin.go b/util/syspolicy/setting/origin.go
new file mode 100644
index 000000000..3e61cd946
--- /dev/null
+++ b/util/syspolicy/setting/origin.go
@@ -0,0 +1,71 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package setting
+
+import (
+ "fmt"
+
+ jsonv2 "github.com/go-json-experiment/json"
+ "github.com/go-json-experiment/json/jsontext"
+)
+
+// Origin describes where a policy or a policy setting is configured.
+type Origin struct {
+ data settingOrigin
+}
+
+// settingOrigin is the marshallable data of a [Origin].
+type settingOrigin struct {
+ Name string `json:",omitzero"`
+ Scope PolicyScope
+}
+
+// NewOrigin returns a new [Origin] with the specified scope.
+func NewOrigin(scope PolicyScope) *Origin {
+ return NewNamedOrigin("", scope)
+}
+
+// NewNamedOrigin returns a new [Origin] with the specified scope and name.
+func NewNamedOrigin(name string, scope PolicyScope) *Origin {
+ return &Origin{settingOrigin{name, scope}}
+}
+
+// Scope reports the policy [PolicyScope] where the setting is configured.
+func (s Origin) Scope() PolicyScope {
+ return s.data.Scope
+}
+
+// Name returns the name of the policy source where the setting is configured,
+// or "" if not available.
+func (s Origin) Name() string {
+ return s.data.Name
+}
+
+// String implements [fmt.Stringer].
+func (s Origin) String() string {
+ if s.Name() != "" {
+ return fmt.Sprintf("%s (%v)", s.Name(), s.Scope())
+ }
+ return s.Scope().String()
+}
+
+// MarshalJSONV2 implements [jsonv2.MarshalerV2].
+func (s Origin) MarshalJSONV2(out *jsontext.Encoder, opts jsonv2.Options) error {
+ return jsonv2.MarshalEncode(out, &s.data, opts)
+}
+
+// UnmarshalJSONV2 implements [jsonv2.UnmarshalerV2].
+func (s *Origin) UnmarshalJSONV2(in *jsontext.Decoder, opts jsonv2.Options) error {
+ return jsonv2.UnmarshalDecode(in, &s.data, opts)
+}
+
+// MarshalJSON implements [json.Marshaler].
+func (s Origin) MarshalJSON() ([]byte, error) {
+ return jsonv2.Marshal(s) // uses MarshalJSONV2
+}
+
+// UnmarshalJSON implements [json.Unmarshaler].
+func (s *Origin) UnmarshalJSON(b []byte) error {
+ return jsonv2.Unmarshal(b, s) // uses UnmarshalJSONV2
+}
diff --git a/util/syspolicy/setting/policy_scope.go b/util/syspolicy/setting/policy_scope.go
new file mode 100644
index 000000000..636c815b2
--- /dev/null
+++ b/util/syspolicy/setting/policy_scope.go
@@ -0,0 +1,195 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package setting
+
+import (
+ "fmt"
+ "strings"
+
+ "tailscale.com/types/lazy"
+ "tailscale.com/util/syspolicy/internal/lazyinit"
+)
+
+var (
+ lazyCurrentScope lazy.SyncValue[PolicyScope]
+
+ // DeviceScope indicates a scope containing device-global policies.
+ DeviceScope = PolicyScope{kind: DeviceSetting}
+ // CurrentProfileScope indicates a scope containing policies that apply to the
+ // currently active Tailscale profile.
+ CurrentProfileScope = PolicyScope{kind: ProfileSetting}
+ // CurrentUserScope indicates a scope containing policies that apply to the
+ // current user, for whatever that means on the current platform and
+ // in the current application context.
+ CurrentUserScope = PolicyScope{kind: UserSetting}
+)
+
+// PolicyScope is a management scope.
+type PolicyScope struct {
+ kind Scope
+ userID string
+ profileID string
+}
+
+// CurrentScope returns the default [PolicyScope] that the package will use to return
+// the policy settings for unless a different scope is explicitly requested.
+// This defaults to [DeviceScope], unless the process runs as a user (rather than LocalSystem)
+// on Windows, in which case it returns the [CurrentUserScope].
+func CurrentScope() PolicyScope {
+ // Allow deferred package init functions to override the default scope.
+ lazyinit.Do()
+ return lazyCurrentScope.Get(func() PolicyScope { return DeviceScope })
+}
+
+// SetCurrentScope attempts to set the specified scope as the current scope,
+// and reports whether it succeeds.
+// It can be called only once and must be during lazy package initialization.
+func SetCurrentScope(scope PolicyScope) bool {
+ return lazyCurrentScope.Set(scope)
+}
+
+// UserScopeOf returns a policy [PolicyScope] of the specified user.
+func UserScopeOf(uid string) PolicyScope {
+ return PolicyScope{kind: UserSetting, userID: uid}
+}
+
+// Kind reports the base [Scope] of s.
+func (s PolicyScope) Kind() Scope {
+ return s.kind
+}
+
+// IsApplicableSetting reports whether the specified setting applies to
+// and can be retrieved for this scope. Policy settings are applicable
+// to their own scopes as well as more specific scopes. For example,
+// device settings are applicable to device, profile and user scopes,
+// but user settings are only applicable to user scopes.
+// For instance, a menu visibility setting is inherently a user setting
+// and only makes sense in the context of a specific user.
+func (s PolicyScope) IsApplicableSetting(setting *Definition) bool {
+ return setting != nil && setting.Scope() <= s.Kind()
+}
+
+// IsConfigurableSetting reports whether the specified setting can be configured
+// by a policy at this scope. Policy settings are configurable at their own scopes
+// as well as broader scopes. For example, [UserSetting]s are configurable in
+// user, profile, and device scopes, but [DeviceSetting]s are only configurable
+// in the [DeviceScope]. For instance, the InstallUpdates policy setting
+// can only be configured in the device scope, as it controls whether updates
+// will be installed automatically on the device, rather than for specific users.
+func (s PolicyScope) IsConfigurableSetting(setting *Definition) bool {
+ return setting != nil && setting.Scope() >= s.Kind()
+}
+
+// IsWithinOf reports whether policy settings that apply to s2 also apply to s.
+// For example, policy settings that apply to the [DeviceScope] also apply to
+// the [CurrentUserScope].
+func (s PolicyScope) IsWithinOf(s2 PolicyScope) bool {
+ if s2.Kind() > s.Kind() {
+ return false
+ }
+ switch s2.Kind() {
+ case DeviceSetting:
+ return true
+ case ProfileSetting:
+ return s.profileID == s2.profileID
+ case UserSetting:
+ return s.userID == s2.userID
+ default:
+ panic("unreachable")
+ }
+}
+
+// IsStrictlyWithinOf is like [IsWithinOf], except it returns false
+// when s and s2 is the same scope.
+func (s PolicyScope) IsStrictlyWithinOf(s2 PolicyScope) bool {
+ return s != s2 && s.IsWithinOf(s2)
+}
+
+// String implements [fmt.Stringer].
+func (s PolicyScope) String() string {
+ if s.profileID == "" && s.userID == "" {
+ return s.kind.String()
+ }
+ return s.stringSlow()
+}
+
+// MarshalText implements [encoding.TextMarshaler].
+func (s PolicyScope) MarshalText() ([]byte, error) {
+ return []byte(s.String()), nil
+}
+
+// MarshalText implements [encoding.TextUnmarshaler].
+func (s *PolicyScope) UnmarshalText(b []byte) error {
+ *s = PolicyScope{}
+ parts := strings.SplitN(string(b), "/", 2)
+ if len(parts) == 0 {
+ return fmt.Errorf("%s is not a valid scope", b)
+ }
+ for i, part := range parts {
+ kind, id, err := parseScopeAndID(part)
+ if err != nil {
+ return err
+ }
+ if i > 0 && kind <= s.kind {
+ return fmt.Errorf("invalid scope hierarchy: %s", b)
+ }
+ s.kind = kind
+ switch kind {
+ case DeviceSetting:
+ if id != "" {
+ return fmt.Errorf("the device scope must not have an ID: %s", b)
+ }
+ case ProfileSetting:
+ s.profileID = id
+ case UserSetting:
+ s.userID = id
+ }
+ }
+ return nil
+}
+
+func (s PolicyScope) stringSlow() string {
+ var sb strings.Builder
+ writeScopeWithID := func(s Scope, id string) {
+ sb.WriteString(s.String())
+ if id != "" {
+ sb.WriteRune('(')
+ sb.WriteString(id)
+ sb.WriteRune(')')
+ }
+ }
+ if s.kind == ProfileSetting || s.profileID != "" {
+ writeScopeWithID(ProfileSetting, s.profileID)
+ if s.kind != ProfileSetting {
+ sb.WriteRune('/')
+ }
+ }
+ if s.kind == UserSetting {
+ writeScopeWithID(UserSetting, s.userID)
+ }
+ return sb.String()
+}
+
+func parseScopeAndID(s string) (scope Scope, id string, err error) {
+ name, params, ok := extractScopeAndParams(s)
+ if !ok {
+ return 0, "", fmt.Errorf("%q is not a valid scope string", s)
+ }
+ if err := scope.UnmarshalText([]byte(name)); err != nil {
+ return 0, "", err
+ }
+ return scope, params, nil
+}
+
+func extractScopeAndParams(s string) (name, params string, ok bool) {
+ paramsStart := strings.Index(s, "(")
+ if paramsStart == -1 {
+ return s, "", true
+ }
+ paramsEnd := strings.LastIndex(s, ")")
+ if paramsEnd < paramsStart {
+ return "", "", false
+ }
+ return s[0:paramsStart], s[paramsStart+1 : paramsEnd], true
+}
diff --git a/util/syspolicy/setting/policy_scope_test.go b/util/syspolicy/setting/policy_scope_test.go
new file mode 100644
index 000000000..8140fc5a0
--- /dev/null
+++ b/util/syspolicy/setting/policy_scope_test.go
@@ -0,0 +1,550 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package setting
+
+import (
+ "reflect"
+ "testing"
+
+ jsonv2 "github.com/go-json-experiment/json"
+)
+
+func TestPolicyScopeIsApplicableSetting(t *testing.T) {
+ tests := []struct {
+ name string
+ scope PolicyScope
+ setting *Definition
+ wantApplicable bool
+ }{
+ {
+ name: "DeviceScope/DeviceSetting",
+ scope: DeviceScope,
+ setting: NewDefinition("TestSetting", DeviceSetting, IntegerValue),
+ wantApplicable: true,
+ },
+ {
+ name: "DeviceScope/ProfileSetting",
+ scope: DeviceScope,
+ setting: NewDefinition("TestSetting", ProfileSetting, IntegerValue),
+ wantApplicable: false,
+ },
+ {
+ name: "DeviceScope/UserSetting",
+ scope: DeviceScope,
+ setting: NewDefinition("TestSetting", UserSetting, IntegerValue),
+ wantApplicable: false,
+ },
+ {
+ name: "ProfileScope/DeviceSetting",
+ scope: CurrentProfileScope,
+ setting: NewDefinition("TestSetting", DeviceSetting, IntegerValue),
+ wantApplicable: true,
+ },
+ {
+ name: "ProfileScope/ProfileSetting",
+ scope: CurrentProfileScope,
+ setting: NewDefinition("TestSetting", ProfileSetting, IntegerValue),
+ wantApplicable: true,
+ },
+ {
+ name: "ProfileScope/UserSetting",
+ scope: CurrentProfileScope,
+ setting: NewDefinition("TestSetting", UserSetting, IntegerValue),
+ wantApplicable: false,
+ },
+ {
+ name: "UserScope/DeviceSetting",
+ scope: CurrentUserScope,
+ setting: NewDefinition("TestSetting", DeviceSetting, IntegerValue),
+ wantApplicable: true,
+ },
+ {
+ name: "UserScope/ProfileSetting",
+ scope: CurrentUserScope,
+ setting: NewDefinition("TestSetting", ProfileSetting, IntegerValue),
+ wantApplicable: true,
+ },
+ {
+ name: "UserScope/UserSetting",
+ scope: CurrentUserScope,
+ setting: NewDefinition("TestSetting", UserSetting, IntegerValue),
+ wantApplicable: true,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ gotApplicable := tt.scope.IsApplicableSetting(tt.setting)
+ if gotApplicable != tt.wantApplicable {
+ t.Fatalf("got %v, want %v", gotApplicable, tt.wantApplicable)
+ }
+ })
+ }
+}
+
+func TestPolicyScopeIsConfigurableSetting(t *testing.T) {
+ tests := []struct {
+ name string
+ scope PolicyScope
+ setting *Definition
+ wantConfigurable bool
+ }{
+ {
+ name: "DeviceScope/DeviceSetting",
+ scope: DeviceScope,
+ setting: NewDefinition("TestSetting", DeviceSetting, IntegerValue),
+ wantConfigurable: true,
+ },
+ {
+ name: "DeviceScope/ProfileSetting",
+ scope: DeviceScope,
+ setting: NewDefinition("TestSetting", ProfileSetting, IntegerValue),
+ wantConfigurable: true,
+ },
+ {
+ name: "DeviceScope/UserSetting",
+ scope: DeviceScope,
+ setting: NewDefinition("TestSetting", UserSetting, IntegerValue),
+ wantConfigurable: true,
+ },
+ {
+ name: "ProfileScope/DeviceSetting",
+ scope: CurrentProfileScope,
+ setting: NewDefinition("TestSetting", DeviceSetting, IntegerValue),
+ wantConfigurable: false,
+ },
+ {
+ name: "ProfileScope/ProfileSetting",
+ scope: CurrentProfileScope,
+ setting: NewDefinition("TestSetting", ProfileSetting, IntegerValue),
+ wantConfigurable: true,
+ },
+ {
+ name: "ProfileScope/UserSetting",
+ scope: CurrentProfileScope,
+ setting: NewDefinition("TestSetting", UserSetting, IntegerValue),
+ wantConfigurable: true,
+ },
+ {
+ name: "UserScope/DeviceSetting",
+ scope: CurrentUserScope,
+ setting: NewDefinition("TestSetting", DeviceSetting, IntegerValue),
+ wantConfigurable: false,
+ },
+ {
+ name: "UserScope/ProfileSetting",
+ scope: CurrentUserScope,
+ setting: NewDefinition("TestSetting", ProfileSetting, IntegerValue),
+ wantConfigurable: false,
+ },
+ {
+ name: "UserScope/UserSetting",
+ scope: CurrentUserScope,
+ setting: NewDefinition("TestSetting", UserSetting, IntegerValue),
+ wantConfigurable: true,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ gotConfigurable := tt.scope.IsConfigurableSetting(tt.setting)
+ if gotConfigurable != tt.wantConfigurable {
+ t.Fatalf("got %v, want %v", gotConfigurable, tt.wantConfigurable)
+ }
+ })
+ }
+}
+
+func TestPolicyScopeIsWithinOf(t *testing.T) {
+ tests := []struct {
+ name string
+ scopeA PolicyScope
+ scopeB PolicyScope
+ wantBWithinOfA bool
+ wantBStrictlyWithinOfA bool
+ }{
+ {
+ name: "DeviceScope/DeviceScope",
+ scopeA: DeviceScope,
+ scopeB: DeviceScope,
+ wantBWithinOfA: true,
+ wantBStrictlyWithinOfA: false,
+ },
+ {
+ name: "DeviceScope/CurrentProfileScope",
+ scopeA: DeviceScope,
+ scopeB: CurrentProfileScope,
+ wantBWithinOfA: true,
+ wantBStrictlyWithinOfA: true,
+ },
+ {
+ name: "DeviceScope/UserScope",
+ scopeA: DeviceScope,
+ scopeB: CurrentUserScope,
+ wantBWithinOfA: true,
+ wantBStrictlyWithinOfA: true,
+ },
+ {
+ name: "ProfileScope/DeviceScope",
+ scopeA: CurrentProfileScope,
+ scopeB: DeviceScope,
+ wantBWithinOfA: false,
+ wantBStrictlyWithinOfA: false,
+ },
+ {
+ name: "ProfileScope/ProfileScope",
+ scopeA: CurrentProfileScope,
+ scopeB: CurrentProfileScope,
+ wantBWithinOfA: true,
+ wantBStrictlyWithinOfA: false,
+ },
+ {
+ name: "ProfileScope/UserScope",
+ scopeA: CurrentProfileScope,
+ scopeB: CurrentUserScope,
+ wantBWithinOfA: true,
+ wantBStrictlyWithinOfA: true,
+ },
+ {
+ name: "UserScope/DeviceScope",
+ scopeA: CurrentUserScope,
+ scopeB: DeviceScope,
+ wantBWithinOfA: false,
+ wantBStrictlyWithinOfA: false,
+ },
+ {
+ name: "UserScope/ProfileScope",
+ scopeA: CurrentUserScope,
+ scopeB: CurrentProfileScope,
+ wantBWithinOfA: false,
+ wantBStrictlyWithinOfA: false,
+ },
+ {
+ name: "UserScope/UserScope",
+ scopeA: CurrentUserScope,
+ scopeB: CurrentUserScope,
+ wantBWithinOfA: true,
+ wantBStrictlyWithinOfA: false,
+ },
+ {
+ name: "UserScope(1234)/UserScope(1234)",
+ scopeA: UserScopeOf("1234"),
+ scopeB: UserScopeOf("1234"),
+ wantBWithinOfA: true,
+ wantBStrictlyWithinOfA: false,
+ },
+ {
+ name: "UserScope(1234)/UserScope(5678)",
+ scopeA: UserScopeOf("1234"),
+ scopeB: UserScopeOf("5678"),
+ wantBWithinOfA: false,
+ wantBStrictlyWithinOfA: false,
+ },
+ {
+ name: "ProfileScope(A)/UserScope(A/1234)",
+ scopeA: PolicyScope{kind: ProfileSetting, profileID: "A"},
+ scopeB: PolicyScope{kind: UserSetting, userID: "1234", profileID: "A"},
+ wantBWithinOfA: true,
+ wantBStrictlyWithinOfA: true,
+ },
+ {
+ name: "ProfileScope(A)/UserScope(B/1234)",
+ scopeA: PolicyScope{kind: ProfileSetting, profileID: "A"},
+ scopeB: PolicyScope{kind: UserSetting, userID: "1234", profileID: "B"},
+ wantBWithinOfA: false,
+ wantBStrictlyWithinOfA: false,
+ },
+ {
+ name: "UserScope(1234)/UserScope(A/1234)",
+ scopeA: PolicyScope{kind: UserSetting, userID: "1234"},
+ scopeB: PolicyScope{kind: UserSetting, userID: "1234", profileID: "A"},
+ wantBWithinOfA: true,
+ wantBStrictlyWithinOfA: true,
+ },
+ {
+ name: "UserScope(1234)/UserScope(A/5678)",
+ scopeA: PolicyScope{kind: UserSetting, userID: "1234"},
+ scopeB: PolicyScope{kind: UserSetting, userID: "5678", profileID: "A"},
+ wantBWithinOfA: false,
+ wantBStrictlyWithinOfA: false,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ gotWithinOf := tt.scopeB.IsWithinOf(tt.scopeA)
+ if gotWithinOf != tt.wantBWithinOfA {
+ t.Fatalf("WithinOf: got %v, want %v", gotWithinOf, tt.wantBWithinOfA)
+ }
+
+ gotStrictlyWithinOf := tt.scopeB.IsStrictlyWithinOf(tt.scopeA)
+ if gotStrictlyWithinOf != tt.wantBStrictlyWithinOfA {
+ t.Fatalf("StrictlyWithinOf: got %v, want %v", gotStrictlyWithinOf, tt.wantBStrictlyWithinOfA)
+ }
+ })
+ }
+}
+
+func TestPolicyScopeMarshalUnmarshal(t *testing.T) {
+ tests := []struct {
+ name string
+ in any
+ wantJSON string
+ wantError bool
+ }{
+ {
+ name: "null-scope",
+ in: &struct {
+ Scope PolicyScope
+ }{},
+ wantJSON: `{"Scope":"Device"}`,
+ },
+ {
+ name: "null-scope-omit-zero",
+ in: &struct {
+ Scope PolicyScope `json:",omitzero"`
+ }{},
+ wantJSON: `{}`,
+ },
+ {
+ name: "device-scope",
+ in: &struct {
+ Scope PolicyScope
+ }{DeviceScope},
+ wantJSON: `{"Scope":"Device"}`,
+ },
+ {
+ name: "current-profile-scope",
+ in: &struct {
+ Scope PolicyScope
+ }{CurrentProfileScope},
+ wantJSON: `{"Scope":"Profile"}`,
+ },
+ {
+ name: "current-user-scope",
+ in: &struct {
+ Scope PolicyScope
+ }{CurrentUserScope},
+ wantJSON: `{"Scope":"User"}`,
+ },
+ {
+ name: "specific-user-scope",
+ in: &struct {
+ Scope PolicyScope
+ }{UserScopeOf("_")},
+ wantJSON: `{"Scope":"User(_)"}`,
+ },
+ {
+ name: "specific-user-scope",
+ in: &struct {
+ Scope PolicyScope
+ }{UserScopeOf("S-1-5-21-3698941153-1525015703-2649197413-1001")},
+ wantJSON: `{"Scope":"User(S-1-5-21-3698941153-1525015703-2649197413-1001)"}`,
+ },
+ {
+ name: "specific-profile-scope",
+ in: &struct {
+ Scope PolicyScope
+ }{PolicyScope{kind: ProfileSetting, profileID: "1234"}},
+ wantJSON: `{"Scope":"Profile(1234)"}`,
+ },
+ {
+ name: "specific-profile-and-user-scope",
+ in: &struct {
+ Scope PolicyScope
+ }{PolicyScope{
+ kind: UserSetting,
+ profileID: "1234",
+ userID: "S-1-5-21-3698941153-1525015703-2649197413-1001",
+ }},
+ wantJSON: `{"Scope":"Profile(1234)/User(S-1-5-21-3698941153-1525015703-2649197413-1001)"}`,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ gotJSON, err := jsonv2.Marshal(tt.in)
+ if err != nil {
+ t.Fatalf("Marshal failed: %v", err)
+ }
+ if string(gotJSON) != tt.wantJSON {
+ t.Fatalf("Marshal got %s, want %s", gotJSON, tt.wantJSON)
+ }
+ wantBack := tt.in
+ gotBack := reflect.New(reflect.TypeOf(tt.in).Elem()).Interface()
+ err = jsonv2.Unmarshal(gotJSON, gotBack)
+ if err != nil {
+ t.Fatalf("Unmarshal failed: %v", err)
+ }
+ if !reflect.DeepEqual(gotBack, wantBack) {
+ t.Fatalf("Unmarshal got %+v, want %+v", gotBack, wantBack)
+ }
+ })
+ }
+}
+
+func TestPolicyScopeUnmarshalSpecial(t *testing.T) {
+ tests := []struct {
+ name string
+ json string
+ want any
+ wantError bool
+ }{
+ {
+ name: "empty",
+ json: "{}",
+ want: &struct {
+ Scope PolicyScope
+ }{},
+ },
+ {
+ name: "too-many-scopes",
+ json: `{"Scope":"Device/Profile/User"}`,
+ wantError: true,
+ },
+ {
+ name: "user/profile", // incorrect order
+ json: `{"Scope":"User/Profile"}`,
+ wantError: true,
+ },
+ {
+ name: "profile-user-no-params",
+ json: `{"Scope":"Profile/User"}`,
+ want: &struct {
+ Scope PolicyScope
+ }{CurrentUserScope},
+ },
+ {
+ name: "unknown-scope",
+ json: `{"Scope":"Unknown"}`,
+ wantError: true,
+ },
+ {
+ name: "unknown-scope/unknown-scope",
+ json: `{"Scope":"Unknown/Unknown"}`,
+ wantError: true,
+ },
+ {
+ name: "device-scope/unknown-scope",
+ json: `{"Scope":"Device/Unknown"}`,
+ wantError: true,
+ },
+ {
+ name: "unknown-scope/device-scope",
+ json: `{"Scope":"Unknown/Device"}`,
+ wantError: true,
+ },
+ {
+ name: "slash",
+ json: `{"Scope":"/"}`,
+ wantError: true,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := &struct {
+ Scope PolicyScope
+ }{}
+ err := jsonv2.Unmarshal([]byte(tt.json), got)
+ if (err != nil) != tt.wantError {
+ t.Errorf("Marshal error: got %v, want %v", err, tt.wantError)
+ }
+ if err != nil {
+ return
+ }
+ if !reflect.DeepEqual(got, tt.want) {
+ t.Fatalf("Unmarshal got %+v, want %+v", got, tt.want)
+ }
+ })
+ }
+
+}
+
+func TestExtractScopeAndParams(t *testing.T) {
+ tests := []struct {
+ name string
+ s string
+ scope string
+ params string
+ wantOk bool
+ }{
+ {
+ name: "empty",
+ s: "",
+ wantOk: true,
+ },
+ {
+ name: "scope-only",
+ s: "device",
+ scope: "device",
+ wantOk: true,
+ },
+ {
+ name: "scope-with-params",
+ s: "user(1234)",
+ scope: "user",
+ params: "1234",
+ wantOk: true,
+ },
+ {
+ name: "params-empty-scope",
+ s: "(1234)",
+ scope: "",
+ params: "1234",
+ wantOk: true,
+ },
+ {
+ name: "params-with-brackets",
+ s: "test()())))())",
+ scope: "test",
+ params: ")())))()",
+ wantOk: true,
+ },
+ {
+ name: "no-closing-bracket",
+ s: "user(1234",
+ scope: "",
+ params: "",
+ wantOk: false,
+ },
+ {
+ name: "open-before-close",
+ s: ")user(1234",
+ scope: "",
+ params: "",
+ wantOk: false,
+ },
+ {
+ name: "brackets-only",
+ s: ")(",
+ scope: "",
+ params: "",
+ wantOk: false,
+ },
+ {
+ name: "closing-bracket",
+ s: ")",
+ scope: "",
+ params: "",
+ wantOk: false,
+ },
+ {
+ name: "opening-bracket",
+ s: ")",
+ scope: "",
+ params: "",
+ wantOk: false,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ scope, params, ok := extractScopeAndParams(tt.s)
+ if ok != tt.wantOk {
+ t.Logf("OK: got %v; want %v", ok, tt.wantOk)
+ }
+ if scope != tt.scope {
+ t.Logf("Scope: got %q; want %q", scope, tt.scope)
+ }
+ if params != tt.params {
+ t.Logf("Params: got %v; want %v", params, tt.params)
+ }
+ })
+ }
+}
diff --git a/util/syspolicy/setting/raw_item.go b/util/syspolicy/setting/raw_item.go
new file mode 100644
index 000000000..a901b505a
--- /dev/null
+++ b/util/syspolicy/setting/raw_item.go
@@ -0,0 +1,47 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package setting
+
+// RawItem contains a raw policy setting as read from a policy store, or an
+// error if the requested setting could not be read from the store. As a special
+// case, it may also hold a value of the [Visibility], [PreferenceOption],
+// or [time.Duration] types. While the policy store interface does not support
+// these types natively, and the values of these types have to be unmarshalled
+// or converted from strings, these setting types predate the typed policy
+// hierarchies, and must be supported at this layer.
+type RawItem struct {
+ value any
+ err *Error
+ origin *Origin // or nil
+}
+
+// RawItemOf returns [RawItem] with the specified value.
+func RawItemOf(value any) RawItem {
+ return RawItemWith(value, nil, nil)
+}
+
+// RawItemWith returns an [RawItem] with the specified value, error and origin.
+func RawItemWith(value any, err *Error, origin *Origin) RawItem {
+ return RawItem{value, err, origin}
+}
+
+// Value returns the value of an untyped policy setting,
+// or nil if the policy setting is not configured.
+func (i RawItem) Value() any {
+ return i.value
+}
+
+// Error returns the error that occurred when reading the policy setting,
+// or nil if no error occurred.
+func (i RawItem) Error() error {
+ if i.err != nil {
+ return i.err
+ }
+ return nil
+}
+
+// Origin returns an optional [Origin] indicating the policy settings is configured.
+func (i RawItem) Origin() *Origin {
+ return i.origin
+}
diff --git a/util/syspolicy/setting/setting.go b/util/syspolicy/setting/setting.go
new file mode 100644
index 000000000..e60aab12c
--- /dev/null
+++ b/util/syspolicy/setting/setting.go
@@ -0,0 +1,352 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Package setting contain types for policy settings.
+package setting
+
+import (
+ "fmt"
+ "slices"
+ "strings"
+ "sync"
+ "time"
+
+ "tailscale.com/types/lazy"
+ "tailscale.com/util/syspolicy/internal"
+ "tailscale.com/util/syspolicy/internal/lazyinit"
+)
+
+// Scope indicates the broadest scope at which a policy setting may apply,
+// and the narrowest scope at which it may be configured.
+type Scope int8
+
+const (
+ // DeviceSetting indicates a policy setting that applies to a device, regardless of
+ // which OS user or Tailscale profile is currently active, if any.
+ // It can only be configured at a [DeviceScope].
+ DeviceSetting Scope = iota
+ // ProfileSetting indicates a policy setting that applies to a Tailscale profile.
+ // It can only be configured for a specific profile or at a [DeviceScope],
+ // in which case it applies to all profiles on the device.
+ ProfileSetting
+ // UserSetting indicates a policy setting that applies to users.
+ // It can be configured for a user, profile, or the entire device.
+ UserSetting
+
+ // MaxSettingScope is the maximum possible [Scope] value.
+ MaxSettingScope = UserSetting
+)
+
+// String implements [fmt.Stringer].
+func (s Scope) String() string {
+ switch s {
+ case DeviceSetting:
+ return "Device"
+ case ProfileSetting:
+ return "Profile"
+ case UserSetting:
+ return "User"
+ default:
+ panic("unreachable")
+ }
+}
+
+// MarshalText implements [encoding.TextMarshaler].
+func (s Scope) MarshalText() (text []byte, err error) {
+ return []byte(s.String()), nil
+}
+
+// UnmarshalText implements [encoding.TextUnmarshaler].
+func (s *Scope) UnmarshalText(text []byte) error {
+ switch strings.ToLower(string(text)) {
+ case "device":
+ *s = DeviceSetting
+ case "profile":
+ *s = ProfileSetting
+ case "user":
+ *s = UserSetting
+ default:
+ return fmt.Errorf("%q is not a valid scope", string(text))
+ }
+ return nil
+}
+
+// Type is a policy setting value type.
+// Except for [InvalidValue], which represents an invalid policy setting type,
+// and [PreferenceOptionValue], [VisibilityValue], and [DurationValue],
+// which have special handling due to their legacy status in the package,
+// SettingTypes represent the raw value types readable from policy stores.
+type Type int
+
+const (
+ // InvalidValue indicates an invalid policy setting value type.
+ InvalidValue Type = iota
+ // BooleanValue indicates a policy setting whose underlying type in the
+ // [source.Store] is a bool.
+ BooleanValue
+ // IntegerValue indicates a policy setting whose underlying type in the
+ // [source.Store] is a uint64.
+ IntegerValue
+ // StringValue indicates a policy setting whose underlying type in the
+ // [source.Store] is a string.
+ StringValue
+ // StringListValue indicates a policy setting whose underlying type in the
+ // [source.Store] is a []string.
+ StringListValue
+ // PreferenceOptionValue indicates a three-state policy setting whose
+ // underlying type in the [source.Store] is a string, but the actual value
+ // is a [PreferenceOption].
+ PreferenceOptionValue
+ // VisibilityValue indicates a two-state boolean-like policy setting whose
+ // underlying type in the [source.Store] is a string, but the actual value
+ // is a [Visibility].
+ VisibilityValue
+ // DurationValue indicates an interval/period/duration policy setting whose
+ // underlying type in the [source.Store] is a string, but the actual value
+ // is a [time.Duration].
+ DurationValue
+)
+
+// String returns a string representation of t.
+func (t Type) String() string {
+ switch t {
+ case InvalidValue:
+ return "Invalid"
+ case BooleanValue:
+ return "Boolean"
+ case IntegerValue:
+ return "Integer"
+ case StringValue:
+ return "String"
+ case StringListValue:
+ return "StringList"
+ case PreferenceOptionValue:
+ return "PreferenceOption"
+ case VisibilityValue:
+ return "Visibility"
+ case DurationValue:
+ return "Duration"
+ default:
+ panic("unreachable")
+ }
+}
+
+// ValueType is a constraint that allows Go types corresponding to [Type].
+type ValueType interface {
+ bool | uint64 | string | []string | Visibility | PreferenceOption | time.Duration
+}
+
+// Definition defines policy key, scope and value type.
+type Definition struct {
+ key Key
+ scope Scope
+ typ Type
+ platforms PlatformList
+}
+
+// NewDefinition returns a new [Definition] with the specified
+// key, scope, type and supported platforms (see [PlatformList]).
+func NewDefinition(k Key, s Scope, t Type, platforms ...string) *Definition {
+ return &Definition{key: k, scope: s, typ: t, platforms: platforms}
+}
+
+// Key returns a policy setting's identifier.
+func (d *Definition) Key() Key {
+ if d == nil {
+ return ""
+ }
+ return d.key
+}
+
+// Scope reports the broadest [Scope] the policy setting may apply to.
+func (d *Definition) Scope() Scope {
+ if d == nil {
+ return 0
+ }
+ return d.scope
+}
+
+// Type reports the underlying value type of the policy setting.
+func (d *Definition) Type() Type {
+ if d == nil {
+ return InvalidValue
+ }
+ return d.typ
+}
+
+// IsSupported reports whether the policy setting is supported on the current OS.
+func (d *Definition) IsSupported() bool {
+ if d == nil {
+ return false
+ }
+ return d.platforms.HasCurrent()
+}
+
+// SupportedPlatforms reports platforms on which the policy setting is supported.
+// An empty [PlatformList] indicates that s is available on all platforms.
+func (d *Definition) SupportedPlatforms() PlatformList {
+ if d == nil {
+ return nil
+ }
+ return d.platforms
+}
+
+// String implements [fmt.Stringer].
+func (d *Definition) String() string {
+ if d == nil {
+ return "(nil)"
+ }
+ return fmt.Sprintf("%v(%q, %v)", d.scope, d.key, d.typ)
+}
+
+// Equal reports whether d and d2 have the same key, type and scope.
+// It does not check whether both s and s2 are supported on the same platforms.
+func (d *Definition) Equal(d2 *Definition) bool {
+ if d == d2 {
+ return true
+ }
+ if d == nil || d2 == nil {
+ return false
+ }
+ return d.key == d2.key && d.typ == d2.typ && d.scope == d2.scope
+}
+
+// DefinitionMap is a map of setting [Definition] by [Key].
+type DefinitionMap map[Key]*Definition
+
+var (
+ definitions lazy.SyncValue[DefinitionMap]
+
+ definitionsMu sync.Mutex
+ definitionsList []*Definition
+ definitionsUsed bool
+)
+
+// Register registers a policy setting with the specified key, scope, and value type.
+// All policy settings must be registered before any of them can be used.
+// Register panics if called after invoking any syspolicy functions that use the
+// registered policy definitions, such as functions that read the policy.
+func Register(k Key, s Scope, t Type, platforms ...string) {
+ RegisterDefinition(NewDefinition(k, s, t, platforms...))
+}
+
+// RegisterDefinition is like [Register], but accepts a [Definition].
+func RegisterDefinition(d *Definition) {
+ definitionsMu.Lock()
+ defer definitionsMu.Unlock()
+ registerLocked(d)
+}
+
+func registerLocked(d *Definition) {
+ if definitionsUsed {
+ panic("policy definitions are already in use")
+ }
+ definitionsList = append(definitionsList, d)
+}
+
+func settingDefinitions() (DefinitionMap, error) {
+ return definitions.GetErr(func() (DefinitionMap, error) {
+ lazyinit.Do()
+ definitionsMu.Lock()
+ defer definitionsMu.Unlock()
+ definitionsUsed = true
+ return DefinitionMapOf(definitionsList)
+ })
+}
+
+// DefinitionMapOf returns a [DefinitionMap] with the specified settings,
+// or an error if any settings have the same key but different type or scope.
+func DefinitionMapOf(settings []*Definition) (DefinitionMap, error) {
+ m := make(DefinitionMap, len(settings))
+ for _, s := range settings {
+ if existing, exists := m[s.key]; exists {
+ if existing.Equal(s) {
+ // Ignore duplicate setting definitions if they match. It is acceptable
+ // if the same policy setting was registered more than once
+ // (e.g. by the syspolicy package itself and by iOS/Android code).
+ existing.platforms.mergeFrom(s.platforms)
+ continue
+ }
+ return nil, fmt.Errorf("duplicate policy definition: %q", s.key)
+ }
+ m[s.key] = s
+ }
+ return m, nil
+}
+
+// SetDefinitionsForTest allows to register the specified setting definitions
+// for the test duration. It is not concurrency-safe, but unlike [Register],
+// it does not panic and can be called anytime.
+// It returns an error if ds contains two different settings with the same [Key].
+func SetDefinitionsForTest(tb lazy.TB, ds ...*Definition) error {
+ m, err := DefinitionMapOf(ds)
+ if err != nil {
+ return err
+ }
+ definitions.SetForTest(tb, m, err)
+ return nil
+}
+
+// DefinitionOf returns a setting definition by key,
+// or [ErrNoSuchKey] if the specified key does not exist,
+// or an error if there are conflicting policy definitions.
+func DefinitionOf(k Key) (*Definition, error) {
+ ds, err := settingDefinitions()
+ if err != nil {
+ return nil, err
+ }
+ if d, ok := ds[k]; ok {
+ return d, nil
+ }
+ return nil, ErrNoSuchKey
+}
+
+// Definitions returns all registered setting definitions,
+// or an error if different policies were registered under the same name.
+func Definitions() ([]*Definition, error) {
+ ds, err := settingDefinitions()
+ if err != nil {
+ return nil, err
+ }
+ res := make([]*Definition, 0, len(ds))
+ for _, d := range ds {
+ res = append(res, d)
+ }
+ return res, nil
+}
+
+// PlatformList is a list of OSes.
+// An empty list indicates that all possible platforms are supported.
+type PlatformList []string
+
+// Has reports whether the list contains the target platform.
+func (l PlatformList) Has(target string) bool {
+ if len(l) == 0 {
+ return true
+ }
+ return slices.ContainsFunc(l, func(os string) bool {
+ return strings.EqualFold(os, target)
+ })
+}
+
+// HasCurrent is like Has, but for the current platform.
+func (l PlatformList) HasCurrent() bool {
+ return l.Has(internal.OS())
+}
+
+// mergeFrom merges l2 into l. Since an empty list indicates no platform restrictions,
+// if either l or l2 is empty, the merged result in l will also be empty.
+func (l *PlatformList) mergeFrom(l2 PlatformList) {
+ switch {
+ case len(*l) == 0:
+ // No-op. An empty list indicates no platform restrictions.
+ case len(l2) == 0:
+ // Merging with an empty list results in an empty list.
+ *l = l2
+ default:
+ // Append, sort and dedup.
+ *l = append(*l, l2...)
+ slices.Sort(*l)
+ *l = slices.Compact(*l)
+ }
+}
diff --git a/util/syspolicy/setting/setting_test.go b/util/syspolicy/setting/setting_test.go
new file mode 100644
index 000000000..3cc08e7da
--- /dev/null
+++ b/util/syspolicy/setting/setting_test.go
@@ -0,0 +1,344 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package setting
+
+import (
+ "slices"
+ "strings"
+ "testing"
+
+ "tailscale.com/types/lazy"
+ "tailscale.com/types/ptr"
+ "tailscale.com/util/syspolicy/internal"
+)
+
+func TestSettingDefinition(t *testing.T) {
+ tests := []struct {
+ name string
+ setting *Definition
+ osOverride string
+ wantKey Key
+ wantScope Scope
+ wantType Type
+ wantIsSupported bool
+ wantSupportedPlatforms PlatformList
+ wantString string
+ }{
+ {
+ name: "Nil",
+ setting: nil,
+ wantKey: "",
+ wantScope: 0,
+ wantType: InvalidValue,
+ wantIsSupported: false,
+ wantString: "(nil)",
+ },
+ {
+ name: "Device/Invalid",
+ setting: NewDefinition("TestDevicePolicySetting", DeviceSetting, InvalidValue),
+ wantKey: "TestDevicePolicySetting",
+ wantScope: DeviceSetting,
+ wantType: InvalidValue,
+ wantIsSupported: true,
+ wantString: `Device("TestDevicePolicySetting", Invalid)`,
+ },
+ {
+ name: "Device/Integer",
+ setting: NewDefinition("TestDevicePolicySetting", DeviceSetting, IntegerValue),
+ wantKey: "TestDevicePolicySetting",
+ wantScope: DeviceSetting,
+ wantType: IntegerValue,
+ wantIsSupported: true,
+ wantString: `Device("TestDevicePolicySetting", Integer)`,
+ },
+ {
+ name: "Profile/String",
+ setting: NewDefinition("TestProfilePolicySetting", ProfileSetting, StringValue),
+ wantKey: "TestProfilePolicySetting",
+ wantScope: ProfileSetting,
+ wantType: StringValue,
+ wantIsSupported: true,
+ wantString: `Profile("TestProfilePolicySetting", String)`,
+ },
+ {
+ name: "Device/StringList",
+ setting: NewDefinition("AllowedSuggestedExitNodes", DeviceSetting, StringListValue),
+ wantKey: "AllowedSuggestedExitNodes",
+ wantScope: DeviceSetting,
+ wantType: StringListValue,
+ wantIsSupported: true,
+ wantString: `Device("AllowedSuggestedExitNodes", StringList)`,
+ },
+ {
+ name: "Device/PreferenceOption",
+ setting: NewDefinition("AdvertiseExitNode", DeviceSetting, PreferenceOptionValue),
+ wantKey: "AdvertiseExitNode",
+ wantScope: DeviceSetting,
+ wantType: PreferenceOptionValue,
+ wantIsSupported: true,
+ wantString: `Device("AdvertiseExitNode", PreferenceOption)`,
+ },
+ {
+ name: "User/Boolean",
+ setting: NewDefinition("TestUserPolicySetting", UserSetting, BooleanValue),
+ wantKey: "TestUserPolicySetting",
+ wantScope: UserSetting,
+ wantType: BooleanValue,
+ wantIsSupported: true,
+ wantString: `User("TestUserPolicySetting", Boolean)`,
+ },
+ {
+ name: "User/Visibility",
+ setting: NewDefinition("AdminConsole", UserSetting, VisibilityValue),
+ wantKey: "AdminConsole",
+ wantScope: UserSetting,
+ wantType: VisibilityValue,
+ wantIsSupported: true,
+ wantString: `User("AdminConsole", Visibility)`,
+ },
+ {
+ name: "User/Duration",
+ setting: NewDefinition("KeyExpirationNotice", UserSetting, DurationValue),
+ wantKey: "KeyExpirationNotice",
+ wantScope: UserSetting,
+ wantType: DurationValue,
+ wantIsSupported: true,
+ wantString: `User("KeyExpirationNotice", Duration)`,
+ },
+ {
+ name: "SupportedSetting",
+ setting: NewDefinition("DesktopPolicySetting", DeviceSetting, StringValue, "macos", "windows"),
+ osOverride: "windows",
+ wantKey: "DesktopPolicySetting",
+ wantScope: DeviceSetting,
+ wantType: StringValue,
+ wantIsSupported: true,
+ wantSupportedPlatforms: PlatformList{"macos", "windows"},
+ wantString: `Device("DesktopPolicySetting", String)`,
+ },
+ {
+ name: "UnsupportedSetting",
+ setting: NewDefinition("AndroidPolicySetting", DeviceSetting, StringValue, "android"),
+ osOverride: "macos",
+ wantKey: "AndroidPolicySetting",
+ wantScope: DeviceSetting,
+ wantType: StringValue,
+ wantIsSupported: false,
+ wantSupportedPlatforms: PlatformList{"android"},
+ wantString: `Device("AndroidPolicySetting", String)`,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if tt.osOverride != "" {
+ internal.OSForTesting.SetForTest(t, tt.osOverride, nil)
+ }
+ if !tt.setting.Equal(tt.setting) {
+ t.Errorf("the setting should be equal to itself")
+ }
+ if tt.setting != nil && !tt.setting.Equal(ptr.To(*tt.setting)) {
+ t.Errorf("the setting should be equal to its shallow copy")
+ }
+ if gotKey := tt.setting.Key(); gotKey != tt.wantKey {
+ t.Errorf("Key: got %q, want %q", gotKey, tt.wantKey)
+ }
+ if gotScope := tt.setting.Scope(); gotScope != tt.wantScope {
+ t.Errorf("Scope: got %v, want %v", gotScope, tt.wantScope)
+ }
+ if gotType := tt.setting.Type(); gotType != tt.wantType {
+ t.Errorf("Type: got %v, want %v", gotType, tt.wantType)
+ }
+ if gotIsSupported := tt.setting.IsSupported(); gotIsSupported != tt.wantIsSupported {
+ t.Errorf("IsSupported: got %v, want %v", gotIsSupported, tt.wantIsSupported)
+ }
+ if gotSupportedPlatforms := tt.setting.SupportedPlatforms(); !slices.Equal(gotSupportedPlatforms, tt.wantSupportedPlatforms) {
+ t.Errorf("SupportedPlatforms: got %v, want %v", gotSupportedPlatforms, tt.wantSupportedPlatforms)
+ }
+ if gotString := tt.setting.String(); gotString != tt.wantString {
+ t.Errorf("String: got %v, want %v", gotString, tt.wantString)
+ }
+ })
+ }
+}
+
+func TestRegisterSettingDefinition(t *testing.T) {
+ const testPolicySettingKey Key = "TestPolicySetting"
+ tests := []struct {
+ name string
+ key Key
+ wantEq *Definition
+ wantErr error
+ }{
+ {
+ name: "GetRegistered",
+ key: "TestPolicySetting",
+ wantEq: NewDefinition(testPolicySettingKey, DeviceSetting, StringValue),
+ },
+ {
+ name: "GetNonRegistered",
+ key: "OtherPolicySetting",
+ wantEq: nil,
+ wantErr: ErrNoSuchKey,
+ },
+ }
+
+ resetSettingDefinitions(t)
+ Register(testPolicySettingKey, DeviceSetting, StringValue)
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got, gotErr := DefinitionOf(tt.key)
+ if gotErr != tt.wantErr {
+ t.Errorf("gotErr %v, wantErr %v", gotErr, tt.wantErr)
+ }
+ if !got.Equal(tt.wantEq) {
+ t.Errorf("got %v, want %v", got, tt.wantEq)
+ }
+ })
+ }
+}
+
+func TestRegisterAfterUsePanics(t *testing.T) {
+ resetSettingDefinitions(t)
+
+ Register("TestPolicySetting", DeviceSetting, StringValue)
+ DefinitionOf("TestPolicySetting")
+
+ func() {
+ defer func() {
+ if gotPanic, wantPanic := recover(), "policy definitions are already in use"; gotPanic != wantPanic {
+ t.Errorf("gotPanic: %q, wantPanic: %q", gotPanic, wantPanic)
+ }
+ }()
+
+ Register("TestPolicySetting", DeviceSetting, StringValue)
+ }()
+}
+
+func TestRegisterDuplicateSettings(t *testing.T) {
+
+ tests := []struct {
+ name string
+ settings []*Definition
+ wantEq *Definition
+ wantErrStr string
+ }{
+ {
+ name: "NoConflict/Exact",
+ settings: []*Definition{
+ NewDefinition("TestPolicySetting", DeviceSetting, StringValue),
+ NewDefinition("TestPolicySetting", DeviceSetting, StringValue),
+ },
+ wantEq: NewDefinition("TestPolicySetting", DeviceSetting, StringValue),
+ },
+ {
+ name: "NoConflict/MergeOS-First",
+ settings: []*Definition{
+ NewDefinition("TestPolicySetting", DeviceSetting, StringValue, "android", "macos"),
+ NewDefinition("TestPolicySetting", DeviceSetting, StringValue), // all platforms
+ },
+ wantEq: NewDefinition("TestPolicySetting", DeviceSetting, StringValue), // all platforms
+ },
+ {
+ name: "NoConflict/MergeOS-Second",
+ settings: []*Definition{
+ NewDefinition("TestPolicySetting", DeviceSetting, StringValue), // all platforms
+ NewDefinition("TestPolicySetting", DeviceSetting, StringValue, "android", "macos"),
+ },
+ wantEq: NewDefinition("TestPolicySetting", DeviceSetting, StringValue), // all platforms
+ },
+ {
+ name: "NoConflict/MergeOS-Both",
+ settings: []*Definition{
+ NewDefinition("TestPolicySetting", DeviceSetting, StringValue, "macos"),
+ NewDefinition("TestPolicySetting", DeviceSetting, StringValue, "windows"),
+ },
+ wantEq: NewDefinition("TestPolicySetting", DeviceSetting, StringValue, "macos", "windows"),
+ },
+ {
+ name: "Conflict/Scope",
+ settings: []*Definition{
+ NewDefinition("TestPolicySetting", DeviceSetting, StringValue),
+ NewDefinition("TestPolicySetting", UserSetting, StringValue),
+ },
+ wantEq: nil,
+ wantErrStr: `duplicate policy definition: "TestPolicySetting"`,
+ },
+ {
+ name: "Conflict/Type",
+ settings: []*Definition{
+ NewDefinition("TestPolicySetting", UserSetting, StringValue),
+ NewDefinition("TestPolicySetting", UserSetting, IntegerValue),
+ },
+ wantEq: nil,
+ wantErrStr: `duplicate policy definition: "TestPolicySetting"`,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ resetSettingDefinitions(t)
+ for _, s := range tt.settings {
+ Register(s.Key(), s.Scope(), s.Type(), s.SupportedPlatforms()...)
+ }
+ got, err := DefinitionOf("TestPolicySetting")
+ var gotErrStr string
+ if err != nil {
+ gotErrStr = err.Error()
+ }
+ if gotErrStr != tt.wantErrStr {
+ t.Fatalf("ErrStr: got %q, want %q", gotErrStr, tt.wantErrStr)
+ }
+ if !got.Equal(tt.wantEq) {
+ t.Errorf("Definition got %v, want %v", got, tt.wantEq)
+ }
+ if !slices.Equal(got.SupportedPlatforms(), tt.wantEq.SupportedPlatforms()) {
+ t.Errorf("SupportedPlatforms got %v, want %v", got.SupportedPlatforms(), tt.wantEq.SupportedPlatforms())
+ }
+ })
+ }
+}
+
+func TestListSettingDefinitions(t *testing.T) {
+ definitions := []*Definition{
+ NewDefinition("TestDevicePolicySetting", DeviceSetting, IntegerValue),
+ NewDefinition("TestProfilePolicySetting", ProfileSetting, StringValue),
+ NewDefinition("TestUserPolicySetting", UserSetting, BooleanValue),
+ NewDefinition("TestStringListPolicySetting", DeviceSetting, StringListValue),
+ }
+ if err := SetDefinitionsForTest(t, definitions...); err != nil {
+ t.Fatalf("SetDefinitionsForTest failed: %v", err)
+ }
+
+ cmp := func(l, r *Definition) int {
+ return strings.Compare(string(l.Key()), string(r.Key()))
+ }
+ want := append([]*Definition{}, definitions...)
+ slices.SortFunc(want, cmp)
+
+ got, err := Definitions()
+ if err != nil {
+ t.Fatalf("Definitions failed: %v", err)
+ }
+ slices.SortFunc(got, cmp)
+
+ if !slices.Equal(got, want) {
+ t.Errorf("got %v, want %v", got, want)
+ }
+}
+
+func resetSettingDefinitions(t *testing.T) {
+ t.Cleanup(func() {
+ definitionsMu.Lock()
+ definitionsList = nil
+ definitions = lazy.SyncValue[DefinitionMap]{}
+ definitionsUsed = false
+ definitionsMu.Unlock()
+ })
+
+ definitionsMu.Lock()
+ definitionsList = nil
+ definitions = lazy.SyncValue[DefinitionMap]{}
+ definitionsUsed = false
+ definitionsMu.Unlock()
+}
diff --git a/util/syspolicy/setting/snapshot.go b/util/syspolicy/setting/snapshot.go
new file mode 100644
index 000000000..4f4934a72
--- /dev/null
+++ b/util/syspolicy/setting/snapshot.go
@@ -0,0 +1,153 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package setting
+
+import (
+ xmaps "golang.org/x/exp/maps"
+ "tailscale.com/util/deephash"
+)
+
+// Snapshot is an immutable collection of [RawItem]s, representing
+// a set of policy settings applied at a specific moment in time.
+// A nil pointer to [Snapshot] is valid.
+type Snapshot struct {
+ m map[Key]RawItem
+ sig deephash.Sum // of m
+ summary Summary
+}
+
+// NewSnapshot returns a new [Snapshot] with the specified items and options.
+func NewSnapshot(items map[Key]RawItem, opts ...SummaryOption) *Snapshot {
+ return &Snapshot{m: items, sig: deephash.Hash(&items), summary: SummaryWith(opts...)}
+}
+
+type keyItemPair struct {
+ Key Key
+ Item RawItem
+}
+
+// All returns an iterator over [[Key], [RawItem]] key-value pairs in b. The
+// iteration order is not specified and is not guaranteed to be the same from
+// one call to the next.
+func (s *Snapshot) All() []keyItemPair {
+ if s == nil {
+ return nil
+ }
+ // TODO(nickkhyl): return iter.Seq2[[Key], [RawItem]] in Go 1.23,
+ // and remove [keyItemPair].
+ items := make([]keyItemPair, 0, len(s.m))
+ for k, i := range s.m {
+ items = append(items, keyItemPair{k, i})
+ }
+ return items
+}
+
+// Get returns the value of the policy setting with the specified key
+// or nil if it does not exist or could not be read.
+func (s *Snapshot) Get(k Key) any {
+ v, _ := s.GetErr(k)
+ return v
+}
+
+// GetErr returns the value of the policy setting with the specified key,
+// [ErrNotConfigured] if it does not exist, or an error returned by
+// the policy Store if the policy setting could not be read.
+func (s *Snapshot) GetErr(k Key) (any, error) {
+ if s != nil {
+ if s, ok := s.m[k]; ok {
+ return s.Value(), s.Error()
+ }
+ }
+ return nil, ErrNotConfigured
+}
+
+// GetSetting returns the untyped policy setting with the specified key and true
+// if a policy setting with such key has been configured;
+// otherwise, it returns zero, false.
+func (s *Snapshot) GetSetting(k Key) (setting RawItem, ok bool) {
+ setting, ok = s.m[k]
+ return setting, ok
+}
+
+// Equal reports whether s and s2 are equal.
+func (s *Snapshot) Equal(s2 *Snapshot) bool {
+ if !s.EqualItems(s2) {
+ return false
+ }
+ return s.Summary() == s2.Summary()
+}
+
+// EqualItems reports whether items in s and s2 are equal.
+func (s *Snapshot) EqualItems(s2 *Snapshot) bool {
+ if s == s2 {
+ return true
+ }
+ if s.Len() != s2.Len() {
+ return false
+ }
+ if s.Len() == 0 {
+ return true
+ }
+ return s.sig == s2.sig
+}
+
+// Keys return an iterator over keys in s. The iteration order is not specified
+// and is not guaranteed to be the same from one call to the next.
+func (s *Snapshot) Keys() []Key {
+ if s.m == nil {
+ return nil
+ }
+ // TODO(nickkhyl): return iter.Seq[Key] in Go 1.23.
+ return xmaps.Keys(s.m)
+}
+
+// Len reports the number of [RawItem]s in s.
+func (s *Snapshot) Len() int {
+ if s == nil {
+ return 0
+ }
+ return len(s.m)
+}
+
+// Summary returns information about s as a whole rather than about specific [RawItem]s in it.
+func (s *Snapshot) Summary() Summary {
+ if s == nil {
+ return Summary{}
+ }
+ return s.summary
+}
+
+// MergeSnapshots returns a [Snapshot] that contains all [RawItem]s
+// from snapshot1 and snapshot2 and the [Summary] with the narrower [PolicyScope].
+// If there's a conflict between policy settings in the two snapshots,
+// the policy settings from the snapshot with the broader scope take precedence.
+// In other words, policy settings configured for the [DeviceScope] win
+// over policy settings configured for a user scope.
+func MergeSnapshots(snapshot1, snapshot2 *Snapshot) *Snapshot {
+ scope1, ok1 := snapshot1.Summary().Scope().GetOk()
+ scope2, ok2 := snapshot2.Summary().Scope().GetOk()
+ if ok1 && ok2 && scope2.IsStrictlyWithinOf(scope1) {
+ // Swap snapshots if snapshot1 has higher precedence than snapshot2.
+ snapshot1, snapshot2 = snapshot2, snapshot1
+ }
+ if snapshot2.Len() == 0 {
+ return snapshot1
+ }
+ summaryOpts := make([]SummaryOption, 0, 2)
+ if scope, ok := snapshot1.Summary().Scope().GetOk(); ok {
+ // Use the scope from snapshot1, if present, which is the more specific snapshot.
+ summaryOpts = append(summaryOpts, scope)
+ }
+ if snapshot1.Len() == 0 {
+ if origin, ok := snapshot2.Summary().Origin().GetOk(); ok {
+ // Use the origin from snapshot2 if snapshot1 is empty.
+ summaryOpts = append(summaryOpts, origin)
+ }
+ return &Snapshot{snapshot2.m, snapshot2.sig, SummaryWith(summaryOpts...)}
+ }
+ m := make(map[Key]RawItem, snapshot1.Len()+snapshot2.Len())
+ xmaps.Copy(m, snapshot1.m)
+ xmaps.Copy(m, snapshot2.m) // snapshot2 has higher precedence
+ return &Snapshot{m, deephash.Hash(&m), SummaryWith(summaryOpts...)}
+}
diff --git a/util/syspolicy/setting/snapshot_test.go b/util/syspolicy/setting/snapshot_test.go
new file mode 100644
index 000000000..378fa6033
--- /dev/null
+++ b/util/syspolicy/setting/snapshot_test.go
@@ -0,0 +1,372 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package setting
+
+import (
+ "testing"
+ "time"
+)
+
+func TestMergeSnapshots(t *testing.T) {
+ tests := []struct {
+ name string
+ s1, s2 *Snapshot
+ want *Snapshot
+ }{
+ {
+ name: "both-nil",
+ s1: nil,
+ s2: nil,
+ want: NewSnapshot(map[Key]RawItem{}),
+ },
+ {
+ name: "both-empty",
+ s1: NewSnapshot(map[Key]RawItem{}),
+ s2: NewSnapshot(map[Key]RawItem{}),
+ want: NewSnapshot(map[Key]RawItem{}),
+ },
+ {
+ name: "first-nil",
+ s1: nil,
+ s2: NewSnapshot(map[Key]RawItem{
+ "Setting1": {value: 123},
+ "Setting2": {value: "String"},
+ "Setting3": {value: true},
+ }),
+ want: NewSnapshot(map[Key]RawItem{
+ "Setting1": {value: 123},
+ "Setting2": {value: "String"},
+ "Setting3": {value: true},
+ }),
+ },
+ {
+ name: "first-empty",
+ s1: NewSnapshot(map[Key]RawItem{}),
+ s2: NewSnapshot(map[Key]RawItem{
+ "Setting1": {value: 123},
+ "Setting2": {value: "String"},
+ "Setting3": {value: false},
+ }),
+ want: NewSnapshot(map[Key]RawItem{
+ "Setting1": {value: 123},
+ "Setting2": {value: "String"},
+ "Setting3": {value: false},
+ }),
+ },
+ {
+ name: "second-nil",
+ s1: NewSnapshot(map[Key]RawItem{
+ "Setting1": {value: 123},
+ "Setting2": {value: "String"},
+ "Setting3": {value: true},
+ }),
+ s2: nil,
+ want: NewSnapshot(map[Key]RawItem{
+ "Setting1": {value: 123},
+ "Setting2": {value: "String"},
+ "Setting3": {value: true},
+ }),
+ },
+ {
+ name: "second-empty",
+ s1: NewSnapshot(map[Key]RawItem{
+ "Setting1": {value: 123},
+ "Setting2": {value: "String"},
+ "Setting3": {value: false},
+ }),
+ s2: NewSnapshot(map[Key]RawItem{}),
+ want: NewSnapshot(map[Key]RawItem{
+ "Setting1": {value: 123},
+ "Setting2": {value: "String"},
+ "Setting3": {value: false},
+ }),
+ },
+ {
+ name: "no-conflicts",
+ s1: NewSnapshot(map[Key]RawItem{
+ "Setting1": {value: 123},
+ "Setting2": {value: "String"},
+ "Setting3": {value: false},
+ }),
+ s2: NewSnapshot(map[Key]RawItem{
+ "Setting4": {value: 2 * time.Hour},
+ "Setting5": {value: VisibleByPolicy},
+ "Setting6": {value: ShowChoiceByPolicy},
+ }),
+ want: NewSnapshot(map[Key]RawItem{
+ "Setting1": {value: 123},
+ "Setting2": {value: "String"},
+ "Setting3": {value: false},
+ "Setting4": {value: 2 * time.Hour},
+ "Setting5": {value: VisibleByPolicy},
+ "Setting6": {value: ShowChoiceByPolicy},
+ }),
+ },
+ {
+ name: "with-conflicts",
+ s1: NewSnapshot(map[Key]RawItem{
+ "Setting1": {value: 123},
+ "Setting2": {value: "String"},
+ "Setting3": {value: true},
+ }),
+ s2: NewSnapshot(map[Key]RawItem{
+ "Setting1": {value: 456},
+ "Setting3": {value: false},
+ "Setting4": {value: 2 * time.Hour},
+ }),
+ want: NewSnapshot(map[Key]RawItem{
+ "Setting1": {value: 456},
+ "Setting2": {value: "String"},
+ "Setting3": {value: false},
+ "Setting4": {value: 2 * time.Hour},
+ }),
+ },
+ {
+ name: "with-scope-first-wins",
+ s1: NewSnapshot(map[Key]RawItem{
+ "Setting1": {value: 123},
+ "Setting2": {value: "String"},
+ "Setting3": {value: true},
+ }, DeviceScope),
+ s2: NewSnapshot(map[Key]RawItem{
+ "Setting1": {value: 456},
+ "Setting3": {value: false},
+ "Setting4": {value: 2 * time.Hour},
+ }, CurrentUserScope),
+ want: NewSnapshot(map[Key]RawItem{
+ "Setting1": {value: 123},
+ "Setting2": {value: "String"},
+ "Setting3": {value: true},
+ "Setting4": {value: 2 * time.Hour},
+ }, CurrentUserScope),
+ },
+ {
+ name: "with-scope-second-wins",
+ s1: NewSnapshot(map[Key]RawItem{
+ "Setting1": {value: 123},
+ "Setting2": {value: "String"},
+ "Setting3": {value: true},
+ }, CurrentUserScope),
+ s2: NewSnapshot(map[Key]RawItem{
+ "Setting1": {value: 456},
+ "Setting3": {value: false},
+ "Setting4": {value: 2 * time.Hour},
+ }, DeviceScope),
+ want: NewSnapshot(map[Key]RawItem{
+ "Setting1": {value: 456},
+ "Setting2": {value: "String"},
+ "Setting3": {value: false},
+ "Setting4": {value: 2 * time.Hour},
+ }, CurrentUserScope),
+ },
+ {
+ name: "with-scope-both-empty",
+ s1: NewSnapshot(map[Key]RawItem{}, CurrentUserScope),
+ s2: NewSnapshot(map[Key]RawItem{}, DeviceScope),
+ want: NewSnapshot(map[Key]RawItem{}, CurrentUserScope),
+ },
+ {
+ name: "with-scope-first-empty",
+ s1: NewSnapshot(map[Key]RawItem{}, CurrentUserScope),
+ s2: NewSnapshot(map[Key]RawItem{
+ "Setting1": {value: 123},
+ "Setting2": {value: "String"},
+ "Setting3": {value: true}}, DeviceScope),
+ want: NewSnapshot(map[Key]RawItem{
+ "Setting1": {value: 123},
+ "Setting2": {value: "String"},
+ "Setting3": {value: true},
+ }, CurrentUserScope),
+ },
+ {
+ name: "with-scope-second-empty",
+ s1: NewSnapshot(map[Key]RawItem{
+ "Setting1": {value: 123},
+ "Setting2": {value: "String"},
+ "Setting3": {value: true},
+ }, CurrentUserScope),
+ s2: NewSnapshot(map[Key]RawItem{}, DeviceScope),
+ want: NewSnapshot(map[Key]RawItem{
+ "Setting1": {value: 123},
+ "Setting2": {value: "String"},
+ "Setting3": {value: true},
+ }, CurrentUserScope),
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := MergeSnapshots(tt.s1, tt.s2)
+ if !got.Equal(tt.want) {
+ t.Errorf("got %v, want %v", got, tt.want)
+ }
+ })
+ }
+}
+
+func TestSnapshotEqual(t *testing.T) {
+ tests := []struct {
+ name string
+ b1, b2 *Snapshot
+ wantEqual bool
+ wantEqualItems bool
+ }{
+ {
+ name: "nil-nil",
+ b1: nil,
+ b2: nil,
+ wantEqual: true,
+ wantEqualItems: true,
+ },
+ {
+ name: "nil-empty",
+ b1: nil,
+ b2: NewSnapshot(map[Key]RawItem{}),
+ wantEqual: true,
+ wantEqualItems: true,
+ },
+ {
+ name: "empty-nil",
+ b1: NewSnapshot(map[Key]RawItem{}),
+ b2: nil,
+ wantEqual: true,
+ wantEqualItems: true,
+ },
+ {
+ name: "empty-empty",
+ b1: NewSnapshot(map[Key]RawItem{}),
+ b2: NewSnapshot(map[Key]RawItem{}),
+ wantEqual: true,
+ wantEqualItems: true,
+ },
+ {
+ name: "first-nil",
+ b1: nil,
+ b2: NewSnapshot(map[Key]RawItem{
+ "Setting1": {value: 123},
+ "Setting2": {value: "String"},
+ "Setting3": {value: false},
+ }),
+ wantEqual: false,
+ wantEqualItems: false,
+ },
+ {
+ name: "first-empty",
+ b1: NewSnapshot(map[Key]RawItem{}),
+ b2: NewSnapshot(map[Key]RawItem{
+ "Setting1": {value: 123},
+ "Setting2": {value: "String"},
+ "Setting3": {value: false},
+ }),
+ wantEqual: false,
+ wantEqualItems: false,
+ },
+ {
+ name: "second-nil",
+ b1: NewSnapshot(map[Key]RawItem{
+ "Setting1": {value: 123},
+ "Setting2": {value: "String"},
+ "Setting3": {value: true},
+ }),
+ b2: nil,
+ wantEqual: false,
+ wantEqualItems: false,
+ },
+ {
+ name: "second-empty",
+ b1: NewSnapshot(map[Key]RawItem{
+ "Setting1": {value: 123},
+ "Setting2": {value: "String"},
+ "Setting3": {value: false},
+ }),
+ b2: NewSnapshot(map[Key]RawItem{}),
+ wantEqual: false,
+ wantEqualItems: false,
+ },
+ {
+ name: "same-items-same-order-no-scope",
+ b1: NewSnapshot(map[Key]RawItem{
+ "Setting1": {value: 123},
+ "Setting2": {value: "String"},
+ "Setting3": {value: false},
+ }),
+ b2: NewSnapshot(map[Key]RawItem{
+ "Setting1": {value: 123},
+ "Setting2": {value: "String"},
+ "Setting3": {value: false},
+ }),
+ wantEqual: true,
+ wantEqualItems: true,
+ },
+ {
+ name: "same-items-same-order-same-scope",
+ b1: NewSnapshot(map[Key]RawItem{
+ "Setting1": {value: 123},
+ "Setting2": {value: "String"},
+ "Setting3": {value: false},
+ }, DeviceScope),
+ b2: NewSnapshot(map[Key]RawItem{
+ "Setting1": {value: 123},
+ "Setting2": {value: "String"},
+ "Setting3": {value: false},
+ }, DeviceScope),
+ wantEqual: true,
+ wantEqualItems: true,
+ },
+ {
+ name: "same-items-different-order-same-scope",
+ b1: NewSnapshot(map[Key]RawItem{
+ "Setting1": {value: 123},
+ "Setting2": {value: "String"},
+ "Setting3": {value: false},
+ }, DeviceScope),
+ b2: NewSnapshot(map[Key]RawItem{
+ "Setting3": {value: false},
+ "Setting1": {value: 123},
+ "Setting2": {value: "String"},
+ }, DeviceScope),
+ wantEqual: true,
+ wantEqualItems: true,
+ },
+ {
+ name: "same-items-same-order-different-scope",
+ b1: NewSnapshot(map[Key]RawItem{
+ "Setting1": {value: 123},
+ "Setting2": {value: "String"},
+ "Setting3": {value: false},
+ }, DeviceScope),
+ b2: NewSnapshot(map[Key]RawItem{
+ "Setting1": {value: 123},
+ "Setting2": {value: "String"},
+ "Setting3": {value: false},
+ }, CurrentUserScope),
+ wantEqual: false,
+ wantEqualItems: true,
+ },
+ {
+ name: "different-items-same-scope",
+ b1: NewSnapshot(map[Key]RawItem{
+ "Setting1": {value: 123},
+ "Setting2": {value: "String"},
+ "Setting3": {value: false},
+ }, DeviceScope),
+ b2: NewSnapshot(map[Key]RawItem{
+ "Setting4": {value: 2 * time.Hour},
+ "Setting5": {value: VisibleByPolicy},
+ "Setting6": {value: ShowChoiceByPolicy},
+ }, DeviceScope),
+ wantEqual: false,
+ wantEqualItems: false,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if gotEqual := tt.b1.Equal(tt.b2); gotEqual != tt.wantEqual {
+ t.Errorf("WantEqual: got %v, want %v", gotEqual, tt.wantEqual)
+ }
+ if gotEqualItems := tt.b1.EqualItems(tt.b2); gotEqualItems != tt.wantEqualItems {
+ t.Errorf("WantEqualItems: got %v, want %v", gotEqualItems, tt.wantEqualItems)
+ }
+ })
+ }
+}
diff --git a/util/syspolicy/setting/summary.go b/util/syspolicy/setting/summary.go
new file mode 100644
index 000000000..5855b22e3
--- /dev/null
+++ b/util/syspolicy/setting/summary.go
@@ -0,0 +1,84 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package setting
+
+import (
+ jsonv2 "github.com/go-json-experiment/json"
+ "github.com/go-json-experiment/json/jsontext"
+ "tailscale.com/types/opt"
+)
+
+// Summary is an immutable [PolicyScope] and [Origin].
+type Summary struct {
+ data summary
+}
+
+type summary struct {
+ Scope opt.Value[PolicyScope] `json:",omitzero"`
+ Origin opt.Value[Origin] `json:",omitzero"`
+}
+
+// SummaryWith returns a [Summary] with the specified options.
+func SummaryWith(opts ...SummaryOption) Summary {
+ var summary Summary
+ for _, o := range opts {
+ o.applySummaryOption(&summary)
+ }
+ return summary
+}
+
+// Scope reports the [PolicyScope] in s.
+func (s Summary) Scope() opt.Value[PolicyScope] {
+ return s.data.Scope
+}
+
+// Origin reports the [Origin] in s.
+func (s Summary) Origin() opt.Value[Origin] {
+ return s.data.Origin
+}
+
+// MarshalJSONV2 implements [jsonv2.MarshalerV2].
+func (s Summary) MarshalJSONV2(out *jsontext.Encoder, opts jsonv2.Options) error {
+ return jsonv2.MarshalEncode(out, &s.data, opts)
+}
+
+// UnmarshalJSONV2 implements [jsonv2.UnmarshalerV2].
+func (s *Summary) UnmarshalJSONV2(in *jsontext.Decoder, opts jsonv2.Options) error {
+ return jsonv2.UnmarshalDecode(in, &s.data, opts)
+}
+
+// MarshalJSON implements [json.Marshaler].
+func (s Summary) MarshalJSON() ([]byte, error) {
+ return jsonv2.Marshal(s) // uses MarshalJSONV2
+}
+
+// UnmarshalJSON implements [json.Unmarshaler].
+func (s *Summary) UnmarshalJSON(b []byte) error {
+ return jsonv2.Unmarshal(b, s) // uses UnmarshalJSONV2
+}
+
+// SummaryOption is an option that configures [Summary]
+// The following are allowed options:
+//
+// - [Summary]
+// - [PolicyScope]
+// - [Origin]
+type SummaryOption interface {
+ applySummaryOption(summary *Summary)
+}
+
+func (s PolicyScope) applySummaryOption(summary *Summary) {
+ summary.data.Scope.Set(s)
+}
+
+func (o Origin) applySummaryOption(summary *Summary) {
+ summary.data.Origin.Set(o)
+ if !summary.data.Scope.IsSet() {
+ summary.data.Scope.Set(o.Scope())
+ }
+}
+
+func (s Summary) applySummaryOption(summary *Summary) {
+ *summary = s
+}
diff --git a/util/syspolicy/setting/types.go b/util/syspolicy/setting/types.go
new file mode 100644
index 000000000..16f9e7445
--- /dev/null
+++ b/util/syspolicy/setting/types.go
@@ -0,0 +1,132 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package setting
+
+import (
+ "encoding"
+)
+
+// PreferenceOption is a policy that governs whether a boolean variable
+// is forcibly assigned an administrator-defined value, or allowed to receive
+// a user-defined value.
+type PreferenceOption int
+
+const (
+ ShowChoiceByPolicy PreferenceOption = iota
+ NeverByPolicy
+ AlwaysByPolicy
+)
+
+// Show returns if the UI option that controls the choice administered by this
+// policy should be shown. Currently this is true if and only if the policy is
+// [ShowChoiceByPolicy].
+func (p PreferenceOption) Show() bool {
+ return p == ShowChoiceByPolicy
+}
+
+// ShouldEnable checks if the choice administered by this policy should be
+// enabled. If the administrator has chosen a setting, the administrator's
+// setting is returned, otherwise userChoice is returned.
+func (p PreferenceOption) ShouldEnable(userChoice bool) bool {
+ switch p {
+ case NeverByPolicy:
+ return false
+ case AlwaysByPolicy:
+ return true
+ default:
+ return userChoice
+ }
+}
+
+// IsAlways reports whether the preference should always be enabled.
+func (p PreferenceOption) IsAlways() bool {
+ return p == AlwaysByPolicy
+}
+
+// IsNever reports whether the preference should always be disabled.
+func (p PreferenceOption) IsNever() bool {
+ return p == NeverByPolicy
+}
+
+// WillOverride checks if the choice administered by the policy is different
+// from the user's choice.
+func (p PreferenceOption) WillOverride(userChoice bool) bool {
+ return p.ShouldEnable(userChoice) != userChoice
+}
+
+// String returns a string representation of p.
+func (p PreferenceOption) String() string {
+ switch p {
+ case AlwaysByPolicy:
+ return "always"
+ case NeverByPolicy:
+ return "never"
+ default:
+ return "user-decides"
+ }
+}
+
+// MarshalText implements [encoding.TextMarshaler].
+func (p *PreferenceOption) MarshalText() (text []byte, err error) {
+ return []byte(p.String()), nil
+}
+
+// UnmarshalText implements [encoding.TextUnmarshaler].
+func (p *PreferenceOption) UnmarshalText(text []byte) error {
+ switch string(text) {
+ case "always":
+ *p = AlwaysByPolicy
+ case "never":
+ *p = NeverByPolicy
+ default:
+ *p = ShowChoiceByPolicy
+ }
+ return nil
+}
+
+// Visibility is a policy that controls whether or not a particular
+// component of a user interface is to be shown.
+type Visibility byte
+
+var (
+ _ encoding.TextMarshaler = (*Visibility)(nil)
+ _ encoding.TextUnmarshaler = (*Visibility)(nil)
+)
+
+const (
+ VisibleByPolicy Visibility = 'v'
+ HiddenByPolicy Visibility = 'h'
+)
+
+// Show reports whether the UI option administered by this policy should be shown.
+// Currently this is true if the policy is not [hiddenByPolicy].
+func (p Visibility) Show() bool {
+ return p != HiddenByPolicy
+}
+
+// String returns a string representation of p.
+func (p Visibility) String() string {
+ switch p {
+ case 'h':
+ return "hide"
+ default:
+ return "show"
+ }
+}
+
+// MarshalText implements [encoding.TextMarshaler].
+func (p Visibility) MarshalText() (text []byte, err error) {
+ return []byte(p.String()), nil
+}
+
+// UnmarshalText implements [encoding.TextUnmarshaler].
+func (p *Visibility) UnmarshalText(text []byte) error {
+ switch string(text) {
+ case "hide":
+ *p = HiddenByPolicy
+ default:
+ *p = VisibleByPolicy
+ }
+ return nil
+}
diff --git a/util/syspolicy/source/policy_reader.go b/util/syspolicy/source/policy_reader.go
new file mode 100644
index 000000000..e608fd0da
--- /dev/null
+++ b/util/syspolicy/source/policy_reader.go
@@ -0,0 +1,393 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package source
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "slices"
+ "sort"
+ "sync"
+ "time"
+
+ "tailscale.com/util/mak"
+ "tailscale.com/util/set"
+ "tailscale.com/util/syspolicy/internal/loggerx"
+ "tailscale.com/util/syspolicy/internal/metrics"
+ "tailscale.com/util/syspolicy/setting"
+)
+
+// Reader reads all configured policy settings from a given [Store].
+// It registers a change callback with the [Store] and maintains the current version
+// of the [setting.Snapshot] by lazily re-reading policy settings from the [Store]
+// whenever a new snapshot is requested
+// It is safe for concurrent use.
+type Reader struct {
+ store Store
+ origin *setting.Origin
+ settings []*setting.Definition
+ unregisterChangeNotifier func()
+ doneCh chan struct{} // closed when policyCache is closed.
+
+ mu sync.RWMutex
+ closing bool
+ upToDate bool
+ lastPolicy *setting.Snapshot
+ sessions set.HandleSet[*ReadingSession]
+}
+
+// newReader returns a new [Reader] that reads policy settings from a given [Store].
+// The returned reader takes ownership of the store. If the store implements [io.Closer],
+// the returned reader will close the store when it is closed.
+func newReader(store Store, origin *setting.Origin) (*Reader, error) {
+ settings, err := setting.Definitions()
+ if err != nil {
+ return nil, err
+ }
+
+ if expirable, ok := store.(Expirable); ok {
+ select {
+ case <-expirable.Done():
+ return nil, ErrStoreClosed
+ default:
+ }
+ }
+
+ reader := &Reader{store: store, origin: origin, settings: settings, doneCh: make(chan struct{})}
+ if changeable, ok := store.(Changeable); ok {
+ // We should subscribe to policy change notifications first before reading
+ // the policy settings from the store. This way we won't miss any notifications.
+ if reader.unregisterChangeNotifier, err = changeable.RegisterChangeCallback(reader.onPolicyChange); err != nil {
+ // Errors registering policy change callbacks are non-fatal.
+ // TODO(nickkhyl): implement a background policy refresh every X minutes?
+ loggerx.Errorf("failed to register %v policy change callback: %v\n", origin, err)
+ }
+ }
+
+ if _, err := reader.reload(true); err != nil {
+ if reader.unregisterChangeNotifier != nil {
+ reader.unregisterChangeNotifier()
+ }
+ return nil, err
+ }
+
+ if expirable, ok := store.(Expirable); ok {
+ if waitCh := expirable.Done(); waitCh != nil {
+ go func() {
+ select {
+ case <-waitCh:
+ reader.Close()
+ case <-reader.doneCh:
+ }
+ }()
+ }
+ }
+
+ return reader, nil
+}
+
+// GetSettings returns the current [*setting.Snapshot],
+// re-reading it from from the underlying [Store] only if the policy
+// has changed since it was read last. It never fails and returns
+// the previous version of the policy settings if a read attempt fails.
+func (r *Reader) GetSettings() *setting.Snapshot {
+ r.mu.RLock()
+ if r.upToDate {
+ r.mu.RUnlock()
+ return r.lastPolicy
+ }
+ r.mu.RUnlock()
+
+ policy, err := r.reload(false)
+ if err != nil {
+ // If the policy could not be reloaded at all, we'll return the last cached version of it.
+ // On the contrary, errors specific to individual policy items are always propagated to the callers.
+ loggerx.Errorf("failed to reload %v policy: %v\n", r.origin, err)
+ }
+ return policy
+}
+
+// ReadSettings reads policy settings from the underlying [Store] even if no
+// changes were detected. It returns the new [*setting.Snapshot], nil on
+// success, or nil, error in case of failure.
+func (r *Reader) ReadSettings() (*setting.Snapshot, error) {
+ b, err := r.reload(true)
+ if err != nil {
+ return nil, err
+ }
+ return b, nil
+}
+
+// reload is like [Reader.ReadSettings], but allows specifying whether to re-read
+// an unchanged policy, and returns the last [*setting.Snapshot] if the read fails.
+func (r *Reader) reload(force bool) (*setting.Snapshot, error) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ if r.upToDate && !force {
+ return r.lastPolicy, nil
+ }
+
+ if lockable, ok := r.store.(Lockable); ok {
+ if err := lockable.Lock(); err != nil {
+ return r.lastPolicy, err
+ }
+ defer lockable.Unlock()
+ }
+
+ r.upToDate = true
+
+ metrics.Reset(r.origin)
+
+ var m map[setting.Key]setting.RawItem
+ if lastPolicyCount := r.lastPolicy.Len(); lastPolicyCount > 0 {
+ m = make(map[setting.Key]setting.RawItem, lastPolicyCount)
+ }
+ for _, s := range r.settings {
+ if !r.origin.Scope().IsConfigurableSetting(s) {
+ // Skip settings that cannot be configured in the current scope.
+ continue
+ }
+
+ val, err := readPolicySettingValue(r.store, s)
+ if err != nil && (errors.Is(err, setting.ErrNoSuchKey) || errors.Is(err, setting.ErrNotConfigured)) {
+ metrics.ReportNotConfigured(r.origin, s)
+ continue
+ }
+
+ if err == nil {
+ metrics.ReportConfigured(r.origin, s, val)
+ } else {
+ metrics.ReportError(r.origin, s, err)
+ }
+
+ // If there's an error reading a single policy, such as a value type mismatch,
+ // we'll wrap the error to preserve its text and return it
+ // whenever someone attempts to fetch the value.
+ mak.Set(&m, s.Key(), setting.RawItemWith(val, setting.WrapError(err), r.origin))
+ }
+
+ newPolicy := setting.NewSnapshot(m, setting.SummaryWith(r.origin))
+ if r.lastPolicy == nil || !newPolicy.EqualItems(r.lastPolicy) {
+ r.lastPolicy = newPolicy
+ }
+ return r.lastPolicy, nil
+}
+
+// ReadingSession is like [Reader], but with a channel that's written
+// to when there's a policy change, and closed when the session is terminated.
+type ReadingSession struct {
+ reader *Reader
+ policyChangedCh chan struct{} // 1-buffered channel
+ handle set.Handle // in the reader.sessions
+ closeInternal func()
+}
+
+// OpenSession opens and returns a new session to r, allowing the caller
+// to get notified whenever a policy change is reported by the [source.Store],
+// or an [ErrStoreClosed] if the reader has already been closed.
+func (r *Reader) OpenSession() (*ReadingSession, error) {
+ session := &ReadingSession{
+ reader: r,
+ policyChangedCh: make(chan struct{}, 1),
+ }
+ session.closeInternal = sync.OnceFunc(func() { close(session.policyChangedCh) })
+ r.mu.Lock()
+ if !r.closing {
+ session.handle = r.sessions.Add(session)
+ r.mu.Unlock()
+ return session, nil
+ }
+ r.mu.Unlock()
+ return nil, ErrStoreClosed
+}
+
+// GetSettings is like [Reader.GetSettings].
+func (s *ReadingSession) GetSettings() *setting.Snapshot {
+ return s.reader.GetSettings()
+}
+
+// ReadSettings is like [Reader.ReadSettings].
+func (s *ReadingSession) ReadSettings() (*setting.Snapshot, error) {
+ return s.reader.ReadSettings()
+}
+
+// PolicyChanged returns a channel that's written to when
+// there's a policy change, closed when the session is terminated.
+func (s *ReadingSession) PolicyChanged() <-chan struct{} {
+ return s.policyChangedCh
+}
+
+// Close unregisters this session with the [Reader].
+func (s *ReadingSession) Close() {
+ s.reader.mu.Lock()
+ delete(s.reader.sessions, s.handle)
+ s.closeInternal()
+ s.reader.mu.Unlock()
+}
+
+// onPolicyChange handles a policy change notification from the [Store],
+// invalidating the current [setting.Snapshot] in r,
+// and notifying the active [ReadingSession]s.
+func (r *Reader) onPolicyChange() {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ r.upToDate = false
+ for _, s := range r.sessions {
+ select {
+ case s.policyChangedCh <- struct{}{}:
+ // Notified.
+ default:
+ // 1-buffered channel is full, meaning that another policy change
+ // notification is already en route.
+ }
+ }
+}
+
+// Close closes the store reader and the underlying store.
+func (r *Reader) Close() error {
+ r.mu.Lock()
+ if r.closing {
+ r.mu.Unlock()
+ return nil
+ }
+ r.closing = true
+ r.mu.Unlock()
+
+ if r.unregisterChangeNotifier != nil {
+ r.unregisterChangeNotifier()
+ r.unregisterChangeNotifier = nil
+ }
+
+ if closer, ok := r.store.(io.Closer); ok {
+ if err := closer.Close(); err != nil {
+ return err
+ }
+ }
+ r.store = nil
+
+ close(r.doneCh)
+
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ for _, c := range r.sessions {
+ c.closeInternal()
+ }
+ r.sessions = nil
+ return nil
+}
+
+// Done returns a channel that is closed when the reader is closed.
+func (r *Reader) Done() <-chan struct{} {
+ return r.doneCh
+}
+
+// ReadableSource is a [Source] open for reading.
+type ReadableSource struct {
+ *Source
+ *ReadingSession
+}
+
+// Close closes the underlying [ReadingSession].
+func (s ReadableSource) Close() {
+ s.ReadingSession.Close()
+}
+
+// ReadableSources is a slice of [ReadableSource].
+type ReadableSources []ReadableSource
+
+// Contains reports whether s contains the specified source.
+func (s ReadableSources) Contains(source *Source) bool {
+ return s.IndexOf(source) != -1
+}
+
+// IndexOf returns position of the specified source in s, or -1
+// if the source does not exist.
+func (s ReadableSources) IndexOf(source *Source) int {
+ return slices.IndexFunc(s, func(rs ReadableSource) bool {
+ return rs.Source == source
+ })
+}
+
+// InsertionIndexOf returns the position at which source can be inserted
+// to maintain the sorted order of the readableSources.
+// The return value is unspecified if s is not sorted on entry to InsertionIndexOf.
+func (s ReadableSources) InsertionIndexOf(source *Source) int {
+ low, high := 0, len(s)
+ for low < high {
+ mid := (low + high) / 2
+ if s[mid].Compare(source) <= 0 {
+ low = mid + 1
+ } else {
+ high = mid
+ }
+ }
+ return low
+}
+
+// StableSort sorts the readableSources by the precedence, so that policy settings
+// from sources with higher precedence (e.g., [DeviceScope]) will be merged last,
+// overriding any policy settings with the same keys configured in sources with
+// lower precedence (e.g., [CurrentUserScope]).
+func (s *ReadableSources) StableSort() {
+ sort.SliceStable(*s, func(i, j int) bool {
+ return (*s)[i].Source.Compare((*s)[j].Source) < 0
+ })
+}
+
+// DeleteAt closes and deletes the i-th source from s.
+func (s *ReadableSources) DeleteAt(i int) {
+ (*s)[i].Close()
+ *s = slices.Delete(*s, i, i+1)
+}
+
+// Close closes and deletes all sources in s.
+func (s *ReadableSources) Close() {
+ for _, s := range *s {
+ s.Close()
+ }
+ *s = nil
+}
+
+func readPolicySettingValue(store Store, s *setting.Definition) (value any, err error) {
+ switch key := s.Key(); s.Type() {
+ case setting.BooleanValue:
+ return store.ReadBoolean(key)
+ case setting.IntegerValue:
+ return store.ReadUInt64(key)
+ case setting.StringValue:
+ return store.ReadString(key)
+ case setting.StringListValue:
+ return store.ReadStringArray(key)
+ case setting.PreferenceOptionValue:
+ s, err := store.ReadString(key)
+ if err == nil {
+ var value setting.PreferenceOption
+ if err = value.UnmarshalText([]byte(s)); err == nil {
+ return value, nil
+ }
+ }
+ return setting.ShowChoiceByPolicy, err
+ case setting.VisibilityValue:
+ s, err := store.ReadString(key)
+ if err == nil {
+ var value setting.Visibility
+ if err = value.UnmarshalText([]byte(s)); err == nil {
+ return value, nil
+ }
+ }
+ return setting.VisibleByPolicy, err
+ case setting.DurationValue:
+ s, err := store.ReadString(key)
+ if err == nil {
+ var value time.Duration
+ if value, err = time.ParseDuration(s); err == nil {
+ return value, nil
+ }
+ }
+ return nil, err
+ default:
+ return nil, fmt.Errorf("%w: unsupported setting type: %v", setting.ErrTypeMismatch, s.Type())
+ }
+}
diff --git a/util/syspolicy/source/policy_reader_test.go b/util/syspolicy/source/policy_reader_test.go
new file mode 100644
index 000000000..f2d411d12
--- /dev/null
+++ b/util/syspolicy/source/policy_reader_test.go
@@ -0,0 +1,291 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package source
+
+import (
+ "cmp"
+ "testing"
+ "time"
+
+ "tailscale.com/util/must"
+ "tailscale.com/util/syspolicy/setting"
+)
+
+func TestReaderLifecycle(t *testing.T) {
+ tests := []struct {
+ name string
+ origin *setting.Origin
+ definitions []*setting.Definition
+ wantReads []TestExpectedReads
+ initStrings []TestSetting[string]
+ initUInt64s []TestSetting[uint64]
+ initWant *setting.Snapshot
+ addStrings []TestSetting[string]
+ addStringLists []TestSetting[[]string]
+ newWant *setting.Snapshot
+ }{
+ {
+ name: "read-all-settings-once",
+ origin: setting.NewNamedOrigin("Test", setting.DeviceScope),
+ definitions: []*setting.Definition{
+ setting.NewDefinition("StringValue", setting.DeviceSetting, setting.StringValue),
+ setting.NewDefinition("IntegerValue", setting.DeviceSetting, setting.IntegerValue),
+ setting.NewDefinition("BooleanValue", setting.DeviceSetting, setting.BooleanValue),
+ setting.NewDefinition("StringListValue", setting.DeviceSetting, setting.StringListValue),
+ setting.NewDefinition("DurationValue", setting.DeviceSetting, setting.DurationValue),
+ setting.NewDefinition("PreferenceOptionValue", setting.DeviceSetting, setting.PreferenceOptionValue),
+ setting.NewDefinition("VisibilityValue", setting.DeviceSetting, setting.VisibilityValue),
+ },
+ wantReads: []TestExpectedReads{
+ {Key: "StringValue", Type: setting.StringValue, NumTimes: 1},
+ {Key: "IntegerValue", Type: setting.IntegerValue, NumTimes: 1},
+ {Key: "BooleanValue", Type: setting.BooleanValue, NumTimes: 1},
+ {Key: "StringListValue", Type: setting.StringListValue, NumTimes: 1},
+ {Key: "DurationValue", Type: setting.StringValue, NumTimes: 1}, // duration is string from the [Store]'s perspective
+ {Key: "PreferenceOptionValue", Type: setting.StringValue, NumTimes: 1}, // and so are [setting.PreferenceOption]s
+ {Key: "VisibilityValue", Type: setting.StringValue, NumTimes: 1}, // and [setting.Visibility]
+ },
+ initWant: setting.NewSnapshot(nil, setting.NewNamedOrigin("Test", setting.DeviceScope)),
+ },
+ {
+ name: "re-read-all-settings-when-the-policy-changes",
+ origin: setting.NewNamedOrigin("Test", setting.DeviceScope),
+ definitions: []*setting.Definition{
+ setting.NewDefinition("StringValue", setting.DeviceSetting, setting.StringValue),
+ setting.NewDefinition("IntegerValue", setting.DeviceSetting, setting.IntegerValue),
+ setting.NewDefinition("BooleanValue", setting.DeviceSetting, setting.BooleanValue),
+ setting.NewDefinition("StringListValue", setting.DeviceSetting, setting.StringListValue),
+ setting.NewDefinition("DurationValue", setting.DeviceSetting, setting.DurationValue),
+ setting.NewDefinition("PreferenceOptionValue", setting.DeviceSetting, setting.PreferenceOptionValue),
+ setting.NewDefinition("VisibilityValue", setting.DeviceSetting, setting.VisibilityValue),
+ },
+ wantReads: []TestExpectedReads{
+ {Key: "StringValue", Type: setting.StringValue, NumTimes: 1},
+ {Key: "IntegerValue", Type: setting.IntegerValue, NumTimes: 1},
+ {Key: "BooleanValue", Type: setting.BooleanValue, NumTimes: 1},
+ {Key: "StringListValue", Type: setting.StringListValue, NumTimes: 1},
+ {Key: "DurationValue", Type: setting.StringValue, NumTimes: 1}, // duration is string from the [Store]'s perspective
+ {Key: "PreferenceOptionValue", Type: setting.StringValue, NumTimes: 1}, // and so are [setting.PreferenceOption]s
+ {Key: "VisibilityValue", Type: setting.StringValue, NumTimes: 1}, // and [setting.Visibility]
+ },
+ initWant: setting.NewSnapshot(nil, setting.NewNamedOrigin("Test", setting.DeviceScope)),
+ addStrings: []TestSetting[string]{TestSettingOf("StringValue", "S1")},
+ addStringLists: []TestSetting[[]string]{TestSettingOf("StringListValue", []string{"S1", "S2", "S3"})},
+ newWant: setting.NewSnapshot(map[setting.Key]setting.RawItem{
+ "StringValue": setting.RawItemWith("S1", nil, setting.NewNamedOrigin("Test", setting.DeviceScope)),
+ "StringListValue": setting.RawItemWith([]string{"S1", "S2", "S3"}, nil, setting.NewNamedOrigin("Test", setting.DeviceScope)),
+ }, setting.NewNamedOrigin("Test", setting.DeviceScope)),
+ },
+ {
+ name: "read-settings-if-in-scope/device",
+ origin: setting.NewNamedOrigin("Test", setting.DeviceScope),
+ definitions: []*setting.Definition{
+ setting.NewDefinition("DeviceSetting", setting.DeviceSetting, setting.StringValue),
+ setting.NewDefinition("ProfileSetting", setting.ProfileSetting, setting.IntegerValue),
+ setting.NewDefinition("UserSetting", setting.UserSetting, setting.BooleanValue),
+ },
+ wantReads: []TestExpectedReads{
+ {Key: "DeviceSetting", Type: setting.StringValue, NumTimes: 1},
+ {Key: "ProfileSetting", Type: setting.IntegerValue, NumTimes: 1},
+ {Key: "UserSetting", Type: setting.BooleanValue, NumTimes: 1},
+ },
+ },
+ {
+ name: "read-settings-if-in-scope/profile",
+ origin: setting.NewNamedOrigin("Test", setting.CurrentProfileScope),
+ definitions: []*setting.Definition{
+ setting.NewDefinition("DeviceSetting", setting.DeviceSetting, setting.StringValue),
+ setting.NewDefinition("ProfileSetting", setting.ProfileSetting, setting.IntegerValue),
+ setting.NewDefinition("UserSetting", setting.UserSetting, setting.BooleanValue),
+ },
+ wantReads: []TestExpectedReads{
+ // Device settings cannot be configured at the profile scope and should not be read.
+ {Key: "ProfileSetting", Type: setting.IntegerValue, NumTimes: 1},
+ {Key: "UserSetting", Type: setting.BooleanValue, NumTimes: 1},
+ },
+ },
+ {
+ name: "read-settings-if-in-scope/user",
+ origin: setting.NewNamedOrigin("Test", setting.CurrentUserScope),
+ definitions: []*setting.Definition{
+ setting.NewDefinition("DeviceSetting", setting.DeviceSetting, setting.StringValue),
+ setting.NewDefinition("ProfileSetting", setting.ProfileSetting, setting.IntegerValue),
+ setting.NewDefinition("UserSetting", setting.UserSetting, setting.BooleanValue),
+ },
+ wantReads: []TestExpectedReads{
+ // Device and profile settings cannot be configured at the profile scope and should not be read.
+ {Key: "UserSetting", Type: setting.BooleanValue, NumTimes: 1},
+ },
+ },
+ {
+ name: "read-stringy-settings",
+ origin: setting.NewNamedOrigin("Test", setting.DeviceScope),
+ definitions: []*setting.Definition{
+ setting.NewDefinition("DurationValue", setting.DeviceSetting, setting.DurationValue),
+ setting.NewDefinition("PreferenceOptionValue", setting.DeviceSetting, setting.PreferenceOptionValue),
+ setting.NewDefinition("VisibilityValue", setting.DeviceSetting, setting.VisibilityValue),
+ },
+ wantReads: []TestExpectedReads{
+ {Key: "DurationValue", Type: setting.StringValue, NumTimes: 1}, // duration is string from the [Store]'s perspective
+ {Key: "PreferenceOptionValue", Type: setting.StringValue, NumTimes: 1}, // and so are [setting.PreferenceOption]s
+ {Key: "VisibilityValue", Type: setting.StringValue, NumTimes: 1}, // and [setting.Visibility]
+ },
+ initStrings: []TestSetting[string]{
+ TestSettingOf("DurationValue", "2h30m"),
+ TestSettingOf("PreferenceOptionValue", "always"),
+ TestSettingOf("VisibilityValue", "show"),
+ },
+ initWant: setting.NewSnapshot(map[setting.Key]setting.RawItem{
+ "DurationValue": setting.RawItemWith(must.Get(time.ParseDuration("2h30m")), nil, setting.NewNamedOrigin("Test", setting.DeviceScope)),
+ "PreferenceOptionValue": setting.RawItemWith(setting.AlwaysByPolicy, nil, setting.NewNamedOrigin("Test", setting.DeviceScope)),
+ "VisibilityValue": setting.RawItemWith(setting.VisibleByPolicy, nil, setting.NewNamedOrigin("Test", setting.DeviceScope)),
+ }, setting.NewNamedOrigin("Test", setting.DeviceScope)),
+ },
+ {
+ name: "read-erroneous-stringy-settings",
+ origin: setting.NewNamedOrigin("Test", setting.CurrentUserScope),
+ definitions: []*setting.Definition{
+ setting.NewDefinition("DurationValue1", setting.UserSetting, setting.DurationValue),
+ setting.NewDefinition("DurationValue2", setting.UserSetting, setting.DurationValue),
+ setting.NewDefinition("PreferenceOptionValue", setting.UserSetting, setting.PreferenceOptionValue),
+ setting.NewDefinition("VisibilityValue", setting.UserSetting, setting.VisibilityValue),
+ },
+ wantReads: []TestExpectedReads{
+ {Key: "DurationValue1", Type: setting.StringValue, NumTimes: 1}, // duration is string from the [Store]'s perspective
+ {Key: "DurationValue2", Type: setting.StringValue, NumTimes: 1}, // duration is string from the [Store]'s perspective
+ {Key: "PreferenceOptionValue", Type: setting.StringValue, NumTimes: 1}, // and so are [setting.PreferenceOption]s
+ {Key: "VisibilityValue", Type: setting.StringValue, NumTimes: 1}, // and [setting.Visibility]
+ },
+ initStrings: []TestSetting[string]{
+ TestSettingOf("DurationValue1", "soon"),
+ TestSettingWithError[string]("DurationValue2", setting.NewError("bang!")),
+ TestSettingOf("PreferenceOptionValue", "sometimes"),
+ },
+ initUInt64s: []TestSetting[uint64]{
+ TestSettingOf[uint64]("VisibilityValue", 42), // type mismatch
+ },
+ initWant: setting.NewSnapshot(map[setting.Key]setting.RawItem{
+ "DurationValue1": setting.RawItemWith(nil, setting.NewError("time: invalid duration \"soon\""), setting.NewNamedOrigin("Test", setting.CurrentUserScope)),
+ "DurationValue2": setting.RawItemWith(nil, setting.NewError("bang!"), setting.NewNamedOrigin("Test", setting.CurrentUserScope)),
+ "PreferenceOptionValue": setting.RawItemWith(setting.ShowChoiceByPolicy, nil, setting.NewNamedOrigin("Test", setting.CurrentUserScope)),
+ "VisibilityValue": setting.RawItemWith(setting.VisibleByPolicy, setting.NewError("type mismatch in ReadString: got uint64"), setting.NewNamedOrigin("Test", setting.CurrentUserScope)),
+ }, setting.NewNamedOrigin("Test", setting.CurrentUserScope)),
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ setting.SetDefinitionsForTest(t, tt.definitions...)
+ store := NewTestStore(t)
+ store.SetStrings(tt.initStrings...)
+ store.SetUInt64s(tt.initUInt64s...)
+
+ reader, err := newReader(store, tt.origin)
+ if err != nil {
+ t.Fatalf("newReader failed: %v", err)
+ }
+
+ if got := reader.GetSettings(); tt.initWant != nil && !got.Equal(tt.initWant) {
+ t.Errorf("Settings do not match: got %v, want %v", got, tt.initWant)
+ }
+ if tt.wantReads != nil {
+ store.ReadsMustEqual(tt.wantReads...)
+ }
+
+ // Should not result in new reads as there were no changes.
+ N := 100
+ for range N {
+ reader.GetSettings()
+ }
+ if tt.wantReads != nil {
+ store.ReadsMustEqual(tt.wantReads...)
+ }
+ store.ResetCounters()
+
+ got, err := reader.ReadSettings()
+ if err != nil {
+ t.Fatalf("ReadSettings failed: %v", err)
+ }
+
+ if tt.initWant != nil && !got.Equal(tt.initWant) {
+ t.Errorf("Settings do not match: got %v, want %v", got, tt.initWant)
+ }
+
+ if tt.wantReads != nil {
+ store.ReadsMustEqual(tt.wantReads...)
+ }
+ store.ResetCounters()
+
+ if len(tt.addStrings) != 0 || len(tt.addStringLists) != 0 {
+ store.SetStrings(tt.addStrings...)
+ store.SetStringLists(tt.addStringLists...)
+
+ // As the settings have changed, GetSettings needs to re-read them.
+ if got, want := reader.GetSettings(), cmp.Or(tt.newWant, tt.initWant); !got.Equal(want) {
+ t.Errorf("New Settings do not match: got %v, want %v", got, want)
+ }
+ if tt.wantReads != nil {
+ store.ReadsMustEqual(tt.wantReads...)
+ }
+ }
+
+ select {
+ case <-reader.Done():
+ t.Fatalf("the reader is closed")
+ default:
+ }
+
+ store.Close()
+
+ <-reader.Done()
+ })
+ }
+}
+
+func TestReadingSession(t *testing.T) {
+ setting.SetDefinitionsForTest(t, setting.NewDefinition("StringValue", setting.DeviceSetting, setting.StringValue))
+ store := NewTestStore(t)
+
+ origin := setting.NewOrigin(setting.DeviceScope)
+ reader, err := newReader(store, origin)
+ if err != nil {
+ t.Fatalf("newReader failed: %v", err)
+ }
+ session, err := reader.OpenSession()
+ if err != nil {
+ t.Fatalf("failed to open a reading session: %v", err)
+ }
+ t.Cleanup(session.Close)
+
+ if got, want := session.GetSettings(), setting.NewSnapshot(nil, origin); !got.Equal(want) {
+ t.Errorf("Settings do not match: got %v, want %v", got, want)
+ }
+
+ select {
+ case _, ok := <-session.PolicyChanged():
+ if ok {
+ t.Fatalf("the policy changed notification was sent prematurely")
+ } else {
+ t.Fatalf("the session was closed prematurely")
+ }
+ default:
+ }
+
+ store.SetStrings(TestSettingOf("StringValue", "S1"))
+ _, ok := <-session.PolicyChanged()
+ if !ok {
+ t.Fatalf("the session was closed prematurely")
+ }
+
+ want := setting.NewSnapshot(map[setting.Key]setting.RawItem{
+ "StringValue": setting.RawItemWith("S1", nil, origin),
+ }, origin)
+ if got := session.GetSettings(); !got.Equal(want) {
+ t.Errorf("Settings do not match: got %v, want %v", got, want)
+ }
+
+ store.Close()
+ if _, ok = <-session.PolicyChanged(); ok {
+ t.Fatalf("the session must be closed")
+ }
+}
diff --git a/util/syspolicy/source/policy_store.go b/util/syspolicy/source/policy_store.go
new file mode 100644
index 000000000..9b150825e
--- /dev/null
+++ b/util/syspolicy/source/policy_store.go
@@ -0,0 +1,146 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package source
+
+import (
+ "cmp"
+ "errors"
+ "fmt"
+ "io"
+
+ "tailscale.com/types/lazy"
+ "tailscale.com/util/syspolicy/setting"
+)
+
+// ErrStoreClosed is an error returned when attempting to use a [Store] after it
+// has been closed.
+var ErrStoreClosed = errors.New("the policy store has been closed")
+
+// Store provides methods to read system policy settings from OS-specific storage.
+// Implementations must be concurrency-safe, and may also implement
+// [Lockable], [Changeable], [Expirable] and [io.Closer].
+//
+// If a [Store] implementation also implements [io.Closer],
+// it will be called by the package to release the resources
+// when the store is no longer needed.
+type Store interface {
+ // ReadString returns the value of a [setting.StringValue] with the specified key,
+ // an [setting.ErrNotConfigured] if the policy setting is not configured, or
+ // an [setting.ErrTypeMismatch] if the policy setting is not of a string type.
+ ReadString(key setting.Key) (string, error)
+ // ReadUInt64 returns the value of a [setting.IntegerValue] with the specified key,
+ // an [setting.ErrNotConfigured] if the policy setting is not configured, or
+ // an [setting.ErrTypeMismatch] if the policy setting is not of a string type.
+ ReadUInt64(key setting.Key) (uint64, error)
+ // ReadBoolean returns the value of a [setting.BooleanValue] with the specified key,
+ // an [setting.ErrNotConfigured] if the policy setting is not configured, or
+ // an [setting.ErrTypeMismatch] if the policy setting is not of a string type.
+ ReadBoolean(key setting.Key) (bool, error)
+ // ReadStringArray returns the value of a [setting.StringListValue] with the specified key,
+ // an [setting.ErrNotConfigured] if the policy setting is not configured, or
+ // an [setting.ErrTypeMismatch] if the policy setting is not of a string list type.
+ ReadStringArray(key setting.Key) ([]string, error)
+}
+
+// Lockable is an optional interface that [Store] implementations may support.
+// Locking a [Store] is not mandatory as [Store] must be concurrency-safe,
+// but is recommended to avoid issues where consecutive read calls for related
+// policies might return inconsistent results if a policy change occurs between
+// the calls.
+type Lockable interface {
+
+ // Lock acquires a read lock on the policy store,
+ // ensuring the store's state remains unchanged while locked.
+ // Multiple readers can hold the lock simultaneously.
+ // It should return nil if the store does not support locking,
+ // or an error if the store cannot be locked.
+ Lock() error
+ // Unlock unlocks the policy store.
+ // It is a runtime error if the store is not locked on entry to Unlock.
+ Unlock()
+}
+
+// Changeable is an optional interface that [Store] implementations may support.
+type Changeable interface {
+ // RegisterChangeCallback adds a function that will be called
+ // whenever there's a policy change in the [Store].
+ // The returned function can be used to unregister the callback.
+ RegisterChangeCallback(callback func()) (unregister func(), err error)
+}
+
+// Expirable is an optional interface that [Store] implementations may support.
+type Expirable interface {
+ // Done returns a channel that is closed when the policy [Store] should no longer be used.
+ // It should return nil if the store never expires.
+ Done() <-chan struct{}
+}
+
+// Source represents a named source of policy settings for a given scope.
+type Source struct {
+ name string
+ scope setting.PolicyScope
+ store Store
+ origin *setting.Origin
+
+ lazyReader lazy.SyncValue[*Reader]
+}
+
+// NewSource returns a new [Source] with the specified name, scope, and store.
+func NewSource(name string, scope setting.PolicyScope, store Store) *Source {
+ return &Source{name: name, scope: scope, store: store, origin: setting.NewNamedOrigin(name, scope)}
+}
+
+// Name reports the name of the policy source.
+func (s *Source) Name() string {
+ return s.name
+}
+
+// Scope reports the management scope of the policy source.
+func (s *Source) Scope() setting.PolicyScope {
+ return s.scope
+}
+
+// Store returns the [Store] that can be used to read policy settings from this source.
+func (s *Source) Store() Store {
+ return s.store
+}
+
+// Reader returns a [Reader] that reads from this source's [Store].
+func (s *Source) Reader() (*Reader, error) {
+ return s.lazyReader.GetErr(func() (*Reader, error) {
+ return newReader(s.store, s.origin)
+ })
+}
+
+// String implements [fmt.Stringer].
+func (s *Source) String() string {
+ if s.Name() != "" {
+ return fmt.Sprintf("%s (%v)", s.Name(), s.Scope())
+ }
+ return s.Scope().String()
+}
+
+// Compare returns an integer comparing [Source] s and s2
+// by their precedence, following the "last-wins" model.
+// The result will be:
+//
+// -1 if policy settings from s should be processed before policy settings from s2;
+// +1 if policy settings from s should be processed after policy settings from s2, overriding s2;
+// 0 if the relative processing order of policy settings in s and s2 is unspecified.
+func (s *Source) Compare(s2 *Source) int {
+ return cmp.Compare(s2.Scope().Kind(), s.Scope().Kind())
+}
+
+// Close closes the [Source] and the underlying [Store].
+func (s *Source) Close() error {
+ // The [Reader], if any, owns the [Store].
+ if reader, _ := s.lazyReader.GetErr(func() (*Reader, error) { return nil, ErrStoreClosed }); reader != nil {
+ return reader.Close()
+ }
+ // Otherwise, it is our responsibility to close it.
+ if closer, ok := s.store.(io.Closer); ok {
+ return closer.Close()
+ }
+ return nil
+}
diff --git a/util/syspolicy/source/policy_store_windows.go b/util/syspolicy/source/policy_store_windows.go
new file mode 100644
index 000000000..5d6503981
--- /dev/null
+++ b/util/syspolicy/source/policy_store_windows.go
@@ -0,0 +1,438 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package source
+
+import (
+ "errors"
+ "fmt"
+ "strings"
+ "sync"
+
+ "golang.org/x/sys/windows"
+ "golang.org/x/sys/windows/registry"
+ "tailscale.com/util/set"
+ "tailscale.com/util/syspolicy/setting"
+ "tailscale.com/util/winutil/gp"
+)
+
+const (
+ softwareKeyName = "Software"
+ tsPoliciesSubkey = `Policies\Tailscale`
+ tsIPNSubkey = "Tailscale IPN" // the legacy key we need to fallback to
+)
+
+var (
+ // [PlatformPolicyStore] implements [Store].
+ _ Store = (*PlatformPolicyStore)(nil)
+)
+
+// PlatformPolicyStore implements [Store] by providing read access to the Registry-based
+// Tailscale policies, such as those configured via Group Policy or MDM. It is
+// recommended to lock it when reading multiple policy values in a row. It also
+// allows subscribing to notifications when there's a policy change.
+type PlatformPolicyStore struct {
+ scope gp.Scope // [gp.MachinePolicy] or [gp.UserPolicy]
+
+ // The softwareKey can be HKLM\Software, HKCU\Software, or
+ // HKU\{SID}\Software. Anything below the Software subkey, including
+ // Software\Policies, may not yet exist or could be deleted throughout the
+ // [PlatformPolicyStore]'s lifespan, invalidating the handle. We also prefer
+ // to always use a real registry key (rather than a predefined HKLM or HKCU)
+ // to simplify bookkeeping (predefined keys should never be closed).
+ // Finally, this will allow us to watch for any registry changes directly
+ // should we need this in the future in addition to gp.ChangeWatcher.
+ softwareKey registry.Key
+ watcher *gp.ChangeWatcher
+
+ done chan struct{} // done is closed when Close call completes
+
+ // The policyLock can be locked by the caller when reading multiple policy settings
+ // to prevent the Group Policy Client service from modifying policies while
+ // they are being read.
+ //
+ // When both policyLock and mu need to be taken, mu must be taken before policyLock.
+ policyLock *gp.PolicyLock
+
+ mu sync.RWMutex
+ tsKeys []registry.Key // or nil if the [PlatformPolicyStore] hasn't been locked.
+ cbs set.HandleSet[func()] // policy change callbacks
+ lockCnt int
+ locked sync.WaitGroup
+ closing bool
+ readable bool
+}
+
+type registryValueGetter[T any] func(key registry.Key, name setting.Key) (T, error)
+
+// NewMachinePlatformPolicyStore returns a new [PlatformPolicyStore] for the machine.
+func NewMachinePlatformPolicyStore() (*PlatformPolicyStore, error) {
+ softwareKey, err := registry.OpenKey(registry.LOCAL_MACHINE, softwareKeyName, windows.KEY_READ)
+ if err != nil {
+ return nil, fmt.Errorf("failed to open the %s key: %w", softwareKeyName, err)
+ }
+ return newPlatformPolicyStore(gp.MachinePolicy, softwareKey, 0)
+}
+
+// NewUserPlatformPolicyStore returns a new [PlatformPolicyStore] for the user specified by its token.
+// User's profile must be loaded, and the token handle must have [windows.TOKEN_QUERY]
+// access. The caller retains ownership of the token.
+func NewUserPlatformPolicyStore(token windows.Token) (*PlatformPolicyStore, error) {
+ var err error
+ var softwareKey registry.Key
+ if token != 0 {
+ var user *windows.Tokenuser
+ if user, err = token.GetTokenUser(); err != nil {
+ return nil, fmt.Errorf("failed to get token user: %w", err)
+ }
+ userSid := user.User.Sid
+ softwareKey, err = registry.OpenKey(registry.USERS, userSid.String()+`\`+softwareKeyName, windows.KEY_READ)
+ } else {
+ softwareKey, err = registry.OpenKey(registry.CURRENT_USER, softwareKeyName, windows.KEY_READ)
+ }
+ if err != nil {
+ return nil, fmt.Errorf("failed to open the %s key: %w", softwareKeyName, err)
+ }
+ return newPlatformPolicyStore(gp.UserPolicy, softwareKey, token)
+}
+
+func newPlatformPolicyStore(scope gp.Scope, softwareKey registry.Key, token windows.Token) (_ *PlatformPolicyStore, err error) {
+ store := &PlatformPolicyStore{
+ scope: scope,
+ softwareKey: softwareKey,
+ done: make(chan struct{}),
+ readable: true,
+ }
+ defer func() {
+ if err != nil {
+ store.Close()
+ }
+ }()
+
+ switch scope {
+ case gp.MachinePolicy:
+ store.policyLock = gp.NewMachinePolicyLock()
+ case gp.UserPolicy:
+ if store.policyLock, err = gp.NewUserPolicyLock(token); err != nil {
+ return nil, fmt.Errorf("failed to create a user policy lock: %w", err)
+ }
+ default:
+ panic("unreachable")
+ }
+
+ return store, nil
+}
+
+// Lock locks the policy store, preventing the system from modifying the policies
+// while they are being read. It is a read lock that may be acquired by multiple goroutines.
+// Each Lock call must be balanced by exactly one Unlock call.
+func (ps *PlatformPolicyStore) Lock() (err error) {
+ ps.mu.Lock()
+ defer ps.mu.Unlock()
+
+ if ps.closing {
+ return ErrStoreClosed
+ }
+
+ ps.lockCnt += 1
+ if ps.lockCnt != 1 {
+ return nil
+ }
+ defer func() {
+ if err != nil {
+ ps.lockCnt -= 1
+ }
+ }()
+
+ // Ensure ps remains open while the lock is held.
+ ps.locked.Add(1)
+ defer func() {
+ if err != nil {
+ ps.locked.Done()
+ }
+ }()
+
+ // Acquire the GP lock to prevent the system from modifying policy settings
+ // while they are being read.
+ if err := ps.policyLock.Lock(); err != nil {
+ if errors.Is(err, gp.ErrInvalidLockState) {
+ return ErrStoreClosed
+ }
+ return err
+ }
+ defer func() {
+ if err != nil {
+ ps.policyLock.Unlock()
+ }
+ }()
+
+ // Keep the Tailscale's registry keys open for the duration of the lock.
+ keyNames := tailscaleKeyNamesFor(ps.scope)
+ ps.tsKeys = make([]registry.Key, 0, len(keyNames))
+ for _, keyName := range keyNames {
+ var tsKey registry.Key
+ tsKey, err = registry.OpenKey(ps.softwareKey, keyName, windows.KEY_READ)
+ if err != nil {
+ if err == registry.ErrNotExist {
+ continue
+ }
+ return err
+ }
+ ps.tsKeys = append(ps.tsKeys, tsKey)
+ }
+
+ return nil
+}
+
+// Unlock decrements the lock counter and unlocks the policy store once the counter reaches 0.
+// It panics if ps is not locked on entry to Unlock.
+func (ps *PlatformPolicyStore) Unlock() {
+ ps.mu.Lock()
+ defer ps.mu.Unlock()
+
+ ps.lockCnt -= 1
+ if ps.lockCnt < 0 {
+ panic("negative lockCnt")
+ } else if ps.lockCnt != 0 {
+ return
+ }
+
+ for _, key := range ps.tsKeys {
+ key.Close()
+ }
+ ps.tsKeys = nil
+ ps.policyLock.Unlock()
+ ps.locked.Done()
+}
+
+// RegisterChangeCallback adds a function that will be called whenever there's a policy change.
+// It returns a function that needs to be called to unregister the specified callback or an error.
+// The error is [ErrStoreClosed] if ps has already been closed.
+func (ps *PlatformPolicyStore) RegisterChangeCallback(cb func()) (unregister func(), err error) {
+ ps.mu.Lock()
+ defer ps.mu.Unlock()
+ if ps.closing {
+ return nil, ErrStoreClosed
+ }
+
+ handle := ps.cbs.Add(cb)
+ if len(ps.cbs) == 1 {
+ if ps.watcher, err = gp.NewChangeWatcher(ps.scope, ps.onChange); err != nil {
+ return nil, err
+ }
+ }
+
+ return func() {
+ ps.mu.Lock()
+ defer ps.mu.Unlock()
+ delete(ps.cbs, handle)
+ if len(ps.cbs) == 0 {
+ if ps.watcher != nil {
+ ps.watcher.Close()
+ ps.watcher = nil
+ }
+ }
+ }, nil
+}
+
+func (ps *PlatformPolicyStore) onChange() {
+ ps.mu.RLock()
+ defer ps.mu.RUnlock()
+ if ps.closing {
+ return
+ }
+ for _, callback := range ps.cbs {
+ go callback()
+ }
+}
+
+// ReadString retrieves a string policy with the specified name.
+// It returns [ErrNotConfigured] if the policy setting does not exist.
+func (ps *PlatformPolicyStore) ReadString(name setting.Key) (val string, err error) {
+ return getPolicyValue(ps, canonicalizeValueName(name),
+ func(key registry.Key, name setting.Key) (string, error) {
+ val, _, err := key.GetStringValue(string(name))
+ return val, err
+ })
+}
+
+// ReadUInt64 retrieves an integer policy with the specified name.
+// It returns [ErrNotConfigured] if the policy setting does not exist.
+func (ps *PlatformPolicyStore) ReadUInt64(name setting.Key) (uint64, error) {
+ return getPolicyValue(ps, canonicalizeValueName(name),
+ func(key registry.Key, name setting.Key) (uint64, error) {
+ val, _, err := key.GetIntegerValue(string(name))
+ return val, err
+ })
+}
+
+// ReadBoolean retrieves a boolean policy with the specified name.
+// It returns [ErrNotConfigured] if the policy setting does not exist.
+func (ps *PlatformPolicyStore) ReadBoolean(name setting.Key) (bool, error) {
+ return getPolicyValue(ps, canonicalizeValueName(name),
+ func(key registry.Key, name setting.Key) (bool, error) {
+ val, _, err := key.GetIntegerValue(string(name))
+ if err != nil {
+ return false, err
+ }
+ return val != 0, nil
+ })
+}
+
+// ReadString retrieves a multi-string policy with the specified name.
+// It returns [ErrNotConfigured] if the policy setting does not exist.
+func (ps *PlatformPolicyStore) ReadStringArray(name setting.Key) ([]string, error) {
+ return getPolicyValue(ps, name,
+ func(key registry.Key, name setting.Key) ([]string, error) {
+ val, _, err := key.GetStringsValue(string(canonicalizeValueName(name)))
+ if err != registry.ErrNotExist {
+ return val, err
+ }
+
+ // The idiomatic way to store multiple string values in Group Policy
+ // and MDM for Windows is to have multiple REG_SZ (or REG_EXPAND_SZ)
+ // values under a subkey rather than in a single REG_MULTI_SZ value.
+ //
+ // See the Group Policy: Registry Extension Encoding specification,
+ // and specifically the ListElement and ListBox types.
+ // https://web.archive.org/web/20240721033657/https://winprotocoldoc.blob.core.windows.net/productionwindowsarchives/MS-GPREG/%5BMS-GPREG%5D.pdf
+ valKey, err := registry.OpenKey(key, string(canonicalizeKeyName(name)), windows.KEY_READ)
+ if err != nil {
+ return nil, err
+ }
+ valNames, err := valKey.ReadValueNames(0)
+ if err != nil {
+ return nil, err
+ }
+ val = make([]string, 0, len(valNames))
+ for _, name := range valNames {
+ switch item, _, err := valKey.GetStringValue(name); {
+ case err == registry.ErrNotExist:
+ continue
+ case err != nil:
+ return nil, err
+ default:
+ val = append(val, item)
+ }
+ }
+ return val, nil
+ })
+}
+
+func canonicalizeKeyName(name setting.Key) setting.Key {
+ return setting.Key(strings.ReplaceAll(string(name), setting.KeyPathSeparator, `\`))
+}
+
+func canonicalizeValueName(name setting.Key) setting.Key {
+ return setting.Key(strings.ReplaceAll(string(name), setting.KeyPathSeparator, `_`))
+}
+
+func getPolicyValue[T any](ps *PlatformPolicyStore, name setting.Key, getter registryValueGetter[T]) (T, error) {
+ var zero T
+
+ ps.mu.RLock()
+ defer ps.mu.RUnlock()
+ if !ps.readable {
+ return zero, setting.ErrNotConfigured
+ }
+
+ if ps.tsKeys != nil {
+ // A non-nil tsKeys indicates that ps has been locked.
+ // It may be empty if Tailscale policy keys do not exist.
+ for _, tsKey := range ps.tsKeys {
+ val, err := getter(tsKey, name)
+ if err == nil || err != registry.ErrNotExist {
+ return val, err
+ }
+ }
+ return zero, setting.ErrNotConfigured
+ }
+
+ // The ps has not been locked, so we don't have any pre-opened keys.
+ for _, tsKeyName := range tailscaleKeyNamesFor(ps.scope) {
+ var tsKey registry.Key
+ tsKey, err := registry.OpenKey(ps.softwareKey, tsKeyName, windows.KEY_READ)
+ if err != nil {
+ if err == registry.ErrNotExist {
+ continue
+ }
+ return zero, err
+ }
+ defer tsKey.Close()
+
+ val, err := getter(tsKey, name)
+ if err == nil || err != registry.ErrNotExist {
+ return val, err
+ }
+ }
+
+ return zero, setting.ErrNotConfigured
+}
+
+// Close closes the policy store and releases any associated resources.
+// It cancels pending locks and prevents any new lock attempts,
+// but waits for existing locks to be released.
+func (ps *PlatformPolicyStore) Close() error {
+ // Request to close the Group Policy read lock.
+ // Existing held locks will remain valid, but any new or pending locks
+ // will fail. In certain scenarios, the corresponding write lock may be held
+ // by the Group Policy service for extended periods (minutes rather than
+ // seconds or milliseconds). In such cases, we prefer not to wait that long
+ // if the ps is being closed anyway.
+ if ps.policyLock != nil {
+ ps.policyLock.Close()
+ }
+
+ // Signal to the external code that ps should no longer be used.
+ close(ps.done)
+
+ // Mark ps as closing to fast-fail any new lock attempts.
+ // Callers that have already locked it can finish their reading.
+ ps.mu.Lock()
+ if ps.closing {
+ ps.mu.Unlock()
+ return nil
+ }
+ ps.closing = true
+ if ps.watcher != nil {
+ ps.watcher.Close()
+ ps.watcher = nil
+ }
+ ps.mu.Unlock()
+
+ // Wait for any outstanding locks to be released.
+ ps.locked.Wait()
+
+ // Deny any further read attempts and release remaining resources.
+ ps.mu.Lock()
+ defer ps.mu.Unlock()
+ ps.cbs = nil
+ ps.policyLock = nil
+ ps.readable = false
+ if ps.softwareKey != 0 {
+ ps.softwareKey.Close()
+ ps.softwareKey = 0
+ }
+ return nil
+}
+
+// Done returns a channel that is closed when the Close method is called.
+func (ps *PlatformPolicyStore) Done() <-chan struct{} {
+ return ps.done
+}
+
+func tailscaleKeyNamesFor(scope gp.Scope) []string {
+ switch scope {
+ case gp.MachinePolicy:
+ // If a computer-side policy value does not exist under Software\Policies\Tailscale,
+ // we need to fallback and use the legacy Software\Tailscale IPN key.
+ return []string{tsPoliciesSubkey, tsIPNSubkey}
+ case gp.UserPolicy:
+ // However, we've never used the legacy key with user-side policies,
+ // and we should never do so. Unlike HKLM\Software\Tailscale IPN,
+ // its HKCU counterpart is user-writable.
+ return []string{tsPoliciesSubkey}
+ default:
+ panic("unreachable")
+ }
+}
diff --git a/util/syspolicy/source/policy_store_windows_test.go b/util/syspolicy/source/policy_store_windows_test.go
new file mode 100644
index 000000000..60c76837f
--- /dev/null
+++ b/util/syspolicy/source/policy_store_windows_test.go
@@ -0,0 +1,298 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package source
+
+import (
+ "errors"
+ "fmt"
+ "reflect"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/windows"
+ "golang.org/x/sys/windows/registry"
+ "tailscale.com/util/cibuild"
+ "tailscale.com/util/syspolicy/setting"
+ "tailscale.com/util/winutil"
+ "tailscale.com/util/winutil/gp"
+)
+
+type testPolicyValue struct {
+ name setting.Key
+ value any
+}
+
+func TestLockUnlockPolicyStore(t *testing.T) {
+ store, err := NewMachinePlatformPolicyStore()
+ if err != nil {
+ t.Fatalf("NewMachinePolicyStore failed: %v", err)
+ }
+
+ t.Run("One-Goroutine", func(t *testing.T) {
+ if err := store.Lock(); err != nil {
+ t.Errorf("store.Lock(): got %v; want nil", err)
+ return
+ }
+ if v, err := store.ReadString("NonExistingPolicySetting"); err == nil || !errors.Is(err, setting.ErrNotConfigured) {
+ t.Errorf(`ReadString: got %v, %v; want "", %v`, v, err, setting.ErrNotConfigured)
+ }
+ store.Unlock()
+ })
+
+ // Lock the store N times from different goroutines.
+ const N = 100
+ var unlocked atomic.Int32
+ t.Run("N-Goroutines", func(t *testing.T) {
+ var wg sync.WaitGroup
+ wg.Add(N)
+ for range N {
+ go func() {
+ if err := store.Lock(); err != nil {
+ t.Errorf("store.Lock(): got %v; want nil", err)
+ return
+ }
+ if v, err := store.ReadString("NonExistingPolicySetting"); err == nil || !errors.Is(err, setting.ErrNotConfigured) {
+ t.Errorf(`ReadString: got %v, %v; want "", %v`, v, err, setting.ErrNotConfigured)
+ }
+ wg.Done()
+ time.Sleep(10 * time.Millisecond)
+ unlocked.Add(1)
+ store.Unlock()
+ }()
+ }
+
+ // Wait until the store is locked N times.
+ wg.Wait()
+ })
+
+ // Close the store. The call should wait for all held locks to be released.
+ if err := store.Close(); err != nil {
+ t.Fatalf("(*PolicyStore).Close failed: %v", err)
+ }
+ if locked := unlocked.Load(); locked != N {
+ t.Errorf("locked.Load(): got %v; want %v", locked, N)
+ }
+
+ // Any further attempts to lock it should fail.
+ if err = store.Lock(); err == nil || !errors.Is(err, ErrStoreClosed) {
+ t.Errorf("store.Lock(): got %v; want %v", err, ErrStoreClosed)
+ }
+}
+
+func TestReadPolicyStore(t *testing.T) {
+ if !winutil.IsCurrentProcessElevated() {
+ t.Skipf("test requires running as elevated user")
+ }
+ tests := []struct {
+ name setting.Key
+ newValue any
+ legacyValue any
+ want any
+ }{
+ {name: "LegacyPolicy", legacyValue: "LegacyValue", want: "LegacyValue"},
+ {name: "StringPolicy", legacyValue: "LegacyValue", newValue: "Value", want: "Value"},
+ {name: "StringPolicy_Empty", legacyValue: "LegacyValue", newValue: "", want: ""},
+ {name: "BoolPolicy_True", newValue: true, want: true},
+ {name: "BoolPolicy_False", newValue: false, want: false},
+ {name: "UIntPolicy_1", newValue: uint32(10), want: uint64(10)}, // uint32 values should be returned as uint64
+ {name: "UIntPolicy_2", newValue: uint64(1 << 37), want: uint64(1 << 37)},
+ {name: "StringListPolicy", newValue: []string{"Value1", "Value2"}, want: []string{"Value1", "Value2"}},
+ {name: "StringListPolicy_Empty", newValue: []string{}, want: []string{}},
+ }
+
+ runTests := func(t *testing.T, userStore bool, token windows.Token) {
+ var hive registry.Key
+ if userStore {
+ hive = registry.CURRENT_USER
+ } else {
+ hive = registry.LOCAL_MACHINE
+ }
+
+ // Write policy values to the registry.
+ newValues := make([]testPolicyValue, 0, len(tests))
+ for _, tt := range tests {
+ if tt.newValue != nil {
+ newValues = append(newValues, testPolicyValue{name: tt.name, value: tt.newValue})
+ }
+ }
+ policiesKeyName := softwareKeyName + `\` + tsPoliciesSubkey
+ cleanup, err := createTestPolicyValues(hive, policiesKeyName, newValues)
+ if err != nil {
+ t.Fatalf("createTestPolicyValues failed: %v", err)
+ }
+ t.Cleanup(cleanup)
+
+ // Write legacy policy values to the registry.
+ legacyValues := make([]testPolicyValue, 0, len(tests))
+ for _, tt := range tests {
+ if tt.legacyValue != nil {
+ legacyValues = append(legacyValues, testPolicyValue{name: tt.name, value: tt.legacyValue})
+ }
+ }
+ legacyKeyName := softwareKeyName + `\` + tsIPNSubkey
+ cleanup, err = createTestPolicyValues(hive, legacyKeyName, legacyValues)
+ if err != nil {
+ t.Fatalf("createTestPolicyValues failed: %v", err)
+ }
+ t.Cleanup(cleanup)
+
+ var store *PlatformPolicyStore
+ if userStore {
+ store, err = NewUserPlatformPolicyStore(token)
+ } else {
+ store, err = NewMachinePlatformPolicyStore()
+ }
+ if err != nil {
+ t.Fatalf("NewXPolicyStore failed: %v", err)
+ }
+ t.Cleanup(func() {
+ if err := store.Close(); err != nil {
+ t.Errorf("(*PolicyStore).Close failed: %v", err)
+ }
+ })
+
+ // testReadValues checks that [PolicyStore] returns the same values we wrote directly to the registry.
+ testReadValues := func(t *testing.T, withLocks bool) {
+ for _, tt := range tests {
+ t.Run(string(tt.name), func(t *testing.T) {
+ if userStore && tt.newValue == nil {
+ t.Skip("there is no legacy policies for users")
+ }
+
+ t.Parallel()
+
+ if withLocks {
+ if err := store.Lock(); err != nil {
+ t.Errorf("failed to acquire the lock: %v", err)
+ }
+ defer store.Unlock()
+ }
+
+ var got any
+ var err error
+ switch tt.want.(type) {
+ case string:
+ got, err = store.ReadString(tt.name)
+ case uint64:
+ got, err = store.ReadUInt64(tt.name)
+ case bool:
+ got, err = store.ReadBoolean(tt.name)
+ case []string:
+ got, err = store.ReadStringArray(tt.name)
+ }
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !reflect.DeepEqual(got, tt.want) {
+ t.Errorf("got %v; want %v", got, tt.want)
+ }
+ })
+ }
+ }
+ t.Run("NoLock", func(t *testing.T) {
+ testReadValues(t, false)
+ })
+
+ t.Run("WithLock", func(t *testing.T) {
+ testReadValues(t, true)
+ })
+ }
+
+ t.Run("MachineStore", func(t *testing.T) {
+ runTests(t, false, 0)
+ })
+
+ t.Run("CurrentUserStore", func(t *testing.T) {
+ runTests(t, true, 0)
+ })
+
+ t.Run("UserStoreWithToken", func(t *testing.T) {
+ var token windows.Token
+ if err := windows.OpenProcessToken(windows.CurrentProcess(), windows.TOKEN_QUERY, &token); err != nil {
+ t.Fatalf("OpenProcessToken: %v", err)
+ }
+ defer token.Close()
+ runTests(t, true, token)
+ })
+}
+
+func TestPolicyStoreChangeNotifications(t *testing.T) {
+ if cibuild.On() {
+ t.Skipf("test requires running on a real Windows environment")
+ }
+ store, err := NewMachinePlatformPolicyStore()
+ if err != nil {
+ t.Fatalf("NewMachinePolicyStore failed: %v", err)
+ }
+ t.Cleanup(func() {
+ if err := store.Close(); err != nil {
+ t.Errorf("(*PolicyStore).Close failed: %v", err)
+ }
+ })
+
+ done := make(chan struct{})
+ unregister, err := store.RegisterChangeCallback(func() { close(done) })
+ if err != nil {
+ t.Fatalf("RegisterChangeCallback failed: %v", err)
+ }
+ t.Cleanup(unregister)
+
+ // RefreshMachinePolicy is a non-blocking call.
+ if err := gp.RefreshMachinePolicy(true); err != nil {
+ t.Fatalf("RefreshMachinePolicy failed: %v", err)
+ }
+
+ // We should receive a policy change notification when
+ // the Group Policy service completes policy processing.
+ // Otherwise, the test will eventually time out.
+ <-done
+}
+
+func createTestPolicyValues(hive registry.Key, keyName string, values []testPolicyValue) (cleanup func(), err error) {
+ key, existing, err := registry.CreateKey(hive, keyName, registry.ALL_ACCESS)
+ if err != nil {
+ return nil, err
+ }
+ doCleanup := func() {
+ for _, v := range values {
+ key.DeleteValue(string(v.name))
+ }
+ key.Close()
+ if !existing {
+ registry.DeleteKey(hive, keyName)
+ }
+ }
+ defer func() {
+ if err != nil {
+ doCleanup()
+ }
+ }()
+
+ for _, v := range values {
+ switch value := v.value.(type) {
+ case string:
+ err = key.SetStringValue(string(v.name), value)
+ case uint32:
+ err = key.SetDWordValue(string(v.name), value)
+ case uint64:
+ err = key.SetQWordValue(string(v.name), value)
+ case bool:
+ if value {
+ err = key.SetDWordValue(string(v.name), 1)
+ } else {
+ err = key.SetDWordValue(string(v.name), 0)
+ }
+ case []string:
+ err = key.SetStringsValue(string(v.name), value)
+ default:
+ err = fmt.Errorf("unsupported value: %v (%T), name: %q", value, value, v.name)
+ }
+ if err != nil {
+ return nil, err
+ }
+ }
+ return doCleanup, nil
+}
diff --git a/util/syspolicy/source/test_store.go b/util/syspolicy/source/test_store.go
new file mode 100644
index 000000000..fd422d852
--- /dev/null
+++ b/util/syspolicy/source/test_store.go
@@ -0,0 +1,446 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package source
+
+import (
+ "fmt"
+ "sync"
+ "sync/atomic"
+
+ xmaps "golang.org/x/exp/maps"
+ "tailscale.com/util/mak"
+ "tailscale.com/util/set"
+ "tailscale.com/util/syspolicy/internal"
+ "tailscale.com/util/syspolicy/setting"
+)
+
+var _ Store = (*TestStore)(nil)
+
+// TestValueType is a constraint that allows types supported by [TestStore].
+type TestValueType interface {
+ bool | uint64 | string | []string
+}
+
+// TestSetting is a policy setting in a [TestStore].
+type TestSetting[T TestValueType] struct {
+ // Key is the setting's unique identifier.
+ Key setting.Key
+ // Error is the error to be returned by the [TestStore] when reading
+ // a policy setting with the specified key.
+ Error error
+ // Value is the value to be returned by the [TestStore] when reading
+ // a policy setting with the specified key.
+ // It is only used if the Error is nil.
+ Value T
+}
+
+// TestSettingOf returns a [TestSetting] representing a policy setting
+// configured with the specified key and value.
+func TestSettingOf[T TestValueType](key setting.Key, value T) TestSetting[T] {
+ return TestSetting[T]{Key: key, Value: value}
+}
+
+// TestSettingWithError returns a [TestSetting] representing a policy setting
+// with the specified key and error.
+func TestSettingWithError[T TestValueType](key setting.Key, err error) TestSetting[T] {
+ return TestSetting[T]{Key: key, Error: err}
+}
+
+// testReadOperation describes a single policy setting read operation.
+type testReadOperation struct {
+ // Key is the setting's unique identifier.
+ Key setting.Key
+ // Type is a value type of a read operation.
+ // [setting.BooleanValue], [setting.IntegerValue], [setting.StringValue] or [setting.StringListValue]
+ Type setting.Type
+}
+
+// TestExpectedReads is the number of read operations with the specified details.
+type TestExpectedReads struct {
+ // Key is the setting's unique identifier.
+ Key setting.Key
+ // Type is a value type of a read operation.
+ // [setting.BooleanValue], [setting.IntegerValue], [setting.StringValue] or [setting.StringListValue]
+ Type setting.Type
+ // NumTimes is how many times a setting with the specified key and type should have been read.
+ NumTimes int
+}
+
+func (r TestExpectedReads) operation() testReadOperation {
+ return testReadOperation{r.Key, r.Type}
+}
+
+// TestStore is a [Store] that can be used in tests.
+type TestStore struct {
+ tb internal.TB
+
+ done chan struct{}
+
+ storeLock sync.RWMutex // its RLock is exposed via [Store.Lock]/[Store.Unlock].
+ storeLockCount atomic.Int32
+
+ mu sync.RWMutex
+ suspendCount int // change callback are suspended if > 0
+ mr, mw map[setting.Key]any // maps for reading and writing; they're the same unless the store is suspended.
+ cbs set.HandleSet[func()]
+
+ readsMu sync.Mutex
+ reads map[testReadOperation]int // how many times a policy setting was read
+}
+
+// NewTestStore returns a new [TestStore].
+// The tb will be used to report coding errors detected by the [TestStore].
+func NewTestStore(tb internal.TB) *TestStore {
+ m := make(map[setting.Key]any)
+ return &TestStore{
+ tb: tb,
+ done: make(chan struct{}),
+ mr: m,
+ mw: m,
+ }
+}
+
+// NewTestStoreOf is a shorthand for [NewTestStore] followed by [TestStore.SetBooleans],
+// [TestStore.SetUInt64s], [TestStore.SetStrings] or [TestStore.SetStringLists].
+func NewTestStoreOf[T TestValueType](tb internal.TB, settings ...TestSetting[T]) *TestStore {
+ m := make(map[setting.Key]any)
+ store := &TestStore{
+ tb: tb,
+ done: make(chan struct{}),
+ mr: m,
+ mw: m,
+ }
+ switch settings := any(settings).(type) {
+ case []TestSetting[bool]:
+ store.SetBooleans(settings...)
+ case []TestSetting[uint64]:
+ store.SetUInt64s(settings...)
+ case []TestSetting[string]:
+ store.SetStrings(settings...)
+ case []TestSetting[[]string]:
+ store.SetStringLists(settings...)
+ }
+ return store
+}
+
+// Lock implements [Store].
+func (s *TestStore) Lock() error {
+ s.storeLock.RLock()
+ s.storeLockCount.Add(1)
+ return nil
+}
+
+// Unlock implements [Store].
+func (s *TestStore) Unlock() {
+ if s.storeLockCount.Add(-1) < 0 {
+ s.tb.Fatal("negative storeLockCount")
+ }
+ s.storeLock.RUnlock()
+}
+
+// RegisterChangeCallback implements [Store].
+func (s *TestStore) RegisterChangeCallback(callback func()) (unregister func(), err error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ handle := s.cbs.Add(callback)
+ return func() {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ delete(s.cbs, handle)
+ }, nil
+}
+
+// ReadString implements [Store].
+func (s *TestStore) ReadString(key setting.Key) (string, error) {
+ defer s.recordRead(key, setting.StringValue)
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ v, ok := s.mr[key]
+ if !ok {
+ return "", setting.ErrNotConfigured
+ }
+ if err, ok := v.(error); ok {
+ return "", err
+ }
+ str, ok := v.(string)
+ if !ok {
+ return "", fmt.Errorf("%w in ReadString: got %T", setting.ErrTypeMismatch, v)
+ }
+ return str, nil
+}
+
+// ReadUInt64 implements [Store].
+func (s *TestStore) ReadUInt64(key setting.Key) (uint64, error) {
+ defer s.recordRead(key, setting.IntegerValue)
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ v, ok := s.mr[key]
+ if !ok {
+ return 0, setting.ErrNotConfigured
+ }
+ if err, ok := v.(error); ok {
+ return 0, err
+ }
+ u64, ok := v.(uint64)
+ if !ok {
+ return 0, fmt.Errorf("%w in ReadUInt64: got %T", setting.ErrTypeMismatch, v)
+ }
+ return u64, nil
+}
+
+// ReadBoolean implements [Store].
+func (s *TestStore) ReadBoolean(key setting.Key) (bool, error) {
+ defer s.recordRead(key, setting.BooleanValue)
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ v, ok := s.mr[key]
+ if !ok {
+ return false, setting.ErrNotConfigured
+ }
+ if err, ok := v.(error); ok {
+ return false, err
+ }
+ b, ok := v.(bool)
+ if !ok {
+ return false, fmt.Errorf("%w in ReadBoolean: got %T", setting.ErrTypeMismatch, v)
+ }
+ return b, nil
+}
+
+// ReadStringArray implements [Store].
+func (s *TestStore) ReadStringArray(key setting.Key) ([]string, error) {
+ defer s.recordRead(key, setting.StringListValue)
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ v, ok := s.mr[key]
+ if !ok {
+ return nil, setting.ErrNotConfigured
+ }
+ if err, ok := v.(error); ok {
+ return nil, err
+ }
+ slice, ok := v.([]string)
+ if !ok {
+ return nil, fmt.Errorf("%w in ReadStringArray: got %T", setting.ErrTypeMismatch, v)
+ }
+ return slice, nil
+}
+
+func (s *TestStore) recordRead(key setting.Key, typ setting.Type) {
+ s.readsMu.Lock()
+ op := testReadOperation{key, typ}
+ num := s.reads[op]
+ num++
+ mak.Set(&s.reads, op, num)
+ s.readsMu.Unlock()
+}
+
+func (s *TestStore) ResetCounters() {
+ s.readsMu.Lock()
+ clear(s.reads)
+ s.readsMu.Unlock()
+}
+
+// ReadsMustEqual fails the test if the actual reads differs from the specified reads.
+func (s *TestStore) ReadsMustEqual(reads ...TestExpectedReads) {
+ s.tb.Helper()
+ s.readsMu.Lock()
+ defer s.readsMu.Unlock()
+ s.readsMustContainLocked(reads...)
+ s.readMustNoExtraLocked(reads...)
+}
+
+// ReadsMustContain fails the test if the specified reads have not been made,
+// or have been made a different number of times. It permits other values to be
+// read in addition to the ones being tested.
+func (s *TestStore) ReadsMustContain(reads ...TestExpectedReads) {
+ s.tb.Helper()
+ s.readsMu.Lock()
+ defer s.readsMu.Unlock()
+ s.readsMustContainLocked(reads...)
+}
+
+func (s *TestStore) readsMustContainLocked(reads ...TestExpectedReads) {
+ s.tb.Helper()
+ for _, r := range reads {
+ if numTimes := s.reads[r.operation()]; numTimes != r.NumTimes {
+ s.tb.Errorf("%q (%v) reads: got %v, want %v", r.Key, r.Type, numTimes, r.NumTimes)
+ }
+ }
+}
+
+func (s *TestStore) readMustNoExtraLocked(reads ...TestExpectedReads) {
+ s.tb.Helper()
+ rs := make(set.Set[testReadOperation])
+ for i := range reads {
+ rs.Add(reads[i].operation())
+ }
+ for ro, num := range s.reads {
+ if !rs.Contains(ro) {
+ s.tb.Errorf("%q (%v) reads: got %v, want 0", ro.Key, ro.Type, num)
+ }
+ }
+}
+
+// Suspend suspends the store, batching changes and notifications
+// until [TestStore.Resume] is called the same number of times as Suspend.
+func (s *TestStore) Suspend() {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ if s.suspendCount++; s.suspendCount == 1 {
+ s.mw = xmaps.Clone(s.mr)
+ }
+}
+
+// Resume resumes the store, applying the changes and invoking
+// the change callbacks.
+func (s *TestStore) Resume() {
+ s.storeLock.Lock()
+ s.mu.Lock()
+ switch s.suspendCount--; {
+ case s.suspendCount == 0:
+ s.mr = s.mw
+ s.mu.Unlock()
+ s.storeLock.Unlock()
+ s.notifyPolicyChanged()
+ case s.suspendCount < 0:
+ s.tb.Fatal("negative suspendCount")
+ default:
+ s.mu.Unlock()
+ s.storeLock.Unlock()
+ }
+}
+
+// SetBooleans sets the specified boolean settings in s.
+func (s *TestStore) SetBooleans(settings ...TestSetting[bool]) {
+ s.storeLock.Lock()
+ for _, setting := range settings {
+ if setting.Key == "" {
+ s.tb.Fatal("empty keys disallowed")
+ }
+ s.mu.Lock()
+ if setting.Error != nil {
+ mak.Set(&s.mw, setting.Key, any(setting.Error))
+ } else {
+ mak.Set(&s.mw, setting.Key, any(setting.Value))
+ }
+ s.mu.Unlock()
+ }
+ s.storeLock.Unlock()
+ s.notifyPolicyChanged()
+}
+
+// SetUInt64s sets the specified integer settings in s.
+func (s *TestStore) SetUInt64s(settings ...TestSetting[uint64]) {
+ s.storeLock.Lock()
+ for _, setting := range settings {
+ if setting.Key == "" {
+ s.tb.Fatal("empty keys disallowed")
+ }
+ s.mu.Lock()
+ if setting.Error != nil {
+ mak.Set(&s.mw, setting.Key, any(setting.Error))
+ } else {
+ mak.Set(&s.mw, setting.Key, any(setting.Value))
+ }
+ s.mu.Unlock()
+ }
+ s.storeLock.Unlock()
+ s.notifyPolicyChanged()
+}
+
+// SetStrings sets the specified string settings in s.
+func (s *TestStore) SetStrings(settings ...TestSetting[string]) {
+ s.storeLock.Lock()
+ for _, setting := range settings {
+ if setting.Key == "" {
+ s.tb.Fatal("empty keys disallowed")
+ }
+ s.mu.Lock()
+ if setting.Error != nil {
+ mak.Set(&s.mw, setting.Key, any(setting.Error))
+ } else {
+ mak.Set(&s.mw, setting.Key, any(setting.Value))
+ }
+ s.mu.Unlock()
+ }
+ s.storeLock.Unlock()
+ s.notifyPolicyChanged()
+}
+
+// SetStrings sets the specified string list settings in s.
+func (s *TestStore) SetStringLists(settings ...TestSetting[[]string]) {
+ s.storeLock.Lock()
+ for _, setting := range settings {
+ if setting.Key == "" {
+ s.tb.Fatal("empty keys disallowed")
+ }
+ s.mu.Lock()
+ if setting.Error != nil {
+ mak.Set(&s.mw, setting.Key, any(setting.Error))
+ } else {
+ mak.Set(&s.mw, setting.Key, any(setting.Value))
+ }
+ s.mu.Unlock()
+ }
+ s.storeLock.Unlock()
+ s.notifyPolicyChanged()
+}
+
+// Delete deletes the specified settings from s.
+func (s *TestStore) Delete(keys ...setting.Key) {
+ s.storeLock.Lock()
+ for _, key := range keys {
+ s.mu.Lock()
+ delete(s.mw, key)
+ s.mu.Unlock()
+ }
+ s.storeLock.Unlock()
+ s.notifyPolicyChanged()
+}
+
+// Clear deletes all settings from s.
+func (s *TestStore) Clear() {
+ s.storeLock.Lock()
+ s.mu.Lock()
+ clear(s.mw)
+ s.mu.Unlock()
+ s.storeLock.Unlock()
+ s.notifyPolicyChanged()
+}
+
+func (s *TestStore) notifyPolicyChanged() {
+ s.mu.RLock()
+ if s.suspendCount != 0 {
+ s.mu.RUnlock()
+ return
+ }
+ cbs := xmaps.Values(s.cbs)
+ s.mu.RUnlock()
+
+ var wg sync.WaitGroup
+ wg.Add(len(cbs))
+ for _, cb := range cbs {
+ go func() {
+ defer wg.Done()
+ cb()
+ }()
+ }
+ wg.Wait()
+}
+
+// Close closes s, notifying its users that it has expired.
+func (s *TestStore) Close() {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ if s.done != nil {
+ close(s.done)
+ s.done = nil
+ }
+}
+
+// Done implements [Store].
+func (s *TestStore) Done() <-chan struct{} {
+ return s.done
+}
diff --git a/util/syspolicy/syspolicy.go b/util/syspolicy/syspolicy.go
index 76e11e2b6..1ff9ff97a 100644
--- a/util/syspolicy/syspolicy.go
+++ b/util/syspolicy/syspolicy.go
@@ -1,122 +1,83 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
-// Package syspolicy provides functions to retrieve system settings of a device.
+// Package syspolicy facilitates retrieval of the current policy settings
+// applied to the device or user and receiving notifications when the policy
+// changes.
+//
+// It provides functions that return specific policy settings by their unique
+// [setting.Key]s, such as [GetBoolean], [GetUint64], [GetString],
+// [GetStringArray], [GetPreferenceOption], [GetVisibility] and [GetDuration].
package syspolicy
import (
"errors"
+ "fmt"
+ "reflect"
"time"
+
+ "tailscale.com/util/syspolicy/rsop"
+ "tailscale.com/util/syspolicy/setting"
+)
+
+var (
+ // ErrNotConfigured is returned when the requested policy setting is not configured.
+ ErrNotConfigured = setting.ErrNotConfigured
+ // ErrTypeMismatch is returned when there's a type mismatch between the actual type
+ // of the setting value and the expected type.
+ ErrTypeMismatch = setting.ErrTypeMismatch
+ // ErrNoSuchKey is returned by [setting.DefinitionOf] when no policy setting
+ // has been registered with the specified key.
+ //
+ // Until 2024-08-02, this error was also returned by a [Handler] when the specified
+ // key did not have a value set. While the package maintains compatibility with this
+ // usage of ErrNoSuchKey, it is recommended to return [ErrNotConfigured] from newer
+ // [source.Store] implementations.
+ ErrNoSuchKey = setting.ErrNoSuchKey
)
+// GetString returns a string policy setting with the specified key,
+// or defaultValue if it does not exist.
func GetString(key Key, defaultValue string) (string, error) {
- markHandlerInUse()
- v, err := handler.ReadString(string(key))
- if errors.Is(err, ErrNoSuchKey) {
- return defaultValue, nil
- }
- return v, err
+ return getCurrentPolicySettingValue(key, defaultValue)
}
+// GetUint64 returns a numeric policy setting with the specified key,
+// or defaultValue if it does not exist.
func GetUint64(key Key, defaultValue uint64) (uint64, error) {
- markHandlerInUse()
- v, err := handler.ReadUInt64(string(key))
- if errors.Is(err, ErrNoSuchKey) {
- return defaultValue, nil
- }
- return v, err
+ return getCurrentPolicySettingValue(key, defaultValue)
}
+// GetBoolean returns a boolean policy setting with the specified key,
+// or defaultValue if it does not exist.
func GetBoolean(key Key, defaultValue bool) (bool, error) {
- markHandlerInUse()
- v, err := handler.ReadBoolean(string(key))
- if errors.Is(err, ErrNoSuchKey) {
- return defaultValue, nil
- }
- return v, err
+ return getCurrentPolicySettingValue(key, defaultValue)
}
+// GetStringArray returns a multi-string policy setting with the specified key,
+// or defaultValue if it does not exist.
func GetStringArray(key Key, defaultValue []string) ([]string, error) {
- markHandlerInUse()
- v, err := handler.ReadStringArray(string(key))
- if errors.Is(err, ErrNoSuchKey) {
- return defaultValue, nil
- }
- return v, err
+ return getCurrentPolicySettingValue(key, defaultValue)
}
-// PreferenceOption is a policy that governs whether a boolean variable
-// is forcibly assigned an administrator-defined value, or allowed to receive
-// a user-defined value.
-type PreferenceOption int
-
-const (
- showChoiceByPolicy PreferenceOption = iota
- neverByPolicy
- alwaysByPolicy
+type (
+ // PreferenceOption is a policy that governs whether a boolean variable
+ // is forcibly assigned an administrator-defined value, or allowed to receive
+ // a user-defined value.
+ PreferenceOption = setting.PreferenceOption
+ // Visibility is a policy that controls whether or not a particular
+ // component of a user interface is to be shown.
+ Visibility = setting.Visibility
)
-// Show returns if the UI option that controls the choice administered by this
-// policy should be shown. Currently this is true if and only if the policy is
-// showChoiceByPolicy.
-func (p PreferenceOption) Show() bool {
- return p == showChoiceByPolicy
-}
-
-// ShouldEnable checks if the choice administered by this policy should be
-// enabled. If the administrator has chosen a setting, the administrator's
-// setting is returned, otherwise userChoice is returned.
-func (p PreferenceOption) ShouldEnable(userChoice bool) bool {
- switch p {
- case neverByPolicy:
- return false
- case alwaysByPolicy:
- return true
- default:
- return userChoice
- }
-}
-
-// WillOverride checks if the choice administered by the policy is different
-// from the user's choice.
-func (p PreferenceOption) WillOverride(userChoice bool) bool {
- return p.ShouldEnable(userChoice) != userChoice
-}
-
// GetPreferenceOption loads a policy from the registry that can be
// managed by an enterprise policy management system and allows administrative
// overrides of users' choices in a way that we do not want tailcontrol to have
// the authority to set. It describes user-decides/always/never options, where
// "always" and "never" remove the user's ability to make a selection. If not
// present or set to a different value, "user-decides" is the default.
-func GetPreferenceOption(name Key) (PreferenceOption, error) {
- opt, err := GetString(name, "user-decides")
- if err != nil {
- return showChoiceByPolicy, err
- }
- switch opt {
- case "always":
- return alwaysByPolicy, nil
- case "never":
- return neverByPolicy, nil
- default:
- return showChoiceByPolicy, nil
- }
-}
-
-// Visibility is a policy that controls whether or not a particular
-// component of a user interface is to be shown.
-type Visibility byte
-
-const (
- visibleByPolicy Visibility = 'v'
- hiddenByPolicy Visibility = 'h'
-)
-
-// Show reports whether the UI option administered by this policy should be shown.
-// Currently this is true if and only if the policy is visibleByPolicy.
-func (p Visibility) Show() bool {
- return p == visibleByPolicy
+func GetPreferenceOption(name Key) (setting.PreferenceOption, error) {
+ return getCurrentPolicySettingValue(name, setting.ShowChoiceByPolicy)
}
// GetVisibility loads a policy from the registry that can be managed
@@ -124,17 +85,8 @@ func (p Visibility) Show() bool {
// for UI elements. The registry value should be a string set to "show" (return
// true) or "hide" (return true). If not present or set to a different value,
// "show" (return false) is the default.
-func GetVisibility(name Key) (Visibility, error) {
- opt, err := GetString(name, "show")
- if err != nil {
- return visibleByPolicy, err
- }
- switch opt {
- case "hide":
- return hiddenByPolicy, nil
- default:
- return visibleByPolicy, nil
- }
+func GetVisibility(name Key) (setting.Visibility, error) {
+ return getCurrentPolicySettingValue(name, setting.VisibleByPolicy)
}
// GetDuration loads a policy from the registry that can be managed
@@ -143,15 +95,48 @@ func GetVisibility(name Key) (Visibility, error) {
// understands. If the registry value is "" or can not be processed,
// defaultValue is returned instead.
func GetDuration(name Key, defaultValue time.Duration) (time.Duration, error) {
- opt, err := GetString(name, "")
- if opt == "" || err != nil {
- return defaultValue, err
+ d, err := getCurrentPolicySettingValue(name, defaultValue)
+ if err != nil {
+ return d, err
}
- v, err := time.ParseDuration(opt)
- if err != nil || v < 0 {
+ if d < 0 {
return defaultValue, nil
}
- return v, nil
+ return d, nil
+}
+
+// getCurrentPolicySettingValue returns the value of the policy setting
+// specified by its key from the [rsop.Policy] of the [CurrentScope]. It
+// returns def if the policy setting is not configured, or an error if it has
+// an error or could not be converted to the specified type T.
+func getCurrentPolicySettingValue[T setting.ValueType](key Key, def T) (T, error) {
+ resultant, err := rsop.PolicyFor(setting.CurrentScope())
+ if err != nil {
+ return def, err
+ }
+ value, err := resultant.Get().GetErr(key)
+ if err != nil {
+ if errors.Is(err, setting.ErrNotConfigured) || errors.Is(err, setting.ErrNoSuchKey) {
+ return def, nil
+ }
+ return def, err
+ }
+ if res, ok := value.(T); ok {
+ return res, nil
+ }
+ return convertPolicySettingValueTo(value, def)
+}
+
+func convertPolicySettingValueTo[T setting.ValueType](value any, def T) (T, error) {
+ // Convert [PreferenceOption], [Visibility], or [time.Duration] back to a string
+ // if someone requests a string instead of the actual setting's value.
+ // TODO(nickkhyl): check if this behavior is relied upon anywhere besides the old tests.
+ if reflect.TypeFor[T]().Kind() == reflect.String {
+ if str, ok := value.(fmt.Stringer); ok {
+ return any(str.String()).(T), nil
+ }
+ }
+ return def, fmt.Errorf("%w: got %T, want %T", setting.ErrTypeMismatch, value, def)
}
// SelectControlURL returns the ControlURL to use based on a value in
diff --git a/util/syspolicy/syspolicy_test.go b/util/syspolicy/syspolicy_test.go
index c2810ebbb..2adbe9d25 100644
--- a/util/syspolicy/syspolicy_test.go
+++ b/util/syspolicy/syspolicy_test.go
@@ -5,16 +5,24 @@ package syspolicy
import (
"errors"
+ "fmt"
"slices"
"testing"
"time"
+
+ "tailscale.com/types/logger"
+ "tailscale.com/util/syspolicy/internal/loggerx"
+ "tailscale.com/util/syspolicy/internal/metrics"
+ "tailscale.com/util/syspolicy/rsop"
+ "tailscale.com/util/syspolicy/setting"
+ "tailscale.com/util/syspolicy/source"
)
// testHandler encompasses all data types returned when testing any of the syspolicy
// methods that involve getting a policy value.
// For keys and the corresponding values, check policy_keys.go.
type testHandler struct {
- t *testing.T
+ t testing.TB
key Key
s string
u64 uint64
@@ -28,7 +36,10 @@ var someOtherError = errors.New("error other than not found")
func (th *testHandler) ReadString(key string) (string, error) {
if key != string(th.key) {
- th.t.Errorf("ReadString(%q) want %q", key, th.key)
+ // The syspolicy package now reads and caches all registered policy settings.
+ // Therefore, it is expected to call the handler requesting all policies
+ // rather than just the specific ones we asked for.
+ return "", ErrNotConfigured
}
th.calls++
return th.s, th.err
@@ -36,7 +47,10 @@ func (th *testHandler) ReadString(key string) (string, error) {
func (th *testHandler) ReadUInt64(key string) (uint64, error) {
if key != string(th.key) {
- th.t.Errorf("ReadUint64(%q) want %q", key, th.key)
+ // The syspolicy package now reads and caches all registered policy settings.
+ // Therefore, it is expected to call the handler requesting all policies
+ // rather than just the specific ones we asked for.
+ return 0, ErrNotConfigured
}
th.calls++
return th.u64, th.err
@@ -44,7 +58,10 @@ func (th *testHandler) ReadUInt64(key string) (uint64, error) {
func (th *testHandler) ReadBoolean(key string) (bool, error) {
if key != string(th.key) {
- th.t.Errorf("ReadBool(%q) want %q", key, th.key)
+ // The syspolicy package now reads and caches all registered policy settings.
+ // Therefore, it is expected to call the handler requesting all policies
+ // rather than just the specific ones we asked for.
+ return false, ErrNotConfigured
}
th.calls++
return th.b, th.err
@@ -52,7 +69,10 @@ func (th *testHandler) ReadBoolean(key string) (bool, error) {
func (th *testHandler) ReadStringArray(key string) ([]string, error) {
if key != string(th.key) {
- th.t.Errorf("ReadStringArray(%q) want %q", key, th.key)
+ // The syspolicy package now reads and caches all registered policy settings.
+ // Therefore, it is expected to call the handler requesting all policies
+ // rather than just the specific ones we asked for.
+ return nil, ErrNotConfigured
}
th.calls++
return th.sArr, th.err
@@ -67,23 +87,28 @@ func TestGetString(t *testing.T) {
defaultValue string
wantValue string
wantError error
+ wantMetrics []metrics.TestState
}{
{
name: "read existing value",
key: AdminConsoleVisibility,
handlerValue: "hide",
wantValue: "hide",
+ wantMetrics: []metrics.TestState{
+ {Name: "$os_syspolicy_any", Value: 1},
+ {Name: "$os_syspolicy_AdminConsole", Value: 1},
+ },
},
{
name: "read non-existing value",
key: EnableServerMode,
- handlerError: ErrNoSuchKey,
+ handlerError: ErrNotConfigured,
wantError: nil,
},
{
name: "read non-existing value, non-blank default",
key: EnableServerMode,
- handlerError: ErrNoSuchKey,
+ handlerError: ErrNotConfigured,
defaultValue: "test",
wantValue: "test",
wantError: nil,
@@ -93,11 +118,17 @@ func TestGetString(t *testing.T) {
key: NetworkDevicesVisibility,
handlerError: someOtherError,
wantError: someOtherError,
+ wantMetrics: []metrics.TestState{
+ {Name: "$os_syspolicy_errors", Value: 1},
+ {Name: "$os_syspolicy_NetworkDevices_error", Value: 1},
+ },
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
+ h := metrics.NewTestHandler(t)
+ metrics.SetHooksForTest(t, h.AddMetric, h.SetMetric)
SetHandlerForTest(t, &testHandler{
t: t,
key: tt.key,
@@ -105,12 +136,21 @@ func TestGetString(t *testing.T) {
err: tt.handlerError,
})
value, err := GetString(tt.key, tt.defaultValue)
- if err != tt.wantError {
+ if !errorsMatchForTest(err, tt.wantError) {
t.Errorf("err=%q, want %q", err, tt.wantError)
}
if value != tt.wantValue {
t.Errorf("value=%v, want %v", value, tt.wantValue)
}
+ wantMetrics := tt.wantMetrics
+ if !metrics.ShouldReport() {
+ // Check that metrics are not reported on platforms
+ // where they shouldn't be reported.
+ // As of 2024-08-02, syspolicy only reports metrics
+ // on Windows and Android.
+ wantMetrics = nil
+ }
+ h.MustEqual(wantMetrics...)
})
}
}
@@ -127,7 +167,7 @@ func TestGetUint64(t *testing.T) {
}{
{
name: "read existing value",
- key: KeyExpirationNoticeTime,
+ key: LogSCMInteractions,
handlerValue: 1,
wantValue: 1,
},
@@ -135,14 +175,14 @@ func TestGetUint64(t *testing.T) {
name: "read non-existing value",
key: LogSCMInteractions,
handlerValue: 0,
- handlerError: ErrNoSuchKey,
+ handlerError: ErrNotConfigured,
wantValue: 0,
},
{
name: "read non-existing value, non-zero default",
key: LogSCMInteractions,
defaultValue: 2,
- handlerError: ErrNoSuchKey,
+ handlerError: ErrNotConfigured,
wantValue: 2,
},
{
@@ -155,14 +195,21 @@ func TestGetUint64(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- SetHandlerForTest(t, &testHandler{
+ // None of the policy settings tested here are integers.
+ // In fact, we don't have any integer policies as of 2024-07-29.
+ // However, we can register each of them as an integer policy setting
+ // for the duration of the test, providing us with something to test against.
+ if err := setting.SetDefinitionsForTest(t, setting.NewDefinition(tt.key, setting.DeviceSetting, setting.IntegerValue)); err != nil {
+ t.Fatalf("SetDefinitionsForTest failed: %v", err)
+ }
+ rsop.RegisterStoreForTest(t, tt.name, setting.DeviceScope, WrapHandler(&testHandler{
t: t,
key: tt.key,
u64: tt.handlerValue,
err: tt.handlerError,
- })
+ }))
value, err := GetUint64(tt.key, tt.defaultValue)
- if err != tt.wantError {
+ if !errorsMatchForTest(err, tt.wantError) {
t.Errorf("err=%q, want %q", err, tt.wantError)
}
if value != tt.wantValue {
@@ -181,32 +228,43 @@ func TestGetBoolean(t *testing.T) {
defaultValue bool
wantValue bool
wantError error
+ wantMetrics []metrics.TestState
}{
{
name: "read existing value",
key: FlushDNSOnSessionUnlock,
handlerValue: true,
wantValue: true,
+ wantMetrics: []metrics.TestState{
+ {Name: "$os_syspolicy_any", Value: 1},
+ {Name: "$os_syspolicy_FlushDNSOnSessionUnlock", Value: 1},
+ },
},
{
name: "read non-existing value",
key: LogSCMInteractions,
handlerValue: false,
- handlerError: ErrNoSuchKey,
+ handlerError: ErrNotConfigured,
wantValue: false,
},
{
name: "reading value returns other error",
key: FlushDNSOnSessionUnlock,
handlerError: someOtherError,
- wantError: someOtherError,
+ wantError: someOtherError, // expect error...
defaultValue: true,
- wantValue: false,
+ wantValue: true, // ...AND default value if the handler fails.
+ wantMetrics: []metrics.TestState{
+ {Name: "$os_syspolicy_errors", Value: 1},
+ {Name: "$os_syspolicy_FlushDNSOnSessionUnlock_error", Value: 1},
+ },
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
+ h := metrics.NewTestHandler(t)
+ metrics.SetHooksForTest(t, h.AddMetric, h.SetMetric)
SetHandlerForTest(t, &testHandler{
t: t,
key: tt.key,
@@ -214,12 +272,21 @@ func TestGetBoolean(t *testing.T) {
err: tt.handlerError,
})
value, err := GetBoolean(tt.key, tt.defaultValue)
- if err != tt.wantError {
+ if !errorsMatchForTest(err, tt.wantError) {
t.Errorf("err=%q, want %q", err, tt.wantError)
}
if value != tt.wantValue {
t.Errorf("value=%v, want %v", value, tt.wantValue)
}
+ wantMetrics := tt.wantMetrics
+ if !metrics.ShouldReport() {
+ // Check that metrics are not reported on platforms
+ // where they shouldn't be reported.
+ // As of 2024-08-02, syspolicy only reports metrics
+ // on Windows and Android.
+ wantMetrics = nil
+ }
+ h.MustEqual(wantMetrics...)
})
}
}
@@ -232,42 +299,61 @@ func TestGetPreferenceOption(t *testing.T) {
handlerError error
wantValue PreferenceOption
wantError error
+ wantMetrics []metrics.TestState
}{
{
name: "always by policy",
key: EnableIncomingConnections,
handlerValue: "always",
- wantValue: alwaysByPolicy,
+ wantValue: setting.AlwaysByPolicy,
+ wantMetrics: []metrics.TestState{
+ {Name: "$os_syspolicy_any", Value: 1},
+ {Name: "$os_syspolicy_AllowIncomingConnections", Value: 1},
+ },
},
{
name: "never by policy",
key: EnableIncomingConnections,
handlerValue: "never",
- wantValue: neverByPolicy,
+ wantValue: setting.NeverByPolicy,
+ wantMetrics: []metrics.TestState{
+ {Name: "$os_syspolicy_any", Value: 1},
+ {Name: "$os_syspolicy_AllowIncomingConnections", Value: 1},
+ },
},
{
name: "use default",
key: EnableIncomingConnections,
handlerValue: "",
- wantValue: showChoiceByPolicy,
+ wantValue: setting.ShowChoiceByPolicy,
+ wantMetrics: []metrics.TestState{
+ {Name: "$os_syspolicy_any", Value: 1},
+ {Name: "$os_syspolicy_AllowIncomingConnections", Value: 1},
+ },
},
{
name: "read non-existing value",
key: EnableIncomingConnections,
- handlerError: ErrNoSuchKey,
- wantValue: showChoiceByPolicy,
+ handlerError: ErrNotConfigured,
+ wantValue: setting.ShowChoiceByPolicy,
},
{
name: "other error is returned",
key: EnableIncomingConnections,
handlerError: someOtherError,
- wantValue: showChoiceByPolicy,
+ wantValue: setting.ShowChoiceByPolicy,
wantError: someOtherError,
+ wantMetrics: []metrics.TestState{
+ {Name: "$os_syspolicy_errors", Value: 1},
+ {Name: "$os_syspolicy_AllowIncomingConnections_error", Value: 1},
+ },
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
+ h := metrics.NewTestHandler(t)
+ metrics.SetHooksForTest(t, h.AddMetric, h.SetMetric)
SetHandlerForTest(t, &testHandler{
t: t,
key: tt.key,
@@ -275,12 +361,21 @@ func TestGetPreferenceOption(t *testing.T) {
err: tt.handlerError,
})
option, err := GetPreferenceOption(tt.key)
- if err != tt.wantError {
+ if !errorsMatchForTest(err, tt.wantError) {
t.Errorf("err=%q, want %q", err, tt.wantError)
}
if option != tt.wantValue {
t.Errorf("option=%v, want %v", option, tt.wantValue)
}
+ wantMetrics := tt.wantMetrics
+ if !metrics.ShouldReport() {
+ // Check that metrics are not reported on platforms
+ // where they shouldn't be reported.
+ // As of 2024-08-02, syspolicy only reports metrics
+ // on Windows and Android.
+ wantMetrics = nil
+ }
+ h.MustEqual(wantMetrics...)
})
}
}
@@ -293,38 +388,53 @@ func TestGetVisibility(t *testing.T) {
handlerError error
wantValue Visibility
wantError error
+ wantMetrics []metrics.TestState
}{
{
name: "hidden by policy",
key: AdminConsoleVisibility,
handlerValue: "hide",
- wantValue: hiddenByPolicy,
+ wantValue: setting.HiddenByPolicy,
+ wantMetrics: []metrics.TestState{
+ {Name: "$os_syspolicy_any", Value: 1},
+ {Name: "$os_syspolicy_AdminConsole", Value: 1},
+ },
},
{
name: "visibility default",
key: AdminConsoleVisibility,
handlerValue: "show",
- wantValue: visibleByPolicy,
+ wantValue: setting.VisibleByPolicy,
+ wantMetrics: []metrics.TestState{
+ {Name: "$os_syspolicy_any", Value: 1},
+ {Name: "$os_syspolicy_AdminConsole", Value: 1},
+ },
},
{
name: "read non-existing value",
key: AdminConsoleVisibility,
handlerValue: "show",
- handlerError: ErrNoSuchKey,
- wantValue: visibleByPolicy,
+ handlerError: ErrNotConfigured,
+ wantValue: setting.VisibleByPolicy,
},
{
name: "other error is returned",
key: AdminConsoleVisibility,
handlerValue: "show",
handlerError: someOtherError,
- wantValue: visibleByPolicy,
+ wantValue: setting.VisibleByPolicy,
wantError: someOtherError,
+ wantMetrics: []metrics.TestState{
+ {Name: "$os_syspolicy_errors", Value: 1},
+ {Name: "$os_syspolicy_AdminConsole_error", Value: 1},
+ },
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
+ h := metrics.NewTestHandler(t)
+ metrics.SetHooksForTest(t, h.AddMetric, h.SetMetric)
SetHandlerForTest(t, &testHandler{
t: t,
key: tt.key,
@@ -332,12 +442,21 @@ func TestGetVisibility(t *testing.T) {
err: tt.handlerError,
})
visibility, err := GetVisibility(tt.key)
- if err != tt.wantError {
+ if !errorsMatchForTest(err, tt.wantError) {
t.Errorf("err=%q, want %q", err, tt.wantError)
}
if visibility != tt.wantValue {
t.Errorf("visibility=%v, want %v", visibility, tt.wantValue)
}
+ wantMetrics := tt.wantMetrics
+ if !metrics.ShouldReport() {
+ // Check that metrics are not reported on platforms
+ // where they shouldn't be reported.
+ // As of 2024-08-02, syspolicy only reports metrics
+ // on Windows and Android.
+ wantMetrics = nil
+ }
+ h.MustEqual(wantMetrics...)
})
}
}
@@ -351,6 +470,7 @@ func TestGetDuration(t *testing.T) {
defaultValue time.Duration
wantValue time.Duration
wantError error
+ wantMetrics []metrics.TestState
}{
{
name: "read existing value",
@@ -358,25 +478,34 @@ func TestGetDuration(t *testing.T) {
handlerValue: "2h",
wantValue: 2 * time.Hour,
defaultValue: 24 * time.Hour,
+ wantMetrics: []metrics.TestState{
+ {Name: "$os_syspolicy_any", Value: 1},
+ {Name: "$os_syspolicy_KeyExpirationNotice", Value: 1},
+ },
},
{
name: "invalid duration value",
key: KeyExpirationNoticeTime,
handlerValue: "-20",
wantValue: 24 * time.Hour,
+ wantError: errors.New(`time: missing unit in duration "-20"`),
defaultValue: 24 * time.Hour,
+ wantMetrics: []metrics.TestState{
+ {Name: "$os_syspolicy_errors", Value: 1},
+ {Name: "$os_syspolicy_KeyExpirationNotice_error", Value: 1},
+ },
},
{
name: "read non-existing value",
key: KeyExpirationNoticeTime,
- handlerError: ErrNoSuchKey,
+ handlerError: ErrNotConfigured,
wantValue: 24 * time.Hour,
defaultValue: 24 * time.Hour,
},
{
name: "read non-existing value different default",
key: KeyExpirationNoticeTime,
- handlerError: ErrNoSuchKey,
+ handlerError: ErrNotConfigured,
wantValue: 0 * time.Second,
defaultValue: 0 * time.Second,
},
@@ -387,11 +516,17 @@ func TestGetDuration(t *testing.T) {
wantValue: 24 * time.Hour,
wantError: someOtherError,
defaultValue: 24 * time.Hour,
+ wantMetrics: []metrics.TestState{
+ {Name: "$os_syspolicy_errors", Value: 1},
+ {Name: "$os_syspolicy_KeyExpirationNotice_error", Value: 1},
+ },
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
+ h := metrics.NewTestHandler(t)
+ metrics.SetHooksForTest(t, h.AddMetric, h.SetMetric)
SetHandlerForTest(t, &testHandler{
t: t,
key: tt.key,
@@ -399,12 +534,21 @@ func TestGetDuration(t *testing.T) {
err: tt.handlerError,
})
duration, err := GetDuration(tt.key, tt.defaultValue)
- if err != tt.wantError {
+ if fmt.Sprint(err) != fmt.Sprint(tt.wantError) {
t.Errorf("err=%q, want %q", err, tt.wantError)
}
if duration != tt.wantValue {
t.Errorf("duration=%v, want %v", duration, tt.wantValue)
}
+ wantMetrics := tt.wantMetrics
+ if !metrics.ShouldReport() {
+ // Check that metrics are not reported on platforms
+ // where they shouldn't be reported.
+ // As of 2024-08-02, syspolicy only reports metrics
+ // on Windows and Android.
+ wantMetrics = nil
+ }
+ h.MustEqual(wantMetrics...)
})
}
}
@@ -418,23 +562,28 @@ func TestGetStringArray(t *testing.T) {
defaultValue []string
wantValue []string
wantError error
+ wantMetrics []metrics.TestState
}{
{
name: "read existing value",
key: AllowedSuggestedExitNodes,
handlerValue: []string{"foo", "bar"},
wantValue: []string{"foo", "bar"},
+ wantMetrics: []metrics.TestState{
+ {Name: "$os_syspolicy_any", Value: 1},
+ {Name: "$os_syspolicy_AllowedSuggestedExitNodes", Value: 1},
+ },
},
{
name: "read non-existing value",
key: AllowedSuggestedExitNodes,
- handlerError: ErrNoSuchKey,
+ handlerError: ErrNotConfigured,
wantError: nil,
},
{
name: "read non-existing value, non nil default",
key: AllowedSuggestedExitNodes,
- handlerError: ErrNoSuchKey,
+ handlerError: ErrNotConfigured,
defaultValue: []string{"foo", "bar"},
wantValue: []string{"foo", "bar"},
wantError: nil,
@@ -444,11 +593,17 @@ func TestGetStringArray(t *testing.T) {
key: AllowedSuggestedExitNodes,
handlerError: someOtherError,
wantError: someOtherError,
+ wantMetrics: []metrics.TestState{
+ {Name: "$os_syspolicy_errors", Value: 1},
+ {Name: "$os_syspolicy_AllowedSuggestedExitNodes_error", Value: 1},
+ },
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
+ h := metrics.NewTestHandler(t)
+ metrics.SetHooksForTest(t, h.AddMetric, h.SetMetric)
SetHandlerForTest(t, &testHandler{
t: t,
key: tt.key,
@@ -456,16 +611,47 @@ func TestGetStringArray(t *testing.T) {
err: tt.handlerError,
})
value, err := GetStringArray(tt.key, tt.defaultValue)
- if err != tt.wantError {
+ if !errorsMatchForTest(err, tt.wantError) {
t.Errorf("err=%q, want %q", err, tt.wantError)
}
if !slices.Equal(tt.wantValue, value) {
t.Errorf("value=%v, want %v", value, tt.wantValue)
}
+ wantMetrics := tt.wantMetrics
+ if !metrics.ShouldReport() {
+ // Check that metrics are not reported on platforms
+ // where they shouldn't be reported.
+ // As of 2024-08-02, syspolicy only reports metrics
+ // on Windows and Android.
+ wantMetrics = nil
+ }
+ h.MustEqual(wantMetrics...)
})
}
}
+func BenchmarkGetString(b *testing.B) {
+ loggerx.SetForTest(b, logger.Discard, logger.Discard)
+ setWellKnownSettingsForTest(b)
+
+ store := source.NewTestStore(b)
+ wantControlURL := "https://login.tailscale.com"
+ store.SetStrings(source.TestSetting[string]{Key: ControlURL, Value: wantControlURL})
+
+ _, err := rsop.RegisterStoreForTest(b, "Test Store", setting.DeviceScope, store)
+ if err != nil {
+ b.Fatal(err)
+ }
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ gotControlURL, _ := GetString(ControlURL, "https://controlplane.tailscale.com")
+ if gotControlURL != wantControlURL {
+ b.Fatalf("got %v; want %v", gotControlURL, wantControlURL)
+ }
+ }
+}
+
func TestSelectControlURL(t *testing.T) {
tests := []struct {
reg, disk, want string
@@ -497,3 +683,13 @@ func TestSelectControlURL(t *testing.T) {
}
}
}
+
+func errorsMatchForTest(got, want error) bool {
+ if got == nil && want == nil {
+ return true
+ }
+ if got == nil || want == nil {
+ return false
+ }
+ return errors.Is(got, want) || got.Error() == want.Error()
+}
diff --git a/util/syspolicy/syspolicy_windows.go b/util/syspolicy/syspolicy_windows.go
new file mode 100644
index 000000000..d17fa10b5
--- /dev/null
+++ b/util/syspolicy/syspolicy_windows.go
@@ -0,0 +1,93 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package syspolicy
+
+import (
+ "errors"
+ "fmt"
+ "os/user"
+
+ "tailscale.com/util/syspolicy/internal"
+ "tailscale.com/util/syspolicy/internal/lazyinit"
+ "tailscale.com/util/syspolicy/rsop"
+ "tailscale.com/util/syspolicy/setting"
+ "tailscale.com/util/syspolicy/source"
+ "tailscale.com/util/testenv"
+)
+
+func init() {
+ // On Windows, we should automatically register the Registry-based policy
+ // store for the device. If we are running in a user's security context
+ // (e.g., we're the GUI), we should also register the Registry policy store for
+ // the user. In the future, we should register (and unregister) user policy
+ // stores whenever a user connects to the local backend. This ensures the
+ // backend is aware of the user's policy settings and can send them to the
+ // GUI/CLI/Web clients on demand or whenever they change.
+ //
+ // Other platforms, such as macOS, iOS and Android, should register their
+ // platform-specific policy stores via [RegisterStore] (or [RegisterHandler]
+ // until they implement the [Store] interface).
+ //
+ // External code, such as the ipnlocal package, may choose to register
+ // additional policy stores, such as config files and policies received from
+ // the control plane.
+ lazyinit.Defer(func() error {
+ // Do not register or use default policy stores during tests.
+ // Each test should set up its own necessary configurations.
+ if testenv.InTest() {
+ return nil
+ }
+ return configureSyspolicy(nil)
+ })
+}
+
+// configureSyspolicy configures syspolicy for use on Windows,
+// either in test or regular builds depending on whether tb has a non-nil value.
+func configureSyspolicy(tb internal.TB) error {
+ const localSystemSID = "S-1-5-18"
+ // Always create and register a machine policy store that reads
+ // policy settings from the HKEY_LOCAL_MACHINE registry hive.
+ machineStore, err := source.NewMachinePlatformPolicyStore()
+ if err != nil {
+ return fmt.Errorf("failed to create the machine policy store: %v", err)
+ }
+ if tb == nil {
+ _, err = rsop.RegisterStore("Platform", setting.DeviceScope, machineStore)
+ } else {
+ _, err = rsop.RegisterStoreForTest(tb, "Platform", setting.DeviceScope, machineStore)
+ }
+ if err != nil {
+ return err
+ }
+ // Check whether the current process is running as Local System or not.
+ u, err := user.Current()
+ if err != nil {
+ return err
+ }
+ if u.Uid == localSystemSID {
+ return nil
+ }
+ // If it's not a Local System's process (e.g., the GUI and not the tailscaled service),
+ // we should create and use a policy store for the current user that reads
+ // policy settings from that user's registry hive (HKEY_CURRENT_USER).
+ userStore, err := source.NewUserPlatformPolicyStore(0)
+ if err != nil {
+ return fmt.Errorf("failed to create the current user's policy store: %v", err)
+ }
+ if tb == nil {
+ _, err = rsop.RegisterStore("Platform", setting.CurrentUserScope, userStore)
+ } else {
+ _, err = rsop.RegisterStoreForTest(tb, "Platform", setting.CurrentUserScope, userStore)
+ }
+ if err != nil {
+ return err
+ }
+ // And also set [CurrentUserScope] as the [CurrentScope], so [GetString],
+ // [GetVisibility] and similar functions would be returning a merged result
+ // of the machine's and user's policies.
+ if !setting.SetCurrentScope(setting.CurrentUserScope) {
+ return errors.New("current scope already set")
+ }
+ return nil
+}
diff --git a/util/winutil/gp/policylock_windows.go b/util/winutil/gp/policylock_windows.go
index f92c534bb..95453aa16 100644
--- a/util/winutil/gp/policylock_windows.go
+++ b/util/winutil/gp/policylock_windows.go
@@ -189,6 +189,7 @@ func (l *PolicyLock) lockSlow() (err error) {
select {
case resultCh <- policyLockResult{handle, err}:
// lockSlow has received the result.
+ break send_result
default:
select {
case <-closing: