diff options
Diffstat (limited to 'ssh/tailssh')
| -rw-r--r-- | ssh/tailssh/tailssh.go | 115 |
1 files changed, 61 insertions, 54 deletions
diff --git a/ssh/tailssh/tailssh.go b/ssh/tailssh/tailssh.go index 7d12ab45f..07b0aa57a 100644 --- a/ssh/tailssh/tailssh.go +++ b/ssh/tailssh/tailssh.go @@ -25,7 +25,6 @@ import ( "strconv" "strings" "sync" - "sync/atomic" "syscall" "time" @@ -804,11 +803,18 @@ func (ss *sshSession) killProcessOnContextDone() { // We don't need to Process.Wait here, sshSession.run() does // the waiting regardless of termination reason. - // TODO(maisem): should this be a SIGTERM followed by a SIGKILL? - ss.cmd.Process.Kill() + // Send SIGHUP like a real terminal disconnect would. + // The process may ignore it or exit cleanly. + ss.cmd.Process.Signal(syscall.SIGHUP) }) } +// isNotFoundOrExecutable reports whether err is an error indicating +// the command could not be found or executed. +func isNotFoundOrExecutable(err error) bool { + return errors.Is(err, exec.ErrNotFound) || errors.Is(err, os.ErrNotExist) +} + // attachSession registers ss as an active session. func (c *conn) attachSession(ss *sshSession) { c.srv.sessionWaitGroup.Add(1) @@ -894,10 +900,11 @@ func (ss *sshSession) run() { metricActiveSessions.Add(1) defer metricActiveSessions.Add(-1) defer ss.cancelCtx(errSessionDone) + defer ss.Close() if attached := ss.conn.srv.attachSessionToConnIfNotShutdown(ss); !attached { fmt.Fprintf(ss, "Tailscale SSH is shutting down\r\n") - ss.Exit(1) + ss.Exit(255) return } defer ss.conn.detachSession(ss) @@ -919,7 +926,10 @@ func (ss *sshSession) run() { if lu.Uid != fmt.Sprint(euid) { ss.logf("can't switch to user %q from process euid %v", lu.Username, euid) fmt.Fprintf(ss, "can't switch user\r\n") - ss.Exit(1) + // Exit code 255 indicates SSH protocol/permission error. + // This matches OpenSSH behavior for fatal errors that prevent + // the session from starting. + ss.Exit(255) return } } @@ -948,7 +958,9 @@ func (ss *sshSession) run() { fmt.Fprintf(ss, "can't start new recording\r\n") } ss.logf("startNewRecording: %v", err) - ss.Exit(1) + // Exit code 254 for recording infrastructure failure. + // Distinct from 255 (SSH protocol error) and 1 (general command failure). + ss.Exit(254) return } ss.logf("startNewRecording: <nil>") @@ -961,95 +973,90 @@ func (ss *sshSession) run() { err := ss.launchProcess() if err != nil { logf("start failed: %v", err.Error()) + exitCode := 1 if errors.Is(err, context.Canceled) { err := context.Cause(ss.ctx) var uve userVisibleError if errors.As(err, &uve) { fmt.Fprintf(ss, "%s\r\n", uve) } + } else if isNotFoundOrExecutable(err) { + // Use exit code 127 for "command not found" per shell convention. + // This matches standard SSH behavior. + exitCode = 127 } - ss.Exit(1) + ss.Exit(exitCode) return } go ss.killProcessOnContextDone() - var processDone atomic.Bool + // Start goroutines to copy stdin/stdout/stderr. + var wg sync.WaitGroup + + wg.Add(1) go func() { + defer wg.Done() defer ss.wrStdin.Close() if _, err := io.Copy(rec.writer("i", ss.wrStdin), ss); err != nil { logf("stdin copy: %v", err) ss.cancelCtx(err) } }() - outputDone := make(chan struct{}) - var openOutputStreams atomic.Int32 - if ss.rdStderr != nil { - openOutputStreams.Store(2) - } else { - openOutputStreams.Store(1) - } + + wg.Add(1) go func() { + defer wg.Done() defer ss.rdStdout.Close() - _, err := io.Copy(rec.writer("o", ss), ss.rdStdout) - if err != nil && !errors.Is(err, io.EOF) { - isErrBecauseProcessExited := processDone.Load() && errors.Is(err, syscall.EIO) - if !isErrBecauseProcessExited { - logf("stdout copy: %v", err) - ss.cancelCtx(err) - } - } - if openOutputStreams.Add(-1) == 0 { - ss.CloseWrite() - close(outputDone) + if _, err := io.Copy(rec.writer("o", ss), ss.rdStdout); err != nil && !errors.Is(err, io.EOF) { + logf("stdout copy: %v", err) } + // Send EOF as soon as stdout copying completes. This allows sibling + // processes waiting for EOF to proceed, even if the main process hasn't + // exited yet. The channel remains open for sending exit-status later. + ss.CloseWrite() }() - // rdStderr is nil for ptys. + if ss.rdStderr != nil { + wg.Add(1) go func() { + defer wg.Done() defer ss.rdStderr.Close() - _, err := io.Copy(ss.Stderr(), ss.rdStderr) - if err != nil { + if _, err := io.Copy(ss.Stderr(), ss.rdStderr); err != nil { logf("stderr copy: %v", err) } - if openOutputStreams.Add(-1) == 0 { - ss.CloseWrite() - close(outputDone) - } }() } err = ss.cmd.Wait() - processDone.Store(true) // This will either make the SSH Termination goroutine be a no-op, // or itself will be a no-op because the process was killed by the // aforementioned goroutine. ss.exitOnce.Do(func() {}) - // Close the process-side of all pipes to signal the asynchronous - // io.Copy routines reading/writing from the pipes to terminate. - // Block for the io.Copy to finish before calling ss.Exit below. - closeAll(ss.childPipes...) - select { - case <-outputDone: - case <-ss.ctx.Done(): - } - + var exitCode int if err == nil { ss.logf("Session complete") - ss.Exit(0) - return - } - if ee, ok := err.(*exec.ExitError); ok { - code := ee.ProcessState.ExitCode() - ss.logf("Wait: code=%v", code) - ss.Exit(code) - return + exitCode = 0 + } else if ee, ok := err.(*exec.ExitError); ok { + exitCode = ee.ProcessState.ExitCode() + ss.logf("Wait: code=%v", exitCode) + } else { + ss.logf("Wait: %v", err) + exitCode = 1 } - ss.logf("Wait: %v", err) - ss.Exit(1) - return + // Send exit-status immediately. Per RFC 4254 section 6.10, exit-status + // should be sent before channel close. EOF will be sent by the stdout + // goroutine when it finishes, and that's fine - CloseWrite only closes + // the data stream but keeps the channel open for exit-status. + ss.Exit(exitCode) + + // Close process-side of pipes to signal io.Copy goroutines to finish. + closeAll(ss.childPipes...) + + // Wait for all IO to complete. + wg.Wait() } // recordSSHToLocalDisk is a deprecated dev knob to allow recording SSH sessions |
