diff options
Diffstat (limited to 'ssh/tailssh/tailssh.go')
| -rw-r--r-- | ssh/tailssh/tailssh.go | 82 |
1 files changed, 80 insertions, 2 deletions
diff --git a/ssh/tailssh/tailssh.go b/ssh/tailssh/tailssh.go index f4167ffbe..dcbc62334 100644 --- a/ssh/tailssh/tailssh.go +++ b/ssh/tailssh/tailssh.go @@ -15,6 +15,7 @@ import ( "errors" "fmt" "io" + "log" "net" "net/http" "net/netip" @@ -35,6 +36,8 @@ import ( "tailscale.com/ipn/ipnlocal" "tailscale.com/logtail/backoff" "tailscale.com/net/tsaddr" + "tailscale.com/net/tsdial" + "tailscale.com/ssh/haulproto" "tailscale.com/tailcfg" "tailscale.com/tempfork/gliderlabs/ssh" "tailscale.com/types/logger" @@ -63,6 +66,8 @@ type ipnLocalBackend interface { WhoIs(ipp netip.AddrPort) (n *tailcfg.Node, u tailcfg.UserProfile, ok bool) DoNoiseRequest(req *http.Request) (*http.Response, error) TailscaleVarRoot() string + Dialer() *tsdial.Dialer + PeerAPIBase(tailcfg.StableNodeID) (string, error) } type server struct { @@ -1348,6 +1353,21 @@ func (ss *sshSession) startNewRecording() (_ *recording, err error) { return nil, err } rec.out = f + if ss.conn.finalAction.SessionHaulTargetNode != "" { + rec.peerBase, err = ss.conn.srv.lb.PeerAPIBase(ss.conn.finalAction.SessionHaulTargetNode) + if err != nil { + return nil, err + } + src, err := os.Open(f.Name()) + if err != nil { + return nil, err + } + // TODO(skriptble): This is for debugging, switch back to a regular loger + // after. + lggr := logger.WithPrefix(log.Printf, "ssh-session("+ss.sharedID+"): ") + rec.lpc = haulproto.NewClient(lggr, src) + go rec.startHauling() + } // {"version": 2, "width": 221, "height": 84, "timestamp": 1647146075, "env": {"SHELL": "/bin/bash", "TERM": "screen"}} type CastHeader struct { @@ -1389,8 +1409,10 @@ func (ss *sshSession) startNewRecording() (_ *recording, err error) { // recording is the state for an SSH session recording. type recording struct { - ss *sshSession - start time.Time + ss *sshSession + lpc *haulproto.Client + peerBase string + start time.Time mu sync.Mutex // guards writes to, close of out out *os.File // nil if closed @@ -1404,9 +1426,64 @@ func (r *recording) Close() error { } err := r.out.Close() r.out = nil + r.lpc.Notify() // attempt to clear out any remaining log lines before closing + r.lpc.Close() return err } +func (r *recording) startHauling() { + for { + // TODO(skriptble): We need finish hauling the logs to the remote end + // before we exit. Should add another mode for this goroutine where we + // haul until we get an + r.mu.Lock() + closed := r.out == nil + r.mu.Unlock() + if closed { + return + } + // Dial destination + hc := r.ss.conn.srv.lb.Dialer().PeerAPIHTTPClient() + req, err := http.NewRequest(http.MethodPost, r.peerBase+"/v0/ssh-log-haul", nil) + if err != nil { + r.ss.logf("ssh-hauling couldn't create request: %v", err) + return // Should we panic here instead? Something is very broken. + } + req.Header.Add("Connection", "upgrade") + req.Header.Add("Upgrade", "ts-ssh-haul") + req.Header.Add("SSH-Session-Name", filepath.Base(r.out.Name())) + + resp, err := hc.Do(req) + if err != nil { + r.ss.logf("ssh-hauling couldn't establish connection: %v", err) + time.Sleep(2 * time.Second) // TODO(skriptble): Replace this with a better backoff mechanism. + continue + } + if resp.StatusCode != http.StatusSwitchingProtocols { + body, _ := io.ReadAll(resp.Body) + resp.Body.Close() + r.ss.logf("ssh-hauling unexpected HTTP response: %s, %s", resp.Status, body) + time.Sleep(2 * time.Second) + continue + } + rwc, ok := resp.Body.(io.ReadWriteCloser) + if !ok { + resp.Body.Close() + r.ss.logf("ssh-hauling, http Transport did not provide a writeable body") + time.Sleep(2 * time.Second) + continue + } + // Run hauler + err = r.lpc.Run(context.Background(), rwc) + rwc.Close() + if err == haulproto.ErrClosed { + break + } + r.ss.logf("ssh-hauling encountered error: %v", err) + time.Sleep(time.Second) + } +} + // writer returns an io.Writer around w that first records the write. // // The dir should be "i" for input or "o" for output. @@ -1453,6 +1530,7 @@ func (w loggingWriter) writeCastLine(j []byte) error { if err != nil { return fmt.Errorf("logger Write: %w", err) } + w.r.lpc.Notify() return nil } |
