summaryrefslogtreecommitdiffhomepage
path: root/ipn/ipnlocal/diskcache.go
diff options
context:
space:
mode:
Diffstat (limited to 'ipn/ipnlocal/diskcache.go')
-rw-r--r--ipn/ipnlocal/diskcache.go289
1 files changed, 289 insertions, 0 deletions
diff --git a/ipn/ipnlocal/diskcache.go b/ipn/ipnlocal/diskcache.go
new file mode 100644
index 000000000..734b702c3
--- /dev/null
+++ b/ipn/ipnlocal/diskcache.go
@@ -0,0 +1,289 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package ipnlocal
+
+import (
+ "bytes"
+ "cmp"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "log"
+ "os"
+ "path/filepath"
+ "slices"
+ "strings"
+ "time"
+
+ "tailscale.com/feature/buildfeatures"
+ "tailscale.com/tailcfg"
+ "tailscale.com/tka"
+ "tailscale.com/types/key"
+ "tailscale.com/types/netmap"
+ "tailscale.com/util/mak"
+ "tailscale.com/util/set"
+)
+
+// diskCache is the state netmap caching to disk.
+type diskCache struct {
+ // all fields guarded by LocalBackend.mu
+
+ dir string // active directory to write to
+ wantBase set.Set[string] // base names we want to have on disk
+ lastWrote map[string]lastWrote // base name => contents written
+}
+
+type lastWrote struct {
+ baseDir string
+ contents []byte
+ at time.Time
+}
+
+func validBasename(name string) bool {
+ if len(name) == 0 || len(name) > 255 {
+ return false
+ }
+ if strings.ContainsAny(name, "\\/\x00") {
+ return false
+ }
+ return true
+}
+
+func (dc *diskCache) writeJSON(baseName string, v any) error {
+ if !validBasename(baseName) {
+ return fmt.Errorf("invalid baseName %q", baseName)
+ }
+ j, err := json.Marshal(v)
+ if err != nil {
+ return fmt.Errorf("errror JSON marshalling %q: %w", baseName, err)
+ }
+ last, ok := dc.lastWrote[baseName]
+ if ok && last.baseDir == dc.dir && bytes.Equal(j, last.contents) {
+ // Avoid disk writes
+ return nil
+ }
+ err = os.WriteFile(filepath.Join(dc.dir, baseName), j, 0600)
+ if err != nil {
+ return err
+ }
+ dc.wantBase.Make()
+ dc.wantBase.Add(baseName)
+ mak.Set(&dc.lastWrote, baseName, lastWrote{
+ baseDir: dc.dir,
+ contents: j,
+ at: time.Now(),
+ })
+ return nil
+}
+
+func (dc *diskCache) removeUnwantedFiles() error {
+ ents, err := os.ReadDir(dc.dir)
+ if err != nil {
+ return err
+ }
+ for _, de := range ents {
+ baseName := de.Name()
+ if !dc.wantBase.Contains(baseName) {
+ if err := os.Remove(filepath.Join(dc.dir, baseName)); err != nil {
+ return err
+ }
+ }
+ }
+ return nil
+}
+
+// netmapJSON are some misc small, low-churn netmap.NetworkMap fields we
+// serialize to disk together.
+//
+// (We never write a whole NetworkMap to disk; it's not considered a stable format)
+type netmapJSON struct {
+ MachineKey key.MachinePublic
+ CollectServices bool `json:",omitzero"`
+ DisplayMessages map[tailcfg.DisplayMessageID]tailcfg.DisplayMessage `json:",omitempty"`
+ TKAEnabled bool `json:",omitzero"`
+ TKAHead tka.AUMHash `json:",omitzero"`
+ Domain string
+ DomainAuditLogID string `json:",omitzero"`
+}
+
+func (b *LocalBackend) writeNetmapToDiskLocked(nm *netmap.NetworkMap) error {
+ if !buildfeatures.HasCacheNetMap || nm == nil || nm.Cached {
+ return nil
+ }
+ b.logf("writing netmap to disk cache")
+
+ selfUID := nm.User()
+ if selfUID == 0 {
+ return errors.New("no user in netmap")
+ }
+ prof, ok := nm.UserProfiles[selfUID]
+ if !ok {
+ return errors.New("no profile for current user in netmap")
+ }
+ root := b.varRoot
+ if root == "" {
+ return errors.New("no varRoot")
+ }
+
+ dc := &b.diskCache
+ // TODO(bradfitz): the (ID integer, LoginName string) tuple is not sufficiently
+ // globally unique. It doesn't include the control plane server URL. We should
+ // make each profile have a local UUID.
+ dc.dir = filepath.Join(root, fmt.Sprintf("nm-%d-%s", prof.ID(), prof.LoginName()))
+
+ if err := os.MkdirAll(dc.dir, 0700); err != nil {
+ return err
+ }
+
+ dc.wantBase = nil
+
+ misc := &netmapJSON{
+ MachineKey: nm.MachineKey,
+ CollectServices: nm.CollectServices,
+ DisplayMessages: nm.DisplayMessages,
+ TKAEnabled: nm.TKAEnabled,
+ TKAHead: nm.TKAHead,
+ Domain: nm.Domain,
+ DomainAuditLogID: nm.DomainAuditLogID,
+ }
+ if err := dc.writeJSON("misc", misc); err != nil {
+ return err
+ }
+
+ if buildfeatures.HasSSH && nm.SSHPolicy != nil {
+ if err := dc.writeJSON("ssh", nm.SSHPolicy); err != nil {
+ return err
+ }
+ }
+ if err := dc.writeJSON("dns", nm.DNS); err != nil {
+ return err
+ }
+ if err := dc.writeJSON("derpmap", nm.DERPMap); err != nil {
+ return err
+ }
+ if err := dc.writeJSON("self", nm.SelfNode); err != nil {
+ return err
+ }
+ for _, p := range nm.Peers {
+ if err := dc.writeJSON("peer-"+string(p.StableID()), p); err != nil {
+ return err
+ }
+ }
+ for uid, p := range nm.UserProfiles {
+ if err := dc.writeJSON(fmt.Sprintf("user-%d", uid), p); err != nil {
+ return err
+ }
+ }
+
+ if err := dc.removeUnwantedFiles(); err != nil {
+ return fmt.Errorf("cleaning old files from netmap disk cache: %w", err)
+ }
+
+ return nil
+}
+
+func (b *LocalBackend) loadDiskCache(prof tailcfg.UserProfile) (_ *netmap.NetworkMap, ok bool) {
+ if !buildfeatures.HasCacheNetMap {
+ return nil, false
+ }
+ root := b.varRoot
+ if root == "" {
+ return nil, false
+ }
+
+ dir := filepath.Join(root, fmt.Sprintf("nm-%d-%s", prof.ID, prof.LoginName))
+ ents, err := os.ReadDir(dir)
+ if err != nil {
+ if os.IsNotExist(err) {
+ return nil, false
+ }
+ b.logf("loading netmap from disk cache: reading dir %q: %v", dir, err)
+ return nil, false
+ }
+ if len(ents) == 0 {
+ return nil, false
+ }
+
+ nm := &netmap.NetworkMap{Cached: true}
+ dc := diskCache{dir: dir}
+ for _, de := range ents {
+ if err := dc.readFile(nm, de.Name()); err != nil {
+ b.logf("loading netmap from disk cache: reading file %q: %v", de.Name(), err)
+ return nil, false
+ }
+ }
+ slices.SortFunc(nm.Peers, func(a, b tailcfg.NodeView) int {
+ return cmp.Compare(a.ID(), b.ID())
+ })
+ return nm, true
+}
+
+func (dc *diskCache) readFile(nm *netmap.NetworkMap, base string) error {
+ j, err := os.ReadFile(filepath.Join(dc.dir, base))
+ if err != nil {
+ return err
+ }
+ if strings.HasPrefix(base, "peer-") {
+ var p tailcfg.Node
+ if err := setField(&p, j); err != nil {
+ return err
+ }
+ nm.Peers = append(nm.Peers, p.View())
+ return nil
+ }
+ if strings.HasPrefix(base, "user-") {
+ var up tailcfg.UserProfile
+ if err := setField(&up, j); err != nil {
+ return err
+ }
+ mak.Set(&nm.UserProfiles, up.ID, up.View())
+ return nil
+ }
+
+ switch base {
+ case "derpmap":
+ return setField(&nm.DERPMap, j)
+ case "dns":
+ return setField(&nm.DNS, j)
+ case "ssh":
+ return setField(&nm.SSHPolicy, j)
+ case "self":
+ var n *tailcfg.Node
+ if err := setField(&n, j); err != nil {
+ return err
+ }
+ nm.SelfNode = n.View()
+ nm.NodeKey = n.Key
+
+ capSet := set.Set[tailcfg.NodeCapability]{}
+ for _, c := range n.Capabilities {
+ capSet.Add(c)
+ }
+ for c := range n.CapMap {
+ capSet.Add(c)
+ }
+ nm.AllCaps = capSet
+
+ case "misc":
+ misc := &netmapJSON{}
+ if err := setField(misc, j); err != nil {
+ return err
+ }
+ nm.MachineKey = misc.MachineKey
+ nm.CollectServices = misc.CollectServices
+ nm.DisplayMessages = misc.DisplayMessages
+ nm.TKAEnabled = misc.TKAEnabled
+ nm.TKAHead = misc.TKAHead
+ nm.Domain = misc.Domain
+ nm.DomainAuditLogID = misc.DomainAuditLogID
+ return nil
+ default:
+ log.Printf("unknown netmap disk cache file %q; ignoring", base)
+ }
+ return nil
+}
+
+func setField[T any](ptr *T, j []byte) error {
+ return json.Unmarshal(j, ptr)
+}