summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--cmd/tailscale/cli/drive.go173
-rw-r--r--drive/drive_clone.go1
-rw-r--r--drive/drive_view.go6
-rw-r--r--drive/remote.go15
-rw-r--r--drive/remote_test.go16
-rw-r--r--drive/share_access.go235
-rw-r--r--drive/share_access_test.go303
-rw-r--r--ipn/ipnlocal/peerapi_drive.go3
8 files changed, 739 insertions, 13 deletions
diff --git a/cmd/tailscale/cli/drive.go b/cmd/tailscale/cli/drive.go
index 280ff3172..c290dd495 100644
--- a/cmd/tailscale/cli/drive.go
+++ b/cmd/tailscale/cli/drive.go
@@ -7,8 +7,10 @@ package cli
import (
"context"
+ "flag"
"fmt"
"path/filepath"
+ "sort"
"strings"
"github.com/peterbourgon/ff/v3/ffcli"
@@ -16,7 +18,7 @@ import (
)
const (
- driveShareUsage = "tailscale drive share <name> <path>"
+ driveShareUsage = "tailscale drive share [--users user1,user2 | --group groupname] <name> <path>"
driveRenameUsage = "tailscale drive rename <oldname> <newname>"
driveUnshareUsage = "tailscale drive unshare <name>"
driveListUsage = "tailscale drive list"
@@ -27,6 +29,10 @@ func init() {
}
func driveCmd() *ffcli.Command {
+ shareFlags := flag.NewFlagSet("share", flag.ContinueOnError)
+ usersFlag := shareFlags.String("users", "", "comma-separated list of users to share with (share name auto-generated)")
+ groupFlag := shareFlags.String("group", "", "group name to share with (share name auto-generated, only group members can access)")
+
return &ffcli.Command{
Name: "drive",
ShortHelp: "Share a directory with your tailnet",
@@ -42,8 +48,11 @@ func driveCmd() *ffcli.Command {
{
Name: "share",
ShortUsage: driveShareUsage,
- Exec: runDriveShare,
- ShortHelp: "[ALPHA] Create or modify a share",
+ FlagSet: shareFlags,
+ Exec: func(ctx context.Context, args []string) error {
+ return runDriveShare(ctx, args, *usersFlag, *groupFlag)
+ },
+ ShortHelp: "[ALPHA] Create or modify a share",
},
{
Name: "rename",
@@ -68,12 +77,54 @@ func driveCmd() *ffcli.Command {
}
// runDriveShare is the entry point for the "tailscale drive share" command.
-func runDriveShare(ctx context.Context, args []string) error {
- if len(args) != 2 {
- return fmt.Errorf("usage: %s", driveShareUsage)
+func runDriveShare(ctx context.Context, args []string, usersFlag, groupFlag string) error {
+ if usersFlag != "" && groupFlag != "" {
+ return fmt.Errorf("cannot specify both --users and --group")
}
- name, path := args[0], args[1]
+ var name, path string
+ var isGroup bool
+
+ switch {
+ case usersFlag != "":
+ // --users joe,rhea → name = "joe+rhea", path from args[0]
+ if len(args) != 1 {
+ return fmt.Errorf("usage: tailscale drive share --users user1,user2 <path>")
+ }
+ users := strings.Split(usersFlag, ",")
+ for i, u := range users {
+ users[i] = strings.TrimSpace(u)
+ if users[i] == "" {
+ return fmt.Errorf("empty username in --users flag")
+ }
+ }
+ if err := validateUsers(ctx, users); err != nil {
+ return err
+ }
+ sort.Strings(users)
+ name = strings.Join(users, "+")
+ path = args[0]
+
+ case groupFlag != "":
+ // --group eng → name = "eng", path from args[0]
+ if len(args) != 1 {
+ return fmt.Errorf("usage: tailscale drive share --group groupname <path>")
+ }
+ if err := validateGroup(ctx, groupFlag); err != nil {
+ return err
+ }
+ name = groupFlag
+ path = args[0]
+ isGroup = true
+
+ default:
+ // Traditional: <name> <path>
+ if len(args) != 2 {
+ return fmt.Errorf("usage: %s", driveShareUsage)
+ }
+ name = args[0]
+ path = args[1]
+ }
absolutePath, err := filepath.Abs(path)
if err != nil {
@@ -81,8 +132,9 @@ func runDriveShare(ctx context.Context, args []string) error {
}
err = localClient.DriveShareSet(ctx, &drive.Share{
- Name: name,
- Path: absolutePath,
+ Name: name,
+ Path: absolutePath,
+ IsGroup: isGroup,
})
if err == nil {
fmt.Printf("Sharing %q as %q\n", path, name)
@@ -90,6 +142,109 @@ func runDriveShare(ctx context.Context, args []string) error {
return err
}
+// validateUsers checks that all specified usernames exist in the tailnet and
+// resolves display names. It modifies users in place, replacing each entry
+// with its resolved display name (which may include a domain qualifier for
+// disambiguation). It returns an error if any user is unknown or ambiguous.
+func validateUsers(ctx context.Context, users []string) error {
+ status, err := localClient.Status(ctx)
+ if err != nil {
+ return fmt.Errorf("failed to get tailnet status: %w", err)
+ }
+
+ tailnetDomain := ""
+ if status.CurrentTailnet != nil {
+ tailnetDomain = status.CurrentTailnet.Name
+ }
+
+ // Build a map from short name to list of login names.
+ type userInfo struct {
+ loginName string
+ displayName string
+ }
+ shortToUsers := make(map[string][]userInfo)
+ for _, u := range status.User {
+ short := drive.LoginShortName(u.LoginName)
+ display := drive.LoginDisplayName(u.LoginName, tailnetDomain)
+ shortToUsers[short] = append(shortToUsers[short], userInfo{
+ loginName: u.LoginName,
+ displayName: display,
+ })
+ }
+
+ // Also build a lookup by display name for users specifying name(domain).
+ displayToUser := make(map[string]userInfo)
+ for _, infos := range shortToUsers {
+ for _, info := range infos {
+ displayToUser[info.displayName] = info
+ }
+ }
+
+ for i, u := range users {
+ // Check if user specified name(domain) form.
+ if strings.Contains(u, "(") && strings.Contains(u, ")") {
+ if _, ok := displayToUser[u]; !ok {
+ known := make([]string, 0)
+ for d := range displayToUser {
+ known = append(known, d)
+ }
+ sort.Strings(known)
+ return fmt.Errorf("unknown user %q\nvalid users: %s", u, strings.Join(known, ", "))
+ }
+ users[i] = u
+ continue
+ }
+
+ // Plain short name lookup.
+ matches, ok := shortToUsers[u]
+ if !ok || len(matches) == 0 {
+ known := make([]string, 0, len(shortToUsers))
+ for k := range shortToUsers {
+ known = append(known, k)
+ }
+ sort.Strings(known)
+ return fmt.Errorf("unknown user %q\nvalid users: %s", u, strings.Join(known, ", "))
+ }
+ if len(matches) == 1 {
+ users[i] = matches[0].displayName
+ continue
+ }
+ // Ambiguous: multiple users share the same short name.
+ options := make([]string, len(matches))
+ for j, m := range matches {
+ options[j] = m.displayName
+ }
+ sort.Strings(options)
+ return fmt.Errorf("ambiguous user %q, did you mean: %s?", u, strings.Join(options, " or "))
+ }
+ return nil
+}
+
+// validateGroup checks that the specified group exists in the tailnet.
+func validateGroup(ctx context.Context, group string) error {
+ status, err := localClient.Status(ctx)
+ if err != nil {
+ return fmt.Errorf("failed to get tailnet status: %w", err)
+ }
+
+ knownGroups := make(map[string]bool)
+ for _, u := range status.User {
+ for _, g := range u.Groups {
+ knownGroups[drive.GroupShortName(g)] = true
+ }
+ }
+
+ if !knownGroups[group] {
+ known := make([]string, 0, len(knownGroups))
+ for k := range knownGroups {
+ known = append(known, k)
+ }
+ sort.Strings(known)
+ return fmt.Errorf("unknown group: %s\nvalid groups: %s", group, strings.Join(known, ", "))
+ }
+ return nil
+}
+
// runDriveUnshare is the entry point for the "tailscale drive unshare" command.
func runDriveUnshare(ctx context.Context, args []string) error {
if len(args) != 1 {
diff --git a/drive/drive_clone.go b/drive/drive_clone.go
index 724ebc386..ec9945e92 100644
--- a/drive/drive_clone.go
+++ b/drive/drive_clone.go
@@ -23,6 +23,7 @@ var _ShareCloneNeedsRegeneration = Share(struct {
Path string
As string
BookmarkData []byte
+ IsGroup bool
}{})
// Clone duplicates src into dst and reports whether it succeeded.
diff --git a/drive/drive_view.go b/drive/drive_view.go
index 253a2955b..7c22ef6e6 100644
--- a/drive/drive_view.go
+++ b/drive/drive_view.go
@@ -105,10 +105,16 @@ func (v ShareView) BookmarkData() views.ByteSlice[[]byte] {
return views.ByteSliceOf(v.ж.BookmarkData)
}
+// IsGroup indicates that this share's name corresponds to a group
+// identity. When true, only members of the matching group can access
+// the share.
+func (v ShareView) IsGroup() bool { return v.ж.IsGroup }
+
// A compilation failure here means this code must be regenerated, with the command at the top of this file.
var _ShareViewNeedsRegeneration = Share(struct {
Name string
Path string
As string
BookmarkData []byte
+ IsGroup bool
}{})
diff --git a/drive/remote.go b/drive/remote.go
index 5f34d0023..d1f8388e8 100644
--- a/drive/remote.go
+++ b/drive/remote.go
@@ -17,7 +17,7 @@ var (
// for testing.
DisallowShareAs = false
ErrDriveNotEnabled = errors.New("Taildrive not enabled")
- ErrInvalidShareName = errors.New("Share names may only contain the letters a-z, underscore _, parentheses (), or spaces")
+ ErrInvalidShareName = errors.New("Share names may only contain the letters a-z, underscore _, parentheses (), plus +, or spaces")
)
// AllowShareAs reports whether sharing files as a specific user is allowed.
@@ -46,6 +46,11 @@ type Share struct {
// hold on to a security-scoped bookmark. That bookmark is stored here. See
// https://developer.apple.com/documentation/security/app_sandbox/accessing_files_from_the_macos_app_sandbox#4144043
BookmarkData []byte `json:"bookmarkData,omitempty"`
+
+ // IsGroup indicates that this share's name corresponds to a group
+ // identity. When true, only members of the matching group can access
+ // the share.
+ IsGroup bool `json:"isGroup,omitempty"`
}
func ShareViewsEqual(a, b ShareView) bool {
@@ -55,7 +60,7 @@ func ShareViewsEqual(a, b ShareView) bool {
if !a.Valid() || !b.Valid() {
return false
}
- return a.Name() == b.Name() && a.Path() == b.Path() && a.As() == b.As() && a.BookmarkData().Equal(b.ж.BookmarkData)
+ return a.Name() == b.Name() && a.Path() == b.Path() && a.As() == b.As() && a.BookmarkData().Equal(b.ж.BookmarkData) && a.IsGroup() == b.IsGroup()
}
func SharesEqual(a, b *Share) bool {
@@ -65,7 +70,7 @@ func SharesEqual(a, b *Share) bool {
if a == nil || b == nil {
return false
}
- return a.Name == b.Name && a.Path == b.Path && a.As == b.As && bytes.Equal(a.BookmarkData, b.BookmarkData)
+ return a.Name == b.Name && a.Path == b.Path && a.As == b.As && bytes.Equal(a.BookmarkData, b.BookmarkData) && a.IsGroup == b.IsGroup
}
func CompareShares(a, b *Share) int {
@@ -124,6 +129,8 @@ func NormalizeShareName(name string) (string, error) {
return "", ErrInvalidShareName
}
+ name = NormalizeShareNameOrder(name)
+
return name, nil
}
@@ -136,7 +143,7 @@ func validShareName(name string) bool {
continue
}
switch r {
- case '_', ' ', '(', ')':
+ case '_', ' ', '(', ')', '+':
continue
}
return false
diff --git a/drive/remote_test.go b/drive/remote_test.go
index c0de1723a..bd409140e 100644
--- a/drive/remote_test.go
+++ b/drive/remote_test.go
@@ -26,6 +26,22 @@ func TestNormalizeShareName(t *testing.T) {
name: "generally good except for .",
err: ErrInvalidShareName,
},
+ {
+ name: "c++",
+ want: "c++",
+ },
+ {
+ name: " my lib (c++) ",
+ want: "my lib (c++)",
+ },
+ {
+ name: "rhea+joe",
+ want: "joe+rhea",
+ },
+ {
+ name: "Charlie+Alice+Bob",
+ want: "alice+bob+charlie",
+ },
}
for _, tt := range tests {
t.Run(fmt.Sprintf("name %q", tt.name), func(t *testing.T) {
diff --git a/drive/share_access.go b/drive/share_access.go
new file mode 100644
index 000000000..882b968d8
--- /dev/null
+++ b/drive/share_access.go
@@ -0,0 +1,235 @@
+// Copyright (c) Tailscale Inc & contributors
+// SPDX-License-Identifier: BSD-3-Clause
+
+package drive
+
+import (
+ "sort"
+ "strings"
+
+ "tailscale.com/types/views"
+)
+
+// ParseShareAccessNames returns the list of user short names encoded in a
+// share name that uses '+' as a separator. Returns nil if the name is not
+// a multi-user share. A valid multi-user share must have all non-empty
+// segments and at least 2 segments (so "c++" with empty segments returns nil).
+func ParseShareAccessNames(shareName string) []string {
+ if !strings.Contains(shareName, "+") {
+ return nil
+ }
+ parts := strings.Split(shareName, "+")
+ if len(parts) < 2 {
+ return nil
+ }
+ for _, p := range parts {
+ if p == "" {
+ return nil
+ }
+ }
+ return parts
+}
+
+// NormalizeShareNameOrder sorts '+'-separated segments alphabetically.
+// Non-multi-user names are returned unchanged.
+func NormalizeShareNameOrder(name string) string {
+ parts := ParseShareAccessNames(name)
+ if parts == nil {
+ return name
+ }
+ sort.Strings(parts)
+ return strings.Join(parts, "+")
+}
+
+// IsShareAccessibleByUser checks if the given loginName's short name (the
+// part before '@') appears in the share's '+'-separated user list. Returns
+// true for non-multi-user shares (no name-based restriction).
+func IsShareAccessibleByUser(shareName, loginName string) bool {
+ parts := ParseShareAccessNames(shareName)
+ if parts == nil {
+ return true
+ }
+ short := LoginShortName(loginName)
+ domain := loginDomain(loginName)
+ for _, p := range parts {
+ segShort, segDomain := parseShareSegment(p)
+ if segShort != short {
+ continue
+ }
+ // If the segment has no domain qualifier, match on short name only
+ // (backward compat). If it has a domain, the login's domain must
+ // start with that label.
+ if segDomain == "" {
+ return true
+ }
+ if domain != "" && strings.HasPrefix(domain, segDomain) {
+ return true
+ }
+ }
+ return false
+}
+
+// FilterPermissionsByIdentity takes ACL-derived permissions and further
+// restricts them based on share name access control. For each share:
+// - Contains '+' with valid segments: peer's login short name must be listed
+// - Has IsGroup=true on the Share: peer must be in a matching group
+// - Otherwise: no name-based restriction (ACLs only)
+//
+// The wildcard "*" permission is preserved but only applies to shares the
+// peer can access based on name/group rules.
+func FilterPermissionsByIdentity(
+ aclPerms Permissions,
+ loginName string,
+ groups []string,
+ shares views.SliceView[*Share, ShareView],
+) Permissions {
+ // If there are no shares with name-based restrictions, return as-is.
+ hasRestricted := false
+ type shareInfo struct {
+ accessible bool
+ }
+ shareInfos := make(map[string]shareInfo, shares.Len())
+ for i := range shares.Len() {
+ s := shares.At(i)
+ name := s.Name()
+ info := shareInfo{accessible: true}
+ if s.IsGroup() {
+ hasRestricted = true
+ info.accessible = matchesGroup(name, groups)
+ } else if parts := ParseShareAccessNames(name); parts != nil {
+ hasRestricted = true
+ info.accessible = false
+ short := LoginShortName(loginName)
+ domain := loginDomain(loginName)
+ for _, p := range parts {
+ segShort, segDomain := parseShareSegment(p)
+ if segShort != short {
+ continue
+ }
+ if segDomain == "" {
+ info.accessible = true
+ break
+ }
+ if domain != "" && strings.HasPrefix(domain, segDomain) {
+ info.accessible = true
+ break
+ }
+ }
+ }
+ shareInfos[name] = info
+ }
+
+ if !hasRestricted {
+ return aclPerms
+ }
+
+ // Expand the wildcard into per-share permissions so we can selectively
+ // deny access. The Permissions.For method returns max(specific, wildcard),
+ // so the only way to deny a share under a wildcard is to remove the
+ // wildcard and grant each accessible share explicitly.
+ wildcardPerm := aclPerms[wildcardShare]
+
+ filtered := make(Permissions)
+
+ // Copy non-wildcard ACL entries for accessible shares.
+ for shareName, perm := range aclPerms {
+ if shareName == wildcardShare {
+ continue
+ }
+ info, ok := shareInfos[shareName]
+ if !ok {
+ // Share in ACL but not on this node; keep it.
+ filtered[shareName] = perm
+ continue
+ }
+ if info.accessible {
+ filtered[shareName] = perm
+ }
+ }
+
+ // If there was a wildcard, expand it to all accessible shares that
+ // don't already have an explicit (higher) permission.
+ if wildcardPerm > PermissionNone {
+ for name, info := range shareInfos {
+ if info.accessible {
+ if existing := filtered[name]; wildcardPerm > existing {
+ filtered[name] = wildcardPerm
+ }
+ }
+ }
+ }
+
+ return filtered
+}
+
+// LoginShortName extracts the short name from a login name.
+// "joe@example.com" → "joe"
+func LoginShortName(loginName string) string {
+ if i := strings.Index(loginName, "@"); i >= 0 {
+ return loginName[:i]
+ }
+ return loginName
+}
+
+// loginDomain extracts the domain part from a login name.
+// "alice@example.com" → "example.com"
+// "alice" → ""
+func loginDomain(loginName string) string {
+ if i := strings.Index(loginName, "@"); i >= 0 {
+ return loginName[i+1:]
+ }
+ return ""
+}
+
+// LoginDisplayName returns a display name for a login, suitable for use in
+// share names. If the login's domain matches tailnetDomain, only the short
+// name is returned (e.g. "alice"). Otherwise, the format "shortname(domain)"
+// is used (e.g. "alice(company)") where domain has its TLD stripped.
+func LoginDisplayName(loginName, tailnetDomain string) string {
+ short := LoginShortName(loginName)
+ domain := loginDomain(loginName)
+ if domain == "" || domain == tailnetDomain {
+ return short
+ }
+ // Strip TLD from domain for brevity: "company.com" → "company"
+ domainLabel := domain
+ if i := strings.Index(domainLabel, "."); i >= 0 {
+ domainLabel = domainLabel[:i]
+ }
+ return short + "(" + domainLabel + ")"
+}
+
+// parseShareSegment parses a share name segment that may contain a domain
+// qualifier. "alice(company)" returns ("alice", "company"). "alice" returns
+// ("alice", "").
+func parseShareSegment(segment string) (shortName, domain string) {
+ if i := strings.Index(segment, "("); i >= 0 {
+ if j := strings.Index(segment, ")"); j > i {
+ return segment[:i], segment[i+1 : j]
+ }
+ }
+ return segment, ""
+}
+
+// matchesGroup checks if the share name matches any of the peer's group
+// identifiers. Groups can be in the form "group:eng" or "eng@example.com".
+func matchesGroup(shareName string, groups []string) bool {
+ for _, g := range groups {
+ if GroupShortName(g) == shareName {
+ return true
+ }
+ }
+ return false
+}
+
+// GroupShortName extracts a short group name from a group identifier.
+// "group:eng" → "eng", "eng@example.com" → "eng"
+func GroupShortName(group string) string {
+ if strings.HasPrefix(group, "group:") {
+ return strings.TrimPrefix(group, "group:")
+ }
+ if i := strings.Index(group, "@"); i >= 0 {
+ return group[:i]
+ }
+ return group
+}
diff --git a/drive/share_access_test.go b/drive/share_access_test.go
new file mode 100644
index 000000000..91840cf10
--- /dev/null
+++ b/drive/share_access_test.go
@@ -0,0 +1,303 @@
+// Copyright (c) Tailscale Inc & contributors
+// SPDX-License-Identifier: BSD-3-Clause
+
+package drive
+
+import (
+ "testing"
+
+ "tailscale.com/types/views"
+)
+
+func TestParseShareAccessNames(t *testing.T) {
+ tests := []struct {
+ name string
+ want []string
+ }{
+ {"joe+rhea", []string{"joe", "rhea"}},
+ {"alice+joe+rhea", []string{"alice", "joe", "rhea"}},
+ {"c++", nil}, // empty segments
+ {"docs", nil}, // no '+'
+ {"+leading", nil}, // empty first segment
+ {"trailing+", nil}, // empty last segment
+ {"a++b", nil}, // empty middle segment
+ {"a+b", []string{"a", "b"}},
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := ParseShareAccessNames(tt.name)
+ if tt.want == nil {
+ if got != nil {
+ t.Errorf("ParseShareAccessNames(%q) = %v, want nil", tt.name, got)
+ }
+ return
+ }
+ if len(got) != len(tt.want) {
+ t.Errorf("ParseShareAccessNames(%q) = %v, want %v", tt.name, got, tt.want)
+ return
+ }
+ for i := range got {
+ if got[i] != tt.want[i] {
+ t.Errorf("ParseShareAccessNames(%q)[%d] = %q, want %q", tt.name, i, got[i], tt.want[i])
+ }
+ }
+ })
+ }
+}
+
+func TestNormalizeShareNameOrder(t *testing.T) {
+ tests := []struct {
+ name string
+ want string
+ }{
+ {"rhea+joe", "joe+rhea"},
+ {"charlie+alice+bob", "alice+bob+charlie"},
+ {"docs", "docs"},
+ {"c++", "c++"},
+ {"a+b", "a+b"}, // already sorted
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := NormalizeShareNameOrder(tt.name)
+ if got != tt.want {
+ t.Errorf("NormalizeShareNameOrder(%q) = %q, want %q", tt.name, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestIsShareAccessibleByUser(t *testing.T) {
+ tests := []struct {
+ shareName string
+ loginName string
+ want bool
+ }{
+ {"joe+rhea", "joe@example.com", true},
+ {"joe+rhea", "rhea@example.com", true},
+ {"joe+rhea", "alice@example.com", false},
+ {"docs", "anyone@example.com", true}, // not a multi-user share
+ {"c++", "anyone@example.com", true}, // not a multi-user share (empty segments)
+ {"joe+rhea", "joe", true}, // no domain
+
+ // name(domain) format
+ {"alice(contractor)+bob", "alice@contractor.io", true},
+ {"alice(contractor)+bob", "alice@example.com", false}, // wrong domain
+ {"alice(contractor)+bob", "bob@example.com", true}, // bob has no domain qualifier
+ {"alice(contractor)+bob", "charlie@example.com", false}, // not listed
+ }
+ for _, tt := range tests {
+ t.Run(tt.shareName+"_"+tt.loginName, func(t *testing.T) {
+ got := IsShareAccessibleByUser(tt.shareName, tt.loginName)
+ if got != tt.want {
+ t.Errorf("IsShareAccessibleByUser(%q, %q) = %v, want %v", tt.shareName, tt.loginName, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestLoginDisplayName(t *testing.T) {
+ tests := []struct {
+ loginName string
+ tailnetDomain string
+ want string
+ }{
+ {"alice@example.com", "example.com", "alice"}, // home domain
+ {"alice@contractor.io", "example.com", "alice(contractor)"}, // foreign domain
+ {"alice@example.com", "bob@gmail.com", "alice(example)"}, // shared domain tailnet
+ {"alice", "example.com", "alice"}, // no domain in login
+ {"alice@foo.bar.com", "example.com", "alice(foo)"}, // multi-part domain
+ }
+ for _, tt := range tests {
+ t.Run(tt.loginName+"_"+tt.tailnetDomain, func(t *testing.T) {
+ got := LoginDisplayName(tt.loginName, tt.tailnetDomain)
+ if got != tt.want {
+ t.Errorf("LoginDisplayName(%q, %q) = %q, want %q", tt.loginName, tt.tailnetDomain, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestParseShareSegment(t *testing.T) {
+ tests := []struct {
+ input string
+ wantShort string
+ wantDomain string
+ }{
+ {"alice", "alice", ""},
+ {"alice(company)", "alice", "company"},
+ {"alice(contractor)", "alice", "contractor"},
+ {"bob", "bob", ""},
+ }
+ for _, tt := range tests {
+ t.Run(tt.input, func(t *testing.T) {
+ gotShort, gotDomain := parseShareSegment(tt.input)
+ if gotShort != tt.wantShort || gotDomain != tt.wantDomain {
+ t.Errorf("parseShareSegment(%q) = (%q, %q), want (%q, %q)", tt.input, gotShort, gotDomain, tt.wantShort, tt.wantDomain)
+ }
+ })
+ }
+}
+
+func TestLoginShortName(t *testing.T) {
+ tests := []struct {
+ input string
+ want string
+ }{
+ {"joe@example.com", "joe"},
+ {"joe", "joe"},
+ {"alice@foo.bar.com", "alice"},
+ }
+ for _, tt := range tests {
+ t.Run(tt.input, func(t *testing.T) {
+ got := LoginShortName(tt.input)
+ if got != tt.want {
+ t.Errorf("LoginShortName(%q) = %q, want %q", tt.input, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestMatchesGroup(t *testing.T) {
+ tests := []struct {
+ shareName string
+ groups []string
+ want bool
+ }{
+ {"eng", []string{"group:eng"}, true},
+ {"eng", []string{"eng@example.com"}, true},
+ {"eng", []string{"group:design", "group:eng"}, true},
+ {"eng", []string{"group:design"}, false},
+ {"eng", []string{}, false},
+ {"design", []string{"engineering@example.com"}, false},
+ }
+ for _, tt := range tests {
+ t.Run(tt.shareName, func(t *testing.T) {
+ got := matchesGroup(tt.shareName, tt.groups)
+ if got != tt.want {
+ t.Errorf("matchesGroup(%q, %v) = %v, want %v", tt.shareName, tt.groups, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestGroupShortName(t *testing.T) {
+ tests := []struct {
+ input string
+ want string
+ }{
+ {"group:eng", "eng"},
+ {"eng@example.com", "eng"},
+ {"eng", "eng"},
+ }
+ for _, tt := range tests {
+ t.Run(tt.input, func(t *testing.T) {
+ got := GroupShortName(tt.input)
+ if got != tt.want {
+ t.Errorf("GroupShortName(%q) = %q, want %q", tt.input, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestFilterPermissionsByIdentity(t *testing.T) {
+ shares := views.SliceOfViews([]*Share{
+ {Name: "joe+rhea"},
+ {Name: "docs"},
+ {Name: "eng", IsGroup: true},
+ {Name: "alice+bob"},
+ })
+
+ t.Run("multi-user share access", func(t *testing.T) {
+ perms := Permissions{
+ "*": PermissionReadWrite,
+ }
+ filtered := FilterPermissionsByIdentity(perms, "joe@example.com", nil, shares)
+ // joe can access joe+rhea and docs, but not eng (group) or alice+bob
+ if filtered.For("joe+rhea") != PermissionReadWrite {
+ t.Error("joe should access joe+rhea")
+ }
+ if filtered.For("docs") != PermissionReadWrite {
+ t.Error("joe should access docs")
+ }
+ if filtered.For("eng") != PermissionNone {
+ t.Error("joe should not access eng (not in group)")
+ }
+ if filtered.For("alice+bob") != PermissionNone {
+ t.Error("joe should not access alice+bob")
+ }
+ })
+
+ t.Run("group share access", func(t *testing.T) {
+ perms := Permissions{
+ "*": PermissionReadOnly,
+ }
+ filtered := FilterPermissionsByIdentity(perms, "joe@example.com", []string{"group:eng"}, shares)
+ if filtered.For("eng") != PermissionReadOnly {
+ t.Error("joe in group:eng should access eng share")
+ }
+ })
+
+ t.Run("specific share permission without wildcard", func(t *testing.T) {
+ perms := Permissions{
+ "joe+rhea": PermissionReadWrite,
+ "alice+bob": PermissionReadOnly,
+ }
+ filtered := FilterPermissionsByIdentity(perms, "joe@example.com", nil, shares)
+ if filtered.For("joe+rhea") != PermissionReadWrite {
+ t.Error("joe should have rw to joe+rhea")
+ }
+ if filtered.For("alice+bob") != PermissionNone {
+ t.Error("joe should not access alice+bob")
+ }
+ })
+
+ t.Run("no restricted shares means no filtering", func(t *testing.T) {
+ perms := Permissions{
+ "*": PermissionReadWrite,
+ }
+ unrestricted := views.SliceOfViews([]*Share{
+ {Name: "docs"},
+ {Name: "photos"},
+ })
+ filtered := FilterPermissionsByIdentity(perms, "joe@example.com", nil, unrestricted)
+ if filtered.For("docs") != PermissionReadWrite {
+ t.Error("wildcard should pass through with no restricted shares")
+ }
+ })
+
+ t.Run("empty shares means no filtering", func(t *testing.T) {
+ perms := Permissions{
+ "*": PermissionReadWrite,
+ }
+ empty := views.SliceOfViews([]*Share{})
+ filtered := FilterPermissionsByIdentity(perms, "joe@example.com", nil, empty)
+ if filtered.For("anything") != PermissionReadWrite {
+ t.Error("wildcard should pass through with empty shares")
+ }
+ })
+
+ t.Run("name(domain) share access", func(t *testing.T) {
+ domainShares := views.SliceOfViews([]*Share{
+ {Name: "alice(contractor)+bob"},
+ {Name: "docs"},
+ })
+ perms := Permissions{
+ "*": PermissionReadWrite,
+ }
+ // alice@contractor.io should access alice(contractor)+bob
+ filtered := FilterPermissionsByIdentity(perms, "alice@contractor.io", nil, domainShares)
+ if filtered.For("alice(contractor)+bob") != PermissionReadWrite {
+ t.Error("alice@contractor.io should access alice(contractor)+bob")
+ }
+ // alice@example.com should NOT access alice(contractor)+bob
+ filtered = FilterPermissionsByIdentity(perms, "alice@example.com", nil, domainShares)
+ if filtered.For("alice(contractor)+bob") != PermissionNone {
+ t.Error("alice@example.com should not access alice(contractor)+bob")
+ }
+ // bob@example.com should access alice(contractor)+bob
+ filtered = FilterPermissionsByIdentity(perms, "bob@example.com", nil, domainShares)
+ if filtered.For("alice(contractor)+bob") != PermissionReadWrite {
+ t.Error("bob@example.com should access alice(contractor)+bob")
+ }
+ })
+}
diff --git a/ipn/ipnlocal/peerapi_drive.go b/ipn/ipnlocal/peerapi_drive.go
index d42843577..193106c30 100644
--- a/ipn/ipnlocal/peerapi_drive.go
+++ b/ipn/ipnlocal/peerapi_drive.go
@@ -53,6 +53,9 @@ func handleServeDrive(hi PeerAPIHandler, w http.ResponseWriter, r *http.Request)
return
}
+ shares := h.ps.b.DriveGetShares()
+ p = drive.FilterPermissionsByIdentity(p, h.peerUser.LoginName, h.peerUser.Groups, shares)
+
fs, ok := h.ps.b.sys.DriveForRemote.GetOK()
if !ok {
h.logf("taildrive: not supported on platform")