summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--ssh/tailssh/tailssh.go54
1 files changed, 54 insertions, 0 deletions
diff --git a/ssh/tailssh/tailssh.go b/ssh/tailssh/tailssh.go
index 7d12ab45f..40f376da9 100644
--- a/ssh/tailssh/tailssh.go
+++ b/ssh/tailssh/tailssh.go
@@ -32,6 +32,7 @@ import (
gossh "golang.org/x/crypto/ssh"
"tailscale.com/envknob"
"tailscale.com/ipn/ipnlocal"
+ "tailscale.com/ipn/ipnstate"
"tailscale.com/net/tsaddr"
"tailscale.com/net/tsdial"
"tailscale.com/sessionrecording"
@@ -76,6 +77,7 @@ type ipnLocalBackend interface {
Dialer() *tsdial.Dialer
TailscaleVarRoot() string
NodeKey() key.NodePublic
+ Ping(ctx context.Context, ip netip.Addr, pingType tailcfg.PingType, size int) (*ipnstate.PingResult, error)
}
type server struct {
@@ -834,6 +836,7 @@ func (c *conn) detachSession(ss *sshSession) {
}
var errSessionDone = errors.New("session is done")
+var errClientUnreachable = errors.New("client is unreachable")
// handleSSHAgentForwarding starts a Unix socket listener and in the background
// forwards agent connections between the listener and the ssh.Session.
@@ -954,6 +957,57 @@ func (ss *sshSession) run() {
ss.logf("startNewRecording: <nil>")
if rec != nil {
defer rec.Close()
+
+ ping := func() bool {
+ clientIP := ss.conn.info.src.Addr()
+
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+
+ _, err := ss.conn.srv.lb.Ping(ctx, clientIP, tailcfg.PingICMP, 0)
+ if err != nil {
+ ss.logf("pinging SSH client %s failed: %v", clientIP, err)
+ return false
+ }
+
+ ss.logf("pinging SSH client %s successful", clientIP)
+ return true
+ }
+
+ go func() {
+ ss.logf("starting connection monitor for session %s", ss.sharedID)
+ ticker := time.NewTicker(15 * time.Second)
+ defer ticker.Stop()
+
+ consecutiveFailures := 0
+ const maxFailures = 3
+
+ for {
+ select {
+ case <-ss.ctx.Done():
+ ss.logf("session terminated, closing recording: %v", context.Cause(ss.ctx))
+ rec.Close()
+ return
+
+ case <-ticker.C:
+ pong := ping()
+ if pong {
+ consecutiveFailures = 0
+ ss.logf("connection test passed for session %s", ss.sharedID)
+ } else {
+ consecutiveFailures++
+ ss.logf("connection test failed (%d/%d) for session %s", consecutiveFailures, maxFailures, ss.sharedID)
+
+ if consecutiveFailures >= maxFailures {
+ ss.logf("connection lost (connection test failed %d times), closing recording", maxFailures)
+ ss.cancelCtx(errClientUnreachable)
+ rec.Close()
+ return
+ }
+ }
+ }
+ }
+ }()
}
}
}