summaryrefslogtreecommitdiffhomepage
path: root/ipn
diff options
context:
space:
mode:
authorkari-ts <kari@tailscale.com>2025-03-19 11:28:04 -0700
committerkari-ts <kari@tailscale.com>2025-04-04 14:24:56 -0700
commit6d5c7b11913e09b061e863411ad488dc44a13870 (patch)
tree9e1789b5080ae4a92523611e49920dcb1102604b /ipn
parentca50599c95e0a4cb7b4aab179e866e202f10c0c4 (diff)
parent3a2c92f08eac8cd8f50356ff288e40a28636ee42 (diff)
downloadtailscale-kari/taildropsaf.tar.xz
tailscale-kari/taildropsaf.zip
-check if Context.getExternalFilesDirs works as is for private dir
Diffstat (limited to 'ipn')
-rw-r--r--ipn/auditlog/auditlog.go466
-rw-r--r--ipn/auditlog/auditlog_test.go481
-rw-r--r--ipn/ipnauth/actor.go7
-rw-r--r--ipn/ipnauth/policy.go10
-rw-r--r--ipn/ipnlocal/cert.go103
-rw-r--r--ipn/ipnlocal/cert_test.go185
-rw-r--r--ipn/ipnlocal/local.go132
-rw-r--r--ipn/ipnlocal/local_test.go59
-rw-r--r--ipn/ipnstate/ipnstate.go8
-rw-r--r--ipn/store/awsstore/store_aws.go111
-rw-r--r--ipn/store/awsstore/store_aws_stub.go18
-rw-r--r--ipn/store/awsstore/store_aws_test.go61
-rw-r--r--ipn/store/kubestore/store_kube.go317
-rw-r--r--ipn/store/kubestore/store_kube_test.go723
-rw-r--r--ipn/store/store_aws.go10
15 files changed, 2545 insertions, 146 deletions
diff --git a/ipn/auditlog/auditlog.go b/ipn/auditlog/auditlog.go
new file mode 100644
index 000000000..30f39211f
--- /dev/null
+++ b/ipn/auditlog/auditlog.go
@@ -0,0 +1,466 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Package auditlog provides a mechanism for logging audit events.
+package auditlog
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "sort"
+ "sync"
+ "time"
+
+ "tailscale.com/ipn"
+ "tailscale.com/tailcfg"
+ "tailscale.com/types/logger"
+ "tailscale.com/util/rands"
+ "tailscale.com/util/set"
+)
+
+// transaction represents an audit log that has not yet been sent to the control plane.
+type transaction struct {
+ // EventID is the unique identifier for the event being logged.
+ // This is used on the client side only and is not sent to control.
+ EventID string `json:",omitempty"`
+ // Retries is the number of times the logger has attempted to send this log.
+ // This is used on the client side only and is not sent to control.
+ Retries int `json:",omitempty"`
+
+ // Action is the action to be logged. It must correspond to a known action in the control plane.
+ Action tailcfg.ClientAuditAction `json:",omitempty"`
+ // Details is an opaque string specific to the action being logged. Empty strings may not
+ // be valid depending on the action being logged.
+ Details string `json:",omitempty"`
+ // TimeStamp is the time at which the audit log was generated on the node.
+ TimeStamp time.Time `json:",omitzero"`
+}
+
+// Transport provides a means for a client to send audit logs to a consumer (typically the control plane).
+type Transport interface {
+ // SendAuditLog sends an audit log to a consumer of audit logs.
+ // Errors should be checked with [IsRetryableError] for retryability.
+ SendAuditLog(context.Context, tailcfg.AuditLogRequest) error
+}
+
+// LogStore provides a means for a [Logger] to persist logs to disk or memory.
+// Both methods are unexported, so only types declared in this package can
+// implement it; use [NewLogStore] to obtain one.
+type LogStore interface {
+ // save persists txns under the given key, overwriting any data previously
+ // stored for that key.
+ save(key ipn.ProfileID, txns []*transaction) error
+
+ // load retrieves the transactions stored under the given key. It returns a nil
+ // slice and no error if no data exists for the given key.
+ load(key ipn.ProfileID) ([]*transaction, error)
+}
+
+// Opts contains the configuration options for a [Logger].
+type Opts struct {
+ // RetryLimit is the maximum number of attempts the logger will make to send a log before giving up.
+ RetryLimit int
+ // Store is the persistent store used to save logs to disk. Must be non-nil.
+ Store LogStore
+ // Logf is the logger used to log messages from the audit logger. Must be non-nil.
+ Logf logger.Logf
+}
+
+// IsRetryableError reports whether err is retryable: it finds the first error in
+// err's chain that implements a Retryable() bool method and returns that method's result.
+// It returns false if no error in the chain implements the method. See [controlclient.apiResponseError].
+func IsRetryableError(err error) bool {
+ var retryable interface{ Retryable() bool }
+ return errors.As(err, &retryable) && retryable.Retryable()
+}
+
+type backoffOpts struct {
+ min, max time.Duration
+ multiplier float64
+}
+
+// .5, 1, 2, 4, 8, 10, 10, 10, 10, 10...
+var defaultBackoffOpts = backoffOpts{
+ min: time.Millisecond * 500,
+ max: 10 * time.Second,
+ multiplier: 2,
+}
+
+// Logger provides a queue-based mechanism for submitting audit logs to the control plane - or
+// another suitable consumer. Logs are stored to disk and retried until they are successfully sent,
+// or until they permanently fail.
+//
+// Each individual profile/controlclient tuple should construct and manage a unique [Logger] instance.
+type Logger struct {
+ logf logger.Logf
+ retryLimit int // the maximum number of attempts to send a log before giving up.
+ flusher chan struct{} // channel used to signal a flush operation.
+ done chan struct{} // closed when the flush worker exits.
+ ctx context.Context // canceled when the logger is stopped.
+ ctxCancel context.CancelFunc // cancels ctx.
+ backoffOpts // backoff settings for retry operations.
+
+ // mu protects the fields below.
+ mu sync.Mutex
+ store LogStore // persistent storage for unsent logs.
+ profileID ipn.ProfileID // empty if [Logger.SetProfileID] has not been called.
+ transport Transport // nil until [Logger.Start] is called.
+}
+
+// NewLogger creates a new [Logger] with the given options.
+func NewLogger(opts Opts) *Logger {
+ ctx, cancel := context.WithCancel(context.Background())
+
+ al := &Logger{
+ retryLimit: opts.RetryLimit,
+ logf: logger.WithPrefix(opts.Logf, "auditlog: "),
+ store: opts.Store,
+ flusher: make(chan struct{}, 1),
+ done: make(chan struct{}),
+ ctx: ctx,
+ ctxCancel: cancel,
+ backoffOpts: defaultBackoffOpts,
+ }
+ al.logf("created")
+ return al
+}
+
+// FlushAndStop synchronously flushes all pending logs and stops the audit logger.
+// This will block until a final flush operation completes or context is done.
+// If the logger is already stopped, this will return immediately. All unsent
+// logs will be persisted to the store.
+func (al *Logger) FlushAndStop(ctx context.Context) {
+ al.stop()
+ al.flush(ctx)
+}
+
+// SetProfileID sets the profileID for the logger. This must be called before any logs can be enqueued.
+// The profileID of a logger cannot be changed once set.
+func (al *Logger) SetProfileID(profileID ipn.ProfileID) error {
+ al.mu.Lock()
+ defer al.mu.Unlock()
+ if al.profileID != "" {
+ return errors.New("profileID already set")
+ }
+
+ al.profileID = profileID
+ return nil
+}
+
+// Start starts the audit logger with the given transport.
+// It returns an error if the logger is already started.
+func (al *Logger) Start(t Transport) error {
+ al.mu.Lock()
+ defer al.mu.Unlock()
+
+ if al.transport != nil {
+ return errors.New("already started")
+ }
+
+ al.transport = t
+ pending, err := al.storedCountLocked()
+ if err != nil {
+ al.logf("[unexpected] failed to restore logs: %v", err)
+ }
+ go al.flushWorker()
+ if pending > 0 {
+ al.flushAsync()
+ }
+ return nil
+}
+
+// ErrAuditLogStorageFailure is returned when the logger fails to persist logs to the store.
+var ErrAuditLogStorageFailure = errors.New("audit log storage failure")
+
+// Enqueue queues an audit log to be sent to the control plane (or another suitable consumer/transport).
+// It returns an error wrapping [ErrAuditLogStorageFailure] if the log cannot be persisted to the
+// underlying store, which includes the case where no profileID has been set.
+func (al *Logger) Enqueue(action tailcfg.ClientAuditAction, details string) error {
+ txn := &transaction{
+ Action: action,
+ Details: details,
+ TimeStamp: time.Now(),
+ }
+ // The eventID is the formatted timestamp followed by a random hex string.
+ txn.EventID = fmt.Sprint(txn.TimeStamp, rands.HexString(16))
+ return al.enqueue(txn)
+}
+
+// flushAsync requests an asynchronous flush.
+// It is a no-op if a flush is already pending.
+func (al *Logger) flushAsync() {
+ select {
+ case al.flusher <- struct{}{}:
+ default:
+ }
+}
+
+// flushWorker is the background loop started by [Logger.Start]. It runs a flush
+// each time one is requested via al.flusher and, when a flush fails, schedules
+// another attempt after a delay that grows by al.backoffOpts.multiplier and is
+// clamped to [al.backoffOpts.min, al.backoffOpts.max]. It returns, closing
+// al.done, once al.ctx is canceled.
+func (al *Logger) flushWorker() {
+	defer close(al.done)
+
+	var retryDelay time.Duration
+	retry := time.NewTimer(0)
+	retry.Stop()
+	// Release the timer when the worker exits, even if a retry is still armed.
+	defer retry.Stop()
+
+	for {
+		select {
+		case <-al.ctx.Done():
+			return
+		case <-al.flusher:
+			err := al.flush(al.ctx)
+			switch {
+			case errors.Is(err, context.Canceled):
+				// The logger was stopped, no need to retry.
+				return
+			case err != nil:
+				// Scale in floating point so that a fractional multiplier (1.5, say)
+				// is honored instead of being truncated to an integer factor.
+				next := time.Duration(float64(retryDelay) * al.backoffOpts.multiplier)
+				retryDelay = max(al.backoffOpts.min, min(next, al.backoffOpts.max))
+				al.logf("retrying after %v, %v", retryDelay, err)
+				retry.Reset(retryDelay)
+			default:
+				// Success: reset the backoff and cancel any pending retry.
+				retryDelay = 0
+				retry.Stop()
+			}
+		case <-retry.C:
+			al.flushAsync()
+		}
+	}
+}
+
+// flush attempts to send all pending logs to the control plane.
+// al.mu must not be held.
+//
+// It returns nil if there was nothing to send or if every pending log completed
+// (was sent or failed permanently). Logs that could not be sent are written back
+// to the store and a non-nil error is returned. If ctx is done at that point the
+// error wraps ctx.Err(), so callers can detect cancellation with [errors.Is].
+func (al *Logger) flush(ctx context.Context) error {
+	al.mu.Lock()
+	pending, err := al.store.load(al.profileID)
+	t := al.transport
+	al.mu.Unlock()
+
+	if err != nil {
+		// logStateStore.load rejects an empty profileID, so an unset
+		// profileID ends up here as well.
+		return fmt.Errorf("failed to restore pending logs: %w", err)
+	}
+	if len(pending) == 0 {
+		return nil
+	}
+	if t == nil {
+		return errors.New("no transport")
+	}
+
+	complete, unsent := al.sendToTransport(ctx, pending, t)
+	al.markTransactionsDone(complete)
+
+	al.mu.Lock()
+	defer al.mu.Unlock()
+	// Requeue unsent logs so that updated retry counts are persisted.
+	if err = al.appendToStoreLocked(unsent); err != nil {
+		al.logf("[unexpected] failed to persist logs: %v", err)
+	}
+
+	if len(unsent) != 0 {
+		if cerr := ctx.Err(); cerr != nil {
+			// Surface cancellation so the flush worker can exit instead of
+			// scheduling a retry that can never run.
+			return fmt.Errorf("failed to send %d logs: %w", len(unsent), cerr)
+		}
+		return fmt.Errorf("failed to send %d logs", len(unsent))
+	}
+
+	if len(complete) != 0 {
+		al.logf("complete %d audit log transactions", len(complete))
+	}
+	return nil
+}
+
+// sendToTransport sends all pending logs to the control plane. Returns a pair of slices
+// containing the logs that were successfully sent (or failed permanently) and those that were not.
+//
+// This may require multiple round trips to the control plane and can be a long running transaction.
+func (al *Logger) sendToTransport(ctx context.Context, pending []*transaction, t Transport) (complete []*transaction, unsent []*transaction) {
+ for i, txn := range pending {
+ req := tailcfg.AuditLogRequest{
+ Action: tailcfg.ClientAuditAction(txn.Action),
+ Details: txn.Details,
+ Timestamp: txn.TimeStamp,
+ }
+
+ if err := t.SendAuditLog(ctx, req); err != nil {
+ switch {
+ case errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded):
+ // The context is done. All further attempts will fail, so requeue this log
+ // and every log after it without counting a retry against them.
+ unsent = append(unsent, pending[i:]...)
+ return complete, unsent
+ case IsRetryableError(err) && txn.Retries+1 < al.retryLimit:
+ // We permit a maximum number of retries for each log. All retriable
+ // errors should be transient and we should be able to send the log eventually, but
+ // we don't want logs to be persisted indefinitely.
+ txn.Retries++
+ unsent = append(unsent, txn)
+ default:
+ // Either the error is not retryable or the retry limit has been reached.
+ // The log is reported as complete so that the caller stops tracking it.
+ complete = append(complete, txn)
+ al.logf("failed permanently: %v", err)
+ }
+ } else {
+ // No error - we're done.
+ complete = append(complete, txn)
+ }
+ }
+
+ return complete, unsent
+}
+
+// stop cancels the logger's context and waits for the flush worker to exit.
+// It is a no-op if [Logger.Start] was never called. al.mu must not be held.
+func (al *Logger) stop() {
+	al.mu.Lock()
+	t := al.transport
+	// Snapshot profileID while holding mu, which guards it, for the log line below.
+	profileID := al.profileID
+	al.mu.Unlock()
+
+	if t == nil {
+		// No transport means no worker goroutine and done will not be
+		// closed if we cancel the context.
+		return
+	}
+
+	al.ctxCancel()
+	<-al.done
+	al.logf("stopped for profileID: %v", profileID)
+}
+
+// appendToStoreLocked persists logs to the store. This will deduplicate
+// logs so it is safe to call this with the same logs multiple times, to
+// requeue failed transactions for example. It is a no-op for an empty txns
+// and returns an error if no profileID has been set or the store fails to save.
+//
+// al.mu must be held.
+func (al *Logger) appendToStoreLocked(txns []*transaction) error {
+	if len(txns) == 0 {
+		return nil
+	}
+
+	if al.profileID == "" {
+		return errors.New("no logId set")
+	}
+
+	persisted, err := al.store.load(al.profileID)
+	if err != nil {
+		// NOTE(review): persisted is nil here, so the save below replaces whatever
+		// was stored with txns alone - confirm that is the intended recovery.
+		al.logf("[unexpected] append failed to restore logs: %v", err)
+	}
+
+	// The order is important here. We want the latest transactions first, which will
+	// ensure when we dedup, the new transactions are seen and the older transactions
+	// are discarded. The merged list is built in a fresh slice rather than by
+	// appending to txns, which could write into spare capacity of the caller's
+	// backing array.
+	txnsOut := make([]*transaction, 0, len(txns)+len(persisted))
+	txnsOut = append(txnsOut, txns...)
+	txnsOut = append(txnsOut, persisted...)
+	txnsOut = deduplicateAndSort(txnsOut)
+
+	return al.store.save(al.profileID, txnsOut)
+}
+
+// storedCountLocked returns the number of logs persisted to the store.
+// al.mu must be held.
+func (al *Logger) storedCountLocked() (int, error) {
+ persisted, err := al.store.load(al.profileID)
+ return len(persisted), err
+}
+
+// markTransactionsDone removes logs from the store that are complete (sent or failed permanently).
+// Logs are matched by EventID. Store failures are logged rather than returned; if the
+// stored logs cannot be loaded the store is left untouched.
+// al.mu must not be held.
+func (al *Logger) markTransactionsDone(sent []*transaction) {
+	if len(sent) == 0 {
+		// Nothing to remove, so skip the load/save round trip.
+		return
+	}
+
+	al.mu.Lock()
+	defer al.mu.Unlock()
+
+	ids := set.Set[string]{}
+	for _, txn := range sent {
+		ids.Add(txn.EventID)
+	}
+
+	persisted, err := al.store.load(al.profileID)
+	if err != nil {
+		// Without the stored list we cannot tell which logs to keep; saving
+		// now would overwrite the store with an empty list.
+		al.logf("[unexpected] markTransactionsDone failed to restore logs: %v", err)
+		return
+	}
+	var unsent []*transaction
+	for _, txn := range persisted {
+		if !ids.Contains(txn.EventID) {
+			unsent = append(unsent, txn)
+		}
+	}
+	if err := al.store.save(al.profileID, unsent); err != nil {
+		al.logf("[unexpected] markTransactionsDone failed to persist logs: %v", err)
+	}
+}
+
+// deduplicateAndSort removes duplicate logs from the given slice and sorts them by timestamp,
+// oldest first. For each EventID the first entry in txns is retained and later entries with the
+// same EventID are discarded. The sort is stable, so logs with equal timestamps keep their
+// relative order from txns. The input slice is not modified.
+func deduplicateAndSort(txns []*transaction) []*transaction {
+	seen := set.Set[string]{}
+	deduped := make([]*transaction, 0, len(txns))
+	for _, txn := range txns {
+		if !seen.Contains(txn.EventID) {
+			deduped = append(deduped, txn)
+			seen.Add(txn.EventID)
+		}
+	}
+	// Sort logs by timestamp - oldest to newest. This will put the oldest logs at
+	// the front of the queue.
+	sort.SliceStable(deduped, func(i, j int) bool {
+		return deduped[i].TimeStamp.Before(deduped[j].TimeStamp)
+	})
+	return deduped
+}
+
+// enqueue persists txn to the store and, if a transport has been set by
+// [Logger.Start], requests an asynchronous flush. Storage errors are returned
+// wrapped in [ErrAuditLogStorageFailure].
+func (al *Logger) enqueue(txn *transaction) error {
+ al.mu.Lock()
+ defer al.mu.Unlock()
+
+ if err := al.appendToStoreLocked([]*transaction{txn}); err != nil {
+ return fmt.Errorf("%w: %w", ErrAuditLogStorageFailure, err)
+ }
+
+ // al.transport is nil until Start is called. In that case the log just stays in the store.
+ if al.transport != nil {
+ al.flushAsync()
+ }
+
+ return nil
+}
+
+var _ LogStore = (*logStateStore)(nil)
+
+// logStateStore is a concrete implementation of [LogStore]
+// using [ipn.StateStore] as the underlying storage.
+type logStateStore struct {
+ store ipn.StateStore
+}
+
+// NewLogStore returns a [LogStore] that persists audit logs in the given [ipn.StateStore].
+func NewLogStore(store ipn.StateStore) LogStore {
+ return &logStateStore{
+ store: store,
+ }
+}
+
+func (s *logStateStore) generateKey(key ipn.ProfileID) string {
+ return "auditlog-" + string(key)
+}
+
+// save JSON-encodes the given logs and writes them to the underlying
+// [ipn.StateStore], overwriting any existing entries for the given key.
+// It returns an error if key is empty.
+func (s *logStateStore) save(key ipn.ProfileID, txns []*transaction) error {
+ if key == "" {
+ return errors.New("empty key")
+ }
+
+ data, err := json.Marshal(txns)
+ if err != nil {
+ return err
+ }
+ k := ipn.StateKey(s.generateKey(key))
+ return s.store.WriteState(k, data)
+}
+
+// load reads and JSON-decodes the logs stored for key in the underlying
+// [ipn.StateStore]. It returns an error if key is empty, and a nil slice with
+// no error if the store reports [ipn.ErrStateNotExist].
+func (s *logStateStore) load(key ipn.ProfileID) ([]*transaction, error) {
+ if key == "" {
+ return nil, errors.New("empty key")
+ }
+
+ k := ipn.StateKey(s.generateKey(key))
+ data, err := s.store.ReadState(k)
+
+ switch {
+ case errors.Is(err, ipn.ErrStateNotExist):
+ return nil, nil
+ case err != nil:
+ return nil, err
+ }
+
+ var txns []*transaction
+ err = json.Unmarshal(data, &txns)
+ return txns, err
+}
diff --git a/ipn/auditlog/auditlog_test.go b/ipn/auditlog/auditlog_test.go
new file mode 100644
index 000000000..3d3bf95cb
--- /dev/null
+++ b/ipn/auditlog/auditlog_test.go
@@ -0,0 +1,481 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package auditlog
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "sync"
+ "testing"
+ "time"
+
+ qt "github.com/frankban/quicktest"
+ "tailscale.com/ipn/store/mem"
+ "tailscale.com/tailcfg"
+ "tailscale.com/tstest"
+)
+
+// loggerForTest creates an auditLogger for you and cleans it up
+// (and ensures no goroutines are leaked) when the test is done.
+func loggerForTest(t *testing.T, opts Opts) *Logger {
+ t.Helper()
+ tstest.ResourceCheck(t)
+
+ if opts.Logf == nil {
+ opts.Logf = t.Logf
+ }
+
+ if opts.Store == nil {
+ t.Fatalf("opts.Store must be set")
+ }
+
+ a := NewLogger(opts)
+
+ t.Cleanup(func() {
+ a.FlushAndStop(context.Background())
+ })
+ return a
+}
+
+func TestNonRetryableErrors(t *testing.T) {
+ errorTests := []struct {
+ desc string
+ err error
+ want bool
+ }{
+ {"DeadlineExceeded", context.DeadlineExceeded, false},
+ {"Canceled", context.Canceled, false},
+ {"Canceled wrapped", fmt.Errorf("%w: %w", context.Canceled, errors.New("ctx cancelled")), false},
+ {"Random error", errors.New("random error"), false},
+ }
+
+ for _, tt := range errorTests {
+ t.Run(tt.desc, func(t *testing.T) {
+ if IsRetryableError(tt.err) != tt.want {
+ t.Fatalf("retriable: got %v, want %v", !tt.want, tt.want)
+ }
+ })
+ }
+}
+
+// TestEnqueueAndFlush enqueues n logs and flushes them.
+// We expect all logs to be flushed and for no
+// logs to remain in the store once FlushAndStop returns.
+func TestEnqueueAndFlush(t *testing.T) {
+ c := qt.New(t)
+ mockTransport := newMockTransport(nil)
+ al := loggerForTest(t, Opts{
+ RetryLimit: 200,
+ Logf: t.Logf,
+ Store: NewLogStore(&mem.Store{}),
+ })
+
+ c.Assert(al.SetProfileID("test"), qt.IsNil)
+ c.Assert(al.Start(mockTransport), qt.IsNil)
+
+ wantSent := 10
+
+ for i := range wantSent {
+ err := al.Enqueue(tailcfg.AuditNodeDisconnect, fmt.Sprintf("log %d", i))
+ c.Assert(err, qt.IsNil)
+ }
+
+ al.FlushAndStop(context.Background())
+
+ al.mu.Lock()
+ defer al.mu.Unlock()
+ gotStored, err := al.storedCountLocked()
+ c.Assert(err, qt.IsNil)
+
+ if wantStored := 0; gotStored != wantStored {
+ t.Fatalf("stored: got %d, want %d", gotStored, wantStored)
+ }
+
+ if gotSent := mockTransport.sentCount(); gotSent != wantSent {
+ t.Fatalf("sent: got %d, want %d", gotSent, wantSent)
+ }
+}
+
+// TestEnqueueAndFlushWithFlushCancel calls FlushAndStop with a cancelled
+// context. We expect nothing to be sent and all logs to be stored.
+func TestEnqueueAndFlushWithFlushCancel(t *testing.T) {
+ c := qt.New(t)
+ mockTransport := newMockTransport(&retriableError)
+ al := loggerForTest(t, Opts{
+ RetryLimit: 200,
+ Logf: t.Logf,
+ Store: NewLogStore(&mem.Store{}),
+ })
+
+ c.Assert(al.SetProfileID("test"), qt.IsNil)
+ c.Assert(al.Start(mockTransport), qt.IsNil)
+
+ for i := range 10 {
+ err := al.Enqueue(tailcfg.AuditNodeDisconnect, fmt.Sprintf("log %d", i))
+ c.Assert(err, qt.IsNil)
+ }
+
+ // Cancel the context before calling FlushAndStop - nothing should get sent.
+ // This mimics a timeout before flush() has a chance to execute.
+ ctx, cancel := context.WithCancel(context.Background())
+ cancel()
+
+ al.FlushAndStop(ctx)
+
+ al.mu.Lock()
+ defer al.mu.Unlock()
+ gotStored, err := al.storedCountLocked()
+ c.Assert(err, qt.IsNil)
+
+ if wantStored := 10; gotStored != wantStored {
+ t.Fatalf("stored: got %d, want %d", gotStored, wantStored)
+ }
+
+ if gotSent, wantSent := mockTransport.sentCount(), 0; gotSent != wantSent {
+ t.Fatalf("sent: got %d, want %d", gotSent, wantSent)
+ }
+}
+
+// TestDeduplicateAndSort tests that the most recent log is kept when deduplicating logs
+func TestDeduplicateAndSort(t *testing.T) {
+ c := qt.New(t)
+ al := loggerForTest(t, Opts{
+ RetryLimit: 100,
+ Logf: t.Logf,
+ Store: NewLogStore(&mem.Store{}),
+ })
+
+ c.Assert(al.SetProfileID("test"), qt.IsNil)
+
+ logs := []*transaction{
+ {EventID: "1", Details: "log 1", TimeStamp: time.Now().Add(-time.Minute * 1), Retries: 1},
+ }
+
+ al.mu.Lock()
+ defer al.mu.Unlock()
+ al.appendToStoreLocked(logs)
+
+ // Update the transaction and re-append it
+ logs[0].Retries = 2
+ al.appendToStoreLocked(logs)
+
+ fromStore, err := al.store.load("test")
+ c.Assert(err, qt.IsNil)
+
+ // We should see only one transaction
+ if wantStored, gotStored := len(logs), len(fromStore); gotStored != wantStored {
+ t.Fatalf("stored: got %d, want %d", gotStored, wantStored)
+ }
+
+ // We should see the latest transaction
+ if wantRetryCount, gotRetryCount := 2, fromStore[0].Retries; gotRetryCount != wantRetryCount {
+ t.Fatalf("reties: got %d, want %d", gotRetryCount, wantRetryCount)
+ }
+}
+
+func TestChangeProfileId(t *testing.T) {
+ c := qt.New(t)
+ al := loggerForTest(t, Opts{
+ RetryLimit: 100,
+ Logf: t.Logf,
+ Store: NewLogStore(&mem.Store{}),
+ })
+ c.Assert(al.SetProfileID("test"), qt.IsNil)
+
+ // Changing a profile ID must fail
+ c.Assert(al.SetProfileID("test"), qt.IsNotNil)
+}
+
+// TestSendOnRestore pushes n logs to the persistent store, and ensures they
+// are sent as soon as Start is called then checks to ensure the sent logs no
+// longer exist in the store.
+func TestSendOnRestore(t *testing.T) {
+ c := qt.New(t)
+ mockTransport := newMockTransport(nil)
+ al := loggerForTest(t, Opts{
+ RetryLimit: 100,
+ Logf: t.Logf,
+ Store: NewLogStore(&mem.Store{}),
+ })
+ al.SetProfileID("test")
+
+ wantTotal := 10
+
+ for range 10 {
+ al.Enqueue(tailcfg.AuditNodeDisconnect, "log")
+ }
+
+ c.Assert(al.Start(mockTransport), qt.IsNil)
+
+ al.FlushAndStop(context.Background())
+
+ al.mu.Lock()
+ defer al.mu.Unlock()
+ gotStored, err := al.storedCountLocked()
+ c.Assert(err, qt.IsNil)
+
+ if wantStored := 0; gotStored != wantStored {
+ t.Fatalf("stored: got %d, want %d", gotStored, wantStored)
+ }
+
+ if gotSent, wantSent := mockTransport.sentCount(), wantTotal; gotSent != wantSent {
+ t.Fatalf("sent: got %d, want %d", gotSent, wantSent)
+ }
+}
+
+// TestFailureExhaustion enqueues n logs with the transport returning a retryable error
+// and a RetryLimit of 1. We expect each log to be dropped after its first failed attempt,
+// so nothing is sent and nothing remains in the store once FlushAndStop returns.
+func TestFailureExhaustion(t *testing.T) {
+ c := qt.New(t)
+ mockTransport := newMockTransport(&retriableError)
+
+ al := loggerForTest(t, Opts{
+ RetryLimit: 1,
+ Logf: t.Logf,
+ Store: NewLogStore(&mem.Store{}),
+ })
+
+ c.Assert(al.SetProfileID("test"), qt.IsNil)
+ c.Assert(al.Start(mockTransport), qt.IsNil)
+
+ for range 10 {
+ err := al.Enqueue(tailcfg.AuditNodeDisconnect, "log")
+ c.Assert(err, qt.IsNil)
+ }
+
+ al.FlushAndStop(context.Background())
+ al.mu.Lock()
+ defer al.mu.Unlock()
+ gotStored, err := al.storedCountLocked()
+ c.Assert(err, qt.IsNil)
+
+ if wantStored := 0; gotStored != wantStored {
+ t.Fatalf("stored: got %d, want %d", gotStored, wantStored)
+ }
+
+ if gotSent, wantSent := mockTransport.sentCount(), 0; gotSent != wantSent {
+ t.Fatalf("sent: got %d, want %d", gotSent, wantSent)
+ }
+}
+
+// TestEnqueueAndFailNoRetry enqueues a set of logs, all of which will fail and are not
+// retriable. We then call FlushAndStop and expect all to be unsent.
+func TestEnqueueAndFailNoRetry(t *testing.T) {
+ c := qt.New(t)
+ mockTransport := newMockTransport(&nonRetriableError)
+
+ al := loggerForTest(t, Opts{
+ RetryLimit: 100,
+ Logf: t.Logf,
+ Store: NewLogStore(&mem.Store{}),
+ })
+
+ c.Assert(al.SetProfileID("test"), qt.IsNil)
+ c.Assert(al.Start(mockTransport), qt.IsNil)
+
+ for i := range 10 {
+ err := al.Enqueue(tailcfg.AuditNodeDisconnect, fmt.Sprintf("log %d", i))
+ c.Assert(err, qt.IsNil)
+ }
+
+ al.FlushAndStop(context.Background())
+ al.mu.Lock()
+ defer al.mu.Unlock()
+ gotStored, err := al.storedCountLocked()
+ c.Assert(err, qt.IsNil)
+
+ if wantStored := 0; gotStored != wantStored {
+ t.Fatalf("stored: got %d, want %d", gotStored, wantStored)
+ }
+
+ if gotSent, wantSent := mockTransport.sentCount(), 0; gotSent != wantSent {
+ t.Fatalf("sent: got %d, want %d", gotSent, wantSent)
+ }
+}
+
+// TestEnqueueAndRetry enqueues a set of logs, all of which will fail and are retriable.
+// Mid-test, we set the transport to not-fail and expect the queue to flush properly
+// We shrink the backoff to between 1ms and 4ms so retries happen almost immediately.
+func TestEnqueueAndRetry(t *testing.T) {
+ c := qt.New(t)
+ mockTransport := newMockTransport(&retriableError)
+
+ al := loggerForTest(t, Opts{
+ RetryLimit: 100,
+ Logf: t.Logf,
+ Store: NewLogStore(&mem.Store{}),
+ })
+
+ al.backoffOpts = backoffOpts{
+ min: 1 * time.Millisecond,
+ max: 4 * time.Millisecond,
+ multiplier: 2.0,
+ }
+
+ c.Assert(al.SetProfileID("test"), qt.IsNil)
+ c.Assert(al.Start(mockTransport), qt.IsNil)
+
+ err := al.Enqueue(tailcfg.AuditNodeDisconnect, fmt.Sprintf("log 1"))
+ c.Assert(err, qt.IsNil)
+
+ // This will wait for at least 2 retries
+ gotRetried, wantRetried := mockTransport.waitForSendAttemptsToReach(3), true
+ if gotRetried != wantRetried {
+ t.Fatalf("retried: got %v, want %v", gotRetried, wantRetried)
+ }
+
+ mockTransport.setErrorCondition(nil)
+
+ al.FlushAndStop(context.Background())
+ al.mu.Lock()
+ defer al.mu.Unlock()
+
+ gotStored, err := al.storedCountLocked()
+ c.Assert(err, qt.IsNil)
+
+ if wantStored := 0; gotStored != wantStored {
+ t.Fatalf("stored: got %d, want %d", gotStored, wantStored)
+ }
+
+ if gotSent, wantSent := mockTransport.sentCount(), 1; gotSent != wantSent {
+ t.Fatalf("sent: got %d, want %d", gotSent, wantSent)
+ }
+}
+
+// TestEnqueueBeforeSetProfileID tests that Enqueue fails, and nothing is stored, when it is called before SetProfileID.
+func TestEnqueueBeforeSetProfileID(t *testing.T) {
+ c := qt.New(t)
+ al := loggerForTest(t, Opts{
+ RetryLimit: 100,
+ Logf: t.Logf,
+ Store: NewLogStore(&mem.Store{}),
+ })
+
+ err := al.Enqueue(tailcfg.AuditNodeDisconnect, "log")
+ c.Assert(err, qt.IsNotNil)
+ al.FlushAndStop(context.Background())
+
+ al.mu.Lock()
+ defer al.mu.Unlock()
+ gotStored, err := al.storedCountLocked()
+ c.Assert(err, qt.IsNotNil)
+
+ if wantStored := 0; gotStored != wantStored {
+ t.Fatalf("stored: got %d, want %d", gotStored, wantStored)
+ }
+}
+
+// TestLogSorting tests that deduplicateAndSort drops entries with a repeated EventID
+// and orders the remaining audit logs by timestamp, oldest to newest, after a
+// round trip through the store.
+func TestLogSorting(t *testing.T) {
+	c := qt.New(t)
+	mockStore := NewLogStore(&mem.Store{})
+
+	logs := []*transaction{
+		{EventID: "1", Details: "log 3", TimeStamp: time.Now().Add(-time.Minute * 1)},
+		{EventID: "1", Details: "log 3", TimeStamp: time.Now().Add(-time.Minute * 2)},
+		{EventID: "2", Details: "log 2", TimeStamp: time.Now().Add(-time.Minute * 3)},
+		{EventID: "3", Details: "log 1", TimeStamp: time.Now().Add(-time.Minute * 4)},
+	}
+
+	wantLogs := []transaction{
+		{Details: "log 1"},
+		{Details: "log 2"},
+		{Details: "log 3"},
+	}
+
+	c.Assert(mockStore.save("test", logs), qt.IsNil)
+
+	gotLogs, err := mockStore.load("test")
+	c.Assert(err, qt.IsNil)
+	gotLogs = deduplicateAndSort(gotLogs)
+
+	// Check the count first; otherwise an empty or short result would pass the
+	// loop below without comparing anything.
+	if got, want := len(gotLogs), len(wantLogs); got != want {
+		t.Fatalf("log count: got %d, want %d", got, want)
+	}
+
+	for i := range gotLogs {
+		if want, got := wantLogs[i].Details, gotLogs[i].Details; want != got {
+			t.Fatalf("Details: got %v, want %v", got, want)
+		}
+	}
+}
+
+// mock implementations for testing
+
+// newMockTransport returns a mock transport for testing
+// If err is not nil, SendAuditLog will return this error if the send is attempted
+// before the context is cancelled.
+func newMockTransport(err error) *mockAuditLogTransport {
+ return &mockAuditLogTransport{
+ err: err,
+ attempts: make(chan int, 1),
+ }
+}
+
+type mockAuditLogTransport struct {
+ attempts chan int // channel to notify of send attempts
+
+ mu sync.Mutex
+ sendAttmpts int // number of attempts to send logs
+ sendCount int // number of logs sent by the transport
+ err error // error to return when sending logs
+}
+
+// waitForSendAttemptsToReach blocks until the number of send attempts reaches n
+// This should be used only in tests where the transport is expected to retry sending logs
+func (t *mockAuditLogTransport) waitForSendAttemptsToReach(n int) bool {
+ for attempts := range t.attempts {
+ if attempts >= n {
+ return true
+ }
+ }
+ return false
+}
+
+func (t *mockAuditLogTransport) setErrorCondition(err error) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ t.err = err
+}
+
+func (t *mockAuditLogTransport) sentCount() int {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ return t.sendCount
+}
+
+func (t *mockAuditLogTransport) SendAuditLog(ctx context.Context, _ tailcfg.AuditLogRequest) (err error) {
+ t.mu.Lock()
+ t.sendAttmpts += 1
+ defer func() {
+ a := t.sendAttmpts
+ t.mu.Unlock()
+ select {
+ case t.attempts <- a:
+ default:
+ }
+ }()
+
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ default:
+ }
+
+ if t.err != nil {
+ return t.err
+ }
+ t.sendCount += 1
+ return nil
+}
+
+var (
+ retriableError = mockError{errors.New("retriable error")}
+ nonRetriableError = mockError{errors.New("permanent failure error")}
+)
+
+type mockError struct {
+ error
+}
+
+func (e mockError) Retryable() bool {
+ return e == retriableError
+}
diff --git a/ipn/ipnauth/actor.go b/ipn/ipnauth/actor.go
index 8a0e77645..108bdd341 100644
--- a/ipn/ipnauth/actor.go
+++ b/ipn/ipnauth/actor.go
@@ -10,12 +10,11 @@ import (
"tailscale.com/client/tailscale/apitype"
"tailscale.com/ipn"
+ "tailscale.com/tailcfg"
)
// AuditLogFunc is any function that can be used to log audit actions performed by an [Actor].
-//
-// TODO(nickkhyl,barnstar): define a named string type for the action (in tailcfg?) and use it here.
-type AuditLogFunc func(action, details string)
+type AuditLogFunc func(action tailcfg.ClientAuditAction, details string) error
// Actor is any actor using the [ipnlocal.LocalBackend].
//
@@ -45,7 +44,7 @@ type Actor interface {
//
// If the auditLogger is non-nil, it is used to write details about the action
// to the audit log when required by the policy.
- CheckProfileAccess(profile ipn.LoginProfileView, requestedAccess ProfileAccess, auditLogger AuditLogFunc) error
+ CheckProfileAccess(profile ipn.LoginProfileView, requestedAccess ProfileAccess, auditLogFn AuditLogFunc) error
// IsLocalSystem reports whether the actor is the Windows' Local System account.
//
diff --git a/ipn/ipnauth/policy.go b/ipn/ipnauth/policy.go
index f09be0fcb..aa4ec4100 100644
--- a/ipn/ipnauth/policy.go
+++ b/ipn/ipnauth/policy.go
@@ -9,6 +9,7 @@ import (
"tailscale.com/client/tailscale/apitype"
"tailscale.com/ipn"
+ "tailscale.com/tailcfg"
"tailscale.com/util/syspolicy"
)
@@ -48,7 +49,7 @@ func (a actorWithPolicyChecks) CheckProfileAccess(profile ipn.LoginProfileView,
//
// TODO(nickkhyl): unexport it when we move [ipn.Actor] implementations from [ipnserver]
// and corp to this package.
-func CheckDisconnectPolicy(actor Actor, profile ipn.LoginProfileView, reason string, auditLogger AuditLogFunc) error {
+func CheckDisconnectPolicy(actor Actor, profile ipn.LoginProfileView, reason string, auditFn AuditLogFunc) error {
if alwaysOn, _ := syspolicy.GetBoolean(syspolicy.AlwaysOn, false); !alwaysOn {
return nil
}
@@ -58,15 +59,16 @@ func CheckDisconnectPolicy(actor Actor, profile ipn.LoginProfileView, reason str
if reason == "" {
return errors.New("disconnect not allowed: reason required")
}
- if auditLogger != nil {
+ if auditFn != nil {
var details string
if username, _ := actor.Username(); username != "" { // best-effort; we don't have it on all platforms
details = fmt.Sprintf("%q is being disconnected by %q: %v", profile.Name(), username, reason)
} else {
details = fmt.Sprintf("%q is being disconnected: %v", profile.Name(), reason)
}
- // TODO(nickkhyl,barnstar): use a const for DISCONNECT_NODE.
- auditLogger("DISCONNECT_NODE", details)
+ if err := auditFn(tailcfg.AuditNodeDisconnect, details); err != nil {
+ return err
+ }
}
return nil
}
diff --git a/ipn/ipnlocal/cert.go b/ipn/ipnlocal/cert.go
index cfa4fe1ba..111dc5a2d 100644
--- a/ipn/ipnlocal/cert.go
+++ b/ipn/ipnlocal/cert.go
@@ -119,6 +119,9 @@ func (b *LocalBackend) GetCertPEMWithValidity(ctx context.Context, domain string
}
if pair, err := getCertPEMCached(cs, domain, now); err == nil {
+ if envknob.IsCertShareReadOnlyMode() {
+ return pair, nil
+ }
// If we got here, we have a valid unexpired cert.
// Check whether we should start an async renewal.
shouldRenew, err := b.shouldStartDomainRenewal(cs, domain, now, pair, minValidity)
@@ -134,7 +137,7 @@ func (b *LocalBackend) GetCertPEMWithValidity(ctx context.Context, domain string
if minValidity == 0 {
logf("starting async renewal")
// Start renewal in the background, return current valid cert.
- go b.getCertPEM(context.Background(), cs, logf, traceACME, domain, now, minValidity)
+ b.goTracker.Go(func() { getCertPEM(context.Background(), b, cs, logf, traceACME, domain, now, minValidity) })
return pair, nil
}
// If the caller requested a specific validity duration, fall through
@@ -142,7 +145,11 @@ func (b *LocalBackend) GetCertPEMWithValidity(ctx context.Context, domain string
logf("starting sync renewal")
}
- pair, err := b.getCertPEM(ctx, cs, logf, traceACME, domain, now, minValidity)
+ if envknob.IsCertShareReadOnlyMode() {
+ return nil, fmt.Errorf("retrieving cached TLS certificate failed and cert store is configured in read-only mode, not attempting to issue a new certificate: %w", err)
+ }
+
+ pair, err := getCertPEM(ctx, b, cs, logf, traceACME, domain, now, minValidity)
if err != nil {
logf("getCertPEM: %v", err)
return nil, err
@@ -250,15 +257,13 @@ type certStore interface {
// for now. If they're expired, it returns errCertExpired.
// If they don't exist, it returns ipn.ErrStateNotExist.
Read(domain string, now time.Time) (*TLSCertKeyPair, error)
- // WriteCert writes the cert for domain.
- WriteCert(domain string, cert []byte) error
- // WriteKey writes the key for domain.
- WriteKey(domain string, key []byte) error
// ACMEKey returns the value previously stored via WriteACMEKey.
// It is a PEM encoded ECDSA key.
ACMEKey() ([]byte, error)
// WriteACMEKey stores the provided PEM encoded ECDSA key.
WriteACMEKey([]byte) error
+ // WriteTLSCertAndKey writes the cert and key for domain.
+ WriteTLSCertAndKey(domain string, cert, key []byte) error
}
var errCertExpired = errors.New("cert expired")
@@ -344,6 +349,13 @@ func (f certFileStore) WriteKey(domain string, key []byte) error {
return atomicfile.WriteFile(keyFile(f.dir, domain), key, 0600)
}
+func (f certFileStore) WriteTLSCertAndKey(domain string, cert, key []byte) error {
+ if err := f.WriteKey(domain, key); err != nil {
+ return err
+ }
+ return f.WriteCert(domain, cert)
+}
+
// certStateStore implements certStore by storing the cert & key files in an ipn.StateStore.
type certStateStore struct {
ipn.StateStore
@@ -353,7 +365,29 @@ type certStateStore struct {
testRoots *x509.CertPool
}
+// TLSCertKeyReader is an interface implemented by state stores where it makes
+// sense to read the TLS cert and key in a single operation that can be
+// distinguished from generic state value reads. Currently this is only implemented
+// by the kubestore.Store, which, in some cases, needs to read cert and key from a
+// non-cached TLS Secret.
+type TLSCertKeyReader interface {
+ ReadTLSCertAndKey(domain string) ([]byte, []byte, error)
+}
+
func (s certStateStore) Read(domain string, now time.Time) (*TLSCertKeyPair, error) {
+ // If we're using a store that supports atomic reads, use that
+ if kr, ok := s.StateStore.(TLSCertKeyReader); ok {
+ cert, key, err := kr.ReadTLSCertAndKey(domain)
+ if err != nil {
+ return nil, err
+ }
+ if !validCertPEM(domain, key, cert, s.testRoots, now) {
+ return nil, errCertExpired
+ }
+ return &TLSCertKeyPair{CertPEM: cert, KeyPEM: key, Cached: true}, nil
+ }
+
+ // Otherwise fall back to separate reads
certPEM, err := s.ReadState(ipn.StateKey(domain + ".crt"))
if err != nil {
return nil, err
@@ -384,6 +418,27 @@ func (s certStateStore) WriteACMEKey(key []byte) error {
return ipn.WriteState(s.StateStore, ipn.StateKey(acmePEMName), key)
}
+// TLSCertKeyWriter is an interface implemented by state stores that can write the TLS
+// cert and key in a single atomic operation. Currently this is only implemented
+// by the kubestore.StoreKube.
+type TLSCertKeyWriter interface {
+ WriteTLSCertAndKey(domain string, cert, key []byte) error
+}
+
+// WriteTLSCertAndKey writes the TLS cert and key for domain to the current
+// LocalBackend's StateStore.
+func (s certStateStore) WriteTLSCertAndKey(domain string, cert, key []byte) error {
+ // If we're using a store that supports atomic writes, use that.
+ if aw, ok := s.StateStore.(TLSCertKeyWriter); ok {
+ return aw.WriteTLSCertAndKey(domain, cert, key)
+ }
+ // Otherwise fall back to separate writes for cert and key.
+ if err := s.WriteKey(domain, key); err != nil {
+ return err
+ }
+ return s.WriteCert(domain, cert)
+}
+
// TLSCertKeyPair is a TLS public and private key, and whether they were obtained
// from cache or freshly obtained.
type TLSCertKeyPair struct {
@@ -420,7 +475,9 @@ func getCertPEMCached(cs certStore, domain string, now time.Time) (p *TLSCertKey
return cs.Read(domain, now)
}
-func (b *LocalBackend) getCertPEM(ctx context.Context, cs certStore, logf logger.Logf, traceACME func(any), domain string, now time.Time, minValidity time.Duration) (*TLSCertKeyPair, error) {
+// getCertPEM checks if a cert needs to be renewed and if so, renews it.
+// It can be overridden in tests.
+var getCertPEM = func(ctx context.Context, b *LocalBackend, cs certStore, logf logger.Logf, traceACME func(any), domain string, now time.Time, minValidity time.Duration) (*TLSCertKeyPair, error) {
acmeMu.Lock()
defer acmeMu.Unlock()
@@ -445,6 +502,10 @@ func (b *LocalBackend) getCertPEM(ctx context.Context, cs certStore, logf logger
return nil, err
}
+ if !isDefaultDirectoryURL(ac.DirectoryURL) {
+ logf("acme: using Directory URL %q", ac.DirectoryURL)
+ }
+
a, err := ac.GetReg(ctx, "" /* pre-RFC param */)
switch {
case err == nil:
@@ -546,9 +607,6 @@ func (b *LocalBackend) getCertPEM(ctx context.Context, cs certStore, logf logger
if err := encodeECDSAKey(&privPEM, certPrivKey); err != nil {
return nil, err
}
- if err := cs.WriteKey(domain, privPEM.Bytes()); err != nil {
- return nil, err
- }
csr, err := certRequest(certPrivKey, domain, nil)
if err != nil {
@@ -570,7 +628,7 @@ func (b *LocalBackend) getCertPEM(ctx context.Context, cs certStore, logf logger
return nil, err
}
}
- if err := cs.WriteCert(domain, certPEM.Bytes()); err != nil {
+ if err := cs.WriteTLSCertAndKey(domain, certPEM.Bytes(), privPEM.Bytes()); err != nil {
return nil, err
}
b.domainRenewed(domain)
@@ -714,7 +772,28 @@ func validateLeaf(leaf *x509.Certificate, intermediates *x509.CertPool, domain s
// binary's baked-in roots (LetsEncrypt). See tailscale/tailscale#14690.
return validateLeaf(leaf, intermediates, domain, now, bakedroots.Get())
}
- return err == nil
+
+ if err == nil {
+ return true
+ }
+
+ // When pointed at a non-prod ACME server, we don't expect to have the CA
+ // in our system or baked-in roots. Verify only throws UnknownAuthorityError
+ // after first checking the leaf cert's expiry, hostnames etc, so we know
+ // that the only reason for an error is to do with constructing a full chain.
+ // Allow this error so that cert caching still works in testing environments.
+ if errors.As(err, &x509.UnknownAuthorityError{}) {
+ acmeURL := envknob.String("TS_DEBUG_ACME_DIRECTORY_URL")
+ if !isDefaultDirectoryURL(acmeURL) {
+ return true
+ }
+ }
+
+ return false
+}
+
+func isDefaultDirectoryURL(u string) bool {
+ return u == "" || u == acme.LetsEncryptURL
}
// validLookingCertDomain reports whether name looks like a valid domain name that
diff --git a/ipn/ipnlocal/cert_test.go b/ipn/ipnlocal/cert_test.go
index 21741ca95..e2398f670 100644
--- a/ipn/ipnlocal/cert_test.go
+++ b/ipn/ipnlocal/cert_test.go
@@ -6,6 +6,7 @@
package ipnlocal
import (
+ "context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
@@ -14,11 +15,17 @@ import (
"embed"
"encoding/pem"
"math/big"
+ "os"
+ "path/filepath"
"testing"
"time"
"github.com/google/go-cmp/cmp"
+ "tailscale.com/envknob"
"tailscale.com/ipn/store/mem"
+ "tailscale.com/tstest"
+ "tailscale.com/types/logger"
+ "tailscale.com/util/must"
)
func TestValidLookingCertDomain(t *testing.T) {
@@ -47,10 +54,10 @@ var certTestFS embed.FS
func TestCertStoreRoundTrip(t *testing.T) {
const testDomain = "example.com"
- // Use a fixed verification timestamp so validity doesn't fall off when the
- // cert expires. If you update the test data below, this may also need to be
- // updated.
+ // Use fixed verification timestamps so validity doesn't change over time.
+ // If you update the test data below, these may also need to be updated.
testNow := time.Date(2023, time.February, 10, 0, 0, 0, 0, time.UTC)
+ testExpired := time.Date(2026, time.February, 10, 0, 0, 0, 0, time.UTC)
// To re-generate a root certificate and domain certificate for testing,
// use:
@@ -78,21 +85,23 @@ func TestCertStoreRoundTrip(t *testing.T) {
}
tests := []struct {
- name string
- store certStore
+ name string
+ store certStore
+ debugACMEURL bool
}{
- {"FileStore", certFileStore{dir: t.TempDir(), testRoots: roots}},
- {"StateStore", certStateStore{StateStore: new(mem.Store), testRoots: roots}},
+ {"FileStore", certFileStore{dir: t.TempDir(), testRoots: roots}, false},
+ {"FileStore_UnknownCA", certFileStore{dir: t.TempDir()}, true},
+ {"StateStore", certStateStore{StateStore: new(mem.Store), testRoots: roots}, false},
+ {"StateStore_UnknownCA", certStateStore{StateStore: new(mem.Store)}, true},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- if err := test.store.WriteCert(testDomain, testCert); err != nil {
- t.Fatalf("WriteCert: unexpected error: %v", err)
+ if test.debugACMEURL {
+ t.Setenv("TS_DEBUG_ACME_DIRECTORY_URL", "https://acme-staging-v02.api.letsencrypt.org/directory")
}
- if err := test.store.WriteKey(testDomain, testKey); err != nil {
- t.Fatalf("WriteKey: unexpected error: %v", err)
+ if err := test.store.WriteTLSCertAndKey(testDomain, testCert, testKey); err != nil {
+ t.Fatalf("WriteTLSCertAndKey: unexpected error: %v", err)
}
-
kp, err := test.store.Read(testDomain, testNow)
if err != nil {
t.Fatalf("Read: unexpected error: %v", err)
@@ -103,6 +112,10 @@ func TestCertStoreRoundTrip(t *testing.T) {
if diff := cmp.Diff(kp.KeyPEM, testKey); diff != "" {
t.Errorf("Key (-got, +want):\n%s", diff)
}
+ unexpected, err := test.store.Read(testDomain, testExpired)
+ if err != errCertExpired {
+ t.Fatalf("Read: expected expiry error: %v", string(unexpected.CertPEM))
+ }
})
}
}
@@ -215,3 +228,151 @@ func TestDebugACMEDirectoryURL(t *testing.T) {
})
}
}
+
+func TestGetCertPEMWithValidity(t *testing.T) {
+ const testDomain = "example.com"
+ b := &LocalBackend{
+ store: &mem.Store{},
+ varRoot: t.TempDir(),
+ ctx: context.Background(),
+ logf: t.Logf,
+ }
+ certDir, err := b.certDir()
+ if err != nil {
+ t.Fatalf("certDir error: %v", err)
+ }
+ if _, err := b.getCertStore(); err != nil {
+ t.Fatalf("getCertStore error: %v", err)
+ }
+ testRoot, err := certTestFS.ReadFile("testdata/rootCA.pem")
+ if err != nil {
+ t.Fatal(err)
+ }
+ roots := x509.NewCertPool()
+ if !roots.AppendCertsFromPEM(testRoot) {
+ t.Fatal("Unable to add test CA to the cert pool")
+ }
+ testX509Roots = roots
+ defer func() { testX509Roots = nil }()
+ tests := []struct {
+ name string
+ now time.Time
+ // storeCerts is true if the test cert and key should be written to store.
+ storeCerts bool
+ readOnlyMode bool // sets TS_CERT_SHARE_MODE=ro via envknob
+ wantAsyncRenewal bool // async issuance should be started
+ wantIssuance bool // sync issuance should be started
+ wantErr bool
+ }{
+ {
+ name: "valid_no_renewal",
+ now: time.Date(2023, time.February, 20, 0, 0, 0, 0, time.UTC),
+ storeCerts: true,
+ wantAsyncRenewal: false,
+ wantIssuance: false,
+ wantErr: false,
+ },
+ {
+ name: "issuance_needed",
+ now: time.Date(2023, time.February, 20, 0, 0, 0, 0, time.UTC),
+ storeCerts: false,
+ wantAsyncRenewal: false,
+ wantIssuance: true,
+ wantErr: false,
+ },
+ {
+ name: "renewal_needed",
+ now: time.Date(2025, time.May, 1, 0, 0, 0, 0, time.UTC),
+ storeCerts: true,
+ wantAsyncRenewal: true,
+ wantIssuance: false,
+ wantErr: false,
+ },
+ {
+ name: "renewal_needed_read_only_mode",
+ now: time.Date(2025, time.May, 1, 0, 0, 0, 0, time.UTC),
+ storeCerts: true,
+ readOnlyMode: true,
+ wantAsyncRenewal: false,
+ wantIssuance: false,
+ wantErr: false,
+ },
+ {
+ name: "no_certs_read_only_mode",
+ now: time.Date(2025, time.May, 1, 0, 0, 0, 0, time.UTC),
+ storeCerts: false,
+ readOnlyMode: true,
+ wantAsyncRenewal: false,
+ wantIssuance: false,
+ wantErr: true,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+
+ if tt.readOnlyMode {
+ envknob.Setenv("TS_CERT_SHARE_MODE", "ro")
+ }
+
+ os.RemoveAll(certDir)
+ if tt.storeCerts {
+ os.MkdirAll(certDir, 0755)
+ if err := os.WriteFile(filepath.Join(certDir, "example.com.crt"),
+ must.Get(os.ReadFile("testdata/example.com.pem")), 0644); err != nil {
+ t.Fatal(err)
+ }
+ if err := os.WriteFile(filepath.Join(certDir, "example.com.key"),
+ must.Get(os.ReadFile("testdata/example.com-key.pem")), 0644); err != nil {
+ t.Fatal(err)
+ }
+ }
+
+ b.clock = tstest.NewClock(tstest.ClockOpts{Start: tt.now})
+
+ allDone := make(chan bool, 1)
+ defer b.goTracker.AddDoneCallback(func() {
+ b.mu.Lock()
+ defer b.mu.Unlock()
+ if b.goTracker.RunningGoroutines() > 0 {
+ return
+ }
+ select {
+ case allDone <- true:
+ default:
+ }
+ })()
+
+ // Set to true if getCertPEM is called. getCertPEM can be called in a goroutine for async
+ // renewal or in the main goroutine if issuance is required to obtain valid TLS credentials.
+ getCertPemWasCalled := false
+ getCertPEM = func(ctx context.Context, b *LocalBackend, cs certStore, logf logger.Logf, traceACME func(any), domain string, now time.Time, minValidity time.Duration) (*TLSCertKeyPair, error) {
+ getCertPemWasCalled = true
+ return nil, nil
+ }
+ prevGoRoutines := b.goTracker.StartedGoroutines()
+ _, err = b.GetCertPEMWithValidity(context.Background(), testDomain, 0)
+ if (err != nil) != tt.wantErr {
+ t.Errorf("b.GetCertPemWithValidity got err %v, wants error: '%v'", err, tt.wantErr)
+ }
+ // GetCertPEMWithValidity calls getCertPEM in a goroutine if async renewal is needed. That's the
+ // only goroutine it starts, so this can be used to test if async renewal was started.
+ gotAsyncRenewal := b.goTracker.StartedGoroutines()-prevGoRoutines != 0
+ if gotAsyncRenewal {
+ select {
+ case <-time.After(5 * time.Second):
+ t.Fatal("timed out waiting for goroutines to finish")
+ case <-allDone:
+ }
+ }
+ // Verify that async renewal was triggered if expected.
+ if tt.wantAsyncRenewal != gotAsyncRenewal {
+ t.Fatalf("wants getCertPem to be called async: %v, got called %v", tt.wantAsyncRenewal, gotAsyncRenewal)
+ }
+ // Verify that (non-async) issuance was started if expected.
+ gotIssuance := getCertPemWasCalled && !gotAsyncRenewal
+ if tt.wantIssuance != gotIssuance {
+ t.Errorf("wants getCertPem to be called: %v, got called %v", tt.wantIssuance, gotIssuance)
+ }
+ })
+ }
+}
diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go
index 5c367c876..cb7d06407 100644
--- a/ipn/ipnlocal/local.go
+++ b/ipn/ipnlocal/local.go
@@ -57,10 +57,12 @@ import (
"tailscale.com/health/healthmsg"
"tailscale.com/hostinfo"
"tailscale.com/ipn"
+ "tailscale.com/ipn/auditlog"
"tailscale.com/ipn/conffile"
"tailscale.com/ipn/ipnauth"
"tailscale.com/ipn/ipnstate"
"tailscale.com/ipn/policy"
+ memstore "tailscale.com/ipn/store/mem"
"tailscale.com/log/sockstatlog"
"tailscale.com/logpolicy"
"tailscale.com/net/captivedetection"
@@ -406,8 +408,8 @@ type LocalBackend struct {
// outgoingFiles keeps track of Taildrop outgoing files keyed to their OutgoingFile.ID
outgoingFiles map[string]*ipn.OutgoingFile
- // getSafFd gets the Storage Access Framework file descriptor for writing Taildrop files to
- GetSafFd func(filename string) int32
+ // FileOps abstracts platform-specific file operations needed for file transfers.
+ FileOps taildrop.FileOps
// lastSuggestedExitNode stores the last suggested exit node suggestion to
// avoid unnecessary churn between multiple equally-good options.
@@ -453,6 +455,12 @@ type LocalBackend struct {
// Each callback is called exactly once in unspecified order and without b.mu held.
// Returned errors are logged but otherwise ignored and do not affect the shutdown process.
shutdownCbs set.HandleSet[func() error]
+
+ // auditLogger, if non-nil, manages audit logging for the backend.
+ //
+ // It queues, persists, and sends audit logs
+ // to the control client. auditLogger has the same lifespan as b.cc.
+ auditLogger *auditlog.Logger
}
// HealthTracker returns the health tracker for the backend.
@@ -621,19 +629,6 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo
}
}
- // initialize Taildrive shares from saved state
- fs, ok := b.sys.DriveForRemote.GetOK()
- if ok {
- currentShares := b.pm.prefs.DriveShares()
- if currentShares.Len() > 0 {
- var shares []*drive.Share
- for _, share := range currentShares.All() {
- shares = append(shares, share.AsStruct())
- }
- fs.SetShares(shares)
- }
- }
-
for name, newFn := range registeredExtensions {
ext, err := newFn(logf, sys)
if err != nil {
@@ -813,6 +808,13 @@ func (b *LocalBackend) SetDirectFileRoot(dir string) {
b.directFileRoot = dir
}
+// SetFileOps sets the platform-specific file operations used for file transfers.
+func (b *LocalBackend) SetFileOps(fileOps taildrop.FileOps) {
+ b.mu.Lock()
+ defer b.mu.Unlock()
+ b.FileOps = fileOps
+}
+
// ReloadConfig reloads the backend's config from disk.
//
// It returns (false, nil) if not running in declarative mode, (true, nil) on
@@ -1695,6 +1697,15 @@ func (b *LocalBackend) SetControlClientStatus(c controlclient.Client, st control
b.logf("Failed to save new controlclient state: %v", err)
}
}
+
+ // Update the audit logger with the current profile ID.
+ if b.auditLogger != nil && prefsChanged {
+ pid := b.pm.CurrentProfile().ID()
+ if err := b.auditLogger.SetProfileID(pid); err != nil {
+ b.logf("Failed to set profile ID in audit logger: %v", err)
+ }
+ }
+
// initTKALocked is dependent on CurrentProfile.ID, which is initialized
// (for new profiles) on the first call to b.pm.SetPrefs.
if err := b.initTKALocked(); err != nil {
@@ -2402,6 +2413,27 @@ func (b *LocalBackend) Start(opts ipn.Options) error {
debugFlags = append([]string{"netstack"}, debugFlags...)
}
+ var auditLogShutdown func()
+ // Audit logs are kept in a persistent store only if the client has set one up
+ // in sys; otherwise we fall back to an in-memory store.
+ store, ok := b.sys.AuditLogStore.GetOK()
+ if !ok {
+ b.logf("auditlog: [unexpected] no persistent audit log storage configured. using memory store.")
+ store = auditlog.NewLogStore(&memstore.Store{})
+ }
+
+ al := auditlog.NewLogger(auditlog.Opts{
+ Logf: b.logf,
+ RetryLimit: 32,
+ Store: store,
+ })
+ b.auditLogger = al
+ auditLogShutdown = func() {
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+ al.FlushAndStop(ctx)
+ }
+
// TODO(apenwarr): The only way to change the ServerURL is to
// re-run b.Start, because this is the only place we create a
// new controlclient. EditPrefs allows you to overwrite ServerURL,
@@ -2427,6 +2459,7 @@ func (b *LocalBackend) Start(opts ipn.Options) error {
C2NHandler: http.HandlerFunc(b.handleC2N),
DialPlan: &b.dialPlan, // pointer because it can't be copied
ControlKnobs: b.sys.ControlKnobs(),
+ Shutdown: auditLogShutdown,
// Don't warn about broken Linux IP forwarding when
// netstack is being used.
@@ -2461,6 +2494,16 @@ func (b *LocalBackend) Start(opts ipn.Options) error {
b.logf("Backend: logs: be:%v fe:%v", blid, opts.FrontendLogID)
b.sendToLocked(ipn.Notify{Prefs: &prefs}, allClients)
+ // initialize Taildrive shares from saved state
+ if fs, ok := b.sys.DriveForRemote.GetOK(); ok {
+ currentShares := b.pm.CurrentPrefs().DriveShares()
+ var shares []*drive.Share
+ for _, share := range currentShares.All() {
+ shares = append(shares, share.AsStruct())
+ }
+ fs.SetShares(shares)
+ }
+
if !loggedOut && (b.hasNodeKeyLocked() || confWantRunning) {
// If we know that we're either logged in or meant to be
// running, tell the controlclient that it should also assume
@@ -4269,6 +4312,21 @@ func (b *LocalBackend) MaybeClearAppConnector(mp *ipn.MaskedPrefs) error {
return err
}
+var errNoAuditLogger = errors.New("no audit logger configured")
+
+func (b *LocalBackend) getAuditLoggerLocked() ipnauth.AuditLogFunc {
+ logger := b.auditLogger
+ return func(action tailcfg.ClientAuditAction, details string) error {
+ if logger == nil {
+ return errNoAuditLogger
+ }
+ if err := logger.Enqueue(action, details); err != nil {
+ return fmt.Errorf("failed to enqueue audit log %v %q: %w", action, details, err)
+ }
+ return nil
+ }
+}
+
// EditPrefs applies the changes in mp to the current prefs,
// acting as the tailscaled itself rather than a specific user.
func (b *LocalBackend) EditPrefs(mp *ipn.MaskedPrefs) (ipn.PrefsView, error) {
@@ -4294,9 +4352,8 @@ func (b *LocalBackend) EditPrefsAs(mp *ipn.MaskedPrefs, actor ipnauth.Actor) (ip
unlock := b.lockAndGetUnlock()
defer unlock()
if mp.WantRunningSet && !mp.WantRunning && b.pm.CurrentPrefs().WantRunning() {
- // TODO(barnstar,nickkhyl): replace loggerFn with the actual audit logger.
- loggerFn := func(action, details string) { b.logf("[audit]: %s: %s", action, details) }
- if err := actor.CheckProfileAccess(b.pm.CurrentProfile(), ipnauth.Disconnect, loggerFn); err != nil {
+ if err := actor.CheckProfileAccess(b.pm.CurrentProfile(), ipnauth.Disconnect, b.getAuditLoggerLocked()); err != nil {
+ b.logf("check profile access failed: %v", err)
return ipn.PrefsView{}, err
}
@@ -5296,7 +5353,6 @@ func (b *LocalBackend) initPeerAPIListener() {
if fileRoot == "" {
b.logf("peerapi starting without Taildrop directory configured")
}
-
ps := &peerAPIServer{
b: b,
taildrop: taildrop.ManagerOptions{
@@ -5306,7 +5362,7 @@ func (b *LocalBackend) initPeerAPIListener() {
Dir: fileRoot,
DirectFileMode: b.directFileRoot != "",
SendFileNotify: b.sendFileNotify,
- }.New(b.getSafFd),
+ }.New(b.FileOps),
}
if dm, ok := b.sys.DNSManager.GetOK(); ok {
ps.resolver = dm.Resolver()
@@ -5880,6 +5936,15 @@ func (b *LocalBackend) requestEngineStatusAndWait() {
func (b *LocalBackend) setControlClientLocked(cc controlclient.Client) {
b.cc = cc
b.ccAuto, _ = cc.(*controlclient.Auto)
+ if b.auditLogger != nil {
+ if err := b.auditLogger.SetProfileID(b.pm.CurrentProfile().ID()); err != nil {
+ b.logf("audit logger set profile ID failure: %v", err)
+ }
+
+ if err := b.auditLogger.Start(b.ccAuto); err != nil {
+ b.logf("audit logger start failure: %v", err)
+ }
+ }
}
// resetControlClientLocked sets b.cc to nil and returns the old value. If the
@@ -6712,7 +6777,7 @@ func (b *LocalBackend) FileTargets() ([]*apitype.FileTarget, error) {
}
func (b *LocalBackend) taildropTargetStatus(p tailcfg.NodeView) ipnstate.TaildropTargetStatus {
- if b.netMap == nil || b.state != ipn.Running {
+ if b.state != ipn.Running {
return ipnstate.TaildropTargetIpnStateNotRunning
}
if b.netMap == nil {
@@ -8225,15 +8290,13 @@ func (b *LocalBackend) vipServiceHash(services []*tailcfg.VIPService) string {
func (b *LocalBackend) vipServicesFromPrefsLocked(prefs ipn.PrefsView) []*tailcfg.VIPService {
// keyed by service name
var services map[tailcfg.ServiceName]*tailcfg.VIPService
- if !b.serveConfig.Valid() {
- return nil
- }
-
- for svc, config := range b.serveConfig.Services().All() {
- mak.Set(&services, svc, &tailcfg.VIPService{
- Name: svc,
- Ports: config.ServicePortRange(),
- })
+ if b.serveConfig.Valid() {
+ for svc, config := range b.serveConfig.Services().All() {
+ mak.Set(&services, svc, &tailcfg.VIPService{
+ Name: svc,
+ Ports: config.ServicePortRange(),
+ })
+ }
}
for _, s := range prefs.AdvertiseServices().All() {
@@ -8246,7 +8309,14 @@ func (b *LocalBackend) vipServicesFromPrefsLocked(prefs ipn.PrefsView) []*tailcf
services[sn].Active = true
}
- return slicesx.MapValues(services)
+ servicesList := slicesx.MapValues(services)
+ // [slicesx.MapValues] provides the values in an indeterminate order, but since we'll
+ // be hashing a representation of this list later we want it to be in a consistent
+ // order.
+ slices.SortFunc(servicesList, func(a, b *tailcfg.VIPService) int {
+ return strings.Compare(a.Name.String(), b.Name.String())
+ })
+ return servicesList
}
var (
diff --git a/ipn/ipnlocal/local_test.go b/ipn/ipnlocal/local_test.go
index 35977e679..aa9137275 100644
--- a/ipn/ipnlocal/local_test.go
+++ b/ipn/ipnlocal/local_test.go
@@ -44,6 +44,7 @@ import (
"tailscale.com/tsd"
"tailscale.com/tstest"
"tailscale.com/types/dnstype"
+ "tailscale.com/types/ipproto"
"tailscale.com/types/key"
"tailscale.com/types/logger"
"tailscale.com/types/logid"
@@ -60,6 +61,7 @@ import (
"tailscale.com/util/syspolicy/source"
"tailscale.com/wgengine"
"tailscale.com/wgengine/filter"
+ "tailscale.com/wgengine/filter/filtertype"
"tailscale.com/wgengine/wgcfg"
)
@@ -5206,3 +5208,60 @@ func TestUpdateIngressLocked(t *testing.T) {
})
}
}
+
+// TestSrcCapPacketFilter tests that LocalBackend handles packet filters with
+// SrcCaps instead of Srcs (IPs)
+func TestSrcCapPacketFilter(t *testing.T) {
+ lb := newLocalBackendWithTestControl(t, false, func(tb testing.TB, opts controlclient.Options) controlclient.Client {
+ return newClient(tb, opts)
+ })
+ if err := lb.Start(ipn.Options{}); err != nil {
+ t.Fatalf("(*LocalBackend).Start(): %v", err)
+ }
+
+ var k key.NodePublic
+ must.Do(k.UnmarshalText([]byte("nodekey:5c8f86d5fc70d924e55f02446165a5dae8f822994ad26bcf4b08fd841f9bf261")))
+
+ controlClient := lb.cc.(*mockControl)
+ controlClient.send(nil, "", false, &netmap.NetworkMap{
+ SelfNode: (&tailcfg.Node{
+ Addresses: []netip.Prefix{netip.MustParsePrefix("1.1.1.1/32")},
+ }).View(),
+ Peers: []tailcfg.NodeView{
+ (&tailcfg.Node{
+ Addresses: []netip.Prefix{netip.MustParsePrefix("2.2.2.2/32")},
+ ID: 2,
+ Key: k,
+ CapMap: tailcfg.NodeCapMap{"cap-X": nil}, // node 2 has cap
+ }).View(),
+ (&tailcfg.Node{
+ Addresses: []netip.Prefix{netip.MustParsePrefix("3.3.3.3/32")},
+ ID: 3,
+ Key: k,
+ CapMap: tailcfg.NodeCapMap{}, // node 3 does not have the cap
+ }).View(),
+ },
+ PacketFilter: []filtertype.Match{{
+ IPProto: views.SliceOf([]ipproto.Proto{ipproto.TCP}),
+ SrcCaps: []tailcfg.NodeCapability{"cap-X"}, // cap in packet filter rule
+ Dsts: []filtertype.NetPortRange{{
+ Net: netip.MustParsePrefix("1.1.1.1/32"),
+ Ports: filtertype.PortRange{
+ First: 22,
+ Last: 22,
+ },
+ }},
+ }},
+ })
+
+ f := lb.GetFilterForTest()
+ res := f.Check(netip.MustParseAddr("2.2.2.2"), netip.MustParseAddr("1.1.1.1"), 22, ipproto.TCP)
+ if res != filter.Accept {
+ t.Errorf("Check(2.2.2.2, ...) = %s, want %s", res, filter.Accept)
+ }
+
+ res = f.Check(netip.MustParseAddr("3.3.3.3"), netip.MustParseAddr("1.1.1.1"), 22, ipproto.TCP)
+ if !res.IsDrop() {
+ t.Error("IsDrop() for node without cap = false, want true")
+ }
+}
diff --git a/ipn/ipnstate/ipnstate.go b/ipn/ipnstate/ipnstate.go
index bc1ba615d..89c6d7e24 100644
--- a/ipn/ipnstate/ipnstate.go
+++ b/ipn/ipnstate/ipnstate.go
@@ -216,6 +216,11 @@ type PeerStatusLite struct {
}
// PeerStatus describes a peer node and its current state.
+// WARNING: The fields in PeerStatus are merged by the AddPeer method in the StatusBuilder.
+// When adding a new field to PeerStatus, you must update AddPeer to handle merging
+// the new field. The AddPeer function is responsible for combining multiple updates
+// to the same peer, and any new field that is not merged properly may lead to
+// inconsistencies or lost data in the peer status.
type PeerStatus struct {
ID tailcfg.StableNodeID
PublicKey key.NodePublic
@@ -533,6 +538,9 @@ func (sb *StatusBuilder) AddPeer(peer key.NodePublic, st *PeerStatus) {
if v := st.Capabilities; v != nil {
e.Capabilities = v
}
+ if v := st.TaildropTarget; v != TaildropTargetUnknown {
+ e.TaildropTarget = v
+ }
e.Location = st.Location
}
diff --git a/ipn/store/awsstore/store_aws.go b/ipn/store/awsstore/store_aws.go
index 0fb78d45a..40bbbf037 100644
--- a/ipn/store/awsstore/store_aws.go
+++ b/ipn/store/awsstore/store_aws.go
@@ -10,7 +10,9 @@ import (
"context"
"errors"
"fmt"
+ "net/url"
"regexp"
+ "strings"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/aws/arn"
@@ -28,6 +30,14 @@ const (
var parameterNameRx = regexp.MustCompile(parameterNameRxStr)
+// Option defines a functional option type for configuring awsStore.
+type Option func(*storeOptions)
+
+// storeOptions holds optional settings for creating a new awsStore.
+type storeOptions struct {
+ kmsKey string
+}
+
// awsSSMClient is an interface allowing us to mock the couple of
// API calls we are leveraging with the AWSStore provider
type awsSSMClient interface {
@@ -46,6 +56,10 @@ type awsStore struct {
ssmClient awsSSMClient
ssmARN arn.ARN
+ // kmsKey is optional. If empty, the parameter is stored in plaintext.
+ // If non-empty, the parameter is encrypted with this KMS key.
+ kmsKey string
+
memory mem.Store
}
@@ -57,30 +71,80 @@ type awsStore struct {
// Tailscaled to only store new state in-memory and
// restarting Tailscaled can fail until you delete your state
// from the AWS Parameter Store.
-func New(_ logger.Logf, ssmARN string) (ipn.StateStore, error) {
- return newStore(ssmARN, nil)
+//
+// If you want to specify an optional KMS key,
+// pass one or more Option objects, e.g. awsstore.WithKeyID("alias/my-key").
+func New(_ logger.Logf, ssmARN string, opts ...Option) (ipn.StateStore, error) {
+ // Apply all options to an empty storeOptions
+ var so storeOptions
+ for _, opt := range opts {
+ opt(&so)
+ }
+
+ return newStore(ssmARN, so, nil)
+}
+
+// WithKeyID sets the KMS key to be used for encryption. It can be
+// a KeyID, an alias ("alias/my-key"), or a full ARN.
+//
+// If kmsKey is empty, the Option is a no-op.
+func WithKeyID(kmsKey string) Option {
+ return func(o *storeOptions) {
+ o.kmsKey = kmsKey
+ }
+}
+
+// ParseARNAndOpts parses an ARN and optional URL-encoded parameters
+// from arg.
+func ParseARNAndOpts(arg string) (ssmARN string, opts []Option, err error) {
+ ssmARN = arg
+
+ // Support optional ?url-encoded-parameters.
+ if s, q, ok := strings.Cut(arg, "?"); ok {
+ ssmARN = s
+ q, err := url.ParseQuery(q)
+ if err != nil {
+ return "", nil, err
+ }
+
+ for k := range q {
+ switch k {
+ default:
+ return "", nil, fmt.Errorf("unknown arn option parameter %q", k)
+ case "kmsKey":
+ // We allow an ARN, a key ID, or an alias name for kmsKeyID.
+ // If it doesn't look like an ARN and doesn't have a '/',
+ // prepend "alias/" for KMS alias references.
+ kmsKey := q.Get(k)
+ if kmsKey != "" &&
+ !strings.Contains(kmsKey, "/") &&
+ !strings.HasPrefix(kmsKey, "arn:") {
+ kmsKey = "alias/" + kmsKey
+ }
+ if kmsKey != "" {
+ opts = append(opts, WithKeyID(kmsKey))
+ }
+ }
+ }
+ }
+ return ssmARN, opts, nil
}
// newStore is New, but for tests. If client is non-nil, it's
// used instead of making one.
-func newStore(ssmARN string, client awsSSMClient) (ipn.StateStore, error) {
+func newStore(ssmARN string, so storeOptions, client awsSSMClient) (ipn.StateStore, error) {
s := &awsStore{
ssmClient: client,
+ kmsKey: so.kmsKey,
}
var err error
-
- // Parse the ARN
if s.ssmARN, err = arn.Parse(ssmARN); err != nil {
return nil, fmt.Errorf("unable to parse the ARN correctly: %v", err)
}
-
- // Validate the ARN corresponds to the SSM service
if s.ssmARN.Service != "ssm" {
return nil, fmt.Errorf("invalid service %q, expected 'ssm'", s.ssmARN.Service)
}
-
- // Validate the ARN corresponds to a parameter store resource
if !parameterNameRx.MatchString(s.ssmARN.Resource) {
return nil, fmt.Errorf("invalid resource %q, expected to match %v", s.ssmARN.Resource, parameterNameRxStr)
}
@@ -96,12 +160,11 @@ func newStore(ssmARN string, client awsSSMClient) (ipn.StateStore, error) {
s.ssmClient = ssm.NewFromConfig(cfg)
}
- // Hydrate cache with the potentially current state
+ // Preload existing state, if any
if err := s.LoadState(); err != nil {
return nil, err
}
return s, nil
-
}
// LoadState attempts to read the state from AWS SSM parameter store key.
@@ -172,15 +235,21 @@ func (s *awsStore) persistState() error {
// which is free. However, if it exceeds 4kb it switches the parameter to advanced tiering
// doubling the capacity to 8kb per the following docs:
// https://aws.amazon.com/about-aws/whats-new/2019/08/aws-systems-manager-parameter-store-announces-intelligent-tiering-to-enable-automatic-parameter-tier-selection/
- _, err = s.ssmClient.PutParameter(
- context.TODO(),
- &ssm.PutParameterInput{
- Name: aws.String(s.ParameterName()),
- Value: aws.String(string(bs)),
- Overwrite: aws.Bool(true),
- Tier: ssmTypes.ParameterTierIntelligentTiering,
- Type: ssmTypes.ParameterTypeSecureString,
- },
- )
+ in := &ssm.PutParameterInput{
+ Name: aws.String(s.ParameterName()),
+ Value: aws.String(string(bs)),
+ Overwrite: aws.Bool(true),
+ Tier: ssmTypes.ParameterTierIntelligentTiering,
+ Type: ssmTypes.ParameterTypeSecureString,
+ }
+
+ // If kmsKey is specified, encrypt with that key
+ // NOTE: this input allows any alias, keyID or ARN
+ // If this isn't specified, AWS will use the default KMS key
+ if s.kmsKey != "" {
+ in.KeyId = aws.String(s.kmsKey)
+ }
+
+ _, err = s.ssmClient.PutParameter(context.TODO(), in)
return err
}
diff --git a/ipn/store/awsstore/store_aws_stub.go b/ipn/store/awsstore/store_aws_stub.go
deleted file mode 100644
index 8d2156ce9..000000000
--- a/ipn/store/awsstore/store_aws_stub.go
+++ /dev/null
@@ -1,18 +0,0 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-//go:build !linux || ts_omit_aws
-
-package awsstore
-
-import (
- "fmt"
- "runtime"
-
- "tailscale.com/ipn"
- "tailscale.com/types/logger"
-)
-
-func New(logger.Logf, string) (ipn.StateStore, error) {
- return nil, fmt.Errorf("AWS store is not supported on %v", runtime.GOOS)
-}
diff --git a/ipn/store/awsstore/store_aws_test.go b/ipn/store/awsstore/store_aws_test.go
index f6c8fedb3..3382635a7 100644
--- a/ipn/store/awsstore/store_aws_test.go
+++ b/ipn/store/awsstore/store_aws_test.go
@@ -1,7 +1,7 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
-//go:build linux
+//go:build linux && !ts_omit_aws
package awsstore
@@ -65,7 +65,11 @@ func TestNewAWSStore(t *testing.T) {
Resource: "parameter/foo",
}
- s, err := newStore(storeParameterARN.String(), mc)
+ opts := storeOptions{
+ kmsKey: "arn:aws:kms:eu-west-1:123456789:key/MyCustomKey",
+ }
+
+ s, err := newStore(storeParameterARN.String(), opts, mc)
if err != nil {
t.Fatalf("creating aws store failed: %v", err)
}
@@ -73,7 +77,7 @@ func TestNewAWSStore(t *testing.T) {
// Build a brand new file store and check that both IDs written
// above are still there.
- s2, err := newStore(storeParameterARN.String(), mc)
+ s2, err := newStore(storeParameterARN.String(), opts, mc)
if err != nil {
t.Fatalf("creating second aws store failed: %v", err)
}
@@ -162,3 +166,54 @@ func testStoreSemantics(t *testing.T, store ipn.StateStore) {
}
}
}
+
+func TestParseARNAndOpts(t *testing.T) {
+ tests := []struct {
+ name string
+ arg string
+ wantARN string
+ wantKey string
+ }{
+ {
+ name: "no-key",
+ arg: "arn:aws:ssm:us-east-1:123456789012:parameter/myTailscaleParam",
+ wantARN: "arn:aws:ssm:us-east-1:123456789012:parameter/myTailscaleParam",
+ },
+ {
+ name: "custom-key",
+ arg: "arn:aws:ssm:us-east-1:123456789012:parameter/myTailscaleParam?kmsKey=alias/MyCustomKey",
+ wantARN: "arn:aws:ssm:us-east-1:123456789012:parameter/myTailscaleParam",
+ wantKey: "alias/MyCustomKey",
+ },
+ {
+ name: "bare-name",
+ arg: "arn:aws:ssm:us-east-1:123456789012:parameter/myTailscaleParam?kmsKey=Bare",
+ wantARN: "arn:aws:ssm:us-east-1:123456789012:parameter/myTailscaleParam",
+ wantKey: "alias/Bare",
+ },
+ {
+ name: "arn-arg",
+ arg: "arn:aws:ssm:us-east-1:123456789012:parameter/myTailscaleParam?kmsKey=arn:foo",
+ wantARN: "arn:aws:ssm:us-east-1:123456789012:parameter/myTailscaleParam",
+ wantKey: "arn:foo",
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ arn, opts, err := ParseARNAndOpts(tt.arg)
+ if err != nil {
+ t.Fatalf("New: %v", err)
+ }
+ if arn != tt.wantARN {
+ t.Errorf("ARN = %q; want %q", arn, tt.wantARN)
+ }
+ var got storeOptions
+ for _, opt := range opts {
+ opt(&got)
+ }
+ if got.kmsKey != tt.wantKey {
+ t.Errorf("kmsKey = %q; want %q", got.kmsKey, tt.wantKey)
+ }
+ })
+ }
+}
diff --git a/ipn/store/kubestore/store_kube.go b/ipn/store/kubestore/store_kube.go
index 462e6d434..ed37f06c2 100644
--- a/ipn/store/kubestore/store_kube.go
+++ b/ipn/store/kubestore/store_kube.go
@@ -13,11 +13,15 @@ import (
"strings"
"time"
+ "tailscale.com/envknob"
"tailscale.com/ipn"
"tailscale.com/ipn/store/mem"
"tailscale.com/kube/kubeapi"
"tailscale.com/kube/kubeclient"
+ "tailscale.com/kube/kubetypes"
"tailscale.com/types/logger"
+ "tailscale.com/util/dnsname"
+ "tailscale.com/util/mak"
)
const (
@@ -31,21 +35,37 @@ const (
reasonTailscaleStateLoadFailed = "TailscaleStateLoadFailed"
eventTypeWarning = "Warning"
eventTypeNormal = "Normal"
+
+ keyTLSCert = "tls.crt"
+ keyTLSKey = "tls.key"
)
// Store is an ipn.StateStore that uses a Kubernetes Secret for persistence.
type Store struct {
- client kubeclient.Client
- canPatch bool
- secretName string
+ client kubeclient.Client
+ canPatch bool
+ secretName string // state Secret
+ certShareMode string // 'ro', 'rw', or empty
+ podName string
- // memory holds the latest tailscale state. Writes write state to a kube Secret and memory, Reads read from
- // memory.
+ // memory holds the latest tailscale state. Writes write state to a kube
+ // Secret and memory, Reads read from memory.
memory mem.Store
}
-// New returns a new Store that persists to the named Secret.
-func New(_ logger.Logf, secretName string) (*Store, error) {
+// New returns a new Store that persists state to Kubernetes Secret(s).
+// Tailscale state is stored in a Secret named by the secretName parameter.
+// TLS certs are stored in and retrieved from the state Secret, or separate
+// Secrets named after TLS endpoints if running in cert share mode.
+func New(logf logger.Logf, secretName string) (*Store, error) {
+ c, err := newClient()
+ if err != nil {
+ return nil, err
+ }
+ return newWithClient(logf, c, secretName)
+}
+
+func newClient() (kubeclient.Client, error) {
c, err := kubeclient.New("tailscale-state-store")
if err != nil {
return nil, err
@@ -54,6 +74,10 @@ func New(_ logger.Logf, secretName string) (*Store, error) {
// Derive the API server address from the environment variables
c.SetURL(fmt.Sprintf("https://%s:%s", os.Getenv("KUBERNETES_SERVICE_HOST"), os.Getenv("KUBERNETES_SERVICE_PORT_HTTPS")))
}
+ return c, nil
+}
+
+func newWithClient(logf logger.Logf, c kubeclient.Client, secretName string) (*Store, error) {
canPatch, _, err := c.CheckSecretPermissions(context.Background(), secretName)
if err != nil {
return nil, err
@@ -62,11 +86,30 @@ func New(_ logger.Logf, secretName string) (*Store, error) {
client: c,
canPatch: canPatch,
secretName: secretName,
+ podName: os.Getenv("POD_NAME"),
+ }
+ if envknob.IsCertShareReadWriteMode() {
+ s.certShareMode = "rw"
+ } else if envknob.IsCertShareReadOnlyMode() {
+ s.certShareMode = "ro"
}
+
// Load latest state from kube Secret if it already exists.
if err := s.loadState(); err != nil && err != ipn.ErrStateNotExist {
return nil, fmt.Errorf("error loading state from kube Secret: %w", err)
}
+ // If we are in cert share mode, pre-load existing shared certs.
+ if s.certShareMode == "rw" || s.certShareMode == "ro" {
+ sel := s.certSecretSelector()
+ if err := s.loadCerts(context.Background(), sel); err != nil {
+ // We will attempt to retrieve the certs from Secrets again when a request for an HTTPS endpoint
+ // is received.
+ log.Printf("[unexpected] error loading TLS certs: %v", err)
+ }
+ }
+ if s.certShareMode == "ro" {
+ go s.runCertReload(context.Background(), logf)
+ }
return s, nil
}
@@ -83,11 +126,101 @@ func (s *Store) ReadState(id ipn.StateKey) ([]byte, error) {
// WriteState implements the StateStore interface.
func (s *Store) WriteState(id ipn.StateKey, bs []byte) (err error) {
- ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer func() {
if err == nil {
s.memory.WriteState(ipn.StateKey(sanitizeKey(id)), bs)
}
+ }()
+ return s.updateSecret(map[string][]byte{string(id): bs}, s.secretName)
+}
+
+// WriteTLSCertAndKey writes a TLS cert and key to domain.crt, domain.key fields of a Tailscale Kubernetes
+// node's state Secret or, in cert share 'rw' mode, to tls.crt, tls.key fields of a Secret named after the domain.
+func (s *Store) WriteTLSCertAndKey(domain string, cert, key []byte) (err error) {
+ if s.certShareMode == "ro" {
+ log.Printf("[unexpected] TLS cert and key write in read-only mode")
+ }
+ if err := dnsname.ValidHostname(domain); err != nil {
+ return fmt.Errorf("invalid domain name %q: %w", domain, err)
+ }
+ defer func() {
+ // TODO(irbekrm): a read between these two separate writes would
+ // get a mismatched cert and key. Allow writing both cert and
+ // key to the memory store in a single, lock-protected operation.
+ if err == nil {
+ s.memory.WriteState(ipn.StateKey(domain+".crt"), cert)
+ s.memory.WriteState(ipn.StateKey(domain+".key"), key)
+ }
+ }()
+ secretName := s.secretName
+ data := map[string][]byte{
+ domain + ".crt": cert,
+ domain + ".key": key,
+ }
+ // If we run in cert share mode, cert and key for a DNS name are written
+ // to a separate Secret.
+ if s.certShareMode == "rw" {
+ secretName = domain
+ data = map[string][]byte{
+ keyTLSCert: cert,
+ keyTLSKey: key,
+ }
+ }
+ return s.updateSecret(data, secretName)
+}
+
+// ReadTLSCertAndKey reads a TLS cert and key from memory or from a
+// domain-specific Secret. It first checks the in-memory store; if they are not
+// found there and the store runs in cert share read-only mode, it looks up a Secret.
+func (s *Store) ReadTLSCertAndKey(domain string) (cert, key []byte, err error) {
+ if err := dnsname.ValidHostname(domain); err != nil {
+ return nil, nil, fmt.Errorf("invalid domain name %q: %w", domain, err)
+ }
+ certKey := domain + ".crt"
+ keyKey := domain + ".key"
+
+ cert, err = s.memory.ReadState(ipn.StateKey(certKey))
+ if err == nil {
+ key, err = s.memory.ReadState(ipn.StateKey(keyKey))
+ if err == nil {
+ return cert, key, nil
+ }
+ }
+ if s.certShareMode != "ro" {
+ return nil, nil, ipn.ErrStateNotExist
+ }
+ // If we are in cert share read only mode, it is possible that a write
+ // replica just issued the TLS cert for this DNS name and it has not
+ // been loaded to store yet, so check the Secret.
+
+ ctx, cancel := context.WithTimeout(context.Background(), timeout)
+ defer cancel()
+
+ secret, err := s.client.GetSecret(ctx, domain)
+ if err != nil {
+ if kubeclient.IsNotFoundErr(err) {
+ // TODO(irbekrm): we should return a more specific error
+ // that wraps ipn.ErrStateNotExist here.
+ return nil, nil, ipn.ErrStateNotExist
+ }
+ return nil, nil, fmt.Errorf("getting TLS Secret %q: %w", domain, err)
+ }
+ cert = secret.Data[keyTLSCert]
+ key = secret.Data[keyTLSKey]
+ if len(cert) == 0 || len(key) == 0 {
+ return nil, nil, ipn.ErrStateNotExist
+ }
+ // TODO(irbekrm): a read between these two separate writes would
+ // get a mismatched cert and key. Allow writing both cert and
+ // key to the memory store in a single lock-protected operation.
+ s.memory.WriteState(ipn.StateKey(certKey), cert)
+ s.memory.WriteState(ipn.StateKey(keyKey), key)
+ return cert, key, nil
+}
+
+func (s *Store) updateSecret(data map[string][]byte, secretName string) (err error) {
+ ctx, cancel := context.WithTimeout(context.Background(), timeout)
+ defer func() {
if err != nil {
if err := s.client.Event(ctx, eventTypeWarning, reasonTailscaleStateUpdateFailed, err.Error()); err != nil {
log.Printf("kubestore: error creating tailscaled state update Event: %v", err)
@@ -99,56 +232,69 @@ func (s *Store) WriteState(id ipn.StateKey, bs []byte) (err error) {
}
cancel()
}()
-
- secret, err := s.client.GetSecret(ctx, s.secretName)
+ secret, err := s.client.GetSecret(ctx, secretName)
if err != nil {
- if kubeclient.IsNotFoundErr(err) {
+ // If the Secret does not exist, create it with the required data.
+ if kubeclient.IsNotFoundErr(err) && s.canCreateSecret(secretName) {
return s.client.CreateSecret(ctx, &kubeapi.Secret{
TypeMeta: kubeapi.TypeMeta{
APIVersion: "v1",
Kind: "Secret",
},
ObjectMeta: kubeapi.ObjectMeta{
- Name: s.secretName,
- },
- Data: map[string][]byte{
- sanitizeKey(id): bs,
+ Name: secretName,
},
+ Data: func(m map[string][]byte) map[string][]byte {
+ d := make(map[string][]byte, len(m))
+ for key, val := range m {
+ d[sanitizeKey(key)] = val
+ }
+ return d
+ }(data),
})
}
- return err
+ return fmt.Errorf("error getting Secret %s: %w", secretName, err)
}
- if s.canPatch {
- if len(secret.Data) == 0 { // if user has pre-created a blank Secret
- m := []kubeclient.JSONPatch{
+ if s.canPatchSecret(secretName) {
+ var m []kubeclient.JSONPatch
+ // If the user has pre-created a Secret with no data, we need to ensure the top level /data field.
+ if len(secret.Data) == 0 {
+ m = []kubeclient.JSONPatch{
{
- Op: "add",
- Path: "/data",
- Value: map[string][]byte{sanitizeKey(id): bs},
+ Op: "add",
+ Path: "/data",
+ Value: func(m map[string][]byte) map[string][]byte {
+ d := make(map[string][]byte, len(m))
+ for key, val := range m {
+ d[sanitizeKey(key)] = val
+ }
+ return d
+ }(data),
},
}
- if err := s.client.JSONPatchResource(ctx, s.secretName, kubeclient.TypeSecrets, m); err != nil {
- return fmt.Errorf("error patching Secret %s with a /data field: %v", s.secretName, err)
+ // If the Secret has data, patch it with the new data.
+ } else {
+ for key, val := range data {
+ m = append(m, kubeclient.JSONPatch{
+ Op: "add",
+ Path: "/data/" + sanitizeKey(key),
+ Value: val,
+ })
}
- return nil
- }
- m := []kubeclient.JSONPatch{
- {
- Op: "add",
- Path: "/data/" + sanitizeKey(id),
- Value: bs,
- },
}
- if err := s.client.JSONPatchResource(ctx, s.secretName, kubeclient.TypeSecrets, m); err != nil {
- return fmt.Errorf("error patching Secret %s with /data/%s field: %v", s.secretName, sanitizeKey(id), err)
+ if err := s.client.JSONPatchResource(ctx, secretName, kubeclient.TypeSecrets, m); err != nil {
+ return fmt.Errorf("error patching Secret %s: %w", secretName, err)
}
return nil
}
- secret.Data[sanitizeKey(id)] = bs
+ // No patch permissions, use UPDATE instead.
+ for key, val := range data {
+ mak.Set(&secret.Data, sanitizeKey(key), val)
+ }
if err := s.client.UpdateSecret(ctx, secret); err != nil {
- return err
+ return fmt.Errorf("error updating Secret %s: %w", s.secretName, err)
}
- return err
+ return nil
}
func (s *Store) loadState() (err error) {
@@ -172,9 +318,100 @@ func (s *Store) loadState() (err error) {
return nil
}
-func sanitizeKey(k ipn.StateKey) string {
- // The only valid characters in a Kubernetes secret key are alphanumeric, -,
- // _, and .
+// runCertReload relists and reloads all TLS certs for endpoints shared by this
+// node from Secrets other than the state Secret to ensure that renewed certs get eventually loaded.
+// It is not critical to reload a cert immediately after
+// renewal, so a daily check is acceptable.
+// Currently (3/2025) this is only used for the shared HA Ingress certs on 'read' replicas.
+// Note that if shared certs are not found in memory on an HTTPS request, we
+// do a Secret lookup, so this mechanism does not need to ensure that newly
+// added Ingresses' certs get loaded.
+func (s *Store) runCertReload(ctx context.Context, logf logger.Logf) {
+ ticker := time.NewTicker(time.Hour * 24)
+ defer ticker.Stop()
+ for {
+ select {
+ case <-ctx.Done():
+ return
+ case <-ticker.C:
+ sel := s.certSecretSelector()
+ if err := s.loadCerts(ctx, sel); err != nil {
+ logf("[unexpected] error reloading TLS certs: %v", err)
+ }
+ }
+ }
+}
+
+// loadCerts lists all Secrets matching the provided selector and loads TLS
+// certs and keys from those.
+func (s *Store) loadCerts(ctx context.Context, sel map[string]string) error {
+ ss, err := s.client.ListSecrets(ctx, sel)
+ if err != nil {
+ return fmt.Errorf("error listing TLS Secrets: %w", err)
+ }
+ for _, secret := range ss.Items {
+ if !hasTLSData(&secret) {
+ continue
+ }
+ // Only load secrets that have valid domain names (ending in .ts.net)
+ if !strings.HasSuffix(secret.Name, ".ts.net") {
+ continue
+ }
+ s.memory.WriteState(ipn.StateKey(secret.Name)+".crt", secret.Data[keyTLSCert])
+ s.memory.WriteState(ipn.StateKey(secret.Name)+".key", secret.Data[keyTLSKey])
+ }
+ return nil
+}
+
+// canCreateSecret returns true if this node should be allowed to create the given
+// Secret in its namespace.
+func (s *Store) canCreateSecret(secret string) bool {
+ // Only allow creating the state Secret (and not TLS Secrets).
+ return secret == s.secretName
+}
+
+// canPatchSecret returns true if this node should be allowed to patch the given
+// Secret.
+func (s *Store) canPatchSecret(secret string) bool {
+ // For backwards compatibility reasons, setups where the proxies are not
+ // given PATCH permissions for state Secrets are allowed. For TLS
+ // Secrets, we should always have PATCH permissions.
+ if secret == s.secretName {
+ return s.canPatch
+ }
+ return true
+}
+
+// certSecretSelector returns a label selector that can be used to list all
+// Secrets that aren't Tailscale state Secrets and contain TLS certificates for
+// HTTPS endpoints that this node serves.
+// Currently (3/2025) this only applies to the Kubernetes Operator's ingress
+// ProxyGroup.
+func (s *Store) certSecretSelector() map[string]string {
+ if s.podName == "" {
+ return map[string]string{}
+ }
+ p := strings.LastIndex(s.podName, "-")
+ if p == -1 {
+ return map[string]string{}
+ }
+ pgName := s.podName[:p]
+ return map[string]string{
+ kubetypes.LabelSecretType: "certs",
+ kubetypes.LabelManaged: "true",
+ "tailscale.com/proxy-group": pgName,
+ }
+}
+
+// hasTLSData returns true if the provided Secret contains non-empty TLS cert and key.
+func hasTLSData(s *kubeapi.Secret) bool {
+ return len(s.Data[keyTLSCert]) != 0 && len(s.Data[keyTLSKey]) != 0
+}
+
+// sanitizeKey converts any value that can be converted to a string into a valid Kubernetes Secret key.
+// Valid characters are alphanumeric, -, _, and .
+// https://kubernetes.io/docs/concepts/configuration/secret/#restriction-names-data.
+func sanitizeKey[T ~string](k T) string {
return strings.Map(func(r rune) rune {
if r >= 'a' && r <= 'z' || r >= 'A' && r <= 'Z' || r >= '0' && r <= '9' || r == '-' || r == '_' || r == '.' {
return r
diff --git a/ipn/store/kubestore/store_kube_test.go b/ipn/store/kubestore/store_kube_test.go
new file mode 100644
index 000000000..2ed16e77b
--- /dev/null
+++ b/ipn/store/kubestore/store_kube_test.go
@@ -0,0 +1,723 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package kubestore
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "fmt"
+ "strings"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "tailscale.com/envknob"
+ "tailscale.com/ipn"
+ "tailscale.com/ipn/store/mem"
+ "tailscale.com/kube/kubeapi"
+ "tailscale.com/kube/kubeclient"
+)
+
+func TestWriteState(t *testing.T) {
+ tests := []struct {
+ name string
+ initial map[string][]byte
+ key ipn.StateKey
+ value []byte
+ wantData map[string][]byte
+ allowPatch bool
+ }{
+ {
+ name: "basic_write",
+ initial: map[string][]byte{
+ "existing": []byte("old"),
+ },
+ key: "foo",
+ value: []byte("bar"),
+ wantData: map[string][]byte{
+ "existing": []byte("old"),
+ "foo": []byte("bar"),
+ },
+ allowPatch: true,
+ },
+ {
+ name: "update_existing",
+ initial: map[string][]byte{
+ "foo": []byte("old"),
+ },
+ key: "foo",
+ value: []byte("new"),
+ wantData: map[string][]byte{
+ "foo": []byte("new"),
+ },
+ allowPatch: true,
+ },
+ {
+ name: "create_new_secret",
+ key: "foo",
+ value: []byte("bar"),
+ wantData: map[string][]byte{
+ "foo": []byte("bar"),
+ },
+ allowPatch: true,
+ },
+ {
+ name: "patch_denied",
+ initial: map[string][]byte{
+ "foo": []byte("old"),
+ },
+ key: "foo",
+ value: []byte("new"),
+ wantData: map[string][]byte{
+ "foo": []byte("new"),
+ },
+ allowPatch: false,
+ },
+ {
+ name: "sanitize_key",
+ initial: map[string][]byte{
+ "clean-key": []byte("old"),
+ },
+ key: "dirty@key",
+ value: []byte("new"),
+ wantData: map[string][]byte{
+ "clean-key": []byte("old"),
+ "dirty_key": []byte("new"),
+ },
+ allowPatch: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ secret := tt.initial // track current state
+ client := &kubeclient.FakeClient{
+ GetSecretImpl: func(ctx context.Context, name string) (*kubeapi.Secret, error) {
+ if secret == nil {
+ return nil, &kubeapi.Status{Code: 404}
+ }
+ return &kubeapi.Secret{Data: secret}, nil
+ },
+ CheckSecretPermissionsImpl: func(ctx context.Context, name string) (bool, bool, error) {
+ return tt.allowPatch, true, nil
+ },
+ CreateSecretImpl: func(ctx context.Context, s *kubeapi.Secret) error {
+ secret = s.Data
+ return nil
+ },
+ UpdateSecretImpl: func(ctx context.Context, s *kubeapi.Secret) error {
+ secret = s.Data
+ return nil
+ },
+ JSONPatchResourceImpl: func(ctx context.Context, name, resourceType string, patches []kubeclient.JSONPatch) error {
+ if !tt.allowPatch {
+ return &kubeapi.Status{Reason: "Forbidden"}
+ }
+ if secret == nil {
+ secret = make(map[string][]byte)
+ }
+ for _, p := range patches {
+ if p.Op == "add" && p.Path == "/data" {
+ secret = p.Value.(map[string][]byte)
+ } else if p.Op == "add" && strings.HasPrefix(p.Path, "/data/") {
+ key := strings.TrimPrefix(p.Path, "/data/")
+ secret[key] = p.Value.([]byte)
+ }
+ }
+ return nil
+ },
+ }
+
+ s := &Store{
+ client: client,
+ canPatch: tt.allowPatch,
+ secretName: "ts-state",
+ memory: mem.Store{},
+ }
+
+ err := s.WriteState(tt.key, tt.value)
+ if err != nil {
+ t.Errorf("WriteState() error = %v", err)
+ return
+ }
+
+ // Verify secret data
+ if diff := cmp.Diff(secret, tt.wantData); diff != "" {
+ t.Errorf("secret data mismatch (-got +want):\n%s", diff)
+ }
+
+ // Verify memory store was updated
+ got, err := s.memory.ReadState(ipn.StateKey(sanitizeKey(string(tt.key))))
+ if err != nil {
+ t.Errorf("reading from memory store: %v", err)
+ }
+ if !cmp.Equal(got, tt.value) {
+ t.Errorf("memory store key %q = %v, want %v", tt.key, got, tt.value)
+ }
+ })
+ }
+}
+
+func TestWriteTLSCertAndKey(t *testing.T) {
+ const (
+ testDomain = "my-app.tailnetxyz.ts.net"
+ testCert = "fake-cert"
+ testKey = "fake-key"
+ )
+
+ tests := []struct {
+ name string
+ initial map[string][]byte // pre-existing cert and key
+ certShareMode string
+ allowPatch bool // whether client can patch the Secret
+ wantSecretName string // name of the Secret where cert and key should be written
+ wantSecretData map[string][]byte
+ wantMemoryStore map[ipn.StateKey][]byte
+ }{
+ {
+ name: "basic_write",
+ initial: map[string][]byte{
+ "existing": []byte("old"),
+ },
+ allowPatch: true,
+ wantSecretName: "ts-state",
+ wantSecretData: map[string][]byte{
+ "existing": []byte("old"),
+ "my-app.tailnetxyz.ts.net.crt": []byte(testCert),
+ "my-app.tailnetxyz.ts.net.key": []byte(testKey),
+ },
+ wantMemoryStore: map[ipn.StateKey][]byte{
+ "my-app.tailnetxyz.ts.net.crt": []byte(testCert),
+ "my-app.tailnetxyz.ts.net.key": []byte(testKey),
+ },
+ },
+ {
+ name: "cert_share_mode_write",
+ certShareMode: "rw",
+ allowPatch: true,
+ wantSecretName: "my-app.tailnetxyz.ts.net",
+ wantSecretData: map[string][]byte{
+ "tls.crt": []byte(testCert),
+ "tls.key": []byte(testKey),
+ },
+ wantMemoryStore: map[ipn.StateKey][]byte{
+ "my-app.tailnetxyz.ts.net.crt": []byte(testCert),
+ "my-app.tailnetxyz.ts.net.key": []byte(testKey),
+ },
+ },
+ {
+ name: "cert_share_mode_write_update_existing",
+ initial: map[string][]byte{
+ "tls.crt": []byte("old-cert"),
+ "tls.key": []byte("old-key"),
+ },
+ certShareMode: "rw",
+ allowPatch: true,
+ wantSecretName: "my-app.tailnetxyz.ts.net",
+ wantSecretData: map[string][]byte{
+ "tls.crt": []byte(testCert),
+ "tls.key": []byte(testKey),
+ },
+ wantMemoryStore: map[ipn.StateKey][]byte{
+ "my-app.tailnetxyz.ts.net.crt": []byte(testCert),
+ "my-app.tailnetxyz.ts.net.key": []byte(testKey),
+ },
+ },
+ {
+ name: "update_existing",
+ initial: map[string][]byte{
+ "my-app.tailnetxyz.ts.net.crt": []byte("old-cert"),
+ "my-app.tailnetxyz.ts.net.key": []byte("old-key"),
+ },
+ certShareMode: "",
+ allowPatch: true,
+ wantSecretName: "ts-state",
+ wantSecretData: map[string][]byte{
+ "my-app.tailnetxyz.ts.net.crt": []byte(testCert),
+ "my-app.tailnetxyz.ts.net.key": []byte(testKey),
+ },
+ wantMemoryStore: map[ipn.StateKey][]byte{
+ "my-app.tailnetxyz.ts.net.crt": []byte(testCert),
+ "my-app.tailnetxyz.ts.net.key": []byte(testKey),
+ },
+ },
+ {
+ name: "patch_denied",
+ certShareMode: "",
+ allowPatch: false,
+ wantSecretName: "ts-state",
+ wantSecretData: map[string][]byte{
+ "my-app.tailnetxyz.ts.net.crt": []byte(testCert),
+ "my-app.tailnetxyz.ts.net.key": []byte(testKey),
+ },
+ wantMemoryStore: map[ipn.StateKey][]byte{
+ "my-app.tailnetxyz.ts.net.crt": []byte(testCert),
+ "my-app.tailnetxyz.ts.net.key": []byte(testKey),
+ },
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+
+ // Set POD_NAME for testing selectors
+ envknob.Setenv("POD_NAME", "ingress-proxies-1")
+ defer envknob.Setenv("POD_NAME", "")
+
+ secret := tt.initial // track current state
+ client := &kubeclient.FakeClient{
+ GetSecretImpl: func(ctx context.Context, name string) (*kubeapi.Secret, error) {
+ if secret == nil {
+ return nil, &kubeapi.Status{Code: 404}
+ }
+ return &kubeapi.Secret{Data: secret}, nil
+ },
+ CheckSecretPermissionsImpl: func(ctx context.Context, name string) (bool, bool, error) {
+ return tt.allowPatch, true, nil
+ },
+ CreateSecretImpl: func(ctx context.Context, s *kubeapi.Secret) error {
+ if s.Name != tt.wantSecretName {
+ t.Errorf("CreateSecret called with wrong name, got %q, want %q", s.Name, tt.wantSecretName)
+ }
+ secret = s.Data
+ return nil
+ },
+ UpdateSecretImpl: func(ctx context.Context, s *kubeapi.Secret) error {
+ if s.Name != tt.wantSecretName {
+ t.Errorf("UpdateSecret called with wrong name, got %q, want %q", s.Name, tt.wantSecretName)
+ }
+ secret = s.Data
+ return nil
+ },
+ JSONPatchResourceImpl: func(ctx context.Context, name, resourceType string, patches []kubeclient.JSONPatch) error {
+ if !tt.allowPatch {
+ return &kubeapi.Status{Reason: "Forbidden"}
+ }
+ if name != tt.wantSecretName {
+ t.Errorf("JSONPatchResource called with wrong name, got %q, want %q", name, tt.wantSecretName)
+ }
+ if secret == nil {
+ secret = make(map[string][]byte)
+ }
+ for _, p := range patches {
+ if p.Op == "add" && p.Path == "/data" {
+ secret = p.Value.(map[string][]byte)
+ } else if p.Op == "add" && strings.HasPrefix(p.Path, "/data/") {
+ key := strings.TrimPrefix(p.Path, "/data/")
+ secret[key] = p.Value.([]byte)
+ }
+ }
+ return nil
+ },
+ }
+
+ s := &Store{
+ client: client,
+ canPatch: tt.allowPatch,
+ secretName: tt.wantSecretName,
+ certShareMode: tt.certShareMode,
+ memory: mem.Store{},
+ }
+
+ err := s.WriteTLSCertAndKey(testDomain, []byte(testCert), []byte(testKey))
+ if err != nil {
+ t.Errorf("WriteTLSCertAndKey() error = '%v'", err)
+ return
+ }
+
+ // Verify secret data
+ if diff := cmp.Diff(secret, tt.wantSecretData); diff != "" {
+ t.Errorf("secret data mismatch (-got +want):\n%s", diff)
+ }
+
+ // Verify memory store was updated
+ for key, want := range tt.wantMemoryStore {
+ got, err := s.memory.ReadState(key)
+ if err != nil {
+ t.Errorf("reading from memory store: %v", err)
+ continue
+ }
+ if !cmp.Equal(got, want) {
+ t.Errorf("memory store key %q = %v, want %v", key, got, want)
+ }
+ }
+ })
+ }
+}
+
+func TestReadTLSCertAndKey(t *testing.T) {
+ const (
+ testDomain = "my-app.tailnetxyz.ts.net"
+ testCert = "fake-cert"
+ testKey = "fake-key"
+ )
+
+ tests := []struct {
+ name string
+ memoryStore map[ipn.StateKey][]byte // pre-existing memory store state
+ certShareMode string
+ domain string
+ secretData map[string][]byte // data to return from mock GetSecret
+ secretGetErr error // error to return from mock GetSecret
+ wantCert []byte
+ wantKey []byte
+ wantErr error
+ // what should end up in memory store after the cert and key are read
+ wantMemoryStore map[ipn.StateKey][]byte
+ }{
+ {
+ name: "found",
+ memoryStore: map[ipn.StateKey][]byte{
+ "my-app.tailnetxyz.ts.net.crt": []byte(testCert),
+ "my-app.tailnetxyz.ts.net.key": []byte(testKey),
+ },
+ domain: testDomain,
+ wantCert: []byte(testCert),
+ wantKey: []byte(testKey),
+ wantMemoryStore: map[ipn.StateKey][]byte{
+ "my-app.tailnetxyz.ts.net.crt": []byte(testCert),
+ "my-app.tailnetxyz.ts.net.key": []byte(testKey),
+ },
+ },
+ {
+ name: "not_found",
+ domain: testDomain,
+ wantErr: ipn.ErrStateNotExist,
+ },
+ {
+ name: "cert_share_ro_mode_found_in_secret",
+ certShareMode: "ro",
+ domain: testDomain,
+ secretData: map[string][]byte{
+ "tls.crt": []byte(testCert),
+ "tls.key": []byte(testKey),
+ },
+ wantCert: []byte(testCert),
+ wantKey: []byte(testKey),
+ wantMemoryStore: map[ipn.StateKey][]byte{
+ "my-app.tailnetxyz.ts.net.crt": []byte(testCert),
+ "my-app.tailnetxyz.ts.net.key": []byte(testKey),
+ },
+ },
+ {
+ name: "cert_share_ro_mode_found_in_memory",
+ certShareMode: "ro",
+ memoryStore: map[ipn.StateKey][]byte{
+ "my-app.tailnetxyz.ts.net.crt": []byte(testCert),
+ "my-app.tailnetxyz.ts.net.key": []byte(testKey),
+ },
+ domain: testDomain,
+ wantCert: []byte(testCert),
+ wantKey: []byte(testKey),
+ wantMemoryStore: map[ipn.StateKey][]byte{
+ "my-app.tailnetxyz.ts.net.crt": []byte(testCert),
+ "my-app.tailnetxyz.ts.net.key": []byte(testKey),
+ },
+ },
+ {
+ name: "cert_share_ro_mode_not_found",
+ certShareMode: "ro",
+ domain: testDomain,
+ secretGetErr: &kubeapi.Status{Code: 404},
+ wantErr: ipn.ErrStateNotExist,
+ },
+ {
+ name: "cert_share_ro_mode_empty_cert_in_secret",
+ certShareMode: "ro",
+ domain: testDomain,
+ secretData: map[string][]byte{
+ "tls.crt": {},
+ "tls.key": []byte(testKey),
+ },
+ wantErr: ipn.ErrStateNotExist,
+ },
+ {
+ name: "cert_share_ro_mode_kube_api_error",
+ certShareMode: "ro",
+ domain: testDomain,
+ secretGetErr: fmt.Errorf("api error"),
+ wantErr: fmt.Errorf("getting TLS Secret %q: api error", sanitizeKey(testDomain)),
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+
+ client := &kubeclient.FakeClient{
+ GetSecretImpl: func(ctx context.Context, name string) (*kubeapi.Secret, error) {
+ if tt.secretGetErr != nil {
+ return nil, tt.secretGetErr
+ }
+ return &kubeapi.Secret{Data: tt.secretData}, nil
+ },
+ }
+
+ s := &Store{
+ client: client,
+ secretName: "ts-state",
+ certShareMode: tt.certShareMode,
+ memory: mem.Store{},
+ }
+
+ // Initialize memory store
+ for k, v := range tt.memoryStore {
+ s.memory.WriteState(k, v)
+ }
+
+ gotCert, gotKey, err := s.ReadTLSCertAndKey(tt.domain)
+ if tt.wantErr != nil {
+ if err == nil {
+ t.Errorf("ReadTLSCertAndKey() error = nil, want error containing %v", tt.wantErr)
+ return
+ }
+ if !strings.Contains(err.Error(), tt.wantErr.Error()) {
+ t.Errorf("ReadTLSCertAndKey() error = %v, want error containing %v", err, tt.wantErr)
+ }
+ return
+ }
+ if err != nil {
+ t.Errorf("ReadTLSCertAndKey() unexpected error: %v", err)
+ return
+ }
+
+ if !bytes.Equal(gotCert, tt.wantCert) {
+ t.Errorf("ReadTLSCertAndKey() gotCert = %v, want %v", gotCert, tt.wantCert)
+ }
+ if !bytes.Equal(gotKey, tt.wantKey) {
+ t.Errorf("ReadTLSCertAndKey() gotKey = %v, want %v", gotKey, tt.wantKey)
+ }
+
+ // Verify memory store contents after operation
+ if tt.wantMemoryStore != nil {
+ for key, want := range tt.wantMemoryStore {
+ got, err := s.memory.ReadState(key)
+ if err != nil {
+ t.Errorf("reading from memory store: %v", err)
+ continue
+ }
+ if !bytes.Equal(got, want) {
+ t.Errorf("memory store key %q = %v, want %v", key, got, want)
+ }
+ }
+ }
+ })
+ }
+}
+
+// TestNewWithClient is a table-driven test of newWithClient. Every case wires
+// a kubeclient.FakeClient that serves a state Secret (and, optionally, a set
+// of labelled TLS Secrets), sets TS_CERT_SHARE_MODE, and then compares either
+// the returned error or the exported contents of the in-memory store against
+// the case's expectation.
+func TestNewWithClient(t *testing.T) {
+	const (
+		secretName = "ts-state"
+		testCert   = "fake-cert"
+		testKey    = "fake-key"
+	)
+
+	// proxyGroupLabels returns the label set of a cert Secret that belongs to
+	// the given proxy group.
+	proxyGroupLabels := func(group string) map[string]string {
+		return map[string]string{
+			"tailscale.com/secret-type": "certs",
+			"tailscale.com/managed":     "true",
+			"tailscale.com/proxy-group": group,
+		}
+	}
+	ingressLabels := proxyGroupLabels("ingress-proxies")
+
+	// tlsSecret builds a Secret fixture whose tls.crt/tls.key payloads are the
+	// fake cert/key with suffix appended, so each fixture is distinguishable.
+	tlsSecret := func(name string, labels map[string]string, suffix string) kubeapi.Secret {
+		meta := kubeapi.ObjectMeta{Name: name, Labels: labels}
+		data := map[string][]byte{
+			"tls.crt": []byte(testCert + suffix),
+			"tls.key": []byte(testKey + suffix),
+		}
+		return kubeapi.Secret{ObjectMeta: meta, Data: data}
+	}
+
+	// mixedTLSSecrets returns a fresh fixture list: two Secrets labelled for
+	// the "ingress-proxies" group, one without labels, and one labelled for a
+	// different proxy group.
+	mixedTLSSecrets := func() []kubeapi.Secret {
+		return []kubeapi.Secret{
+			tlsSecret("app1.tailnetxyz.ts.net", ingressLabels, "1"),
+			tlsSecret("app2.tailnetxyz.ts.net", ingressLabels, "2"),
+			tlsSecret("some-other-secret", nil, "3"),
+			tlsSecret("app3.other-proxies.ts.net", proxyGroupLabels("some-other-proxygroup"), "4"),
+		}
+	}
+
+	// stateWithIngressCerts is the store content expected when the state key
+	// "foo" and only the two "ingress-proxies" certs have been loaded.
+	stateWithIngressCerts := func() map[ipn.StateKey][]byte {
+		return map[ipn.StateKey][]byte{
+			"foo":                        []byte("bar"),
+			"app1.tailnetxyz.ts.net.crt": []byte(testCert + "1"),
+			"app1.tailnetxyz.ts.net.key": []byte(testKey + "1"),
+			"app2.tailnetxyz.ts.net.crt": []byte(testCert + "2"),
+			"app2.tailnetxyz.ts.net.key": []byte(testKey + "2"),
+		}
+	}
+
+	// hasLabels reports whether labels contains every key/value pair of
+	// selector. An empty selector matches everything.
+	hasLabels := func(labels, selector map[string]string) bool {
+		for k, v := range selector {
+			if labels[k] != v {
+				return false
+			}
+		}
+		return true
+	}
+
+	cases := []struct {
+		name                    string
+		stateSecretContents     map[string][]byte // Data served for the state Secret
+		tlsSecrets              []kubeapi.Secret  // candidates served by ListSecrets
+		certMode                string            // value for TS_CERT_SHARE_MODE
+		secretGetErr            error             // if set, returned by every GetSecret call
+		secretsListErr          error             // if set, returned by every ListSecrets call
+		wantMemoryStoreContents map[ipn.StateKey][]byte
+		wantErr                 error
+	}{
+		{
+			name:                    "empty_state_secret",
+			stateSecretContents:     map[string][]byte{},
+			wantMemoryStoreContents: map[ipn.StateKey][]byte{},
+		},
+		{
+			name:                    "state_secret_not_found",
+			secretGetErr:            &kubeapi.Status{Code: 404},
+			wantMemoryStoreContents: map[ipn.StateKey][]byte{},
+		},
+		{
+			name:         "state_secret_get_error",
+			secretGetErr: fmt.Errorf("some error"),
+			wantErr:      fmt.Errorf("error loading state from kube Secret: some error"),
+		},
+		{
+			name: "load_existing_state",
+			stateSecretContents: map[string][]byte{
+				"foo": []byte("bar"),
+				"baz": []byte("qux"),
+			},
+			wantMemoryStoreContents: map[ipn.StateKey][]byte{
+				"foo": []byte("bar"),
+				"baz": []byte("qux"),
+			},
+		},
+		{
+			name:                    "load_select_certs_in_read_only_mode",
+			certMode:                "ro",
+			stateSecretContents:     map[string][]byte{"foo": []byte("bar")},
+			tlsSecrets:              mixedTLSSecrets(),
+			wantMemoryStoreContents: stateWithIngressCerts(),
+		},
+		{
+			name:                    "load_select_certs_in_read_write_mode",
+			certMode:                "rw",
+			stateSecretContents:     map[string][]byte{"foo": []byte("bar")},
+			tlsSecrets:              mixedTLSSecrets(),
+			wantMemoryStoreContents: stateWithIngressCerts(),
+		},
+		{
+			// A ListSecrets failure must not surface as an error from
+			// newWithClient; the state Secret's contents are still expected.
+			name:                    "list_cert_secrets_fails",
+			certMode:                "ro",
+			stateSecretContents:     map[string][]byte{"foo": []byte("bar")},
+			secretsListErr:          fmt.Errorf("list error"),
+			wantMemoryStoreContents: map[ipn.StateKey][]byte{"foo": []byte("bar")},
+		},
+		{
+			name:                "cert_secrets_not_loaded_when_not_in_share_mode",
+			certMode:            "",
+			stateSecretContents: map[string][]byte{"foo": []byte("bar")},
+			tlsSecrets: []kubeapi.Secret{
+				tlsSecret("app1.tailnetxyz.ts.net", ingressLabels, "1"),
+			},
+			wantMemoryStoreContents: map[ipn.StateKey][]byte{"foo": []byte("bar")},
+		},
+	}
+
+	for _, tc := range cases {
+		t.Run(tc.name, func(t *testing.T) {
+			envknob.Setenv("TS_CERT_SHARE_MODE", tc.certMode)
+			// NOTE(review): presumably the proxy group name is derived from
+			// the pod name; confirm against newWithClient.
+			t.Setenv("POD_NAME", "ingress-proxies-1")
+
+			fake := &kubeclient.FakeClient{
+				// Only the state Secret exists; any other name is a 404.
+				GetSecretImpl: func(ctx context.Context, name string) (*kubeapi.Secret, error) {
+					switch {
+					case tc.secretGetErr != nil:
+						return nil, tc.secretGetErr
+					case name == secretName:
+						return &kubeapi.Secret{Data: tc.stateSecretContents}, nil
+					default:
+						return nil, &kubeapi.Status{Code: 404}
+					}
+				},
+				CheckSecretPermissionsImpl: func(ctx context.Context, name string) (bool, bool, error) {
+					return true, true, nil
+				},
+				// Return the fixtures whose labels satisfy the selector.
+				ListSecretsImpl: func(ctx context.Context, selector map[string]string) (*kubeapi.SecretList, error) {
+					if tc.secretsListErr != nil {
+						return nil, tc.secretsListErr
+					}
+					var selected []kubeapi.Secret
+					for _, sec := range tc.tlsSecrets {
+						if hasLabels(sec.Labels, selector) {
+							selected = append(selected, sec)
+						}
+					}
+					return &kubeapi.SecretList{Items: selected}, nil
+				},
+			}
+
+			store, err := newWithClient(t.Logf, fake, secretName)
+			switch {
+			case tc.wantErr != nil && err == nil:
+				t.Errorf("NewWithClient() error = nil, want error containing %v", tc.wantErr)
+				return
+			case tc.wantErr != nil:
+				if !strings.Contains(err.Error(), tc.wantErr.Error()) {
+					t.Errorf("NewWithClient() error = %v, want error containing %v", err, tc.wantErr)
+				}
+				return
+			case err != nil:
+				t.Errorf("NewWithClient() unexpected error: %v", err)
+				return
+			}
+
+			// Round-trip the memory store through its JSON export so that
+			// its contents can be compared as a plain map.
+			exported, err := store.memory.ExportToJSON()
+			if err != nil {
+				t.Errorf("ExportToJSON failed: %v", err)
+				return
+			}
+			var got map[ipn.StateKey][]byte
+			if err := json.Unmarshal(exported, &got); err != nil {
+				t.Errorf("failed to unmarshal memory store JSON: %v", err)
+				return
+			}
+
+			// A case without an expectation means "store must be empty".
+			want := map[ipn.StateKey][]byte{}
+			if tc.wantMemoryStoreContents != nil {
+				want = tc.wantMemoryStoreContents
+			}
+			if diff := cmp.Diff(got, want); diff != "" {
+				t.Errorf("memory store contents mismatch (-got +want):\n%s", diff)
+			}
+		})
+	}
+}
diff --git a/ipn/store/store_aws.go b/ipn/store/store_aws.go
index e164f9de7..d39e84319 100644
--- a/ipn/store/store_aws.go
+++ b/ipn/store/store_aws.go
@@ -6,7 +6,9 @@
package store
import (
+ "tailscale.com/ipn"
"tailscale.com/ipn/store/awsstore"
+ "tailscale.com/types/logger"
)
func init() {
@@ -14,5 +16,11 @@ func init() {
}
+// registerAWSStore registers the store constructor for arguments with the
+// "arn:" prefix. The registered constructor splits its argument into an ARN
+// and options via awsstore.ParseARNAndOpts and forwards both to awsstore.New.
 func registerAWSStore() {
- Register("arn:", awsstore.New)
+ Register("arn:", func(logf logger.Logf, arg string) (ipn.StateStore, error) {
+ // A parse failure is returned as-is; awsstore.New is not called.
+ ssmARN, opts, err := awsstore.ParseARNAndOpts(arg)
+ if err != nil {
+ return nil, err
+ }
+ return awsstore.New(logf, ssmARN, opts...)
+ })
 }