diff options
Diffstat (limited to 'ssh/tailssh/tailssh.go')
| -rw-r--r-- | ssh/tailssh/tailssh.go | 64 |
1 files changed, 23 insertions, 41 deletions
diff --git a/ssh/tailssh/tailssh.go b/ssh/tailssh/tailssh.go index d8dea7da2..b1df339f3 100644 --- a/ssh/tailssh/tailssh.go +++ b/ssh/tailssh/tailssh.go @@ -193,8 +193,8 @@ func (srv *server) OnPolicyChange() { defer srv.mu.Unlock() for c := range srv.activeConns { ci, lu := c.getInfoAndLocalUser() - if ci == nil || lu == nil { - // c.info or c.localUser are nil when the connection hasn't been + if !ci.isSet || lu.Username == "" { + // c.info or c.localUser are empty when the connection hasn't been // authenticated yet. We will continue here, but the connection will // be checked once it is authenticated. If it no longer conforms // with the SSH access policy at that point, it will be terminated. @@ -250,9 +250,9 @@ type conn struct { // srv.mu should be acquired prior to mu. // It is safe to just acquire mu, but unsafe to // acquire mu and then srv.mu. - mu sync.Mutex // protects the following - info *sshConnInfo // set by setInfo - localUser *userMeta // set by clientAuth + mu sync.Mutex // protects the following + info sshConnInfo // set by setInfo + localUser userMeta // set by clientAuth sessions []*sshSession } @@ -267,19 +267,19 @@ func (c *conn) vlogf(format string, args ...any) { } } -func (c *conn) getInfo() *sshConnInfo { +func (c *conn) getInfo() sshConnInfo { c.mu.Lock() defer c.mu.Unlock() return c.info } -func (c *conn) getLocalUser() *userMeta { +func (c *conn) getLocalUser() userMeta { c.mu.Lock() defer c.mu.Unlock() return c.localUser } -func (c *conn) getInfoAndLocalUser() (*sshConnInfo, *userMeta) { +func (c *conn) getInfoAndLocalUser() (sshConnInfo, userMeta) { c.mu.Lock() defer c.mu.Unlock() return c.info, c.localUser @@ -356,11 +356,7 @@ func (c *conn) clientAuth(cm gossh.ConnMetadata) (perms *gossh.Permissions, retE // do nothing case rejectedUser: ci := c.getInfo() - if ci != nil { - return nil, c.errBanner(fmt.Sprintf("tailnet policy does not permit you to SSH as user %q", ci.sshUser), nil) - } else { - return nil, c.errBanner(fmt.Sprintf("tailnet policy does not permit you to SSH"), nil) - } + return nil, c.errBanner(fmt.Sprintf("tailnet policy does not permit you to SSH as user %q", ci.sshUser), nil) case rejected, noPolicy: return nil, c.errBanner("tailnet policy does not permit you to SSH to this node", fmt.Errorf("failed to evaluate policy, result: %s", result)) default: @@ -610,7 +606,12 @@ func toIPPort(a net.Addr) (ipp netip.AddrPort) { // connInfo populates the sshConnInfo from the provided arguments, // validating only that they represent a known Tailscale identity. func (c *conn) setInfo(cm gossh.ConnMetadata) error { - ci := &sshConnInfo{ + c.mu.Lock() + defer c.mu.Unlock() + if c.info.isSet { + return nil + } + ci := sshConnInfo{ sshUser: strings.TrimSuffix(cm.User(), forcePasswordSuffix), src: toIPPort(cm.RemoteAddr()), dst: toIPPort(cm.LocalAddr()), @@ -627,14 +628,9 @@ func (c *conn) setInfo(cm gossh.ConnMetadata) error { } ci.node = node ci.uprof = uprof - - c.mu.Lock() - defer c.mu.Unlock() - if c.info != nil { - return nil - } c.idH = string(cm.SessionID()) c.info = ci + c.info.isSet = true c.logf("handling conn: %v", ci.String()) return nil } @@ -767,7 +763,7 @@ func (c *conn) isStillValid() bool { return false } lu := c.getLocalUser() - return lu != nil && lu.Username == localUser + return lu.Username == localUser } // checkStillValid checks that the conn is still valid per the latest SSHPolicy. @@ -842,11 +838,7 @@ func (ss *sshSession) killProcessOnContextDone() { } } ci := ss.conn.getInfo() - if ci != nil { - ss.logf("terminating SSH session from %v: %v", ci.src.Addr(), err) - } else { - ss.logf("terminating SSH session: %v", err) - } + ss.logf("terminating SSH session from %v: %v", ci.src.Addr(), err) // We don't need to Process.Wait here, sshSession.run() does // the waiting regardless of termination reason. @@ -884,7 +876,7 @@ var errSessionDone = errors.New("session is done") // handleSSHAgentForwarding starts a Unix socket listener and in the background // forwards agent connections between the listener and the ssh.Session. // On success, it assigns ss.agentListener. -func (ss *sshSession) handleSSHAgentForwarding(s ssh.Session, lu *userMeta) error { +func (ss *sshSession) handleSSHAgentForwarding(s ssh.Session, lu userMeta) error { if !ssh.AgentRequested(ss) || !ss.conn.finalAction.AllowAgentForwarding { return nil } @@ -1120,6 +1112,10 @@ func (ss *sshSession) shouldRecord() bool { } type sshConnInfo struct { + // isSet indicates whether the fields have been populated as part of + // authenticating the connection. + isSet bool + // sshUser is the requested local SSH username ("root", "alice", etc). sshUser string @@ -1182,10 +1178,6 @@ func (c *conn) matchRule(r *tailcfg.SSHRule) (a *tailcfg.SSHAction, localUser st return nil, "", nil, errInvalidConn } ci := c.getInfo() - if ci == nil { - c.logf("invalid connection state") - return nil, "", nil, errInvalidConn - } if r == nil { return nil, "", nil, errNilRule } @@ -1237,9 +1229,6 @@ func (c *conn) anyPrincipalMatches(ps []*tailcfg.SSHPrincipal) bool { // that match the Tailscale identity match (Node, NodeIP, UserLogin, Any). func (c *conn) principalMatchesTailscaleIdentity(p *tailcfg.SSHPrincipal) bool { ci := c.getInfo() - if ci == nil { - return false - } if p.Any { return true } @@ -1386,9 +1375,6 @@ func (ss *sshSession) startNewRecording() (_ *recording, err error) { } ci, lu := ss.conn.getInfoAndLocalUser() - if ci == nil || lu == nil { - return nil, errors.New("recording: missing connection metadata") - } ch := sessionrecording.CastHeader{ Version: 2, Width: w.Width, @@ -1440,10 +1426,6 @@ func (ss *sshSession) startNewRecording() (_ *recording, err error) { // an SSH session is a defined EventType. func (ss *sshSession) notifyControl(ctx context.Context, nodeKey key.NodePublic, notifyType tailcfg.SSHEventType, attempts []*tailcfg.SSHRecordingAttempt, url string) { ci, lu := ss.conn.getInfoAndLocalUser() - if ci == nil || lu == nil { - ss.logf("notifyControl: missing connection metadata") - return - } re := tailcfg.SSHEventNotifyRequest{ EventType: notifyType, ConnectionID: ss.conn.connID, |
