diff options
| author | kari-ts <kari@tailscale.com> | 2025-03-19 11:28:04 -0700 |
|---|---|---|
| committer | kari-ts <kari@tailscale.com> | 2025-04-04 14:24:56 -0700 |
| commit | 6d5c7b11913e09b061e863411ad488dc44a13870 (patch) | |
| tree | 9e1789b5080ae4a92523611e49920dcb1102604b /ipn | |
| parent | ca50599c95e0a4cb7b4aab179e866e202f10c0c4 (diff) | |
| parent | 3a2c92f08eac8cd8f50356ff288e40a28636ee42 (diff) | |
| download | tailscale-kari/taildropsaf.tar.xz tailscale-kari/taildropsaf.zip | |
TO DO:kari/taildropsaf
-check if Context.getExternalFilesDirs works as is for private dir
Diffstat (limited to 'ipn')
| -rw-r--r-- | ipn/auditlog/auditlog.go | 466 | ||||
| -rw-r--r-- | ipn/auditlog/auditlog_test.go | 481 | ||||
| -rw-r--r-- | ipn/ipnauth/actor.go | 7 | ||||
| -rw-r--r-- | ipn/ipnauth/policy.go | 10 | ||||
| -rw-r--r-- | ipn/ipnlocal/cert.go | 103 | ||||
| -rw-r--r-- | ipn/ipnlocal/cert_test.go | 185 | ||||
| -rw-r--r-- | ipn/ipnlocal/local.go | 132 | ||||
| -rw-r--r-- | ipn/ipnlocal/local_test.go | 59 | ||||
| -rw-r--r-- | ipn/ipnstate/ipnstate.go | 8 | ||||
| -rw-r--r-- | ipn/store/awsstore/store_aws.go | 111 | ||||
| -rw-r--r-- | ipn/store/awsstore/store_aws_stub.go | 18 | ||||
| -rw-r--r-- | ipn/store/awsstore/store_aws_test.go | 61 | ||||
| -rw-r--r-- | ipn/store/kubestore/store_kube.go | 317 | ||||
| -rw-r--r-- | ipn/store/kubestore/store_kube_test.go | 723 | ||||
| -rw-r--r-- | ipn/store/store_aws.go | 10 |
15 files changed, 2545 insertions, 146 deletions
diff --git a/ipn/auditlog/auditlog.go b/ipn/auditlog/auditlog.go new file mode 100644 index 000000000..30f39211f --- /dev/null +++ b/ipn/auditlog/auditlog.go @@ -0,0 +1,466 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package auditlog provides a mechanism for logging audit events. +package auditlog + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "sort" + "sync" + "time" + + "tailscale.com/ipn" + "tailscale.com/tailcfg" + "tailscale.com/types/logger" + "tailscale.com/util/rands" + "tailscale.com/util/set" +) + +// transaction represents an audit log that has not yet been sent to the control plane. +type transaction struct { + // EventID is the unique identifier for the event being logged. + // This is used on the client side only and is not sent to control. + EventID string `json:",omitempty"` + // Retries is the number of times the logger has attempted to send this log. + // This is used on the client side only and is not sent to control. + Retries int `json:",omitempty"` + + // Action is the action to be logged. It must correspond to a known action in the control plane. + Action tailcfg.ClientAuditAction `json:",omitempty"` + // Details is an opaque string specific to the action being logged. Empty strings may not + // be valid depending on the action being logged. + Details string `json:",omitempty"` + // TimeStamp is the time at which the audit log was generated on the node. + TimeStamp time.Time `json:",omitzero"` +} + +// Transport provides a means for a client to send audit logs to a consumer (typically the control plane). +type Transport interface { + // SendAuditLog sends an audit log to a consumer of audit logs. + // Errors should be checked with [IsRetryableError] for retryability. + SendAuditLog(context.Context, tailcfg.AuditLogRequest) error +} + +// LogStore provides a means for a [Logger] to persist logs to disk or memory. +type LogStore interface { + // Save saves the given data to a persistent store. Save will overwrite existing data + // for the given key. + save(key ipn.ProfileID, txns []*transaction) error + + // Load retrieves the data from a persistent store. Returns a nil slice and + // no error if no data exists for the given key. + load(key ipn.ProfileID) ([]*transaction, error) +} + +// Opts contains the configuration options for a [Logger]. +type Opts struct { + // RetryLimit is the maximum number of attempts the logger will make to send a log before giving up. + RetryLimit int + // Store is the persistent store used to save logs to disk. Must be non-nil. + Store LogStore + // Logf is the logger used to log messages from the audit logger. Must be non-nil. + Logf logger.Logf +} + +// IsRetryableError returns true if the given error is retryable +// See [controlclient.apiResponseError]. Potentially retryable errors implement the Retryable() method. +func IsRetryableError(err error) bool { + var retryable interface{ Retryable() bool } + return errors.As(err, &retryable) && retryable.Retryable() +} + +type backoffOpts struct { + min, max time.Duration + multiplier float64 +} + +// .5, 1, 2, 4, 8, 10, 10, 10, 10, 10... +var defaultBackoffOpts = backoffOpts{ + min: time.Millisecond * 500, + max: 10 * time.Second, + multiplier: 2, +} + +// Logger provides a queue-based mechanism for submitting audit logs to the control plane - or +// another suitable consumer. Logs are stored to disk and retried until they are successfully sent, +// or until they permanently fail. +// +// Each individual profile/controlclient tuple should construct and manage a unique [Logger] instance. +type Logger struct { + logf logger.Logf + retryLimit int // the maximum number of attempts to send a log before giving up. + flusher chan struct{} // channel used to signal a flush operation. + done chan struct{} // closed when the flush worker exits. + ctx context.Context // canceled when the logger is stopped. + ctxCancel context.CancelFunc // cancels ctx. + backoffOpts // backoff settings for retry operations. + + // mu protects the fields below. + mu sync.Mutex + store LogStore // persistent storage for unsent logs. + profileID ipn.ProfileID // empty if [Logger.SetProfileID] has not been called. + transport Transport // nil until [Logger.Start] is called. +} + +// NewLogger creates a new [Logger] with the given options. +func NewLogger(opts Opts) *Logger { + ctx, cancel := context.WithCancel(context.Background()) + + al := &Logger{ + retryLimit: opts.RetryLimit, + logf: logger.WithPrefix(opts.Logf, "auditlog: "), + store: opts.Store, + flusher: make(chan struct{}, 1), + done: make(chan struct{}), + ctx: ctx, + ctxCancel: cancel, + backoffOpts: defaultBackoffOpts, + } + al.logf("created") + return al +} + +// FlushAndStop synchronously flushes all pending logs and stops the audit logger. +// This will block until a final flush operation completes or context is done. +// If the logger is already stopped, this will return immediately. All unsent +// logs will be persisted to the store. +func (al *Logger) FlushAndStop(ctx context.Context) { + al.stop() + al.flush(ctx) +} + +// SetProfileID sets the profileID for the logger. This must be called before any logs can be enqueued. +// The profileID of a logger cannot be changed once set. +func (al *Logger) SetProfileID(profileID ipn.ProfileID) error { + al.mu.Lock() + defer al.mu.Unlock() + if al.profileID != "" { + return errors.New("profileID already set") + } + + al.profileID = profileID + return nil +} + +// Start starts the audit logger with the given transport. +// It returns an error if the logger is already started. +func (al *Logger) Start(t Transport) error { + al.mu.Lock() + defer al.mu.Unlock() + + if al.transport != nil { + return errors.New("already started") + } + + al.transport = t + pending, err := al.storedCountLocked() + if err != nil { + al.logf("[unexpected] failed to restore logs: %v", err) + } + go al.flushWorker() + if pending > 0 { + al.flushAsync() + } + return nil +} + +// ErrAuditLogStorageFailure is returned when the logger fails to persist logs to the store. +var ErrAuditLogStorageFailure = errors.New("audit log storage failure") + +// Enqueue queues an audit log to be sent to the control plane (or another suitable consumer/transport). +// This will return an error if the underlying store fails to save the log or we fail to generate a unique +// eventID for the log. +func (al *Logger) Enqueue(action tailcfg.ClientAuditAction, details string) error { + txn := &transaction{ + Action: action, + Details: details, + TimeStamp: time.Now(), + } + // Generate a suitably random eventID for the transaction. + txn.EventID = fmt.Sprint(txn.TimeStamp, rands.HexString(16)) + return al.enqueue(txn) +} + +// flushAsync requests an asynchronous flush. +// It is a no-op if a flush is already pending. +func (al *Logger) flushAsync() { + select { + case al.flusher <- struct{}{}: + default: + } +} + +func (al *Logger) flushWorker() { + defer close(al.done) + + var retryDelay time.Duration + retry := time.NewTimer(0) + retry.Stop() + + for { + select { + case <-al.ctx.Done(): + return + case <-al.flusher: + err := al.flush(al.ctx) + switch { + case errors.Is(err, context.Canceled): + // The logger was stopped, no need to retry. + return + case err != nil: + retryDelay = max(al.backoffOpts.min, min(retryDelay*time.Duration(al.backoffOpts.multiplier), al.backoffOpts.max)) + al.logf("retrying after %v, %v", retryDelay, err) + retry.Reset(retryDelay) + default: + retryDelay = 0 + retry.Stop() + } + case <-retry.C: + al.flushAsync() + } + } +} + +// flush attempts to send all pending logs to the control plane. +// l.mu must not be held. +func (al *Logger) flush(ctx context.Context) error { + al.mu.Lock() + pending, err := al.store.load(al.profileID) + t := al.transport + al.mu.Unlock() + + if err != nil { + // This will catch nil profileIDs + return fmt.Errorf("failed to restore pending logs: %w", err) + } + if len(pending) == 0 { + return nil + } + if t == nil { + return errors.New("no transport") + } + + complete, unsent := al.sendToTransport(ctx, pending, t) + al.markTransactionsDone(complete) + + al.mu.Lock() + defer al.mu.Unlock() + if err = al.appendToStoreLocked(unsent); err != nil { + al.logf("[unexpected] failed to persist logs: %v", err) + } + + if len(unsent) != 0 { + return fmt.Errorf("failed to send %d logs", len(unsent)) + } + + if len(complete) != 0 { + al.logf("complete %d audit log transactions", len(complete)) + } + return nil +} + +// sendToTransport sends all pending logs to the control plane. Returns a pair of slices +// containing the logs that were successfully sent (or failed permanently) and those that were not. +// +// This may require multiple round trips to the control plane and can be a long running transaction. +func (al *Logger) sendToTransport(ctx context.Context, pending []*transaction, t Transport) (complete []*transaction, unsent []*transaction) { + for i, txn := range pending { + req := tailcfg.AuditLogRequest{ + Action: tailcfg.ClientAuditAction(txn.Action), + Details: txn.Details, + Timestamp: txn.TimeStamp, + } + + if err := t.SendAuditLog(ctx, req); err != nil { + switch { + case errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded): + // The contex is done. All further attempts will fail. + unsent = append(unsent, pending[i:]...) + return complete, unsent + case IsRetryableError(err) && txn.Retries+1 < al.retryLimit: + // We permit a maximum number of retries for each log. All retriable + // errors should be transient and we should be able to send the log eventually, but + // we don't want logs to be persisted indefinitely. + txn.Retries++ + unsent = append(unsent, txn) + default: + complete = append(complete, txn) + al.logf("failed permanently: %v", err) + } + } else { + // No error - we're done. + complete = append(complete, txn) + } + } + + return complete, unsent +} + +func (al *Logger) stop() { + al.mu.Lock() + t := al.transport + al.mu.Unlock() + + if t == nil { + // No transport means no worker goroutine and done will not be + // closed if we cancel the context. + return + } + + al.ctxCancel() + <-al.done + al.logf("stopped for profileID: %v", al.profileID) +} + +// appendToStoreLocked persists logs to the store. This will deduplicate +// logs so it is safe to call this with the same logs multiple time, to +// requeue failed transactions for example. +// +// l.mu must be held. +func (al *Logger) appendToStoreLocked(txns []*transaction) error { + if len(txns) == 0 { + return nil + } + + if al.profileID == "" { + return errors.New("no logId set") + } + + persisted, err := al.store.load(al.profileID) + if err != nil { + al.logf("[unexpected] append failed to restore logs: %v", err) + } + + // The order is important here. We want the latest transactions first, which will + // ensure when we dedup, the new transactions are seen and the older transactions + // are discarded. + txnsOut := append(txns, persisted...) + txnsOut = deduplicateAndSort(txnsOut) + + return al.store.save(al.profileID, txnsOut) +} + +// storedCountLocked returns the number of logs persisted to the store. +// al.mu must be held. +func (al *Logger) storedCountLocked() (int, error) { + persisted, err := al.store.load(al.profileID) + return len(persisted), err +} + +// markTransactionsDone removes logs from the store that are complete (sent or failed permanently). +// al.mu must not be held. +func (al *Logger) markTransactionsDone(sent []*transaction) { + al.mu.Lock() + defer al.mu.Unlock() + + ids := set.Set[string]{} + for _, txn := range sent { + ids.Add(txn.EventID) + } + + persisted, err := al.store.load(al.profileID) + if err != nil { + al.logf("[unexpected] markTransactionsDone failed to restore logs: %v", err) + } + var unsent []*transaction + for _, txn := range persisted { + if !ids.Contains(txn.EventID) { + unsent = append(unsent, txn) + } + } + al.store.save(al.profileID, unsent) +} + +// deduplicateAndSort removes duplicate logs from the given slice and sorts them by timestamp. +// The first log entry in the slice will be retained, subsequent logs with the same EventID will be discarded. +func deduplicateAndSort(txns []*transaction) []*transaction { + seen := set.Set[string]{} + deduped := make([]*transaction, 0, len(txns)) + for _, txn := range txns { + if !seen.Contains(txn.EventID) { + deduped = append(deduped, txn) + seen.Add(txn.EventID) + } + } + // Sort logs by timestamp - oldest to newest. This will put the oldest logs at + // the front of the queue. + sort.Slice(deduped, func(i, j int) bool { + return deduped[i].TimeStamp.Before(deduped[j].TimeStamp) + }) + return deduped +} + +func (al *Logger) enqueue(txn *transaction) error { + al.mu.Lock() + defer al.mu.Unlock() + + if err := al.appendToStoreLocked([]*transaction{txn}); err != nil { + return fmt.Errorf("%w: %w", ErrAuditLogStorageFailure, err) + } + + // If a.transport is nil if the logger is stopped. + if al.transport != nil { + al.flushAsync() + } + + return nil +} + +var _ LogStore = (*logStateStore)(nil) + +// logStateStore is a concrete implementation of [LogStore] +// using [ipn.StateStore] as the underlying storage. +type logStateStore struct { + store ipn.StateStore +} + +// NewLogStore creates a new LogStateStore with the given [ipn.StateStore]. +func NewLogStore(store ipn.StateStore) LogStore { + return &logStateStore{ + store: store, + } +} + +func (s *logStateStore) generateKey(key ipn.ProfileID) string { + return "auditlog-" + string(key) +} + +// Save saves the given logs to an [ipn.StateStore]. This overwrites +// any existing entries for the given key. +func (s *logStateStore) save(key ipn.ProfileID, txns []*transaction) error { + if key == "" { + return errors.New("empty key") + } + + data, err := json.Marshal(txns) + if err != nil { + return err + } + k := ipn.StateKey(s.generateKey(key)) + return s.store.WriteState(k, data) +} + +// Load retrieves the logs from an [ipn.StateStore]. +func (s *logStateStore) load(key ipn.ProfileID) ([]*transaction, error) { + if key == "" { + return nil, errors.New("empty key") + } + + k := ipn.StateKey(s.generateKey(key)) + data, err := s.store.ReadState(k) + + switch { + case errors.Is(err, ipn.ErrStateNotExist): + return nil, nil + case err != nil: + return nil, err + } + + var txns []*transaction + err = json.Unmarshal(data, &txns) + return txns, err +} diff --git a/ipn/auditlog/auditlog_test.go b/ipn/auditlog/auditlog_test.go new file mode 100644 index 000000000..3d3bf95cb --- /dev/null +++ b/ipn/auditlog/auditlog_test.go @@ -0,0 +1,481 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package auditlog + +import ( + "context" + "errors" + "fmt" + "sync" + "testing" + "time" + + qt "github.com/frankban/quicktest" + "tailscale.com/ipn/store/mem" + "tailscale.com/tailcfg" + "tailscale.com/tstest" +) + +// loggerForTest creates an auditLogger for you and cleans it up +// (and ensures no goroutines are leaked) when the test is done. +func loggerForTest(t *testing.T, opts Opts) *Logger { + t.Helper() + tstest.ResourceCheck(t) + + if opts.Logf == nil { + opts.Logf = t.Logf + } + + if opts.Store == nil { + t.Fatalf("opts.Store must be set") + } + + a := NewLogger(opts) + + t.Cleanup(func() { + a.FlushAndStop(context.Background()) + }) + return a +} + +func TestNonRetryableErrors(t *testing.T) { + errorTests := []struct { + desc string + err error + want bool + }{ + {"DeadlineExceeded", context.DeadlineExceeded, false}, + {"Canceled", context.Canceled, false}, + {"Canceled wrapped", fmt.Errorf("%w: %w", context.Canceled, errors.New("ctx cancelled")), false}, + {"Random error", errors.New("random error"), false}, + } + + for _, tt := range errorTests { + t.Run(tt.desc, func(t *testing.T) { + if IsRetryableError(tt.err) != tt.want { + t.Fatalf("retriable: got %v, want %v", !tt.want, tt.want) + } + }) + } +} + +// TestEnqueueAndFlush enqueues n logs and flushes them. +// We expect all logs to be flushed and for no +// logs to remain in the store once FlushAndStop returns. +func TestEnqueueAndFlush(t *testing.T) { + c := qt.New(t) + mockTransport := newMockTransport(nil) + al := loggerForTest(t, Opts{ + RetryLimit: 200, + Logf: t.Logf, + Store: NewLogStore(&mem.Store{}), + }) + + c.Assert(al.SetProfileID("test"), qt.IsNil) + c.Assert(al.Start(mockTransport), qt.IsNil) + + wantSent := 10 + + for i := range wantSent { + err := al.Enqueue(tailcfg.AuditNodeDisconnect, fmt.Sprintf("log %d", i)) + c.Assert(err, qt.IsNil) + } + + al.FlushAndStop(context.Background()) + + al.mu.Lock() + defer al.mu.Unlock() + gotStored, err := al.storedCountLocked() + c.Assert(err, qt.IsNil) + + if wantStored := 0; gotStored != wantStored { + t.Fatalf("stored: got %d, want %d", gotStored, wantStored) + } + + if gotSent := mockTransport.sentCount(); gotSent != wantSent { + t.Fatalf("sent: got %d, want %d", gotSent, wantSent) + } +} + +// TestEnqueueAndFlushWithFlushCancel calls FlushAndCancel with a cancelled +// context. We expect nothing to be sent and all logs to be stored. +func TestEnqueueAndFlushWithFlushCancel(t *testing.T) { + c := qt.New(t) + mockTransport := newMockTransport(&retriableError) + al := loggerForTest(t, Opts{ + RetryLimit: 200, + Logf: t.Logf, + Store: NewLogStore(&mem.Store{}), + }) + + c.Assert(al.SetProfileID("test"), qt.IsNil) + c.Assert(al.Start(mockTransport), qt.IsNil) + + for i := range 10 { + err := al.Enqueue(tailcfg.AuditNodeDisconnect, fmt.Sprintf("log %d", i)) + c.Assert(err, qt.IsNil) + } + + // Cancel the context before calling FlushAndStop - nothing should get sent. + // This mimics a timeout before flush() has a chance to execute. + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + al.FlushAndStop(ctx) + + al.mu.Lock() + defer al.mu.Unlock() + gotStored, err := al.storedCountLocked() + c.Assert(err, qt.IsNil) + + if wantStored := 10; gotStored != wantStored { + t.Fatalf("stored: got %d, want %d", gotStored, wantStored) + } + + if gotSent, wantSent := mockTransport.sentCount(), 0; gotSent != wantSent { + t.Fatalf("sent: got %d, want %d", gotSent, wantSent) + } +} + +// TestDeduplicateAndSort tests that the most recent log is kept when deduplicating logs +func TestDeduplicateAndSort(t *testing.T) { + c := qt.New(t) + al := loggerForTest(t, Opts{ + RetryLimit: 100, + Logf: t.Logf, + Store: NewLogStore(&mem.Store{}), + }) + + c.Assert(al.SetProfileID("test"), qt.IsNil) + + logs := []*transaction{ + {EventID: "1", Details: "log 1", TimeStamp: time.Now().Add(-time.Minute * 1), Retries: 1}, + } + + al.mu.Lock() + defer al.mu.Unlock() + al.appendToStoreLocked(logs) + + // Update the transaction and re-append it + logs[0].Retries = 2 + al.appendToStoreLocked(logs) + + fromStore, err := al.store.load("test") + c.Assert(err, qt.IsNil) + + // We should see only one transaction + if wantStored, gotStored := len(logs), len(fromStore); gotStored != wantStored { + t.Fatalf("stored: got %d, want %d", gotStored, wantStored) + } + + // We should see the latest transaction + if wantRetryCount, gotRetryCount := 2, fromStore[0].Retries; gotRetryCount != wantRetryCount { + t.Fatalf("reties: got %d, want %d", gotRetryCount, wantRetryCount) + } +} + +func TestChangeProfileId(t *testing.T) { + c := qt.New(t) + al := loggerForTest(t, Opts{ + RetryLimit: 100, + Logf: t.Logf, + Store: NewLogStore(&mem.Store{}), + }) + c.Assert(al.SetProfileID("test"), qt.IsNil) + + // Changing a profile ID must fail + c.Assert(al.SetProfileID("test"), qt.IsNotNil) +} + +// TestSendOnRestore pushes a n logs to the persistent store, and ensures they +// are sent as soon as Start is called then checks to ensure the sent logs no +// longer exist in the store. +func TestSendOnRestore(t *testing.T) { + c := qt.New(t) + mockTransport := newMockTransport(nil) + al := loggerForTest(t, Opts{ + RetryLimit: 100, + Logf: t.Logf, + Store: NewLogStore(&mem.Store{}), + }) + al.SetProfileID("test") + + wantTotal := 10 + + for range 10 { + al.Enqueue(tailcfg.AuditNodeDisconnect, "log") + } + + c.Assert(al.Start(mockTransport), qt.IsNil) + + al.FlushAndStop(context.Background()) + + al.mu.Lock() + defer al.mu.Unlock() + gotStored, err := al.storedCountLocked() + c.Assert(err, qt.IsNil) + + if wantStored := 0; gotStored != wantStored { + t.Fatalf("stored: got %d, want %d", gotStored, wantStored) + } + + if gotSent, wantSent := mockTransport.sentCount(), wantTotal; gotSent != wantSent { + t.Fatalf("sent: got %d, want %d", gotSent, wantSent) + } +} + +// TestFailureExhaustion enqueues n logs, with the transport in a failable state. +// We then set it to a non-failing state, call FlushAndStop and expect all logs to be sent. +func TestFailureExhaustion(t *testing.T) { + c := qt.New(t) + mockTransport := newMockTransport(&retriableError) + + al := loggerForTest(t, Opts{ + RetryLimit: 1, + Logf: t.Logf, + Store: NewLogStore(&mem.Store{}), + }) + + c.Assert(al.SetProfileID("test"), qt.IsNil) + c.Assert(al.Start(mockTransport), qt.IsNil) + + for range 10 { + err := al.Enqueue(tailcfg.AuditNodeDisconnect, "log") + c.Assert(err, qt.IsNil) + } + + al.FlushAndStop(context.Background()) + al.mu.Lock() + defer al.mu.Unlock() + gotStored, err := al.storedCountLocked() + c.Assert(err, qt.IsNil) + + if wantStored := 0; gotStored != wantStored { + t.Fatalf("stored: got %d, want %d", gotStored, wantStored) + } + + if gotSent, wantSent := mockTransport.sentCount(), 0; gotSent != wantSent { + t.Fatalf("sent: got %d, want %d", gotSent, wantSent) + } +} + +// TestEnqueueAndFailNoRetry enqueues a set of logs, all of which will fail and are not +// retriable. We then call FlushAndStop and expect all to be unsent. +func TestEnqueueAndFailNoRetry(t *testing.T) { + c := qt.New(t) + mockTransport := newMockTransport(&nonRetriableError) + + al := loggerForTest(t, Opts{ + RetryLimit: 100, + Logf: t.Logf, + Store: NewLogStore(&mem.Store{}), + }) + + c.Assert(al.SetProfileID("test"), qt.IsNil) + c.Assert(al.Start(mockTransport), qt.IsNil) + + for i := range 10 { + err := al.Enqueue(tailcfg.AuditNodeDisconnect, fmt.Sprintf("log %d", i)) + c.Assert(err, qt.IsNil) + } + + al.FlushAndStop(context.Background()) + al.mu.Lock() + defer al.mu.Unlock() + gotStored, err := al.storedCountLocked() + c.Assert(err, qt.IsNil) + + if wantStored := 0; gotStored != wantStored { + t.Fatalf("stored: got %d, want %d", gotStored, wantStored) + } + + if gotSent, wantSent := mockTransport.sentCount(), 0; gotSent != wantSent { + t.Fatalf("sent: got %d, want %d", gotSent, wantSent) + } +} + +// TestEnqueueAndRetry enqueues a set of logs, all of which will fail and are retriable. +// Mid-test, we set the transport to not-fail and expect the queue to flush properly +// We set the backoff parameters to 0 seconds so retries are immediate. +func TestEnqueueAndRetry(t *testing.T) { + c := qt.New(t) + mockTransport := newMockTransport(&retriableError) + + al := loggerForTest(t, Opts{ + RetryLimit: 100, + Logf: t.Logf, + Store: NewLogStore(&mem.Store{}), + }) + + al.backoffOpts = backoffOpts{ + min: 1 * time.Millisecond, + max: 4 * time.Millisecond, + multiplier: 2.0, + } + + c.Assert(al.SetProfileID("test"), qt.IsNil) + c.Assert(al.Start(mockTransport), qt.IsNil) + + err := al.Enqueue(tailcfg.AuditNodeDisconnect, fmt.Sprintf("log 1")) + c.Assert(err, qt.IsNil) + + // This will wait for at least 2 retries + gotRetried, wantRetried := mockTransport.waitForSendAttemptsToReach(3), true + if gotRetried != wantRetried { + t.Fatalf("retried: got %v, want %v", gotRetried, wantRetried) + } + + mockTransport.setErrorCondition(nil) + + al.FlushAndStop(context.Background()) + al.mu.Lock() + defer al.mu.Unlock() + + gotStored, err := al.storedCountLocked() + c.Assert(err, qt.IsNil) + + if wantStored := 0; gotStored != wantStored { + t.Fatalf("stored: got %d, want %d", gotStored, wantStored) + } + + if gotSent, wantSent := mockTransport.sentCount(), 1; gotSent != wantSent { + t.Fatalf("sent: got %d, want %d", gotSent, wantSent) + } +} + +// TestEnqueueBeforeSetProfileID tests that logs enqueued before SetProfileId are not sent +func TestEnqueueBeforeSetProfileID(t *testing.T) { + c := qt.New(t) + al := loggerForTest(t, Opts{ + RetryLimit: 100, + Logf: t.Logf, + Store: NewLogStore(&mem.Store{}), + }) + + err := al.Enqueue(tailcfg.AuditNodeDisconnect, "log") + c.Assert(err, qt.IsNotNil) + al.FlushAndStop(context.Background()) + + al.mu.Lock() + defer al.mu.Unlock() + gotStored, err := al.storedCountLocked() + c.Assert(err, qt.IsNotNil) + + if wantStored := 0; gotStored != wantStored { + t.Fatalf("stored: got %d, want %d", gotStored, wantStored) + } +} + +// TestLogStoring tests that audit logs are persisted sorted by timestamp, oldest to newest +func TestLogSorting(t *testing.T) { + c := qt.New(t) + mockStore := NewLogStore(&mem.Store{}) + + logs := []*transaction{ + {EventID: "1", Details: "log 3", TimeStamp: time.Now().Add(-time.Minute * 1)}, + {EventID: "1", Details: "log 3", TimeStamp: time.Now().Add(-time.Minute * 2)}, + {EventID: "2", Details: "log 2", TimeStamp: time.Now().Add(-time.Minute * 3)}, + {EventID: "3", Details: "log 1", TimeStamp: time.Now().Add(-time.Minute * 4)}, + } + + wantLogs := []transaction{ + {Details: "log 1"}, + {Details: "log 2"}, + {Details: "log 3"}, + } + + mockStore.save("test", logs) + + gotLogs, err := mockStore.load("test") + c.Assert(err, qt.IsNil) + gotLogs = deduplicateAndSort(gotLogs) + + for i := range gotLogs { + if want, got := wantLogs[i].Details, gotLogs[i].Details; want != got { + t.Fatalf("Details: got %v, want %v", got, want) + } + } +} + +// mock implementations for testing + +// newMockTransport returns a mock transport for testing +// If err is no nil, SendAuditLog will return this error if the send is attempted +// before the context is cancelled. +func newMockTransport(err error) *mockAuditLogTransport { + return &mockAuditLogTransport{ + err: err, + attempts: make(chan int, 1), + } +} + +type mockAuditLogTransport struct { + attempts chan int // channel to notify of send attempts + + mu sync.Mutex + sendAttmpts int // number of attempts to send logs + sendCount int // number of logs sent by the transport + err error // error to return when sending logs +} + +// waitForSendAttemptsToReach blocks until the number of send attempts reaches n +// This should be use only in tests where the transport is expected to retry sending logs +func (t *mockAuditLogTransport) waitForSendAttemptsToReach(n int) bool { + for attempts := range t.attempts { + if attempts >= n { + return true + } + } + return false +} + +func (t *mockAuditLogTransport) setErrorCondition(err error) { + t.mu.Lock() + defer t.mu.Unlock() + t.err = err +} + +func (t *mockAuditLogTransport) sentCount() int { + t.mu.Lock() + defer t.mu.Unlock() + return t.sendCount +} + +func (t *mockAuditLogTransport) SendAuditLog(ctx context.Context, _ tailcfg.AuditLogRequest) (err error) { + t.mu.Lock() + t.sendAttmpts += 1 + defer func() { + a := t.sendAttmpts + t.mu.Unlock() + select { + case t.attempts <- a: + default: + } + }() + + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + if t.err != nil { + return t.err + } + t.sendCount += 1 + return nil +} + +var ( + retriableError = mockError{errors.New("retriable error")} + nonRetriableError = mockError{errors.New("permanent failure error")} +) + +type mockError struct { + error +} + +func (e mockError) Retryable() bool { + return e == retriableError +} diff --git a/ipn/ipnauth/actor.go b/ipn/ipnauth/actor.go index 8a0e77645..108bdd341 100644 --- a/ipn/ipnauth/actor.go +++ b/ipn/ipnauth/actor.go @@ -10,12 +10,11 @@ import ( "tailscale.com/client/tailscale/apitype" "tailscale.com/ipn" + "tailscale.com/tailcfg" ) // AuditLogFunc is any function that can be used to log audit actions performed by an [Actor]. -// -// TODO(nickkhyl,barnstar): define a named string type for the action (in tailcfg?) and use it here. -type AuditLogFunc func(action, details string) +type AuditLogFunc func(action tailcfg.ClientAuditAction, details string) error // Actor is any actor using the [ipnlocal.LocalBackend]. // @@ -45,7 +44,7 @@ type Actor interface { // // If the auditLogger is non-nil, it is used to write details about the action // to the audit log when required by the policy. - CheckProfileAccess(profile ipn.LoginProfileView, requestedAccess ProfileAccess, auditLogger AuditLogFunc) error + CheckProfileAccess(profile ipn.LoginProfileView, requestedAccess ProfileAccess, auditLogFn AuditLogFunc) error // IsLocalSystem reports whether the actor is the Windows' Local System account. // diff --git a/ipn/ipnauth/policy.go b/ipn/ipnauth/policy.go index f09be0fcb..aa4ec4100 100644 --- a/ipn/ipnauth/policy.go +++ b/ipn/ipnauth/policy.go @@ -9,6 +9,7 @@ import ( "tailscale.com/client/tailscale/apitype" "tailscale.com/ipn" + "tailscale.com/tailcfg" "tailscale.com/util/syspolicy" ) @@ -48,7 +49,7 @@ func (a actorWithPolicyChecks) CheckProfileAccess(profile ipn.LoginProfileView, // // TODO(nickkhyl): unexport it when we move [ipn.Actor] implementations from [ipnserver] // and corp to this package. -func CheckDisconnectPolicy(actor Actor, profile ipn.LoginProfileView, reason string, auditLogger AuditLogFunc) error { +func CheckDisconnectPolicy(actor Actor, profile ipn.LoginProfileView, reason string, auditFn AuditLogFunc) error { if alwaysOn, _ := syspolicy.GetBoolean(syspolicy.AlwaysOn, false); !alwaysOn { return nil } @@ -58,15 +59,16 @@ func CheckDisconnectPolicy(actor Actor, profile ipn.LoginProfileView, reason str if reason == "" { return errors.New("disconnect not allowed: reason required") } - if auditLogger != nil { + if auditFn != nil { var details string if username, _ := actor.Username(); username != "" { // best-effort; we don't have it on all platforms details = fmt.Sprintf("%q is being disconnected by %q: %v", profile.Name(), username, reason) } else { details = fmt.Sprintf("%q is being disconnected: %v", profile.Name(), reason) } - // TODO(nickkhyl,barnstar): use a const for DISCONNECT_NODE. - auditLogger("DISCONNECT_NODE", details) + if err := auditFn(tailcfg.AuditNodeDisconnect, details); err != nil { + return err + } } return nil } diff --git a/ipn/ipnlocal/cert.go b/ipn/ipnlocal/cert.go index cfa4fe1ba..111dc5a2d 100644 --- a/ipn/ipnlocal/cert.go +++ b/ipn/ipnlocal/cert.go @@ -119,6 +119,9 @@ func (b *LocalBackend) GetCertPEMWithValidity(ctx context.Context, domain string } if pair, err := getCertPEMCached(cs, domain, now); err == nil { + if envknob.IsCertShareReadOnlyMode() { + return pair, nil + } // If we got here, we have a valid unexpired cert. // Check whether we should start an async renewal. shouldRenew, err := b.shouldStartDomainRenewal(cs, domain, now, pair, minValidity) @@ -134,7 +137,7 @@ func (b *LocalBackend) GetCertPEMWithValidity(ctx context.Context, domain string if minValidity == 0 { logf("starting async renewal") // Start renewal in the background, return current valid cert. - go b.getCertPEM(context.Background(), cs, logf, traceACME, domain, now, minValidity) + b.goTracker.Go(func() { getCertPEM(context.Background(), b, cs, logf, traceACME, domain, now, minValidity) }) return pair, nil } // If the caller requested a specific validity duration, fall through @@ -142,7 +145,11 @@ func (b *LocalBackend) GetCertPEMWithValidity(ctx context.Context, domain string logf("starting sync renewal") } - pair, err := b.getCertPEM(ctx, cs, logf, traceACME, domain, now, minValidity) + if envknob.IsCertShareReadOnlyMode() { + return nil, fmt.Errorf("retrieving cached TLS certificate failed and cert store is configured in read-only mode, not attempting to issue a new certificate: %w", err) + } + + pair, err := getCertPEM(ctx, b, cs, logf, traceACME, domain, now, minValidity) if err != nil { logf("getCertPEM: %v", err) return nil, err @@ -250,15 +257,13 @@ type certStore interface { // for now. If they're expired, it returns errCertExpired. // If they don't exist, it returns ipn.ErrStateNotExist. Read(domain string, now time.Time) (*TLSCertKeyPair, error) - // WriteCert writes the cert for domain. - WriteCert(domain string, cert []byte) error - // WriteKey writes the key for domain. - WriteKey(domain string, key []byte) error // ACMEKey returns the value previously stored via WriteACMEKey. // It is a PEM encoded ECDSA key. ACMEKey() ([]byte, error) // WriteACMEKey stores the provided PEM encoded ECDSA key. WriteACMEKey([]byte) error + // WriteTLSCertAndKey writes the cert and key for domain. + WriteTLSCertAndKey(domain string, cert, key []byte) error } var errCertExpired = errors.New("cert expired") @@ -344,6 +349,13 @@ func (f certFileStore) WriteKey(domain string, key []byte) error { return atomicfile.WriteFile(keyFile(f.dir, domain), key, 0600) } +func (f certFileStore) WriteTLSCertAndKey(domain string, cert, key []byte) error { + if err := f.WriteKey(domain, key); err != nil { + return err + } + return f.WriteCert(domain, cert) +} + // certStateStore implements certStore by storing the cert & key files in an ipn.StateStore. type certStateStore struct { ipn.StateStore @@ -353,7 +365,29 @@ type certStateStore struct { testRoots *x509.CertPool } +// TLSCertKeyReader is an interface implemented by state stores where it makes +// sense to read the TLS cert and key in a single operation that can be +// distinguished from generic state value reads. Currently this is only implemented +// by the kubestore.Store, which, in some cases, need to read cert and key from a +// non-cached TLS Secret. +type TLSCertKeyReader interface { + ReadTLSCertAndKey(domain string) ([]byte, []byte, error) +} + func (s certStateStore) Read(domain string, now time.Time) (*TLSCertKeyPair, error) { + // If we're using a store that supports atomic reads, use that + if kr, ok := s.StateStore.(TLSCertKeyReader); ok { + cert, key, err := kr.ReadTLSCertAndKey(domain) + if err != nil { + return nil, err + } + if !validCertPEM(domain, key, cert, s.testRoots, now) { + return nil, errCertExpired + } + return &TLSCertKeyPair{CertPEM: cert, KeyPEM: key, Cached: true}, nil + } + + // Otherwise fall back to separate reads certPEM, err := s.ReadState(ipn.StateKey(domain + ".crt")) if err != nil { return nil, err @@ -384,6 +418,27 @@ func (s certStateStore) WriteACMEKey(key []byte) error { return ipn.WriteState(s.StateStore, ipn.StateKey(acmePEMName), key) } +// TLSCertKeyWriter is an interface implemented by state stores that can write the TLS +// cert and key in a single atomic operation. Currently this is only implemented +// by the kubestore.StoreKube. +type TLSCertKeyWriter interface { + WriteTLSCertAndKey(domain string, cert, key []byte) error +} + +// WriteTLSCertAndKey writes the TLS cert and key for domain to the current +// LocalBackend's StateStore. +func (s certStateStore) WriteTLSCertAndKey(domain string, cert, key []byte) error { + // If we're using a store that supports atomic writes, use that. + if aw, ok := s.StateStore.(TLSCertKeyWriter); ok { + return aw.WriteTLSCertAndKey(domain, cert, key) + } + // Otherwise fall back to separate writes for cert and key. + if err := s.WriteKey(domain, key); err != nil { + return err + } + return s.WriteCert(domain, cert) +} + // TLSCertKeyPair is a TLS public and private key, and whether they were obtained // from cache or freshly obtained. type TLSCertKeyPair struct { @@ -420,7 +475,9 @@ func getCertPEMCached(cs certStore, domain string, now time.Time) (p *TLSCertKey return cs.Read(domain, now) } -func (b *LocalBackend) getCertPEM(ctx context.Context, cs certStore, logf logger.Logf, traceACME func(any), domain string, now time.Time, minValidity time.Duration) (*TLSCertKeyPair, error) { +// getCertPem checks if a cert needs to be renewed and if so, renews it. +// It can be overridden in tests. +var getCertPEM = func(ctx context.Context, b *LocalBackend, cs certStore, logf logger.Logf, traceACME func(any), domain string, now time.Time, minValidity time.Duration) (*TLSCertKeyPair, error) { acmeMu.Lock() defer acmeMu.Unlock() @@ -445,6 +502,10 @@ func (b *LocalBackend) getCertPEM(ctx context.Context, cs certStore, logf logger return nil, err } + if !isDefaultDirectoryURL(ac.DirectoryURL) { + logf("acme: using Directory URL %q", ac.DirectoryURL) + } + a, err := ac.GetReg(ctx, "" /* pre-RFC param */) switch { case err == nil: @@ -546,9 +607,6 @@ func (b *LocalBackend) getCertPEM(ctx context.Context, cs certStore, logf logger if err := encodeECDSAKey(&privPEM, certPrivKey); err != nil { return nil, err } - if err := cs.WriteKey(domain, privPEM.Bytes()); err != nil { - return nil, err - } csr, err := certRequest(certPrivKey, domain, nil) if err != nil { @@ -570,7 +628,7 @@ func (b *LocalBackend) getCertPEM(ctx context.Context, cs certStore, logf logger return nil, err } } - if err := cs.WriteCert(domain, certPEM.Bytes()); err != nil { + if err := cs.WriteTLSCertAndKey(domain, certPEM.Bytes(), privPEM.Bytes()); err != nil { return nil, err } b.domainRenewed(domain) @@ -714,7 +772,28 @@ func validateLeaf(leaf *x509.Certificate, intermediates *x509.CertPool, domain s // binary's baked-in roots (LetsEncrypt). See tailscale/tailscale#14690. return validateLeaf(leaf, intermediates, domain, now, bakedroots.Get()) } - return err == nil + + if err == nil { + return true + } + + // When pointed at a non-prod ACME server, we don't expect to have the CA + // in our system or baked-in roots. Verify only throws UnknownAuthorityError + // after first checking the leaf cert's expiry, hostnames etc, so we know + // that the only reason for an error is to do with constructing a full chain. + // Allow this error so that cert caching still works in testing environments. + if errors.As(err, &x509.UnknownAuthorityError{}) { + acmeURL := envknob.String("TS_DEBUG_ACME_DIRECTORY_URL") + if !isDefaultDirectoryURL(acmeURL) { + return true + } + } + + return false +} + +func isDefaultDirectoryURL(u string) bool { + return u == "" || u == acme.LetsEncryptURL } // validLookingCertDomain reports whether name looks like a valid domain name that diff --git a/ipn/ipnlocal/cert_test.go b/ipn/ipnlocal/cert_test.go index 21741ca95..e2398f670 100644 --- a/ipn/ipnlocal/cert_test.go +++ b/ipn/ipnlocal/cert_test.go @@ -6,6 +6,7 @@ package ipnlocal import ( + "context" "crypto/ecdsa" "crypto/elliptic" "crypto/rand" @@ -14,11 +15,17 @@ import ( "embed" "encoding/pem" "math/big" + "os" + "path/filepath" "testing" "time" "github.com/google/go-cmp/cmp" + "tailscale.com/envknob" "tailscale.com/ipn/store/mem" + "tailscale.com/tstest" + "tailscale.com/types/logger" + "tailscale.com/util/must" ) func TestValidLookingCertDomain(t *testing.T) { @@ -47,10 +54,10 @@ var certTestFS embed.FS func TestCertStoreRoundTrip(t *testing.T) { const testDomain = "example.com" - // Use a fixed verification timestamp so validity doesn't fall off when the - // cert expires. If you update the test data below, this may also need to be - // updated. + // Use fixed verification timestamps so validity doesn't change over time. + // If you update the test data below, these may also need to be updated. testNow := time.Date(2023, time.February, 10, 0, 0, 0, 0, time.UTC) + testExpired := time.Date(2026, time.February, 10, 0, 0, 0, 0, time.UTC) // To re-generate a root certificate and domain certificate for testing, // use: @@ -78,21 +85,23 @@ func TestCertStoreRoundTrip(t *testing.T) { } tests := []struct { - name string - store certStore + name string + store certStore + debugACMEURL bool }{ - {"FileStore", certFileStore{dir: t.TempDir(), testRoots: roots}}, - {"StateStore", certStateStore{StateStore: new(mem.Store), testRoots: roots}}, + {"FileStore", certFileStore{dir: t.TempDir(), testRoots: roots}, false}, + {"FileStore_UnknownCA", certFileStore{dir: t.TempDir()}, true}, + {"StateStore", certStateStore{StateStore: new(mem.Store), testRoots: roots}, false}, + {"StateStore_UnknownCA", certStateStore{StateStore: new(mem.Store)}, true}, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - if err := test.store.WriteCert(testDomain, testCert); err != nil { - t.Fatalf("WriteCert: unexpected error: %v", err) + if test.debugACMEURL { + t.Setenv("TS_DEBUG_ACME_DIRECTORY_URL", "https://acme-staging-v02.api.letsencrypt.org/directory") } - if err := test.store.WriteKey(testDomain, testKey); err != nil { - t.Fatalf("WriteKey: unexpected error: %v", err) + if err := test.store.WriteTLSCertAndKey(testDomain, testCert, testKey); err != nil { + t.Fatalf("WriteTLSCertAndKey: unexpected error: %v", err) } - kp, err := test.store.Read(testDomain, testNow) if err != nil { t.Fatalf("Read: unexpected error: %v", err) @@ -103,6 +112,10 @@ func TestCertStoreRoundTrip(t *testing.T) { if diff := cmp.Diff(kp.KeyPEM, testKey); diff != "" { t.Errorf("Key (-got, +want):\n%s", diff) } + unexpected, err := test.store.Read(testDomain, testExpired) + if err != errCertExpired { + t.Fatalf("Read: expected expiry error: %v", string(unexpected.CertPEM)) + } }) } } @@ -215,3 +228,151 @@ func TestDebugACMEDirectoryURL(t *testing.T) { }) } } + +func TestGetCertPEMWithValidity(t *testing.T) { + const testDomain = "example.com" + b := &LocalBackend{ + store: &mem.Store{}, + varRoot: t.TempDir(), + ctx: context.Background(), + logf: t.Logf, + } + certDir, err := b.certDir() + if err != nil { + t.Fatalf("certDir error: %v", err) + } + if _, err := b.getCertStore(); err != nil { + t.Fatalf("getCertStore error: %v", err) + } + testRoot, err := certTestFS.ReadFile("testdata/rootCA.pem") + if err != nil { + t.Fatal(err) + } + roots := x509.NewCertPool() + if !roots.AppendCertsFromPEM(testRoot) { + t.Fatal("Unable to add test CA to the cert pool") + } + testX509Roots = roots + defer func() { testX509Roots = nil }() + tests := []struct { + name string + now time.Time + // storeCerts is true if the test cert and key should be written to store. + storeCerts bool + readOnlyMode bool // TS_READ_ONLY_CERTS env var + wantAsyncRenewal bool // async issuance should be started + wantIssuance bool // sync issuance should be started + wantErr bool + }{ + { + name: "valid_no_renewal", + now: time.Date(2023, time.February, 20, 0, 0, 0, 0, time.UTC), + storeCerts: true, + wantAsyncRenewal: false, + wantIssuance: false, + wantErr: false, + }, + { + name: "issuance_needed", + now: time.Date(2023, time.February, 20, 0, 0, 0, 0, time.UTC), + storeCerts: false, + wantAsyncRenewal: false, + wantIssuance: true, + wantErr: false, + }, + { + name: "renewal_needed", + now: time.Date(2025, time.May, 1, 0, 0, 0, 0, time.UTC), + storeCerts: true, + wantAsyncRenewal: true, + wantIssuance: false, + wantErr: false, + }, + { + name: "renewal_needed_read_only_mode", + now: time.Date(2025, time.May, 1, 0, 0, 0, 0, time.UTC), + storeCerts: true, + readOnlyMode: true, + wantAsyncRenewal: false, + wantIssuance: false, + wantErr: false, + }, + { + name: "no_certs_read_only_mode", + now: time.Date(2025, time.May, 1, 0, 0, 0, 0, time.UTC), + storeCerts: false, + readOnlyMode: true, + wantAsyncRenewal: false, + wantIssuance: false, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + if tt.readOnlyMode { + envknob.Setenv("TS_CERT_SHARE_MODE", "ro") + } + + os.RemoveAll(certDir) + if tt.storeCerts { + os.MkdirAll(certDir, 0755) + if err := os.WriteFile(filepath.Join(certDir, "example.com.crt"), + must.Get(os.ReadFile("testdata/example.com.pem")), 0644); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(certDir, "example.com.key"), + must.Get(os.ReadFile("testdata/example.com-key.pem")), 0644); err != nil { + t.Fatal(err) + } + } + + b.clock = tstest.NewClock(tstest.ClockOpts{Start: tt.now}) + + allDone := make(chan bool, 1) + defer b.goTracker.AddDoneCallback(func() { + b.mu.Lock() + defer b.mu.Unlock() + if b.goTracker.RunningGoroutines() > 0 { + return + } + select { + case allDone <- true: + default: + } + })() + + // Set to true if get getCertPEM is called. GetCertPEM can be called in a goroutine for async + // renewal or in the main goroutine if issuance is required to obtain valid TLS credentials. + getCertPemWasCalled := false + getCertPEM = func(ctx context.Context, b *LocalBackend, cs certStore, logf logger.Logf, traceACME func(any), domain string, now time.Time, minValidity time.Duration) (*TLSCertKeyPair, error) { + getCertPemWasCalled = true + return nil, nil + } + prevGoRoutines := b.goTracker.StartedGoroutines() + _, err = b.GetCertPEMWithValidity(context.Background(), testDomain, 0) + if (err != nil) != tt.wantErr { + t.Errorf("b.GetCertPemWithValidity got err %v, wants error: '%v'", err, tt.wantErr) + } + // GetCertPEMWithValidity calls getCertPEM in a goroutine if async renewal is needed. That's the + // only goroutine it starts, so this can be used to test if async renewal was started. + gotAsyncRenewal := b.goTracker.StartedGoroutines()-prevGoRoutines != 0 + if gotAsyncRenewal { + select { + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for goroutines to finish") + case <-allDone: + } + } + // Verify that async renewal was triggered if expected. + if tt.wantAsyncRenewal != gotAsyncRenewal { + t.Fatalf("wants getCertPem to be called async: %v, got called %v", tt.wantAsyncRenewal, gotAsyncRenewal) + } + // Verify that (non-async) issuance was started if expected. + gotIssuance := getCertPemWasCalled && !gotAsyncRenewal + if tt.wantIssuance != gotIssuance { + t.Errorf("wants getCertPem to be called: %v, got called %v", tt.wantIssuance, gotIssuance) + } + }) + } +} diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go index 5c367c876..cb7d06407 100644 --- a/ipn/ipnlocal/local.go +++ b/ipn/ipnlocal/local.go @@ -57,10 +57,12 @@ import ( "tailscale.com/health/healthmsg" "tailscale.com/hostinfo" "tailscale.com/ipn" + "tailscale.com/ipn/auditlog" "tailscale.com/ipn/conffile" "tailscale.com/ipn/ipnauth" "tailscale.com/ipn/ipnstate" "tailscale.com/ipn/policy" + memstore "tailscale.com/ipn/store/mem" "tailscale.com/log/sockstatlog" "tailscale.com/logpolicy" "tailscale.com/net/captivedetection" @@ -406,8 +408,8 @@ type LocalBackend struct { // outgoingFiles keeps track of Taildrop outgoing files keyed to their OutgoingFile.ID outgoingFiles map[string]*ipn.OutgoingFile - // getSafFd gets the Storage Access Framework file descriptor for writing Taildrop files to - GetSafFd func(filename string) int32 + // FileOps abstracts platform-specific file operations needed for file transfers. + FileOps taildrop.FileOps // lastSuggestedExitNode stores the last suggested exit node suggestion to // avoid unnecessary churn between multiple equally-good options. @@ -453,6 +455,12 @@ type LocalBackend struct { // Each callback is called exactly once in unspecified order and without b.mu held. // Returned errors are logged but otherwise ignored and do not affect the shutdown process. shutdownCbs set.HandleSet[func() error] + + // auditLogger, if non-nil, manages audit logging for the backend. + // + // It queues, persists, and sends audit logs + // to the control client. auditLogger has the same lifespan as b.cc. + auditLogger *auditlog.Logger } // HealthTracker returns the health tracker for the backend. @@ -621,19 +629,6 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo } } - // initialize Taildrive shares from saved state - fs, ok := b.sys.DriveForRemote.GetOK() - if ok { - currentShares := b.pm.prefs.DriveShares() - if currentShares.Len() > 0 { - var shares []*drive.Share - for _, share := range currentShares.All() { - shares = append(shares, share.AsStruct()) - } - fs.SetShares(shares) - } - } - for name, newFn := range registeredExtensions { ext, err := newFn(logf, sys) if err != nil { @@ -813,6 +808,13 @@ func (b *LocalBackend) SetDirectFileRoot(dir string) { b.directFileRoot = dir } +// SetFileOps sets the +func (b *LocalBackend) SetFileOps(fileOps taildrop.FileOps) { + b.mu.Lock() + defer b.mu.Unlock() + b.FileOps = fileOps +} + // ReloadConfig reloads the backend's config from disk. // // It returns (false, nil) if not running in declarative mode, (true, nil) on @@ -1695,6 +1697,15 @@ func (b *LocalBackend) SetControlClientStatus(c controlclient.Client, st control b.logf("Failed to save new controlclient state: %v", err) } } + + // Update the audit logger with the current profile ID. + if b.auditLogger != nil && prefsChanged { + pid := b.pm.CurrentProfile().ID() + if err := b.auditLogger.SetProfileID(pid); err != nil { + b.logf("Failed to set profile ID in audit logger: %v", err) + } + } + // initTKALocked is dependent on CurrentProfile.ID, which is initialized // (for new profiles) on the first call to b.pm.SetPrefs. if err := b.initTKALocked(); err != nil { @@ -2402,6 +2413,27 @@ func (b *LocalBackend) Start(opts ipn.Options) error { debugFlags = append([]string{"netstack"}, debugFlags...) } + var auditLogShutdown func() + // Audit logging is only available if the client has set up a proper persistent + // store for the logs in sys. + store, ok := b.sys.AuditLogStore.GetOK() + if !ok { + b.logf("auditlog: [unexpected] no persistent audit log storage configured. using memory store.") + store = auditlog.NewLogStore(&memstore.Store{}) + } + + al := auditlog.NewLogger(auditlog.Opts{ + Logf: b.logf, + RetryLimit: 32, + Store: store, + }) + b.auditLogger = al + auditLogShutdown = func() { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + al.FlushAndStop(ctx) + } + // TODO(apenwarr): The only way to change the ServerURL is to // re-run b.Start, because this is the only place we create a // new controlclient. EditPrefs allows you to overwrite ServerURL, @@ -2427,6 +2459,7 @@ func (b *LocalBackend) Start(opts ipn.Options) error { C2NHandler: http.HandlerFunc(b.handleC2N), DialPlan: &b.dialPlan, // pointer because it can't be copied ControlKnobs: b.sys.ControlKnobs(), + Shutdown: auditLogShutdown, // Don't warn about broken Linux IP forwarding when // netstack is being used. @@ -2461,6 +2494,16 @@ func (b *LocalBackend) Start(opts ipn.Options) error { b.logf("Backend: logs: be:%v fe:%v", blid, opts.FrontendLogID) b.sendToLocked(ipn.Notify{Prefs: &prefs}, allClients) + // initialize Taildrive shares from saved state + if fs, ok := b.sys.DriveForRemote.GetOK(); ok { + currentShares := b.pm.CurrentPrefs().DriveShares() + var shares []*drive.Share + for _, share := range currentShares.All() { + shares = append(shares, share.AsStruct()) + } + fs.SetShares(shares) + } + if !loggedOut && (b.hasNodeKeyLocked() || confWantRunning) { // If we know that we're either logged in or meant to be // running, tell the controlclient that it should also assume @@ -4269,6 +4312,21 @@ func (b *LocalBackend) MaybeClearAppConnector(mp *ipn.MaskedPrefs) error { return err } +var errNoAuditLogger = errors.New("no audit logger configured") + +func (b *LocalBackend) getAuditLoggerLocked() ipnauth.AuditLogFunc { + logger := b.auditLogger + return func(action tailcfg.ClientAuditAction, details string) error { + if logger == nil { + return errNoAuditLogger + } + if err := logger.Enqueue(action, details); err != nil { + return fmt.Errorf("failed to enqueue audit log %v %q: %w", action, details, err) + } + return nil + } +} + // EditPrefs applies the changes in mp to the current prefs, // acting as the tailscaled itself rather than a specific user. func (b *LocalBackend) EditPrefs(mp *ipn.MaskedPrefs) (ipn.PrefsView, error) { @@ -4294,9 +4352,8 @@ func (b *LocalBackend) EditPrefsAs(mp *ipn.MaskedPrefs, actor ipnauth.Actor) (ip unlock := b.lockAndGetUnlock() defer unlock() if mp.WantRunningSet && !mp.WantRunning && b.pm.CurrentPrefs().WantRunning() { - // TODO(barnstar,nickkhyl): replace loggerFn with the actual audit logger. - loggerFn := func(action, details string) { b.logf("[audit]: %s: %s", action, details) } - if err := actor.CheckProfileAccess(b.pm.CurrentProfile(), ipnauth.Disconnect, loggerFn); err != nil { + if err := actor.CheckProfileAccess(b.pm.CurrentProfile(), ipnauth.Disconnect, b.getAuditLoggerLocked()); err != nil { + b.logf("check profile access failed: %v", err) return ipn.PrefsView{}, err } @@ -5296,7 +5353,6 @@ func (b *LocalBackend) initPeerAPIListener() { if fileRoot == "" { b.logf("peerapi starting without Taildrop directory configured") } - ps := &peerAPIServer{ b: b, taildrop: taildrop.ManagerOptions{ @@ -5306,7 +5362,7 @@ func (b *LocalBackend) initPeerAPIListener() { Dir: fileRoot, DirectFileMode: b.directFileRoot != "", SendFileNotify: b.sendFileNotify, - }.New(b.getSafFd), + }.New(b.FileOps), } if dm, ok := b.sys.DNSManager.GetOK(); ok { ps.resolver = dm.Resolver() @@ -5880,6 +5936,15 @@ func (b *LocalBackend) requestEngineStatusAndWait() { func (b *LocalBackend) setControlClientLocked(cc controlclient.Client) { b.cc = cc b.ccAuto, _ = cc.(*controlclient.Auto) + if b.auditLogger != nil { + if err := b.auditLogger.SetProfileID(b.pm.CurrentProfile().ID()); err != nil { + b.logf("audit logger set profile ID failure: %v", err) + } + + if err := b.auditLogger.Start(b.ccAuto); err != nil { + b.logf("audit logger start failure: %v", err) + } + } } // resetControlClientLocked sets b.cc to nil and returns the old value. If the @@ -6712,7 +6777,7 @@ func (b *LocalBackend) FileTargets() ([]*apitype.FileTarget, error) { } func (b *LocalBackend) taildropTargetStatus(p tailcfg.NodeView) ipnstate.TaildropTargetStatus { - if b.netMap == nil || b.state != ipn.Running { + if b.state != ipn.Running { return ipnstate.TaildropTargetIpnStateNotRunning } if b.netMap == nil { @@ -8225,15 +8290,13 @@ func (b *LocalBackend) vipServiceHash(services []*tailcfg.VIPService) string { func (b *LocalBackend) vipServicesFromPrefsLocked(prefs ipn.PrefsView) []*tailcfg.VIPService { // keyed by service name var services map[tailcfg.ServiceName]*tailcfg.VIPService - if !b.serveConfig.Valid() { - return nil - } - - for svc, config := range b.serveConfig.Services().All() { - mak.Set(&services, svc, &tailcfg.VIPService{ - Name: svc, - Ports: config.ServicePortRange(), - }) + if b.serveConfig.Valid() { + for svc, config := range b.serveConfig.Services().All() { + mak.Set(&services, svc, &tailcfg.VIPService{ + Name: svc, + Ports: config.ServicePortRange(), + }) + } } for _, s := range prefs.AdvertiseServices().All() { @@ -8246,7 +8309,14 @@ func (b *LocalBackend) vipServicesFromPrefsLocked(prefs ipn.PrefsView) []*tailcf services[sn].Active = true } - return slicesx.MapValues(services) + servicesList := slicesx.MapValues(services) + // [slicesx.MapValues] provides the values in an indeterminate order, but since we'll + // be hashing a representation of this list later we want it to be in a consistent + // order. + slices.SortFunc(servicesList, func(a, b *tailcfg.VIPService) int { + return strings.Compare(a.Name.String(), b.Name.String()) + }) + return servicesList } var ( diff --git a/ipn/ipnlocal/local_test.go b/ipn/ipnlocal/local_test.go index 35977e679..aa9137275 100644 --- a/ipn/ipnlocal/local_test.go +++ b/ipn/ipnlocal/local_test.go @@ -44,6 +44,7 @@ import ( "tailscale.com/tsd" "tailscale.com/tstest" "tailscale.com/types/dnstype" + "tailscale.com/types/ipproto" "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/types/logid" @@ -60,6 +61,7 @@ import ( "tailscale.com/util/syspolicy/source" "tailscale.com/wgengine" "tailscale.com/wgengine/filter" + "tailscale.com/wgengine/filter/filtertype" "tailscale.com/wgengine/wgcfg" ) @@ -5206,3 +5208,60 @@ func TestUpdateIngressLocked(t *testing.T) { }) } } + +// TestSrcCapPacketFilter tests that LocalBackend handles packet filters with +// SrcCaps instead of Srcs (IPs) +func TestSrcCapPacketFilter(t *testing.T) { + lb := newLocalBackendWithTestControl(t, false, func(tb testing.TB, opts controlclient.Options) controlclient.Client { + return newClient(tb, opts) + }) + if err := lb.Start(ipn.Options{}); err != nil { + t.Fatalf("(*LocalBackend).Start(): %v", err) + } + + var k key.NodePublic + must.Do(k.UnmarshalText([]byte("nodekey:5c8f86d5fc70d924e55f02446165a5dae8f822994ad26bcf4b08fd841f9bf261"))) + + controlClient := lb.cc.(*mockControl) + controlClient.send(nil, "", false, &netmap.NetworkMap{ + SelfNode: (&tailcfg.Node{ + Addresses: []netip.Prefix{netip.MustParsePrefix("1.1.1.1/32")}, + }).View(), + Peers: []tailcfg.NodeView{ + (&tailcfg.Node{ + Addresses: []netip.Prefix{netip.MustParsePrefix("2.2.2.2/32")}, + ID: 2, + Key: k, + CapMap: tailcfg.NodeCapMap{"cap-X": nil}, // node 2 has cap + }).View(), + (&tailcfg.Node{ + Addresses: []netip.Prefix{netip.MustParsePrefix("3.3.3.3/32")}, + ID: 3, + Key: k, + CapMap: tailcfg.NodeCapMap{}, // node 3 does not have the cap + }).View(), + }, + PacketFilter: []filtertype.Match{{ + IPProto: views.SliceOf([]ipproto.Proto{ipproto.TCP}), + SrcCaps: []tailcfg.NodeCapability{"cap-X"}, // cap in packet filter rule + Dsts: []filtertype.NetPortRange{{ + Net: netip.MustParsePrefix("1.1.1.1/32"), + Ports: filtertype.PortRange{ + First: 22, + Last: 22, + }, + }}, + }}, + }) + + f := lb.GetFilterForTest() + res := f.Check(netip.MustParseAddr("2.2.2.2"), netip.MustParseAddr("1.1.1.1"), 22, ipproto.TCP) + if res != filter.Accept { + t.Errorf("Check(2.2.2.2, ...) = %s, want %s", res, filter.Accept) + } + + res = f.Check(netip.MustParseAddr("3.3.3.3"), netip.MustParseAddr("1.1.1.1"), 22, ipproto.TCP) + if !res.IsDrop() { + t.Error("IsDrop() for node without cap = false, want true") + } +} diff --git a/ipn/ipnstate/ipnstate.go b/ipn/ipnstate/ipnstate.go index bc1ba615d..89c6d7e24 100644 --- a/ipn/ipnstate/ipnstate.go +++ b/ipn/ipnstate/ipnstate.go @@ -216,6 +216,11 @@ type PeerStatusLite struct { } // PeerStatus describes a peer node and its current state. +// WARNING: The fields in PeerStatus are merged by the AddPeer method in the StatusBuilder. +// When adding a new field to PeerStatus, you must update AddPeer to handle merging +// the new field. The AddPeer function is responsible for combining multiple updates +// to the same peer, and any new field that is not merged properly may lead to +// inconsistencies or lost data in the peer status. type PeerStatus struct { ID tailcfg.StableNodeID PublicKey key.NodePublic @@ -533,6 +538,9 @@ func (sb *StatusBuilder) AddPeer(peer key.NodePublic, st *PeerStatus) { if v := st.Capabilities; v != nil { e.Capabilities = v } + if v := st.TaildropTarget; v != TaildropTargetUnknown { + e.TaildropTarget = v + } e.Location = st.Location } diff --git a/ipn/store/awsstore/store_aws.go b/ipn/store/awsstore/store_aws.go index 0fb78d45a..40bbbf037 100644 --- a/ipn/store/awsstore/store_aws.go +++ b/ipn/store/awsstore/store_aws.go @@ -10,7 +10,9 @@ import ( "context" "errors" "fmt" + "net/url" "regexp" + "strings" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/aws/arn" @@ -28,6 +30,14 @@ const ( var parameterNameRx = regexp.MustCompile(parameterNameRxStr) +// Option defines a functional option type for configuring awsStore. +type Option func(*storeOptions) + +// storeOptions holds optional settings for creating a new awsStore. +type storeOptions struct { + kmsKey string +} + // awsSSMClient is an interface allowing us to mock the couple of // API calls we are leveraging with the AWSStore provider type awsSSMClient interface { @@ -46,6 +56,10 @@ type awsStore struct { ssmClient awsSSMClient ssmARN arn.ARN + // kmsKey is optional. If empty, the parameter is stored in plaintext. + // If non-empty, the parameter is encrypted with this KMS key. + kmsKey string + memory mem.Store } @@ -57,30 +71,80 @@ type awsStore struct { // Tailscaled to only only store new state in-memory and // restarting Tailscaled can fail until you delete your state // from the AWS Parameter Store. -func New(_ logger.Logf, ssmARN string) (ipn.StateStore, error) { - return newStore(ssmARN, nil) +// +// If you want to specify an optional KMS key, +// pass one or more Option objects, e.g. awsstore.WithKeyID("alias/my-key"). +func New(_ logger.Logf, ssmARN string, opts ...Option) (ipn.StateStore, error) { + // Apply all options to an empty storeOptions + var so storeOptions + for _, opt := range opts { + opt(&so) + } + + return newStore(ssmARN, so, nil) +} + +// WithKeyID sets the KMS key to be used for encryption. It can be +// a KeyID, an alias ("alias/my-key"), or a full ARN. +// +// If kmsKey is empty, the Option is a no-op. +func WithKeyID(kmsKey string) Option { + return func(o *storeOptions) { + o.kmsKey = kmsKey + } +} + +// ParseARNAndOpts parses an ARN and optional URL-encoded parameters +// from arg. +func ParseARNAndOpts(arg string) (ssmARN string, opts []Option, err error) { + ssmARN = arg + + // Support optional ?url-encoded-parameters. + if s, q, ok := strings.Cut(arg, "?"); ok { + ssmARN = s + q, err := url.ParseQuery(q) + if err != nil { + return "", nil, err + } + + for k := range q { + switch k { + default: + return "", nil, fmt.Errorf("unknown arn option parameter %q", k) + case "kmsKey": + // We allow an ARN, a key ID, or an alias name for kmsKeyID. + // If it doesn't look like an ARN and doesn't have a '/', + // prepend "alias/" for KMS alias references. + kmsKey := q.Get(k) + if kmsKey != "" && + !strings.Contains(kmsKey, "/") && + !strings.HasPrefix(kmsKey, "arn:") { + kmsKey = "alias/" + kmsKey + } + if kmsKey != "" { + opts = append(opts, WithKeyID(kmsKey)) + } + } + } + } + return ssmARN, opts, nil } // newStore is NewStore, but for tests. If client is non-nil, it's // used instead of making one. -func newStore(ssmARN string, client awsSSMClient) (ipn.StateStore, error) { +func newStore(ssmARN string, so storeOptions, client awsSSMClient) (ipn.StateStore, error) { s := &awsStore{ ssmClient: client, + kmsKey: so.kmsKey, } var err error - - // Parse the ARN if s.ssmARN, err = arn.Parse(ssmARN); err != nil { return nil, fmt.Errorf("unable to parse the ARN correctly: %v", err) } - - // Validate the ARN corresponds to the SSM service if s.ssmARN.Service != "ssm" { return nil, fmt.Errorf("invalid service %q, expected 'ssm'", s.ssmARN.Service) } - - // Validate the ARN corresponds to a parameter store resource if !parameterNameRx.MatchString(s.ssmARN.Resource) { return nil, fmt.Errorf("invalid resource %q, expected to match %v", s.ssmARN.Resource, parameterNameRxStr) } @@ -96,12 +160,11 @@ func newStore(ssmARN string, client awsSSMClient) (ipn.StateStore, error) { s.ssmClient = ssm.NewFromConfig(cfg) } - // Hydrate cache with the potentially current state + // Preload existing state, if any if err := s.LoadState(); err != nil { return nil, err } return s, nil - } // LoadState attempts to read the state from AWS SSM parameter store key. @@ -172,15 +235,21 @@ func (s *awsStore) persistState() error { // which is free. However, if it exceeds 4kb it switches the parameter to advanced tiering // doubling the capacity to 8kb per the following docs: // https://aws.amazon.com/about-aws/whats-new/2019/08/aws-systems-manager-parameter-store-announces-intelligent-tiering-to-enable-automatic-parameter-tier-selection/ - _, err = s.ssmClient.PutParameter( - context.TODO(), - &ssm.PutParameterInput{ - Name: aws.String(s.ParameterName()), - Value: aws.String(string(bs)), - Overwrite: aws.Bool(true), - Tier: ssmTypes.ParameterTierIntelligentTiering, - Type: ssmTypes.ParameterTypeSecureString, - }, - ) + in := &ssm.PutParameterInput{ + Name: aws.String(s.ParameterName()), + Value: aws.String(string(bs)), + Overwrite: aws.Bool(true), + Tier: ssmTypes.ParameterTierIntelligentTiering, + Type: ssmTypes.ParameterTypeSecureString, + } + + // If kmsKey is specified, encrypt with that key + // NOTE: this input allows any alias, keyID or ARN + // If this isn't specified, AWS will use the default KMS key + if s.kmsKey != "" { + in.KeyId = aws.String(s.kmsKey) + } + + _, err = s.ssmClient.PutParameter(context.TODO(), in) return err } diff --git a/ipn/store/awsstore/store_aws_stub.go b/ipn/store/awsstore/store_aws_stub.go deleted file mode 100644 index 8d2156ce9..000000000 --- a/ipn/store/awsstore/store_aws_stub.go +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !linux || ts_omit_aws - -package awsstore - -import ( - "fmt" - "runtime" - - "tailscale.com/ipn" - "tailscale.com/types/logger" -) - -func New(logger.Logf, string) (ipn.StateStore, error) { - return nil, fmt.Errorf("AWS store is not supported on %v", runtime.GOOS) -} diff --git a/ipn/store/awsstore/store_aws_test.go b/ipn/store/awsstore/store_aws_test.go index f6c8fedb3..3382635a7 100644 --- a/ipn/store/awsstore/store_aws_test.go +++ b/ipn/store/awsstore/store_aws_test.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build linux +//go:build linux && !ts_omit_aws package awsstore @@ -65,7 +65,11 @@ func TestNewAWSStore(t *testing.T) { Resource: "parameter/foo", } - s, err := newStore(storeParameterARN.String(), mc) + opts := storeOptions{ + kmsKey: "arn:aws:kms:eu-west-1:123456789:key/MyCustomKey", + } + + s, err := newStore(storeParameterARN.String(), opts, mc) if err != nil { t.Fatalf("creating aws store failed: %v", err) } @@ -73,7 +77,7 @@ func TestNewAWSStore(t *testing.T) { // Build a brand new file store and check that both IDs written // above are still there. - s2, err := newStore(storeParameterARN.String(), mc) + s2, err := newStore(storeParameterARN.String(), opts, mc) if err != nil { t.Fatalf("creating second aws store failed: %v", err) } @@ -162,3 +166,54 @@ func testStoreSemantics(t *testing.T, store ipn.StateStore) { } } } + +func TestParseARNAndOpts(t *testing.T) { + tests := []struct { + name string + arg string + wantARN string + wantKey string + }{ + { + name: "no-key", + arg: "arn:aws:ssm:us-east-1:123456789012:parameter/myTailscaleParam", + wantARN: "arn:aws:ssm:us-east-1:123456789012:parameter/myTailscaleParam", + }, + { + name: "custom-key", + arg: "arn:aws:ssm:us-east-1:123456789012:parameter/myTailscaleParam?kmsKey=alias/MyCustomKey", + wantARN: "arn:aws:ssm:us-east-1:123456789012:parameter/myTailscaleParam", + wantKey: "alias/MyCustomKey", + }, + { + name: "bare-name", + arg: "arn:aws:ssm:us-east-1:123456789012:parameter/myTailscaleParam?kmsKey=Bare", + wantARN: "arn:aws:ssm:us-east-1:123456789012:parameter/myTailscaleParam", + wantKey: "alias/Bare", + }, + { + name: "arn-arg", + arg: "arn:aws:ssm:us-east-1:123456789012:parameter/myTailscaleParam?kmsKey=arn:foo", + wantARN: "arn:aws:ssm:us-east-1:123456789012:parameter/myTailscaleParam", + wantKey: "arn:foo", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + arn, opts, err := ParseARNAndOpts(tt.arg) + if err != nil { + t.Fatalf("New: %v", err) + } + if arn != tt.wantARN { + t.Errorf("ARN = %q; want %q", arn, tt.wantARN) + } + var got storeOptions + for _, opt := range opts { + opt(&got) + } + if got.kmsKey != tt.wantKey { + t.Errorf("kmsKey = %q; want %q", got.kmsKey, tt.wantKey) + } + }) + } +} diff --git a/ipn/store/kubestore/store_kube.go b/ipn/store/kubestore/store_kube.go index 462e6d434..ed37f06c2 100644 --- a/ipn/store/kubestore/store_kube.go +++ b/ipn/store/kubestore/store_kube.go @@ -13,11 +13,15 @@ import ( "strings" "time" + "tailscale.com/envknob" "tailscale.com/ipn" "tailscale.com/ipn/store/mem" "tailscale.com/kube/kubeapi" "tailscale.com/kube/kubeclient" + "tailscale.com/kube/kubetypes" "tailscale.com/types/logger" + "tailscale.com/util/dnsname" + "tailscale.com/util/mak" ) const ( @@ -31,21 +35,37 @@ const ( reasonTailscaleStateLoadFailed = "TailscaleStateLoadFailed" eventTypeWarning = "Warning" eventTypeNormal = "Normal" + + keyTLSCert = "tls.crt" + keyTLSKey = "tls.key" ) // Store is an ipn.StateStore that uses a Kubernetes Secret for persistence. type Store struct { - client kubeclient.Client - canPatch bool - secretName string + client kubeclient.Client + canPatch bool + secretName string // state Secret + certShareMode string // 'ro', 'rw', or empty + podName string - // memory holds the latest tailscale state. Writes write state to a kube Secret and memory, Reads read from - // memory. + // memory holds the latest tailscale state. Writes write state to a kube + // Secret and memory, Reads read from memory. memory mem.Store } -// New returns a new Store that persists to the named Secret. -func New(_ logger.Logf, secretName string) (*Store, error) { +// New returns a new Store that persists state to Kubernets Secret(s). +// Tailscale state is stored in a Secret named by the secretName parameter. +// TLS certs are stored and retrieved from state Secret or separate Secrets +// named after TLS endpoints if running in cert share mode. +func New(logf logger.Logf, secretName string) (*Store, error) { + c, err := newClient() + if err != nil { + return nil, err + } + return newWithClient(logf, c, secretName) +} + +func newClient() (kubeclient.Client, error) { c, err := kubeclient.New("tailscale-state-store") if err != nil { return nil, err @@ -54,6 +74,10 @@ func New(_ logger.Logf, secretName string) (*Store, error) { // Derive the API server address from the environment variables c.SetURL(fmt.Sprintf("https://%s:%s", os.Getenv("KUBERNETES_SERVICE_HOST"), os.Getenv("KUBERNETES_SERVICE_PORT_HTTPS"))) } + return c, nil +} + +func newWithClient(logf logger.Logf, c kubeclient.Client, secretName string) (*Store, error) { canPatch, _, err := c.CheckSecretPermissions(context.Background(), secretName) if err != nil { return nil, err @@ -62,11 +86,30 @@ func New(_ logger.Logf, secretName string) (*Store, error) { client: c, canPatch: canPatch, secretName: secretName, + podName: os.Getenv("POD_NAME"), + } + if envknob.IsCertShareReadWriteMode() { + s.certShareMode = "rw" + } else if envknob.IsCertShareReadOnlyMode() { + s.certShareMode = "ro" } + // Load latest state from kube Secret if it already exists. if err := s.loadState(); err != nil && err != ipn.ErrStateNotExist { return nil, fmt.Errorf("error loading state from kube Secret: %w", err) } + // If we are in cert share mode, pre-load existing shared certs. + if s.certShareMode == "rw" || s.certShareMode == "ro" { + sel := s.certSecretSelector() + if err := s.loadCerts(context.Background(), sel); err != nil { + // We will attempt to again retrieve the certs from Secrets when a request for an HTTPS endpoint + // is received. + log.Printf("[unexpected] error loading TLS certs: %v", err) + } + } + if s.certShareMode == "ro" { + go s.runCertReload(context.Background(), logf) + } return s, nil } @@ -83,11 +126,101 @@ func (s *Store) ReadState(id ipn.StateKey) ([]byte, error) { // WriteState implements the StateStore interface. func (s *Store) WriteState(id ipn.StateKey, bs []byte) (err error) { - ctx, cancel := context.WithTimeout(context.Background(), timeout) defer func() { if err == nil { s.memory.WriteState(ipn.StateKey(sanitizeKey(id)), bs) } + }() + return s.updateSecret(map[string][]byte{string(id): bs}, s.secretName) +} + +// WriteTLSCertAndKey writes a TLS cert and key to domain.crt, domain.key fields +// of a Tailscale Kubernetes node's state Secret. +func (s *Store) WriteTLSCertAndKey(domain string, cert, key []byte) (err error) { + if s.certShareMode == "ro" { + log.Printf("[unexpected] TLS cert and key write in read-only mode") + } + if err := dnsname.ValidHostname(domain); err != nil { + return fmt.Errorf("invalid domain name %q: %w", domain, err) + } + defer func() { + // TODO(irbekrm): a read between these two separate writes would + // get a mismatched cert and key. Allow writing both cert and + // key to the memory store in a single, lock-protected operation. + if err == nil { + s.memory.WriteState(ipn.StateKey(domain+".crt"), cert) + s.memory.WriteState(ipn.StateKey(domain+".key"), key) + } + }() + secretName := s.secretName + data := map[string][]byte{ + domain + ".crt": cert, + domain + ".key": key, + } + // If we run in cert share mode, cert and key for a DNS name are written + // to a separate Secret. + if s.certShareMode == "rw" { + secretName = domain + data = map[string][]byte{ + keyTLSCert: cert, + keyTLSKey: key, + } + } + return s.updateSecret(data, secretName) +} + +// ReadTLSCertAndKey reads a TLS cert and key from memory or from a +// domain-specific Secret. It first checks the in-memory store, if not found in +// memory and running cert store in read-only mode, looks up a Secret. +func (s *Store) ReadTLSCertAndKey(domain string) (cert, key []byte, err error) { + if err := dnsname.ValidHostname(domain); err != nil { + return nil, nil, fmt.Errorf("invalid domain name %q: %w", domain, err) + } + certKey := domain + ".crt" + keyKey := domain + ".key" + + cert, err = s.memory.ReadState(ipn.StateKey(certKey)) + if err == nil { + key, err = s.memory.ReadState(ipn.StateKey(keyKey)) + if err == nil { + return cert, key, nil + } + } + if s.certShareMode != "ro" { + return nil, nil, ipn.ErrStateNotExist + } + // If we are in cert share read only mode, it is possible that a write + // replica just issued the TLS cert for this DNS name and it has not + // been loaded to store yet, so check the Secret. + + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + secret, err := s.client.GetSecret(ctx, domain) + if err != nil { + if kubeclient.IsNotFoundErr(err) { + // TODO(irbekrm): we should return a more specific error + // that wraps ipn.ErrStateNotExist here. + return nil, nil, ipn.ErrStateNotExist + } + return nil, nil, fmt.Errorf("getting TLS Secret %q: %w", domain, err) + } + cert = secret.Data[keyTLSCert] + key = secret.Data[keyTLSKey] + if len(cert) == 0 || len(key) == 0 { + return nil, nil, ipn.ErrStateNotExist + } + // TODO(irbekrm): a read between these two separate writes would + // get a mismatched cert and key. Allow writing both cert and + // key to the memory store in a single lock-protected operation. + s.memory.WriteState(ipn.StateKey(certKey), cert) + s.memory.WriteState(ipn.StateKey(keyKey), key) + return cert, key, nil +} + +func (s *Store) updateSecret(data map[string][]byte, secretName string) (err error) { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer func() { if err != nil { if err := s.client.Event(ctx, eventTypeWarning, reasonTailscaleStateUpdateFailed, err.Error()); err != nil { log.Printf("kubestore: error creating tailscaled state update Event: %v", err) @@ -99,56 +232,69 @@ func (s *Store) WriteState(id ipn.StateKey, bs []byte) (err error) { } cancel() }() - - secret, err := s.client.GetSecret(ctx, s.secretName) + secret, err := s.client.GetSecret(ctx, secretName) if err != nil { - if kubeclient.IsNotFoundErr(err) { + // If the Secret does not exist, create it with the required data. + if kubeclient.IsNotFoundErr(err) && s.canCreateSecret(secretName) { return s.client.CreateSecret(ctx, &kubeapi.Secret{ TypeMeta: kubeapi.TypeMeta{ APIVersion: "v1", Kind: "Secret", }, ObjectMeta: kubeapi.ObjectMeta{ - Name: s.secretName, - }, - Data: map[string][]byte{ - sanitizeKey(id): bs, + Name: secretName, }, + Data: func(m map[string][]byte) map[string][]byte { + d := make(map[string][]byte, len(m)) + for key, val := range m { + d[sanitizeKey(key)] = val + } + return d + }(data), }) } - return err + return fmt.Errorf("error getting Secret %s: %w", secretName, err) } - if s.canPatch { - if len(secret.Data) == 0 { // if user has pre-created a blank Secret - m := []kubeclient.JSONPatch{ + if s.canPatchSecret(secretName) { + var m []kubeclient.JSONPatch + // If the user has pre-created a Secret with no data, we need to ensure the top level /data field. + if len(secret.Data) == 0 { + m = []kubeclient.JSONPatch{ { - Op: "add", - Path: "/data", - Value: map[string][]byte{sanitizeKey(id): bs}, + Op: "add", + Path: "/data", + Value: func(m map[string][]byte) map[string][]byte { + d := make(map[string][]byte, len(m)) + for key, val := range m { + d[sanitizeKey(key)] = val + } + return d + }(data), }, } - if err := s.client.JSONPatchResource(ctx, s.secretName, kubeclient.TypeSecrets, m); err != nil { - return fmt.Errorf("error patching Secret %s with a /data field: %v", s.secretName, err) + // If the Secret has data, patch it with the new data. + } else { + for key, val := range data { + m = append(m, kubeclient.JSONPatch{ + Op: "add", + Path: "/data/" + sanitizeKey(key), + Value: val, + }) } - return nil - } - m := []kubeclient.JSONPatch{ - { - Op: "add", - Path: "/data/" + sanitizeKey(id), - Value: bs, - }, } - if err := s.client.JSONPatchResource(ctx, s.secretName, kubeclient.TypeSecrets, m); err != nil { - return fmt.Errorf("error patching Secret %s with /data/%s field: %v", s.secretName, sanitizeKey(id), err) + if err := s.client.JSONPatchResource(ctx, secretName, kubeclient.TypeSecrets, m); err != nil { + return fmt.Errorf("error patching Secret %s: %w", secretName, err) } return nil } - secret.Data[sanitizeKey(id)] = bs + // No patch permissions, use UPDATE instead. + for key, val := range data { + mak.Set(&secret.Data, sanitizeKey(key), val) + } if err := s.client.UpdateSecret(ctx, secret); err != nil { - return err + return fmt.Errorf("error updating Secret %s: %w", s.secretName, err) } - return err + return nil } func (s *Store) loadState() (err error) { @@ -172,9 +318,100 @@ func (s *Store) loadState() (err error) { return nil } -func sanitizeKey(k ipn.StateKey) string { - // The only valid characters in a Kubernetes secret key are alphanumeric, -, - // _, and . +// runCertReload relists and reloads all TLS certs for endpoints shared by this +// node from Secrets other than the state Secret to ensure that renewed certs get eventually loaded. +// It is not critical to reload a cert immediately after +// renewal, so a daily check is acceptable. +// Currently (3/2025) this is only used for the shared HA Ingress certs on 'read' replicas. +// Note that if shared certs are not found in memory on an HTTPS request, we +// do a Secret lookup, so this mechanism does not need to ensure that newly +// added Ingresses' certs get loaded. +func (s *Store) runCertReload(ctx context.Context, logf logger.Logf) { + ticker := time.NewTicker(time.Hour * 24) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + sel := s.certSecretSelector() + if err := s.loadCerts(ctx, sel); err != nil { + logf("[unexpected] error reloading TLS certs: %v", err) + } + } + } +} + +// loadCerts lists all Secrets matching the provided selector and loads TLS +// certs and keys from those. +func (s *Store) loadCerts(ctx context.Context, sel map[string]string) error { + ss, err := s.client.ListSecrets(ctx, sel) + if err != nil { + return fmt.Errorf("error listing TLS Secrets: %w", err) + } + for _, secret := range ss.Items { + if !hasTLSData(&secret) { + continue + } + // Only load secrets that have valid domain names (ending in .ts.net) + if !strings.HasSuffix(secret.Name, ".ts.net") { + continue + } + s.memory.WriteState(ipn.StateKey(secret.Name)+".crt", secret.Data[keyTLSCert]) + s.memory.WriteState(ipn.StateKey(secret.Name)+".key", secret.Data[keyTLSKey]) + } + return nil +} + +// canCreateSecret returns true if this node should be allowed to create the given +// Secret in its namespace. +func (s *Store) canCreateSecret(secret string) bool { + // Only allow creating the state Secret (and not TLS Secrets). + return secret == s.secretName +} + +// canPatchSecret returns true if this node should be allowed to patch the given +// Secret. +func (s *Store) canPatchSecret(secret string) bool { + // For backwards compatibility reasons, setups where the proxies are not + // given PATCH permissions for state Secrets are allowed. For TLS + // Secrets, we should always have PATCH permissions. + if secret == s.secretName { + return s.canPatch + } + return true +} + +// certSecretSelector returns a label selector that can be used to list all +// Secrets that aren't Tailscale state Secrets and contain TLS certificates for +// HTTPS endpoints that this node serves. +// Currently (3/2025) this only applies to the Kubernetes Operator's ingress +// ProxyGroup. +func (s *Store) certSecretSelector() map[string]string { + if s.podName == "" { + return map[string]string{} + } + p := strings.LastIndex(s.podName, "-") + if p == -1 { + return map[string]string{} + } + pgName := s.podName[:p] + return map[string]string{ + kubetypes.LabelSecretType: "certs", + kubetypes.LabelManaged: "true", + "tailscale.com/proxy-group": pgName, + } +} + +// hasTLSData returns true if the provided Secret contains non-empty TLS cert and key. +func hasTLSData(s *kubeapi.Secret) bool { + return len(s.Data[keyTLSCert]) != 0 && len(s.Data[keyTLSKey]) != 0 +} + +// sanitizeKey converts any value that can be converted to a string into a valid Kubernetes Secret key. +// Valid characters are alphanumeric, -, _, and . +// https://kubernetes.io/docs/concepts/configuration/secret/#restriction-names-data. +func sanitizeKey[T ~string](k T) string { return strings.Map(func(r rune) rune { if r >= 'a' && r <= 'z' || r >= 'A' && r <= 'Z' || r >= '0' && r <= '9' || r == '-' || r == '_' || r == '.' { return r diff --git a/ipn/store/kubestore/store_kube_test.go b/ipn/store/kubestore/store_kube_test.go new file mode 100644 index 000000000..2ed16e77b --- /dev/null +++ b/ipn/store/kubestore/store_kube_test.go @@ -0,0 +1,723 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package kubestore + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "strings" + "testing" + + "github.com/google/go-cmp/cmp" + "tailscale.com/envknob" + "tailscale.com/ipn" + "tailscale.com/ipn/store/mem" + "tailscale.com/kube/kubeapi" + "tailscale.com/kube/kubeclient" +) + +func TestWriteState(t *testing.T) { + tests := []struct { + name string + initial map[string][]byte + key ipn.StateKey + value []byte + wantData map[string][]byte + allowPatch bool + }{ + { + name: "basic_write", + initial: map[string][]byte{ + "existing": []byte("old"), + }, + key: "foo", + value: []byte("bar"), + wantData: map[string][]byte{ + "existing": []byte("old"), + "foo": []byte("bar"), + }, + allowPatch: true, + }, + { + name: "update_existing", + initial: map[string][]byte{ + "foo": []byte("old"), + }, + key: "foo", + value: []byte("new"), + wantData: map[string][]byte{ + "foo": []byte("new"), + }, + allowPatch: true, + }, + { + name: "create_new_secret", + key: "foo", + value: []byte("bar"), + wantData: map[string][]byte{ + "foo": []byte("bar"), + }, + allowPatch: true, + }, + { + name: "patch_denied", + initial: map[string][]byte{ + "foo": []byte("old"), + }, + key: "foo", + value: []byte("new"), + wantData: map[string][]byte{ + "foo": []byte("new"), + }, + allowPatch: false, + }, + { + name: "sanitize_key", + initial: map[string][]byte{ + "clean-key": []byte("old"), + }, + key: "dirty@key", + value: []byte("new"), + wantData: map[string][]byte{ + "clean-key": []byte("old"), + "dirty_key": []byte("new"), + }, + allowPatch: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + secret := tt.initial // track current state + client := &kubeclient.FakeClient{ + GetSecretImpl: func(ctx context.Context, name string) (*kubeapi.Secret, error) { + if secret == nil { + return nil, &kubeapi.Status{Code: 404} + } + return &kubeapi.Secret{Data: secret}, nil + }, + CheckSecretPermissionsImpl: func(ctx context.Context, name string) (bool, bool, error) { + return tt.allowPatch, true, nil + }, + CreateSecretImpl: func(ctx context.Context, s *kubeapi.Secret) error { + secret = s.Data + return nil + }, + UpdateSecretImpl: func(ctx context.Context, s *kubeapi.Secret) error { + secret = s.Data + return nil + }, + JSONPatchResourceImpl: func(ctx context.Context, name, resourceType string, patches []kubeclient.JSONPatch) error { + if !tt.allowPatch { + return &kubeapi.Status{Reason: "Forbidden"} + } + if secret == nil { + secret = make(map[string][]byte) + } + for _, p := range patches { + if p.Op == "add" && p.Path == "/data" { + secret = p.Value.(map[string][]byte) + } else if p.Op == "add" && strings.HasPrefix(p.Path, "/data/") { + key := strings.TrimPrefix(p.Path, "/data/") + secret[key] = p.Value.([]byte) + } + } + return nil + }, + } + + s := &Store{ + client: client, + canPatch: tt.allowPatch, + secretName: "ts-state", + memory: mem.Store{}, + } + + err := s.WriteState(tt.key, tt.value) + if err != nil { + t.Errorf("WriteState() error = %v", err) + return + } + + // Verify secret data + if diff := cmp.Diff(secret, tt.wantData); diff != "" { + t.Errorf("secret data mismatch (-got +want):\n%s", diff) + } + + // Verify memory store was updated + got, err := s.memory.ReadState(ipn.StateKey(sanitizeKey(string(tt.key)))) + if err != nil { + t.Errorf("reading from memory store: %v", err) + } + if !cmp.Equal(got, tt.value) { + t.Errorf("memory store key %q = %v, want %v", tt.key, got, tt.value) + } + }) + } +} + +func TestWriteTLSCertAndKey(t *testing.T) { + const ( + testDomain = "my-app.tailnetxyz.ts.net" + testCert = "fake-cert" + testKey = "fake-key" + ) + + tests := []struct { + name string + initial map[string][]byte // pre-existing cert and key + certShareMode string + allowPatch bool // whether client can patch the Secret + wantSecretName string // name of the Secret where cert and key should be written + wantSecretData map[string][]byte + wantMemoryStore map[ipn.StateKey][]byte + }{ + { + name: "basic_write", + initial: map[string][]byte{ + "existing": []byte("old"), + }, + allowPatch: true, + wantSecretName: "ts-state", + wantSecretData: map[string][]byte{ + "existing": []byte("old"), + "my-app.tailnetxyz.ts.net.crt": []byte(testCert), + "my-app.tailnetxyz.ts.net.key": []byte(testKey), + }, + wantMemoryStore: map[ipn.StateKey][]byte{ + "my-app.tailnetxyz.ts.net.crt": []byte(testCert), + "my-app.tailnetxyz.ts.net.key": []byte(testKey), + }, + }, + { + name: "cert_share_mode_write", + certShareMode: "rw", + allowPatch: true, + wantSecretName: "my-app.tailnetxyz.ts.net", + wantSecretData: map[string][]byte{ + "tls.crt": []byte(testCert), + "tls.key": []byte(testKey), + }, + wantMemoryStore: map[ipn.StateKey][]byte{ + "my-app.tailnetxyz.ts.net.crt": []byte(testCert), + "my-app.tailnetxyz.ts.net.key": []byte(testKey), + }, + }, + { + name: "cert_share_mode_write_update_existing", + initial: map[string][]byte{ + "tls.crt": []byte("old-cert"), + "tls.key": []byte("old-key"), + }, + certShareMode: "rw", + allowPatch: true, + wantSecretName: "my-app.tailnetxyz.ts.net", + wantSecretData: map[string][]byte{ + "tls.crt": []byte(testCert), + "tls.key": []byte(testKey), + }, + wantMemoryStore: map[ipn.StateKey][]byte{ + "my-app.tailnetxyz.ts.net.crt": []byte(testCert), + "my-app.tailnetxyz.ts.net.key": []byte(testKey), + }, + }, + { + name: "update_existing", + initial: map[string][]byte{ + "my-app.tailnetxyz.ts.net.crt": []byte("old-cert"), + "my-app.tailnetxyz.ts.net.key": []byte("old-key"), + }, + certShareMode: "", + allowPatch: true, + wantSecretName: "ts-state", + wantSecretData: map[string][]byte{ + "my-app.tailnetxyz.ts.net.crt": []byte(testCert), + "my-app.tailnetxyz.ts.net.key": []byte(testKey), + }, + wantMemoryStore: map[ipn.StateKey][]byte{ + "my-app.tailnetxyz.ts.net.crt": []byte(testCert), + "my-app.tailnetxyz.ts.net.key": []byte(testKey), + }, + }, + { + name: "patch_denied", + certShareMode: "", + allowPatch: false, + wantSecretName: "ts-state", + wantSecretData: map[string][]byte{ + "my-app.tailnetxyz.ts.net.crt": []byte(testCert), + "my-app.tailnetxyz.ts.net.key": []byte(testKey), + }, + wantMemoryStore: map[ipn.StateKey][]byte{ + "my-app.tailnetxyz.ts.net.crt": []byte(testCert), + "my-app.tailnetxyz.ts.net.key": []byte(testKey), + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + // Set POD_NAME for testing selectors + envknob.Setenv("POD_NAME", "ingress-proxies-1") + defer envknob.Setenv("POD_NAME", "") + + secret := tt.initial // track current state + client := &kubeclient.FakeClient{ + GetSecretImpl: func(ctx context.Context, name string) (*kubeapi.Secret, error) { + if secret == nil { + return nil, &kubeapi.Status{Code: 404} + } + return &kubeapi.Secret{Data: secret}, nil + }, + CheckSecretPermissionsImpl: func(ctx context.Context, name string) (bool, bool, error) { + return tt.allowPatch, true, nil + }, + CreateSecretImpl: func(ctx context.Context, s *kubeapi.Secret) error { + if s.Name != tt.wantSecretName { + t.Errorf("CreateSecret called with wrong name, got %q, want %q", s.Name, tt.wantSecretName) + } + secret = s.Data + return nil + }, + UpdateSecretImpl: func(ctx context.Context, s *kubeapi.Secret) error { + if s.Name != tt.wantSecretName { + t.Errorf("UpdateSecret called with wrong name, got %q, want %q", s.Name, tt.wantSecretName) + } + secret = s.Data + return nil + }, + JSONPatchResourceImpl: func(ctx context.Context, name, resourceType string, patches []kubeclient.JSONPatch) error { + if !tt.allowPatch { + return &kubeapi.Status{Reason: "Forbidden"} + } + if name != tt.wantSecretName { + t.Errorf("JSONPatchResource called with wrong name, got %q, want %q", name, tt.wantSecretName) + } + if secret == nil { + secret = make(map[string][]byte) + } + for _, p := range patches { + if p.Op == "add" && p.Path == "/data" { + secret = p.Value.(map[string][]byte) + } else if p.Op == "add" && strings.HasPrefix(p.Path, "/data/") { + key := strings.TrimPrefix(p.Path, "/data/") + secret[key] = p.Value.([]byte) + } + } + return nil + }, + } + + s := &Store{ + client: client, + canPatch: tt.allowPatch, + secretName: tt.wantSecretName, + certShareMode: tt.certShareMode, + memory: mem.Store{}, + } + + err := s.WriteTLSCertAndKey(testDomain, []byte(testCert), []byte(testKey)) + if err != nil { + t.Errorf("WriteTLSCertAndKey() error = '%v'", err) + return + } + + // Verify secret data + if diff := cmp.Diff(secret, tt.wantSecretData); diff != "" { + t.Errorf("secret data mismatch (-got +want):\n%s", diff) + } + + // Verify memory store was updated + for key, want := range tt.wantMemoryStore { + got, err := s.memory.ReadState(key) + if err != nil { + t.Errorf("reading from memory store: %v", err) + continue + } + if !cmp.Equal(got, want) { + t.Errorf("memory store key %q = %v, want %v", key, got, want) + } + } + }) + } +} + +func TestReadTLSCertAndKey(t *testing.T) { + const ( + testDomain = "my-app.tailnetxyz.ts.net" + testCert = "fake-cert" + testKey = "fake-key" + ) + + tests := []struct { + name string + memoryStore map[ipn.StateKey][]byte // pre-existing memory store state + certShareMode string + domain string + secretData map[string][]byte // data to return from mock GetSecret + secretGetErr error // error to return from mock GetSecret + wantCert []byte + wantKey []byte + wantErr error + // what should end up in memory store after the store is created + wantMemoryStore map[ipn.StateKey][]byte + }{ + { + name: "found", + memoryStore: map[ipn.StateKey][]byte{ + "my-app.tailnetxyz.ts.net.crt": []byte(testCert), + "my-app.tailnetxyz.ts.net.key": []byte(testKey), + }, + domain: testDomain, + wantCert: []byte(testCert), + wantKey: []byte(testKey), + wantMemoryStore: map[ipn.StateKey][]byte{ + "my-app.tailnetxyz.ts.net.crt": []byte(testCert), + "my-app.tailnetxyz.ts.net.key": []byte(testKey), + }, + }, + { + name: "not_found", + domain: testDomain, + wantErr: ipn.ErrStateNotExist, + }, + { + name: "cert_share_ro_mode_found_in_secret", + certShareMode: "ro", + domain: testDomain, + secretData: map[string][]byte{ + "tls.crt": []byte(testCert), + "tls.key": []byte(testKey), + }, + wantCert: []byte(testCert), + wantKey: []byte(testKey), + wantMemoryStore: map[ipn.StateKey][]byte{ + "my-app.tailnetxyz.ts.net.crt": []byte(testCert), + "my-app.tailnetxyz.ts.net.key": []byte(testKey), + }, + }, + { + name: "cert_share_ro_mode_found_in_memory", + certShareMode: "ro", + memoryStore: map[ipn.StateKey][]byte{ + "my-app.tailnetxyz.ts.net.crt": []byte(testCert), + "my-app.tailnetxyz.ts.net.key": []byte(testKey), + }, + domain: testDomain, + wantCert: []byte(testCert), + wantKey: []byte(testKey), + wantMemoryStore: map[ipn.StateKey][]byte{ + "my-app.tailnetxyz.ts.net.crt": []byte(testCert), + "my-app.tailnetxyz.ts.net.key": []byte(testKey), + }, + }, + { + name: "cert_share_ro_mode_not_found", + certShareMode: "ro", + domain: testDomain, + secretGetErr: &kubeapi.Status{Code: 404}, + wantErr: ipn.ErrStateNotExist, + }, + { + name: "cert_share_ro_mode_empty_cert_in_secret", + certShareMode: "ro", + domain: testDomain, + secretData: map[string][]byte{ + "tls.crt": {}, + "tls.key": []byte(testKey), + }, + wantErr: ipn.ErrStateNotExist, + }, + { + name: "cert_share_ro_mode_kube_api_error", + certShareMode: "ro", + domain: testDomain, + secretGetErr: fmt.Errorf("api error"), + wantErr: fmt.Errorf("getting TLS Secret %q: api error", sanitizeKey(testDomain)), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + client := &kubeclient.FakeClient{ + GetSecretImpl: func(ctx context.Context, name string) (*kubeapi.Secret, error) { + if tt.secretGetErr != nil { + return nil, tt.secretGetErr + } + return &kubeapi.Secret{Data: tt.secretData}, nil + }, + } + + s := &Store{ + client: client, + secretName: "ts-state", + certShareMode: tt.certShareMode, + memory: mem.Store{}, + } + + // Initialize memory store + for k, v := range tt.memoryStore { + s.memory.WriteState(k, v) + } + + gotCert, gotKey, err := s.ReadTLSCertAndKey(tt.domain) + if tt.wantErr != nil { + if err == nil { + t.Errorf("ReadTLSCertAndKey() error = nil, want error containing %v", tt.wantErr) + return + } + if !strings.Contains(err.Error(), tt.wantErr.Error()) { + t.Errorf("ReadTLSCertAndKey() error = %v, want error containing %v", err, tt.wantErr) + } + return + } + if err != nil { + t.Errorf("ReadTLSCertAndKey() unexpected error: %v", err) + return + } + + if !bytes.Equal(gotCert, tt.wantCert) { + t.Errorf("ReadTLSCertAndKey() gotCert = %v, want %v", gotCert, tt.wantCert) + } + if !bytes.Equal(gotKey, tt.wantKey) { + t.Errorf("ReadTLSCertAndKey() gotKey = %v, want %v", gotKey, tt.wantKey) + } + + // Verify memory store contents after operation + if tt.wantMemoryStore != nil { + for key, want := range tt.wantMemoryStore { + got, err := s.memory.ReadState(key) + if err != nil { + t.Errorf("reading from memory store: %v", err) + continue + } + if !bytes.Equal(got, want) { + t.Errorf("memory store key %q = %v, want %v", key, got, want) + } + } + } + }) + } +} + +func TestNewWithClient(t *testing.T) { + const ( + secretName = "ts-state" + testCert = "fake-cert" + testKey = "fake-key" + ) + + certSecretsLabels := map[string]string{ + "tailscale.com/secret-type": "certs", + "tailscale.com/managed": "true", + "tailscale.com/proxy-group": "ingress-proxies", + } + + // Helper function to create Secret objects for testing + makeSecret := func(name string, labels map[string]string, certSuffix string) kubeapi.Secret { + return kubeapi.Secret{ + ObjectMeta: kubeapi.ObjectMeta{ + Name: name, + Labels: labels, + }, + Data: map[string][]byte{ + "tls.crt": []byte(testCert + certSuffix), + "tls.key": []byte(testKey + certSuffix), + }, + } + } + + tests := []struct { + name string + stateSecretContents map[string][]byte // data in state Secret + TLSSecrets []kubeapi.Secret // list of TLS cert Secrets + certMode string + secretGetErr error // error to return from GetSecret + secretsListErr error // error to return from ListSecrets + wantMemoryStoreContents map[ipn.StateKey][]byte + wantErr error + }{ + { + name: "empty_state_secret", + stateSecretContents: map[string][]byte{}, + wantMemoryStoreContents: map[ipn.StateKey][]byte{}, + }, + { + name: "state_secret_not_found", + secretGetErr: &kubeapi.Status{Code: 404}, + wantMemoryStoreContents: map[ipn.StateKey][]byte{}, + }, + { + name: "state_secret_get_error", + secretGetErr: fmt.Errorf("some error"), + wantErr: fmt.Errorf("error loading state from kube Secret: some error"), + }, + { + name: "load_existing_state", + stateSecretContents: map[string][]byte{ + "foo": []byte("bar"), + "baz": []byte("qux"), + }, + wantMemoryStoreContents: map[ipn.StateKey][]byte{ + "foo": []byte("bar"), + "baz": []byte("qux"), + }, + }, + { + name: "load_select_certs_in_read_only_mode", + certMode: "ro", + stateSecretContents: map[string][]byte{ + "foo": []byte("bar"), + }, + TLSSecrets: []kubeapi.Secret{ + makeSecret("app1.tailnetxyz.ts.net", certSecretsLabels, "1"), + makeSecret("app2.tailnetxyz.ts.net", certSecretsLabels, "2"), + makeSecret("some-other-secret", nil, "3"), + makeSecret("app3.other-proxies.ts.net", map[string]string{ + "tailscale.com/secret-type": "certs", + "tailscale.com/managed": "true", + "tailscale.com/proxy-group": "some-other-proxygroup", + }, "4"), + }, + wantMemoryStoreContents: map[ipn.StateKey][]byte{ + "foo": []byte("bar"), + "app1.tailnetxyz.ts.net.crt": []byte(testCert + "1"), + "app1.tailnetxyz.ts.net.key": []byte(testKey + "1"), + "app2.tailnetxyz.ts.net.crt": []byte(testCert + "2"), + "app2.tailnetxyz.ts.net.key": []byte(testKey + "2"), + }, + }, + { + name: "load_select_certs_in_read_write_mode", + certMode: "rw", + stateSecretContents: map[string][]byte{ + "foo": []byte("bar"), + }, + TLSSecrets: []kubeapi.Secret{ + makeSecret("app1.tailnetxyz.ts.net", certSecretsLabels, "1"), + makeSecret("app2.tailnetxyz.ts.net", certSecretsLabels, "2"), + makeSecret("some-other-secret", nil, "3"), + makeSecret("app3.other-proxies.ts.net", map[string]string{ + "tailscale.com/secret-type": "certs", + "tailscale.com/managed": "true", + "tailscale.com/proxy-group": "some-other-proxygroup", + }, "4"), + }, + wantMemoryStoreContents: map[ipn.StateKey][]byte{ + "foo": []byte("bar"), + "app1.tailnetxyz.ts.net.crt": []byte(testCert + "1"), + "app1.tailnetxyz.ts.net.key": []byte(testKey + "1"), + "app2.tailnetxyz.ts.net.crt": []byte(testCert + "2"), + "app2.tailnetxyz.ts.net.key": []byte(testKey + "2"), + }, + }, + { + name: "list_cert_secrets_fails", + certMode: "ro", + stateSecretContents: map[string][]byte{ + "foo": []byte("bar"), + }, + secretsListErr: fmt.Errorf("list error"), + // The error is logged but not returned, and state is still loaded + wantMemoryStoreContents: map[ipn.StateKey][]byte{ + "foo": []byte("bar"), + }, + }, + { + name: "cert_secrets_not_loaded_when_not_in_share_mode", + certMode: "", + stateSecretContents: map[string][]byte{ + "foo": []byte("bar"), + }, + TLSSecrets: []kubeapi.Secret{ + makeSecret("app1.tailnetxyz.ts.net", certSecretsLabels, "1"), + }, + wantMemoryStoreContents: map[ipn.StateKey][]byte{ + "foo": []byte("bar"), + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + envknob.Setenv("TS_CERT_SHARE_MODE", tt.certMode) + + t.Setenv("POD_NAME", "ingress-proxies-1") + + client := &kubeclient.FakeClient{ + GetSecretImpl: func(ctx context.Context, name string) (*kubeapi.Secret, error) { + if tt.secretGetErr != nil { + return nil, tt.secretGetErr + } + if name == secretName { + return &kubeapi.Secret{Data: tt.stateSecretContents}, nil + } + return nil, &kubeapi.Status{Code: 404} + }, + CheckSecretPermissionsImpl: func(ctx context.Context, name string) (bool, bool, error) { + return true, true, nil + }, + ListSecretsImpl: func(ctx context.Context, selector map[string]string) (*kubeapi.SecretList, error) { + if tt.secretsListErr != nil { + return nil, tt.secretsListErr + } + var matchingSecrets []kubeapi.Secret + for _, secret := range tt.TLSSecrets { + matches := true + for k, v := range selector { + if secret.Labels[k] != v { + matches = false + break + } + } + if matches { + matchingSecrets = append(matchingSecrets, secret) + } + } + return &kubeapi.SecretList{Items: matchingSecrets}, nil + }, + } + + s, err := newWithClient(t.Logf, client, secretName) + if tt.wantErr != nil { + if err == nil { + t.Errorf("NewWithClient() error = nil, want error containing %v", tt.wantErr) + return + } + if !strings.Contains(err.Error(), tt.wantErr.Error()) { + t.Errorf("NewWithClient() error = %v, want error containing %v", err, tt.wantErr) + } + return + } + + if err != nil { + t.Errorf("NewWithClient() unexpected error: %v", err) + return + } + + // Verify memory store contents + gotJSON, err := s.memory.ExportToJSON() + if err != nil { + t.Errorf("ExportToJSON failed: %v", err) + return + } + var got map[ipn.StateKey][]byte + if err := json.Unmarshal(gotJSON, &got); err != nil { + t.Errorf("failed to unmarshal memory store JSON: %v", err) + return + } + want := tt.wantMemoryStoreContents + if want == nil { + want = map[ipn.StateKey][]byte{} + } + if diff := cmp.Diff(got, want); diff != "" { + t.Errorf("memory store contents mismatch (-got +want):\n%s", diff) + } + }) + } +} diff --git a/ipn/store/store_aws.go b/ipn/store/store_aws.go index e164f9de7..d39e84319 100644 --- a/ipn/store/store_aws.go +++ b/ipn/store/store_aws.go @@ -6,7 +6,9 @@ package store import ( + "tailscale.com/ipn" "tailscale.com/ipn/store/awsstore" + "tailscale.com/types/logger" ) func init() { @@ -14,5 +16,11 @@ func init() { } func registerAWSStore() { - Register("arn:", awsstore.New) + Register("arn:", func(logf logger.Logf, arg string) (ipn.StateStore, error) { + ssmARN, opts, err := awsstore.ParseARNAndOpts(arg) + if err != nil { + return nil, err + } + return awsstore.New(logf, ssmARN, opts...) + }) } |
