summaryrefslogtreecommitdiffhomepage
path: root/ssh/tailssh/tailssh.go
diff options
context:
space:
mode:
Diffstat (limited to 'ssh/tailssh/tailssh.go')
-rw-r--r--ssh/tailssh/tailssh.go64
1 files changed, 23 insertions, 41 deletions
diff --git a/ssh/tailssh/tailssh.go b/ssh/tailssh/tailssh.go
index d8dea7da2..b1df339f3 100644
--- a/ssh/tailssh/tailssh.go
+++ b/ssh/tailssh/tailssh.go
@@ -193,8 +193,8 @@ func (srv *server) OnPolicyChange() {
defer srv.mu.Unlock()
for c := range srv.activeConns {
ci, lu := c.getInfoAndLocalUser()
- if ci == nil || lu == nil {
- // c.info or c.localUser are nil when the connection hasn't been
+ if !ci.isSet || lu.Username == "" {
+ // c.info or c.localUser are empty when the connection hasn't been
// authenticated yet. We will continue here, but the connection will
// be checked once it is authenticated. If it no longer conforms
// with the SSH access policy at that point, it will be terminated.
@@ -250,9 +250,9 @@ type conn struct {
// srv.mu should be acquired prior to mu.
// It is safe to just acquire mu, but unsafe to
// acquire mu and then srv.mu.
- mu sync.Mutex // protects the following
- info *sshConnInfo // set by setInfo
- localUser *userMeta // set by clientAuth
+ mu sync.Mutex // protects the following
+ info sshConnInfo // set by setInfo
+ localUser userMeta // set by clientAuth
sessions []*sshSession
}
@@ -267,19 +267,19 @@ func (c *conn) vlogf(format string, args ...any) {
}
}
-func (c *conn) getInfo() *sshConnInfo {
+func (c *conn) getInfo() sshConnInfo {
c.mu.Lock()
defer c.mu.Unlock()
return c.info
}
-func (c *conn) getLocalUser() *userMeta {
+func (c *conn) getLocalUser() userMeta {
c.mu.Lock()
defer c.mu.Unlock()
return c.localUser
}
-func (c *conn) getInfoAndLocalUser() (*sshConnInfo, *userMeta) {
+func (c *conn) getInfoAndLocalUser() (sshConnInfo, userMeta) {
c.mu.Lock()
defer c.mu.Unlock()
return c.info, c.localUser
@@ -356,11 +356,7 @@ func (c *conn) clientAuth(cm gossh.ConnMetadata) (perms *gossh.Permissions, retE
// do nothing
case rejectedUser:
ci := c.getInfo()
- if ci != nil {
- return nil, c.errBanner(fmt.Sprintf("tailnet policy does not permit you to SSH as user %q", ci.sshUser), nil)
- } else {
- return nil, c.errBanner(fmt.Sprintf("tailnet policy does not permit you to SSH"), nil)
- }
+ return nil, c.errBanner(fmt.Sprintf("tailnet policy does not permit you to SSH as user %q", ci.sshUser), nil)
case rejected, noPolicy:
return nil, c.errBanner("tailnet policy does not permit you to SSH to this node", fmt.Errorf("failed to evaluate policy, result: %s", result))
default:
@@ -610,7 +606,12 @@ func toIPPort(a net.Addr) (ipp netip.AddrPort) {
// connInfo populates the sshConnInfo from the provided arguments,
// validating only that they represent a known Tailscale identity.
func (c *conn) setInfo(cm gossh.ConnMetadata) error {
- ci := &sshConnInfo{
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ if c.info.isSet {
+ return nil
+ }
+ ci := sshConnInfo{
sshUser: strings.TrimSuffix(cm.User(), forcePasswordSuffix),
src: toIPPort(cm.RemoteAddr()),
dst: toIPPort(cm.LocalAddr()),
@@ -627,14 +628,9 @@ func (c *conn) setInfo(cm gossh.ConnMetadata) error {
}
ci.node = node
ci.uprof = uprof
-
- c.mu.Lock()
- defer c.mu.Unlock()
- if c.info != nil {
- return nil
- }
c.idH = string(cm.SessionID())
c.info = ci
+ c.info.isSet = true
c.logf("handling conn: %v", ci.String())
return nil
}
@@ -767,7 +763,7 @@ func (c *conn) isStillValid() bool {
return false
}
lu := c.getLocalUser()
- return lu != nil && lu.Username == localUser
+ return lu.Username == localUser
}
// checkStillValid checks that the conn is still valid per the latest SSHPolicy.
@@ -842,11 +838,7 @@ func (ss *sshSession) killProcessOnContextDone() {
}
}
ci := ss.conn.getInfo()
- if ci != nil {
- ss.logf("terminating SSH session from %v: %v", ci.src.Addr(), err)
- } else {
- ss.logf("terminating SSH session: %v", err)
- }
+ ss.logf("terminating SSH session from %v: %v", ci.src.Addr(), err)
// We don't need to Process.Wait here, sshSession.run() does
// the waiting regardless of termination reason.
@@ -884,7 +876,7 @@ var errSessionDone = errors.New("session is done")
// handleSSHAgentForwarding starts a Unix socket listener and in the background
// forwards agent connections between the listener and the ssh.Session.
// On success, it assigns ss.agentListener.
-func (ss *sshSession) handleSSHAgentForwarding(s ssh.Session, lu *userMeta) error {
+func (ss *sshSession) handleSSHAgentForwarding(s ssh.Session, lu userMeta) error {
if !ssh.AgentRequested(ss) || !ss.conn.finalAction.AllowAgentForwarding {
return nil
}
@@ -1120,6 +1112,10 @@ func (ss *sshSession) shouldRecord() bool {
}
type sshConnInfo struct {
+ // isSet indicates whether the fields have been populated as part of
+ // authenticating the connection.
+ isSet bool
+
// sshUser is the requested local SSH username ("root", "alice", etc).
sshUser string
@@ -1182,10 +1178,6 @@ func (c *conn) matchRule(r *tailcfg.SSHRule) (a *tailcfg.SSHAction, localUser st
return nil, "", nil, errInvalidConn
}
ci := c.getInfo()
- if ci == nil {
- c.logf("invalid connection state")
- return nil, "", nil, errInvalidConn
- }
if r == nil {
return nil, "", nil, errNilRule
}
@@ -1237,9 +1229,6 @@ func (c *conn) anyPrincipalMatches(ps []*tailcfg.SSHPrincipal) bool {
// that match the Tailscale identity match (Node, NodeIP, UserLogin, Any).
func (c *conn) principalMatchesTailscaleIdentity(p *tailcfg.SSHPrincipal) bool {
ci := c.getInfo()
- if ci == nil {
- return false
- }
if p.Any {
return true
}
@@ -1386,9 +1375,6 @@ func (ss *sshSession) startNewRecording() (_ *recording, err error) {
}
ci, lu := ss.conn.getInfoAndLocalUser()
- if ci == nil || lu == nil {
- return nil, errors.New("recording: missing connection metadata")
- }
ch := sessionrecording.CastHeader{
Version: 2,
Width: w.Width,
@@ -1440,10 +1426,6 @@ func (ss *sshSession) startNewRecording() (_ *recording, err error) {
// an SSH session is a defined EventType.
func (ss *sshSession) notifyControl(ctx context.Context, nodeKey key.NodePublic, notifyType tailcfg.SSHEventType, attempts []*tailcfg.SSHRecordingAttempt, url string) {
ci, lu := ss.conn.getInfoAndLocalUser()
- if ci == nil || lu == nil {
- ss.logf("notifyControl: missing connection metadata")
- return
- }
re := tailcfg.SSHEventNotifyRequest{
EventType: notifyType,
ConnectionID: ss.conn.connID,