summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGesa Stupperich <gesa@tailscale.com>2026-02-10 11:21:24 +0000
committerGesa Stupperich <gesa@tailscale.com>2026-02-11 21:26:59 +0000
commitc37f648ad27c00abcc0dc8958c8e9722ff9f9d10 (patch)
tree1683075c8ff6d384eef98fa504f649beff1b4acc
parentdb52827a83553122dadc36e4607e9b7a84abb2ec (diff)
downloadtailscale-gesa/ssh-crash-local-user.tar.xz
tailscale-gesa/ssh-crash-local-user.zip
ssh/tailssh: store c.info and c.localUser as valuesgesa/ssh-crash-local-user
This converts the info and localUser fields on the conn from pointers to values. I consider this an overall improvement since both structs are small and it makes access safer in cases when they've not yet been set. Updates tailscale/corp#36268 Signed-off-by: Gesa Stupperich <gesa@tailscale.com>
-rw-r--r--ssh/tailssh/auditd_linux.go2
-rw-r--r--ssh/tailssh/incubator.go2
-rw-r--r--ssh/tailssh/incubator_plan9.go2
-rw-r--r--ssh/tailssh/tailssh.go64
-rw-r--r--ssh/tailssh/tailssh_test.go167
-rw-r--r--ssh/tailssh/user.go8
6 files changed, 106 insertions, 139 deletions
diff --git a/ssh/tailssh/auditd_linux.go b/ssh/tailssh/auditd_linux.go
index bddb901d5..a0295ca37 100644
--- a/ssh/tailssh/auditd_linux.go
+++ b/ssh/tailssh/auditd_linux.go
@@ -123,7 +123,7 @@ func sendAuditMessage(logf logger.Logf, msgType uint16, message string) {
// logSSHLogin logs an SSH login event to auditd with whois information.
func logSSHLogin(logf logger.Logf, c *conn) {
- if c == nil || c.info == nil || c.localUser == nil {
+ if c == nil {
return
}
diff --git a/ssh/tailssh/incubator.go b/ssh/tailssh/incubator.go
index b414ce3fb..b31330177 100644
--- a/ssh/tailssh/incubator.go
+++ b/ssh/tailssh/incubator.go
@@ -1099,7 +1099,7 @@ func (ss *sshSession) startWithStdPipes() (err error) {
return ss.cmd.Start()
}
-func envForUser(u *userMeta) []string {
+func envForUser(u userMeta) []string {
return []string{
fmt.Sprintf("SHELL=%s", u.LoginShell()),
fmt.Sprintf("USER=%s", u.Username),
diff --git a/ssh/tailssh/incubator_plan9.go b/ssh/tailssh/incubator_plan9.go
index 69112635f..a9b9a1163 100644
--- a/ssh/tailssh/incubator_plan9.go
+++ b/ssh/tailssh/incubator_plan9.go
@@ -400,7 +400,7 @@ func (ss *sshSession) startWithStdPipes() (err error) {
return ss.cmd.Start()
}
-func envForUser(u *userMeta) []string {
+func envForUser(u userMeta) []string {
return []string{
fmt.Sprintf("user=%s", u.Username),
fmt.Sprintf("home=%s", u.HomeDir),
diff --git a/ssh/tailssh/tailssh.go b/ssh/tailssh/tailssh.go
index d8dea7da2..b1df339f3 100644
--- a/ssh/tailssh/tailssh.go
+++ b/ssh/tailssh/tailssh.go
@@ -193,8 +193,8 @@ func (srv *server) OnPolicyChange() {
defer srv.mu.Unlock()
for c := range srv.activeConns {
ci, lu := c.getInfoAndLocalUser()
- if ci == nil || lu == nil {
- // c.info or c.localUser are nil when the connection hasn't been
+ if !ci.isSet || lu.Username == "" {
+ // c.info or c.localUser are empty when the connection hasn't been
// authenticated yet. We will continue here, but the connection will
// be checked once it is authenticated. If it no longer conforms
// with the SSH access policy at that point, it will be terminated.
@@ -250,9 +250,9 @@ type conn struct {
// srv.mu should be acquired prior to mu.
// It is safe to just acquire mu, but unsafe to
// acquire mu and then srv.mu.
- mu sync.Mutex // protects the following
- info *sshConnInfo // set by setInfo
- localUser *userMeta // set by clientAuth
+ mu sync.Mutex // protects the following
+ info sshConnInfo // set by setInfo
+ localUser userMeta // set by clientAuth
sessions []*sshSession
}
@@ -267,19 +267,19 @@ func (c *conn) vlogf(format string, args ...any) {
}
}
-func (c *conn) getInfo() *sshConnInfo {
+func (c *conn) getInfo() sshConnInfo {
c.mu.Lock()
defer c.mu.Unlock()
return c.info
}
-func (c *conn) getLocalUser() *userMeta {
+func (c *conn) getLocalUser() userMeta {
c.mu.Lock()
defer c.mu.Unlock()
return c.localUser
}
-func (c *conn) getInfoAndLocalUser() (*sshConnInfo, *userMeta) {
+func (c *conn) getInfoAndLocalUser() (sshConnInfo, userMeta) {
c.mu.Lock()
defer c.mu.Unlock()
return c.info, c.localUser
@@ -356,11 +356,7 @@ func (c *conn) clientAuth(cm gossh.ConnMetadata) (perms *gossh.Permissions, retE
// do nothing
case rejectedUser:
ci := c.getInfo()
- if ci != nil {
- return nil, c.errBanner(fmt.Sprintf("tailnet policy does not permit you to SSH as user %q", ci.sshUser), nil)
- } else {
- return nil, c.errBanner(fmt.Sprintf("tailnet policy does not permit you to SSH"), nil)
- }
+ return nil, c.errBanner(fmt.Sprintf("tailnet policy does not permit you to SSH as user %q", ci.sshUser), nil)
case rejected, noPolicy:
return nil, c.errBanner("tailnet policy does not permit you to SSH to this node", fmt.Errorf("failed to evaluate policy, result: %s", result))
default:
@@ -610,7 +606,12 @@ func toIPPort(a net.Addr) (ipp netip.AddrPort) {
// connInfo populates the sshConnInfo from the provided arguments,
// validating only that they represent a known Tailscale identity.
func (c *conn) setInfo(cm gossh.ConnMetadata) error {
- ci := &sshConnInfo{
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ if c.info.isSet {
+ return nil
+ }
+ ci := sshConnInfo{
sshUser: strings.TrimSuffix(cm.User(), forcePasswordSuffix),
src: toIPPort(cm.RemoteAddr()),
dst: toIPPort(cm.LocalAddr()),
@@ -627,14 +628,9 @@ func (c *conn) setInfo(cm gossh.ConnMetadata) error {
}
ci.node = node
ci.uprof = uprof
-
- c.mu.Lock()
- defer c.mu.Unlock()
- if c.info != nil {
- return nil
- }
c.idH = string(cm.SessionID())
c.info = ci
+ c.info.isSet = true
c.logf("handling conn: %v", ci.String())
return nil
}
@@ -767,7 +763,7 @@ func (c *conn) isStillValid() bool {
return false
}
lu := c.getLocalUser()
- return lu != nil && lu.Username == localUser
+ return lu.Username == localUser
}
// checkStillValid checks that the conn is still valid per the latest SSHPolicy.
@@ -842,11 +838,7 @@ func (ss *sshSession) killProcessOnContextDone() {
}
}
ci := ss.conn.getInfo()
- if ci != nil {
- ss.logf("terminating SSH session from %v: %v", ci.src.Addr(), err)
- } else {
- ss.logf("terminating SSH session: %v", err)
- }
+ ss.logf("terminating SSH session from %v: %v", ci.src.Addr(), err)
// We don't need to Process.Wait here, sshSession.run() does
// the waiting regardless of termination reason.
@@ -884,7 +876,7 @@ var errSessionDone = errors.New("session is done")
// handleSSHAgentForwarding starts a Unix socket listener and in the background
// forwards agent connections between the listener and the ssh.Session.
// On success, it assigns ss.agentListener.
-func (ss *sshSession) handleSSHAgentForwarding(s ssh.Session, lu *userMeta) error {
+func (ss *sshSession) handleSSHAgentForwarding(s ssh.Session, lu userMeta) error {
if !ssh.AgentRequested(ss) || !ss.conn.finalAction.AllowAgentForwarding {
return nil
}
@@ -1120,6 +1112,10 @@ func (ss *sshSession) shouldRecord() bool {
}
type sshConnInfo struct {
+ // isSet indicates whether the fields have been populated as part of
+ // authenticating the connection.
+ isSet bool
+
// sshUser is the requested local SSH username ("root", "alice", etc).
sshUser string
@@ -1182,10 +1178,6 @@ func (c *conn) matchRule(r *tailcfg.SSHRule) (a *tailcfg.SSHAction, localUser st
return nil, "", nil, errInvalidConn
}
ci := c.getInfo()
- if ci == nil {
- c.logf("invalid connection state")
- return nil, "", nil, errInvalidConn
- }
if r == nil {
return nil, "", nil, errNilRule
}
@@ -1237,9 +1229,6 @@ func (c *conn) anyPrincipalMatches(ps []*tailcfg.SSHPrincipal) bool {
// that match the Tailscale identity match (Node, NodeIP, UserLogin, Any).
func (c *conn) principalMatchesTailscaleIdentity(p *tailcfg.SSHPrincipal) bool {
ci := c.getInfo()
- if ci == nil {
- return false
- }
if p.Any {
return true
}
@@ -1386,9 +1375,6 @@ func (ss *sshSession) startNewRecording() (_ *recording, err error) {
}
ci, lu := ss.conn.getInfoAndLocalUser()
- if ci == nil || lu == nil {
- return nil, errors.New("recording: missing connection metadata")
- }
ch := sessionrecording.CastHeader{
Version: 2,
Width: w.Width,
@@ -1440,10 +1426,6 @@ func (ss *sshSession) startNewRecording() (_ *recording, err error) {
// an SSH session is a defined EventType.
func (ss *sshSession) notifyControl(ctx context.Context, nodeKey key.NodePublic, notifyType tailcfg.SSHEventType, attempts []*tailcfg.SSHRecordingAttempt, url string) {
ci, lu := ss.conn.getInfoAndLocalUser()
- if ci == nil || lu == nil {
- ss.logf("notifyControl: missing connection metadata")
- return
- }
re := tailcfg.SSHEventNotifyRequest{
EventType: notifyType,
ConnectionID: ss.conn.connID,
diff --git a/ssh/tailssh/tailssh_test.go b/ssh/tailssh/tailssh_test.go
index 581c8be82..c750370a2 100644
--- a/ssh/tailssh/tailssh_test.go
+++ b/ssh/tailssh/tailssh_test.go
@@ -64,31 +64,20 @@ func TestMatchRule(t *testing.T) {
tests := []struct {
name string
rule *tailcfg.SSHRule
- ci *sshConnInfo
+ ci sshConnInfo
wantErr error
wantUser string
wantAcceptEnv []string
}{
{
- name: "invalid-conn",
- rule: &tailcfg.SSHRule{
- Action: someAction,
- Principals: []*tailcfg.SSHPrincipal{{Any: true}},
- SSHUsers: map[string]string{
- "*": "ubuntu",
- },
- },
- wantErr: errInvalidConn,
- },
- {
name: "nil-rule",
- ci: &sshConnInfo{},
+ ci: sshConnInfo{},
rule: nil,
wantErr: errNilRule,
},
{
name: "nil-action",
- ci: &sshConnInfo{},
+ ci: sshConnInfo{},
rule: &tailcfg.SSHRule{},
wantErr: errNilAction,
},
@@ -98,7 +87,7 @@ func TestMatchRule(t *testing.T) {
Action: someAction,
RuleExpires: ptr.To(time.Unix(100, 0)),
},
- ci: &sshConnInfo{},
+ ci: sshConnInfo{},
wantErr: errRuleExpired,
},
{
@@ -108,7 +97,7 @@ func TestMatchRule(t *testing.T) {
SSHUsers: map[string]string{
"*": "ubuntu",
}},
- ci: &sshConnInfo{},
+ ci: sshConnInfo{},
wantErr: errPrincipalMatch,
},
{
@@ -117,7 +106,7 @@ func TestMatchRule(t *testing.T) {
Action: someAction,
Principals: []*tailcfg.SSHPrincipal{{Any: true}},
},
- ci: &sshConnInfo{sshUser: "alice"},
+ ci: sshConnInfo{sshUser: "alice"},
wantErr: errUserMatch,
},
{
@@ -129,7 +118,7 @@ func TestMatchRule(t *testing.T) {
"*": "ubuntu",
},
},
- ci: &sshConnInfo{sshUser: "alice"},
+ ci: sshConnInfo{sshUser: "alice"},
wantUser: "ubuntu",
},
{
@@ -144,7 +133,7 @@ func TestMatchRule(t *testing.T) {
"*": "ubuntu",
},
},
- ci: &sshConnInfo{sshUser: "alice"},
+ ci: sshConnInfo{sshUser: "alice"},
wantUser: "ubuntu",
},
{
@@ -157,7 +146,7 @@ func TestMatchRule(t *testing.T) {
"alice": "thealice",
},
},
- ci: &sshConnInfo{sshUser: "alice"},
+ ci: sshConnInfo{sshUser: "alice"},
wantUser: "thealice",
},
{
@@ -171,7 +160,7 @@ func TestMatchRule(t *testing.T) {
},
AcceptEnv: []string{"EXAMPLE", "?_?", "TEST_*"},
},
- ci: &sshConnInfo{sshUser: "alice"},
+ ci: sshConnInfo{sshUser: "alice"},
wantUser: "thealice",
wantAcceptEnv: []string{"EXAMPLE", "?_?", "TEST_*"},
},
@@ -181,7 +170,7 @@ func TestMatchRule(t *testing.T) {
Principals: []*tailcfg.SSHPrincipal{{Any: true}},
Action: &tailcfg.SSHAction{Reject: true},
},
- ci: &sshConnInfo{sshUser: "alice"},
+ ci: sshConnInfo{sshUser: "alice"},
},
{
name: "match-principal-node-ip",
@@ -190,7 +179,7 @@ func TestMatchRule(t *testing.T) {
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "1.2.3.4"}},
SSHUsers: map[string]string{"*": "ubuntu"},
},
- ci: &sshConnInfo{src: netip.MustParseAddrPort("1.2.3.4:30343")},
+ ci: sshConnInfo{src: netip.MustParseAddrPort("1.2.3.4:30343")},
wantUser: "ubuntu",
},
{
@@ -200,7 +189,7 @@ func TestMatchRule(t *testing.T) {
Principals: []*tailcfg.SSHPrincipal{{Node: "some-node-ID"}},
SSHUsers: map[string]string{"*": "ubuntu"},
},
- ci: &sshConnInfo{node: (&tailcfg.Node{StableID: "some-node-ID"}).View()},
+ ci: sshConnInfo{node: (&tailcfg.Node{StableID: "some-node-ID"}).View()},
wantUser: "ubuntu",
},
{
@@ -210,7 +199,7 @@ func TestMatchRule(t *testing.T) {
Principals: []*tailcfg.SSHPrincipal{{UserLogin: "foo@bar.com"}},
SSHUsers: map[string]string{"*": "ubuntu"},
},
- ci: &sshConnInfo{uprof: tailcfg.UserProfile{LoginName: "foo@bar.com"}},
+ ci: sshConnInfo{uprof: tailcfg.UserProfile{LoginName: "foo@bar.com"}},
wantUser: "ubuntu",
},
{
@@ -222,7 +211,7 @@ func TestMatchRule(t *testing.T) {
"*": "=",
},
},
- ci: &sshConnInfo{sshUser: "alice"},
+ ci: sshConnInfo{sshUser: "alice"},
wantUser: "alice",
},
}
@@ -254,7 +243,7 @@ func TestEvalSSHPolicy(t *testing.T) {
tests := []struct {
name string
policy *tailcfg.SSHPolicy
- ci *sshConnInfo
+ ci sshConnInfo
wantResult evalResult
wantUser string
wantAcceptEnv []string
@@ -298,7 +287,7 @@ func TestEvalSSHPolicy(t *testing.T) {
},
},
},
- ci: &sshConnInfo{sshUser: "alice"},
+ ci: sshConnInfo{sshUser: "alice"},
wantUser: "thealice",
wantAcceptEnv: []string{"EXAMPLE", "?_?", "TEST_*"},
wantResult: accepted,
@@ -308,7 +297,7 @@ func TestEvalSSHPolicy(t *testing.T) {
policy: &tailcfg.SSHPolicy{
Rules: []*tailcfg.SSHRule{},
},
- ci: &sshConnInfo{sshUser: "alice"},
+ ci: sshConnInfo{sshUser: "alice"},
wantUser: "",
wantAcceptEnv: nil,
wantResult: rejected,
@@ -349,7 +338,7 @@ func TestEvalSSHPolicy(t *testing.T) {
},
},
},
- ci: &sshConnInfo{sshUser: "alice"},
+ ci: sshConnInfo{sshUser: "alice"},
wantUser: "",
wantAcceptEnv: nil,
wantResult: rejectedUser,
@@ -1100,7 +1089,7 @@ func TestSSH(t *testing.T) {
t.Fatal(err)
}
sc.localUser = um
- sc.info = &sshConnInfo{
+ sc.info = sshConnInfo{
sshUser: "test",
src: netip.MustParseAddrPort("1.2.3.4:32342"),
dst: netip.MustParseAddrPort("1.2.3.5:22"),
@@ -1318,76 +1307,72 @@ func TestStdOsUserUserAssumptions(t *testing.T) {
}
}
-func TestOnPolicyChangeHandlesNilLocalUser(t *testing.T) {
- synctest.Test(t, func(t *testing.T) {
- srv := &server{
- logf: tstest.WhileTestRunningLogger(t),
- lb: &localState{
- sshEnabled: true,
- matchingRule: newSSHRule(&tailcfg.SSHAction{Accept: true}),
- },
- }
- c := &conn{
- srv: srv,
- info: &sshConnInfo{sshUser: "alice"},
- }
- srv.activeConns = map[*conn]bool{c: true}
-
- srv.OnPolicyChange()
-
- synctest.Wait()
- })
-}
-
-func TestRaceWriteAndReadConnInfoAndLocalUser(t *testing.T) {
- synctest.Test(t, func(t *testing.T) {
- srv := &server{
- logf: tstest.WhileTestRunningLogger(t),
- lb: &localState{
- sshEnabled: true,
- matchingRule: newSSHRule(&tailcfg.SSHAction{Accept: true}),
- },
- }
- c := &conn{
- srv: srv,
- info: &sshConnInfo{sshUser: "alice"},
- }
- srv.activeConns = map[*conn]bool{c: true}
+func TestOnPolicyChangeDefersValidationOnEmptyLocalUser(t *testing.T) {
+ tests := []struct {
+ name string
+ sshRule *tailcfg.SSHRule
+ wantCancelOnValidation bool
+ }{
+ {
+ name: "defer-then-accept-when-allowed",
+ sshRule: newSSHRule(&tailcfg.SSHAction{Accept: true}),
+ wantCancelOnValidation: false,
+ },
+ {
+ name: "defer-then-reject-when-not-allowed",
+ sshRule: newSSHRule(&tailcfg.SSHAction{Reject: true}),
+ wantCancelOnValidation: true,
+ },
+ }
- fakeClientAuth := func() {
- c.mu.Lock()
- c.info = &sshConnInfo{sshUser: "alice"}
- c.mu.Unlock()
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
- c.mu.Lock()
- c.localUser = &userMeta{User: user.User{Username: currentUser}}
- c.mu.Unlock()
- }
+ synctest.Test(t, func(t *testing.T) {
+ srv := &server{
+ logf: tstest.WhileTestRunningLogger(t),
+ lb: &localState{
+ sshEnabled: true,
+ matchingRule: tt.sshRule,
+ },
+ }
+ c := &conn{
+ srv: srv,
+ info: sshConnInfo{sshUser: "alice"},
+ }
+ srv.activeConns = map[*conn]bool{c: true}
+ ctx, cancel := context.WithCancelCause(context.Background())
+ ss := &sshSession{ctx: ctx, cancelCtx: cancel}
+ c.sessions = []*sshSession{ss}
- // Simulate a race between clientAuth() writing and OnPolicyChange reading a connection's info and localUser.
- done := make(chan struct{})
- go func() {
- for i := 0; i < 100; i++ {
+ srv.OnPolicyChange()
+ synctest.Wait()
select {
- case <-done:
- return
+ case <-ctx.Done():
+ t.Fatal("expected deferral of cancellation decision while localUser unset but session got canceled")
default:
- fakeClientAuth()
}
- }
- }()
- go func() {
- for i := 0; i < 100; i++ {
+ c.mu.Lock()
+ c.info.isSet = true
+ c.localUser = userMeta{User: user.User{Username: currentUser}}
+ c.mu.Unlock()
+
+ srv.OnPolicyChange()
+ synctest.Wait()
select {
- case <-done:
- return
+ case <-ctx.Done():
+ if !tt.wantCancelOnValidation {
+ t.Fatal("valid session shouldn't have been canceled")
+ }
default:
- srv.OnPolicyChange()
+ if tt.wantCancelOnValidation {
+ t.Fatal("invalid session should have been canceled but it wasn't")
+ }
}
- }
- }()
- })
+ })
+ })
+ }
}
func mockRecordingServer(t *testing.T, handleRecord http.HandlerFunc) *httptest.Server {
diff --git a/ssh/tailssh/user.go b/ssh/tailssh/user.go
index 7da6bb4eb..c84c93821 100644
--- a/ssh/tailssh/user.go
+++ b/ssh/tailssh/user.go
@@ -36,15 +36,15 @@ func (u *userMeta) GroupIds() ([]string, error) {
return osuser.GetGroupIds(&u.User)
}
-// userLookup is like os/user.Lookup but it returns a *userMeta wrapper
+// userLookup is like os/user.Lookup but it returns a userMeta wrapper
// around a *user.User with extra fields.
-func userLookup(username string) (*userMeta, error) {
+func userLookup(username string) (userMeta, error) {
u, s, err := osuser.LookupByUsernameWithShell(username)
if err != nil {
- return nil, err
+ return userMeta{}, err
}
- return &userMeta{User: *u, loginShellCached: s}, nil
+ return userMeta{User: *u, loginShellCached: s}, nil
}
func (u *userMeta) LoginShell() string {