diff options
Diffstat (limited to 'ssh')
| -rw-r--r-- | ssh/tailssh/listen.go | 136 | ||||
| -rw-r--r-- | ssh/tailssh/session.go | 210 |
2 files changed, 346 insertions, 0 deletions
diff --git a/ssh/tailssh/listen.go b/ssh/tailssh/listen.go new file mode 100644 index 000000000..04fa14d87 --- /dev/null +++ b/ssh/tailssh/listen.go @@ -0,0 +1,136 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd || plan9 + +package tailssh + +import ( + "errors" + "net" + "net/netip" + "sync" + + "tailscale.com/ipn/ipnlocal" + "tailscale.com/tempfork/gliderlabs/ssh" + "tailscale.com/types/logger" +) + +func init() { + ipnlocal.RegisterListenSSH(listenSSH) +} + +// listenSSH wraps rawLn with an SSH server that resolves Tailscale peer +// identity for each connection. The returned listener's Accept yields +// *Session values (as net.Conn). +func listenSSH(rawLn net.Listener, lb *ipnlocal.LocalBackend, logf logger.Logf) (net.Listener, error) { + hostKeys, err := getHostKeys(lb.TailscaleVarRoot(), logf) + if err != nil { + return nil, err + } + signers := make([]ssh.Signer, len(hostKeys)) + for i, k := range hostKeys { + signers[i] = k + } + + sl := &sshListener{ + rawLn: rawLn, + sessions: make(chan net.Conn, 16), + done: make(chan struct{}), + } + + sshSrv := &ssh.Server{ + HostSigners: signers, + Handler: func(sess ssh.Session) { + srcAddr := sess.RemoteAddr().String() + ipp, err := netip.ParseAddrPort(srcAddr) + if err != nil { + logf("listenSSH: bad remote addr %q: %v", srcAddr, err) + sess.Exit(1) + return + } + node, userProfile, ok := lb.WhoIs("tcp", ipp) + if !ok { + logf("listenSSH: WhoIs failed for %v", srcAddr) + sess.Exit(1) + return + } + + done := make(chan struct{}) + s := newSession(sess, PeerIdentity{ + Node: node, + UserProfile: userProfile, + }, done) + + // Send the session to the listener. If the listener is + // closed, drop the session. + select { + case sl.sessions <- s: + case <-sl.done: + sess.Exit(1) + return + } + + // Block until the consumer is done with the session. + select { + case <-done: + case <-sess.Context().Done(): + case <-sl.done: + } + }, + } + + go func() { + if err := sshSrv.Serve(rawLn); err != nil { + // Serve returns when the listener is closed. Only log + // unexpected errors. + select { + case <-sl.done: + default: + logf("listenSSH: Serve error: %v", err) + } + } + sl.Close() + }() + + return sl, nil +} + +// sshListener is a net.Listener that yields *Session values from its Accept +// method. It wraps a raw TCP listener with an SSH server. +type sshListener struct { + rawLn net.Listener + sessions chan net.Conn + done chan struct{} + closeOnce sync.Once +} + +// Accept returns the next SSH session as a net.Conn. The returned value can +// be type-asserted to *Session. +func (l *sshListener) Accept() (net.Conn, error) { + select { + case s, ok := <-l.sessions: + if !ok { + return nil, errors.New("listener closed") + } + return s, nil + case <-l.done: + return nil, errors.New("listener closed") + } +} + +// Close closes the underlying raw listener and signals all pending sessions +// to terminate. +func (l *sshListener) Close() error { + var err error + l.closeOnce.Do(func() { + close(l.done) + err = l.rawLn.Close() + }) + return err +} + +// Addr returns the address of the underlying raw listener. +func (l *sshListener) Addr() net.Addr { + return l.rawLn.Addr() +} diff --git a/ssh/tailssh/session.go b/ssh/tailssh/session.go new file mode 100644 index 000000000..d88df1455 --- /dev/null +++ b/ssh/tailssh/session.go @@ -0,0 +1,210 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd || plan9 + +package tailssh + +import ( + "context" + "errors" + "io" + "net" + "time" + + "tailscale.com/tailcfg" + "tailscale.com/tempfork/gliderlabs/ssh" +) + +var errNoDeadline = errors.New("tailssh.Session: deadlines not supported") + +// Signal represents an SSH signal (e.g. "INT", "TERM"). +type Signal = ssh.Signal + +// Pty represents a PTY request and configuration. +type Pty struct { + // Term is the TERM environment variable value. + Term string + + // Window is the initial window size. + Window Window + + // Modes are the RFC 4254 terminal modes as opcode/value pairs. + Modes map[uint8]uint32 +} + +// Window represents the size of a PTY window. +type Window struct { + Width int + Height int + WidthPixels int + HeightPixels int +} + +// PeerIdentity contains the Tailscale identity of the connecting SSH peer. +type PeerIdentity struct { + Node tailcfg.NodeView + UserProfile tailcfg.UserProfile +} + +// Session wraps a gliderlabs ssh.Session with Tailscale peer identity +// information. It implements net.Conn so callers that only need Read/Write/Close +// can use it directly. Callers that need SSH-specific functionality can +// type-assert from the net.Conn returned by the listener's Accept. +type Session struct { + // sess is the underlying gliderlabs SSH session. + sess ssh.Session + + // peer is the Tailscale identity of the remote peer. + peer PeerIdentity + + // done is closed when the session handler should return, + // unblocking the gliderlabs handler goroutine. + done chan struct{} +} + +// newSession creates a new Session wrapping the given gliderlabs session and +// peer identity. The done channel is closed by the session consumer to signal +// that the handler goroutine may return. +func newSession(sess ssh.Session, peer PeerIdentity, done chan struct{}) *Session { + return &Session{ + sess: sess, + peer: peer, + done: done, + } +} + +// Read reads from the SSH channel (stdin from the client). +func (s *Session) Read(p []byte) (int, error) { + return s.sess.Read(p) +} + +// Write writes to the SSH channel (stdout to the client). +func (s *Session) Write(p []byte) (int, error) { + return s.sess.Write(p) +} + +// Close signals the session handler to return and closes the underlying channel. +func (s *Session) Close() error { + select { + case <-s.done: + default: + close(s.done) + } + return nil +} + +// RemoteAddr returns the net.Addr of the client side of the connection. +func (s *Session) RemoteAddr() net.Addr { + return s.sess.RemoteAddr() +} + +// LocalAddr returns the net.Addr of the server side of the connection. +func (s *Session) LocalAddr() net.Addr { + return s.sess.LocalAddr() +} + +// SetDeadline is not supported and returns an error. +func (s *Session) SetDeadline(t time.Time) error { + return errNoDeadline +} + +// SetReadDeadline is not supported and returns an error. +func (s *Session) SetReadDeadline(t time.Time) error { + return errNoDeadline +} + +// SetWriteDeadline is not supported and returns an error. +func (s *Session) SetWriteDeadline(t time.Time) error { + return errNoDeadline +} + +// User returns the SSH username. +func (s *Session) User() string { + return s.sess.User() +} + +// PeerIdentity returns the Tailscale identity of the remote peer. +func (s *Session) PeerIdentity() PeerIdentity { + return s.peer +} + +// Environ returns a copy of the environment variables set by the client. +func (s *Session) Environ() []string { + return s.sess.Environ() +} + +// RawCommand returns the exact command string provided by the client. +func (s *Session) RawCommand() string { + return s.sess.RawCommand() +} + +// Subsystem returns the subsystem requested by the client. +func (s *Session) Subsystem() string { + return s.sess.Subsystem() +} + +// Pty returns PTY information, a channel of window size changes, and whether +// a PTY was requested. The returned types use this package's Pty and Window +// types rather than the internal gliderlabs types. +func (s *Session) Pty() (Pty, <-chan Window, bool) { + gPty, gWinCh, ok := s.sess.Pty() + if !ok { + return Pty{}, nil, false + } + p := Pty{ + Term: gPty.Term, + Window: Window{ + Width: gPty.Window.Width, + Height: gPty.Window.Height, + WidthPixels: gPty.Window.WidthPixels, + HeightPixels: gPty.Window.HeightPixels, + }, + } + if gPty.Modes != nil { + p.Modes = make(map[uint8]uint32, len(gPty.Modes)) + for k, v := range gPty.Modes { + p.Modes[k] = v + } + } + + // Convert the gliderlabs Window channel to our Window type. + winCh := make(chan Window, 1) + go func() { + defer close(winCh) + for gw := range gWinCh { + winCh <- Window{ + Width: gw.Width, + Height: gw.Height, + WidthPixels: gw.WidthPixels, + HeightPixels: gw.HeightPixels, + } + } + }() + + return p, winCh, true +} + +// Signals registers a channel to receive signals from the client. +// Pass nil to unregister. +func (s *Session) Signals(c chan<- Signal) { + s.sess.Signals(c) +} + +// Exit sends an exit status to the client and closes the session. +func (s *Session) Exit(code int) error { + err := s.sess.Exit(code) + s.Close() + return err +} + +// Stderr returns an io.Writer for the SSH stderr channel. +func (s *Session) Stderr() io.Writer { + return s.sess.Stderr() +} + +// Context returns the session's context, which is canceled when the client +// disconnects. +func (s *Session) Context() context.Context { + return s.sess.Context() +} |
