summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorIrbe Krumina <irbe@tailscale.com>2024-07-26 20:05:49 +0300
committerIrbe Krumina <irbe@tailscale.com>2024-07-26 21:32:37 +0300
commit69c27b23cb8ae46e6f0845817e879d636f26e70a (patch)
tree62d00d75cc340334c636c4b58e85ab1aba6dd507
parent8d7b78f3f795e781d939750893610639b224d81a (diff)
downloadtailscale-irbekrm/websocket.tar.xz
tailscale-irbekrm/websocket.zip
cmd/k8s-operator,k8s-operator/session-recording: implement support for WebSocket protocolirbekrm/websocket
Kubernetes currently supports two streaming protocols- SPDY and WebSockets. WebSockets are replacing SPDY, see https://github.com/kubernetes/enhancements/issues/4006 Our 'kubectl exec' session recording was only supporting SPDY. This change: - adds functionality to parse streaming sessions over WebSockets - for sessions that the API server proxy has determined need to be recorded, determines if the session is over SPDY or WebSockets and invoke the relevant parser accordingly - refactors the session recording logic into its own package Updates tailscale/corp#19821 Signed-off-by: Irbe Krumina <irbe@tailscale.com>
-rw-r--r--cmd/k8s-operator/depaware.txt7
-rw-r--r--cmd/k8s-operator/proxy.go81
-rw-r--r--k8s-operator/session-recording/fakes/fakes.go117
-rw-r--r--k8s-operator/session-recording/hijacker.go (renamed from cmd/k8s-operator/spdy-hijacker.go)130
-rw-r--r--k8s-operator/session-recording/hijacker_test.go (renamed from cmd/k8s-operator/spdy-hijacker_test.go)32
-rw-r--r--k8s-operator/session-recording/spdy/conn.go (renamed from cmd/k8s-operator/spdy-remote-conn-recorder.go)37
-rw-r--r--k8s-operator/session-recording/spdy/conn_test.go (renamed from cmd/k8s-operator/spdy-remote-conn-recorder_test.go)123
-rw-r--r--k8s-operator/session-recording/spdy/frame.go (renamed from cmd/k8s-operator/spdy-frame.go)2
-rw-r--r--k8s-operator/session-recording/spdy/frame_test.go (renamed from cmd/k8s-operator/spdy-frame_test.go)2
-rw-r--r--k8s-operator/session-recording/spdy/zlib-reader.go (renamed from cmd/k8s-operator/zlib-reader.go)2
-rw-r--r--k8s-operator/session-recording/tsrecorder/header.go54
-rw-r--r--k8s-operator/session-recording/tsrecorder/tsrecorder.go (renamed from cmd/k8s-operator/recorder.go)56
-rw-r--r--k8s-operator/session-recording/ws/conn.go244
-rw-r--r--k8s-operator/session-recording/ws/conn_test.go171
-rw-r--r--k8s-operator/session-recording/ws/message.go253
-rw-r--r--k8s-operator/session-recording/ws/message_test.go125
16 files changed, 1192 insertions, 244 deletions
diff --git a/cmd/k8s-operator/depaware.txt b/cmd/k8s-operator/depaware.txt
index b5c0ed517..c12fd89b7 100644
--- a/cmd/k8s-operator/depaware.txt
+++ b/cmd/k8s-operator/depaware.txt
@@ -423,6 +423,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/
k8s.io/apimachinery/pkg/util/naming from k8s.io/apimachinery/pkg/runtime+
k8s.io/apimachinery/pkg/util/net from k8s.io/apimachinery/pkg/watch+
k8s.io/apimachinery/pkg/util/rand from k8s.io/apiserver/pkg/storage/names
+ k8s.io/apimachinery/pkg/util/remotecommand from tailscale.com/k8s-operator/session-recording/ws
k8s.io/apimachinery/pkg/util/runtime from k8s.io/apimachinery/pkg/apis/meta/internalversion/scheme+
k8s.io/apimachinery/pkg/util/sets from k8s.io/apimachinery/pkg/api/meta+
k8s.io/apimachinery/pkg/util/strategicpatch from k8s.io/client-go/tools/record+
@@ -692,6 +693,10 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/
tailscale.com/k8s-operator from tailscale.com/cmd/k8s-operator
tailscale.com/k8s-operator/apis from tailscale.com/k8s-operator/apis/v1alpha1
tailscale.com/k8s-operator/apis/v1alpha1 from tailscale.com/cmd/k8s-operator+
+ tailscale.com/k8s-operator/session-recording from tailscale.com/cmd/k8s-operator
+ tailscale.com/k8s-operator/session-recording/spdy from tailscale.com/k8s-operator/session-recording
+ tailscale.com/k8s-operator/session-recording/tsrecorder from tailscale.com/k8s-operator/session-recording+
+ tailscale.com/k8s-operator/session-recording/ws from tailscale.com/k8s-operator/session-recording
tailscale.com/kube from tailscale.com/cmd/k8s-operator+
tailscale.com/licenses from tailscale.com/client/web
tailscale.com/log/filelogger from tailscale.com/logpolicy
@@ -752,7 +757,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/
tailscale.com/tka from tailscale.com/client/tailscale+
W tailscale.com/tsconst from tailscale.com/net/netmon+
tailscale.com/tsd from tailscale.com/ipn/ipnlocal+
- tailscale.com/tsnet from tailscale.com/cmd/k8s-operator
+ tailscale.com/tsnet from tailscale.com/cmd/k8s-operator+
tailscale.com/tstime from tailscale.com/cmd/k8s-operator+
tailscale.com/tstime/mono from tailscale.com/net/tstun+
tailscale.com/tstime/rate from tailscale.com/derp+
diff --git a/cmd/k8s-operator/proxy.go b/cmd/k8s-operator/proxy.go
index 258a958fa..45b048f6f 100644
--- a/cmd/k8s-operator/proxy.go
+++ b/cmd/k8s-operator/proxy.go
@@ -22,6 +22,7 @@ import (
"k8s.io/client-go/transport"
"tailscale.com/client/tailscale"
"tailscale.com/client/tailscale/apitype"
+ sessionrecording "tailscale.com/k8s-operator/session-recording"
tskube "tailscale.com/kube"
"tailscale.com/ssh/tailssh"
"tailscale.com/tailcfg"
@@ -36,12 +37,6 @@ var whoIsKey = ctxkey.New("", (*apitype.WhoIsResponse)(nil))
var (
// counterNumRequestsproxies counts the number of API server requests proxied via this proxy.
counterNumRequestsProxied = clientmetric.NewCounter("k8s_auth_proxy_requests_proxied")
-
- // counterSessionRecordingsAttempted counts the number of session recording attempts.
- counterSessionRecordingsAttempted = clientmetric.NewCounter("k8s_auth_proxy__session_recordings_attempted")
-
- // counterSessionRecordingsUploaded counts the number of successfully uploaded session recordings.
- counterSessionRecordingsUploaded = clientmetric.NewCounter("k8s_auth_proxy_session_recordings_uploaded")
)
type apiServerProxyMode int
@@ -173,7 +168,9 @@ func runAPIServerProxy(ts *tsnet.Server, rt http.RoundTripper, log *zap.SugaredL
mux := http.NewServeMux()
mux.HandleFunc("/", ap.serveDefault)
- mux.HandleFunc("/api/v1/namespaces/{namespace}/pods/{pod}/exec", ap.serveExec)
+ mux.HandleFunc("POST /api/v1/namespaces/{namespace}/pods/{pod}/exec", ap.serveExecSPDY)
+
+ mux.HandleFunc("GET /api/v1/namespaces/{namespace}/pods/{pod}/exec", ap.serveExecWS)
hs := &http.Server{
// Kubernetes uses SPDY for exec and port-forward, however SPDY is
@@ -214,9 +211,10 @@ func (ap *apiserverProxy) serveDefault(w http.ResponseWriter, r *http.Request) {
ap.rp.ServeHTTP(w, r.WithContext(whoIsKey.WithValue(r.Context(), who)))
}
-// serveExec serves 'kubectl exec' requests, optionally configuring the kubectl
-// exec sessions to be recorded.
-func (ap *apiserverProxy) serveExec(w http.ResponseWriter, r *http.Request) {
+// serveExecWS serves 'kubectl exec' requests, optionally configuring the
+// kubectl exec sessions to be recorded. It should only be called for requests
+// for sessions that use WebSockets protocol for streaming.
+func (ap *apiserverProxy) serveExecWS(w http.ResponseWriter, r *http.Request) {
who, err := ap.whoIs(r)
if err != nil {
ap.authError(w, err)
@@ -232,14 +230,59 @@ func (ap *apiserverProxy) serveExec(w http.ResponseWriter, r *http.Request) {
ap.rp.ServeHTTP(w, r.WithContext(whoIsKey.WithValue(r.Context(), who)))
return
}
- counterSessionRecordingsAttempted.Add(1) // at this point we know that users intended for this session to be recorded
+ sessionrecording.CounterSessionRecordingsAttempted.Add(1) // at this point we know that users intended for this session to be recorded
if !failOpen && len(addrs) == 0 {
msg := "forbidden: 'kubectl exec' session must be recorded, but no recorders are available."
ap.log.Error(msg)
http.Error(w, msg, http.StatusForbidden)
return
}
- if r.Method != "POST" || r.Header.Get("Upgrade") != "SPDY/3.1" {
+ if h := r.Header.Get("Upgrade"); h != "websocket" {
+ msg := fmt.Sprintf("[unexpected] 'kubectl exec' session was initiated for WebSocket protocol, but the request does not contain expected upgrade header, wants: 'websocket', got: %q", h)
+ if failOpen {
+ msg = msg + "; failure mode is 'fail open'; continuing session without recording."
+ ap.log.Warn(msg)
+ ap.rp.ServeHTTP(w, r.WithContext(whoIsKey.WithValue(r.Context(), who)))
+ return
+ }
+ ap.log.Error(msg)
+ msg += "; failure mode is 'fail closed'; closing connection."
+ http.Error(w, msg, 403)
+ return
+ } else {
+ ap.log.Debugf("detected 'kubectl exec' session streaming protocol is WebSockets")
+ }
+ wsH := sessionrecording.New(ap.ts, r, who, w, r.PathValue("pod"), r.PathValue("namespace"), sessionrecording.WebSocketsProtocol, addrs, failOpen, tailssh.ConnectToRecorder, ap.log)
+
+ ap.rp.ServeHTTP(wsH, r.WithContext(whoIsKey.WithValue(r.Context(), who)))
+}
+
+// serveExecSPDY serves 'kubectl exec' requests, optionally configuring the
+// kubectl exec sessions to be recorded. It should only be called for requests
+// that initate 'kubectl exec' sessions using the SPDY protocol for streaming.
+func (ap *apiserverProxy) serveExecSPDY(w http.ResponseWriter, r *http.Request) {
+ who, err := ap.whoIs(r)
+ if err != nil {
+ ap.authError(w, err)
+ return
+ }
+ counterNumRequestsProxied.Add(1)
+ failOpen, addrs, err := determineRecorderConfig(who)
+ if err != nil {
+ ap.log.Errorf("error trying to determine whether the 'kubectl exec' session needs to be recorded: %v", err)
+ return
+ }
+ if failOpen && len(addrs) == 0 { // will not record
+ ap.rp.ServeHTTP(w, r.WithContext(whoIsKey.WithValue(r.Context(), who)))
+ return
+ }
+ if !failOpen && len(addrs) == 0 {
+ msg := "forbidden: 'kubectl exec' session must be recorded, but no recorders are available."
+ ap.log.Error(msg)
+ http.Error(w, msg, 403)
+ return
+ }
+ if r.Header.Get("Upgrade") != "SPDY/3.1" {
msg := "'kubectl exec' session recording is configured, but the request is not over SPDY. Session recording is currently only supported for SPDY based clients"
if failOpen {
msg = msg + "; failure mode is 'fail open'; continuing session without recording."
@@ -252,19 +295,7 @@ func (ap *apiserverProxy) serveExec(w http.ResponseWriter, r *http.Request) {
http.Error(w, msg, http.StatusForbidden)
return
}
- spdyH := &spdyHijacker{
- ts: ap.ts,
- req: r,
- who: who,
- ResponseWriter: w,
- log: ap.log,
- pod: r.PathValue("pod"),
- ns: r.PathValue("namespace"),
- addrs: addrs,
- failOpen: failOpen,
- connectToRecorder: tailssh.ConnectToRecorder,
- }
-
+ spdyH := sessionrecording.New(ap.ts, r, who, w, r.PathValue("pod"), r.PathValue("namespace"), sessionrecording.SPDYProtocol, addrs, failOpen, tailssh.ConnectToRecorder, ap.log)
ap.rp.ServeHTTP(spdyH, r.WithContext(whoIsKey.WithValue(r.Context(), who)))
}
diff --git a/k8s-operator/session-recording/fakes/fakes.go b/k8s-operator/session-recording/fakes/fakes.go
new file mode 100644
index 000000000..9f5c349d4
--- /dev/null
+++ b/k8s-operator/session-recording/fakes/fakes.go
@@ -0,0 +1,117 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build !plan9
+
+// package fakes contains utils for testing session recording behaviour.
+package fakes
+
+import (
+ "bytes"
+ "encoding/json"
+ "net"
+ "sync"
+ "testing"
+
+ "tailscale.com/k8s-operator/session-recording/tsrecorder"
+ "tailscale.com/tstime"
+)
+
+func New(conn net.Conn, wb bytes.Buffer, rb bytes.Buffer, closed bool) net.Conn {
+ return &TestConn{
+ Conn: conn,
+ writeBuf: wb,
+ readBuf: rb,
+ closed: closed,
+ }
+}
+
+type TestConn struct {
+ net.Conn
+ // writeBuf contains whatever was send to the conn via Write.
+ writeBuf bytes.Buffer
+ // readBuf contains whatever was sent to the conn via Read.
+ readBuf bytes.Buffer
+ sync.RWMutex // protects the following
+ closed bool
+}
+
+var _ net.Conn = &TestConn{}
+
+func (tc *TestConn) Read(b []byte) (int, error) {
+ return tc.readBuf.Read(b)
+}
+
+func (tc *TestConn) Write(b []byte) (int, error) {
+ return tc.writeBuf.Write(b)
+}
+
+func (tc *TestConn) Close() error {
+ tc.Lock()
+ defer tc.Unlock()
+ tc.closed = true
+ return nil
+}
+
+func (tc *TestConn) IsClosed() bool {
+ tc.Lock()
+ defer tc.Unlock()
+ return tc.closed
+}
+
+func (tc *TestConn) WriteBufBytes() []byte {
+ return tc.writeBuf.Bytes()
+}
+
+func (tc *TestConn) ResetReadBuf() {
+ tc.readBuf.Reset()
+}
+
+func (tc *TestConn) WriteReadBufBytes(b []byte) error {
+ _, err := tc.readBuf.Write(b)
+ return err
+}
+
+type TestSessionRecorder struct {
+ // buf holds data that was sent to the session recorder.
+ buf bytes.Buffer
+}
+
+func (t *TestSessionRecorder) Write(b []byte) (int, error) {
+ return t.buf.Write(b)
+}
+
+func (t *TestSessionRecorder) Close() error {
+ t.buf.Reset()
+ return nil
+}
+
+func (t *TestSessionRecorder) Bytes() []byte {
+ return t.buf.Bytes()
+}
+
+func CastLine(t *testing.T, p []byte, clock tstime.Clock) []byte {
+ t.Helper()
+ j, err := json.Marshal([]any{
+ clock.Now().Sub(clock.Now()).Seconds(),
+ "o",
+ string(p),
+ })
+ if err != nil {
+ t.Fatalf("error marshalling cast line: %v", err)
+ }
+ return append(j, '\n')
+}
+
+func AsciinemaResizeMsg(t *testing.T, width, height int) []byte {
+ t.Helper()
+ ch := tsrecorder.CastHeader{
+ Width: width,
+ Height: height,
+ }
+ bs, err := json.Marshal(ch)
+ if err != nil {
+ t.Fatalf("error marshalling CastHeader: %v", err)
+ }
+ return append(bs, '\n')
+}
diff --git a/cmd/k8s-operator/spdy-hijacker.go b/k8s-operator/session-recording/hijacker.go
index f74771e42..bbaee3ba7 100644
--- a/cmd/k8s-operator/spdy-hijacker.go
+++ b/k8s-operator/session-recording/hijacker.go
@@ -3,12 +3,15 @@
//go:build !plan9
-package main
+// Package sessionrecording has functionality for recording 'kubectl exec'
+// sessions and sending to a tsrecorder.
+package sessionrecording
import (
"bufio"
"bytes"
"context"
+ "errors"
"fmt"
"io"
"net"
@@ -16,20 +19,52 @@ import (
"net/netip"
"strings"
- "github.com/pkg/errors"
"go.uber.org/zap"
"tailscale.com/client/tailscale/apitype"
+ "tailscale.com/k8s-operator/session-recording/spdy"
+ "tailscale.com/k8s-operator/session-recording/tsrecorder"
+ "tailscale.com/k8s-operator/session-recording/ws"
"tailscale.com/tailcfg"
"tailscale.com/tsnet"
"tailscale.com/tstime"
+ "tailscale.com/util/clientmetric"
"tailscale.com/util/multierr"
)
+const (
+ SPDYProtocol = "SPDY"
+ WebSocketsProtocol = "WebSockets"
+)
+
+var (
+ // counterSessionRecordingsAttempted counts the number of session recording attempts.
+ CounterSessionRecordingsAttempted = clientmetric.NewCounter("k8s_auth_proxy__session_recordings_attempted")
+
+ // counterSessionRecordingsUploaded counts the number of successfully uploaded session recordings.
+ CounterSessionRecordingsUploaded = clientmetric.NewCounter("k8s_auth_proxy_session_recordings_uploaded")
+)
+
+func New(ts *tsnet.Server, req *http.Request, who *apitype.WhoIsResponse, w http.ResponseWriter, pod, ns string, proto protocol, addrs []netip.AddrPort, failOpen bool, connFunc RecorderDialFn, log *zap.SugaredLogger) *SpdyHijacker {
+ return &SpdyHijacker{
+ ts: ts,
+ req: req,
+ who: who,
+ ResponseWriter: w,
+ pod: pod,
+ ns: ns,
+ addrs: addrs,
+ failOpen: failOpen,
+ connectToRecorder: connFunc,
+ proto: proto,
+ log: log,
+ }
+}
+
// spdyHijacker implements [net/http.Hijacker] interface.
// It must be configured with an http request for a 'kubectl exec' session that
// needs to be recorded. It knows how to hijack the connection and configure for
// the session contents to be sent to a tsrecorder instance.
-type spdyHijacker struct {
+type SpdyHijacker struct {
http.ResponseWriter
ts *tsnet.Server
req *http.Request
@@ -40,8 +75,13 @@ type spdyHijacker struct {
addrs []netip.AddrPort // tsrecorder addresses
failOpen bool // whether to fail open if recording fails
connectToRecorder RecorderDialFn
+ proto protocol
}
+// protocol is the streaming protocol of the hijacked session. Supported
+// protocols are SPDY and WebSockets.
+type protocol string
+
// RecorderDialFn dials the specified netip.AddrPorts that should be tsrecorder
// addresses. It tries to connect to recorder endpoints one by one, till one
// connection succeeds. In case of success, returns a list with a single
@@ -51,7 +91,7 @@ type RecorderDialFn func(context.Context, []netip.AddrPort, func(context.Context
// Hijack hijacks a 'kubectl exec' session and configures for the session
// contents to be sent to a recorder.
-func (h *spdyHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) {
+func (h *SpdyHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) {
h.log.Infof("recorder addrs: %v, failOpen: %v", h.addrs, h.failOpen)
reqConn, brw, err := h.ResponseWriter.(http.Hijacker).Hijack()
if err != nil {
@@ -69,7 +109,7 @@ func (h *spdyHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) {
// spdyHijacker.addrs. Returns conn from provided opts, wrapped in recording
// logic. If connecting to the recorder fails or an error is received during the
// session and spdyHijacker.failOpen is false, connection will be closed.
-func (h *spdyHijacker) setUpRecording(ctx context.Context, conn net.Conn) (net.Conn, error) {
+func (h *SpdyHijacker) setUpRecording(ctx context.Context, conn net.Conn) (net.Conn, error) {
const (
// https://docs.asciinema.org/manual/asciicast/v2/
asciicastv2 = 2
@@ -91,30 +131,20 @@ func (h *spdyHijacker) setUpRecording(ctx context.Context, conn net.Conn) (net.C
}
return nil, errors.New(msg)
}
-
// TODO (irbekrm): log which recorder
h.log.Info("successfully connected to a session recorder")
wc = rw
cl := tstime.DefaultClock{}
- lc := &spdyRemoteConnRecorder{
- log: h.log,
- Conn: conn,
- rec: &recorder{
- start: cl.Now(),
- clock: cl,
- failOpen: h.failOpen,
- conn: wc,
- },
- }
+ rec := tsrecorder.New(wc, cl, cl.Now(), h.failOpen)
qp := h.req.URL.Query()
- ch := CastHeader{
+ ch := tsrecorder.CastHeader{
Version: asciicastv2,
- Timestamp: lc.rec.start.Unix(),
+ Timestamp: cl.Now().Unix(),
Command: strings.Join(qp["command"], " "),
SrcNode: strings.TrimSuffix(h.who.Node.Name, "."),
SrcNodeID: h.who.Node.StableID,
- Kubernetes: &Kubernetes{
+ Kubernetes: &tsrecorder.Kubernetes{
PodName: h.pod,
Namespace: h.ns,
Container: strings.Join(qp["container"], " "),
@@ -126,7 +156,16 @@ func (h *spdyHijacker) setUpRecording(ctx context.Context, conn net.Conn) (net.C
} else {
ch.SrcNodeTags = h.who.Node.Tags
}
- lc.ch = ch
+ var lc net.Conn
+ switch h.proto {
+ case SPDYProtocol:
+ lc = spdy.New(conn, rec, ch, h.log)
+ case WebSocketsProtocol:
+ lc = ws.New(conn, rec, ch, h.log)
+ default:
+ return nil, fmt.Errorf("unknown protocol: %s", h.proto)
+ }
+
go func() {
var err error
select {
@@ -135,7 +174,7 @@ func (h *spdyHijacker) setUpRecording(ctx context.Context, conn net.Conn) (net.C
case err = <-errChan:
}
if err == nil {
- counterSessionRecordingsUploaded.Add(1)
+ CounterSessionRecordingsUploaded.Add(1)
h.log.Info("finished uploading the recording")
return
}
@@ -147,60 +186,13 @@ func (h *spdyHijacker) setUpRecording(ctx context.Context, conn net.Conn) (net.C
}
msg += "; failure mode set to 'fail closed'; closing connection"
h.log.Error(msg)
- lc.failed = true
- // TODO (irbekrm): write a message to the client
if err := lc.Close(); err != nil {
h.log.Infof("error closing recorder connections: %v", err)
}
return
}()
- return lc, nil
-}
-
-// CastHeader is the asciicast header to be sent to the recorder at the start of
-// the recording of a session.
-// https://docs.asciinema.org/manual/asciicast/v2/#header
-type CastHeader struct {
- // Version is the asciinema file format version.
- Version int `json:"version"`
-
- // Width is the terminal width in characters.
- Width int `json:"width"`
-
- // Height is the terminal height in characters.
- Height int `json:"height"`
-
- // Timestamp is the unix timestamp of when the recording started.
- Timestamp int64 `json:"timestamp"`
-
- // Tailscale-specific fields: SrcNode is the full MagicDNS name of the
- // tailnet node originating the connection, without the trailing dot.
- SrcNode string `json:"srcNode"`
-
- // SrcNodeID is the node ID of the tailnet node originating the connection.
- SrcNodeID tailcfg.StableNodeID `json:"srcNodeID"`
- // SrcNodeTags is the list of tags on the node originating the connection (if any).
- SrcNodeTags []string `json:"srcNodeTags,omitempty"`
-
- // SrcNodeUserID is the user ID of the node originating the connection (if not tagged).
- SrcNodeUserID tailcfg.UserID `json:"srcNodeUserID,omitempty"` // if not tagged
-
- // SrcNodeUser is the LoginName of the node originating the connection (if not tagged).
- SrcNodeUser string `json:"srcNodeUser,omitempty"`
-
- Command string
-
- // Kubernetes-specific fields:
- Kubernetes *Kubernetes `json:"kubernetes,omitempty"`
-}
-
-// Kubernetes contains 'kubectl exec' session specific information for
-// tsrecorder.
-type Kubernetes struct {
- PodName string
- Namespace string
- Container string
+ return lc, nil
}
func closeConnWithWarning(conn net.Conn, msg string) error {
diff --git a/cmd/k8s-operator/spdy-hijacker_test.go b/k8s-operator/session-recording/hijacker_test.go
index 7ac79d7f0..cfc694d26 100644
--- a/cmd/k8s-operator/spdy-hijacker_test.go
+++ b/k8s-operator/session-recording/hijacker_test.go
@@ -3,7 +3,7 @@
//go:build !plan9
-package main
+package sessionrecording
import (
"context"
@@ -19,6 +19,7 @@ import (
"go.uber.org/zap"
"tailscale.com/client/tailscale/apitype"
+ "tailscale.com/k8s-operator/session-recording/fakes"
"tailscale.com/tailcfg"
"tailscale.com/tsnet"
"tailscale.com/tstest"
@@ -34,39 +35,49 @@ func Test_SPDYHijacker(t *testing.T) {
failOpen bool
failRecorderConnect bool // fail initial connect to the recorder
failRecorderConnPostConnect bool // send error down the error channel
+ proto protocol
wantsConnClosed bool
wantsSetupErr bool
}{
{
- name: "setup succeeds, conn stays open",
+ name: "spdy_setup_succeeds_conn_stays_open",
+ proto: SPDYProtocol,
},
{
- name: "setup fails, policy is to fail open, conn stays open",
+ name: "ws_setup_succeeds_conn_stays_open",
+ proto: WebSocketsProtocol,
+ },
+ {
+ name: "setup_fails_policy_is_to_fail_open_conn_stays_open",
failOpen: true,
failRecorderConnect: true,
+ proto: SPDYProtocol,
},
{
- name: "setup fails, policy is to fail closed, conn is closed",
+ name: "setup_fails_policy_is_to_fail_closed_conn_is_closed",
failRecorderConnect: true,
wantsSetupErr: true,
wantsConnClosed: true,
+ proto: SPDYProtocol,
},
{
- name: "connection fails post-initial connect, policy is to fail open, conn stays open",
+ name: "connection_fails_post-initial_connect_policy_is_to_fail_open_conn_stays_open",
failRecorderConnPostConnect: true,
failOpen: true,
+ proto: SPDYProtocol,
},
{
- name: "connection fails post-initial connect, policy is to fail closed, conn is closed",
+ name: "connection_fails_post-initial_connect_policy_is_to_fail_closed_conn_is_closed",
failRecorderConnPostConnect: true,
wantsConnClosed: true,
+ proto: SPDYProtocol,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- tc := &testConn{}
+ tc := &fakes.TestConn{}
ch := make(chan error)
- h := &spdyHijacker{
+ h := &SpdyHijacker{
connectToRecorder: func(context.Context, []netip.AddrPort, func(context.Context, string, string) (net.Conn, error)) (wc io.WriteCloser, rec []*tailcfg.SSHRecordingAttempt, _ <-chan error, err error) {
if tt.failRecorderConnect {
err = errors.New("test")
@@ -78,6 +89,7 @@ func Test_SPDYHijacker(t *testing.T) {
log: zl.Sugar(),
ts: &tsnet.Server{},
req: &http.Request{URL: &url.URL{}},
+ proto: tt.proto,
}
ctx := context.Background()
_, err := h.setUpRecording(ctx, tc)
@@ -98,8 +110,8 @@ func Test_SPDYHijacker(t *testing.T) {
// (test that connection remains open over some period
// of time).
if err := tstest.WaitFor(timeout, func() (err error) {
- if tt.wantsConnClosed != tc.isClosed() {
- return fmt.Errorf("got connection state: %t, wants connection state: %t", tc.isClosed(), tt.wantsConnClosed)
+ if tt.wantsConnClosed != tc.IsClosed() {
+ return fmt.Errorf("got conIection state: %t, wants connection state: %t", tc.IsClosed(), tt.wantsConnClosed)
}
return nil
}); err != nil {
diff --git a/cmd/k8s-operator/spdy-remote-conn-recorder.go b/k8s-operator/session-recording/spdy/conn.go
index 563b2a241..af27f27e6 100644
--- a/cmd/k8s-operator/spdy-remote-conn-recorder.go
+++ b/k8s-operator/session-recording/spdy/conn.go
@@ -3,7 +3,9 @@
//go:build !plan9
-package main
+// Package spdy has functionality to parse 'kubectl exec' sessions streamed over
+// SPDY.
+package spdy
import (
"bytes"
@@ -15,18 +17,30 @@ import (
"sync"
"sync/atomic"
+ "tailscale.com/k8s-operator/session-recording/tsrecorder"
+
"go.uber.org/zap"
corev1 "k8s.io/api/core/v1"
)
+func New(conn net.Conn, rec *tsrecorder.Client, ch tsrecorder.CastHeader, log *zap.SugaredLogger) net.Conn {
+ return &spdyRemoteConnRecorder{
+ Conn: conn,
+ rec: rec,
+ ch: ch,
+ log: log,
+ }
+
+}
+
// spdyRemoteConnRecorder is a wrapper around net.Conn. It reads the bytestream
// for a 'kubectl exec' session, sends session recording data to the configured
// recorder and forwards the raw bytes to the original destination.
type spdyRemoteConnRecorder struct {
net.Conn
// rec knows how to send data written to it to a tsrecorder instance.
- rec *recorder
- ch CastHeader
+ rec *tsrecorder.Client
+ ch tsrecorder.CastHeader
stdoutStreamID atomic.Uint32
stderrStreamID atomic.Uint32
@@ -34,7 +48,6 @@ type spdyRemoteConnRecorder struct {
wmu sync.Mutex // sequences writes
closed bool
- failed bool
rmu sync.Mutex // sequences reads
writeCastHeaderOnce sync.Once
@@ -78,9 +91,9 @@ func (c *spdyRemoteConnRecorder) Read(b []byte) (int, error) {
switch sf.StreamID {
case c.resizeStreamID.Load():
var err error
- var msg spdyResizeMsg
+ var msg tsrecorder.ResizeMsg
if err = json.Unmarshal(sf.Payload, &msg); err != nil {
- return 0, fmt.Errorf("error umarshalling resize msg: %w", err)
+ return 0, err
}
c.ch.Width = msg.Width
c.ch.Height = msg.Height
@@ -127,13 +140,14 @@ func (c *spdyRemoteConnRecorder) Write(b []byte) (int, error) {
case c.stdoutStreamID.Load(), c.stderrStreamID.Load():
var err error
c.writeCastHeaderOnce.Do(func() {
+
var j []byte
j, err = json.Marshal(c.ch)
if err != nil {
return
}
j = append(j, '\n')
- err = c.rec.writeCastLine(j)
+ err = c.rec.WriteCastLine(j)
if err != nil {
c.log.Errorf("received error from recorder: %v", err)
}
@@ -157,7 +171,9 @@ func (c *spdyRemoteConnRecorder) Close() error {
if c.closed {
return nil
}
- if !c.failed && c.writeBuf.Len() > 0 {
+ // TODO: only do this if this is a normal closure rather than the
+ // reocrding has failed.
+ if c.writeBuf.Len() > 0 {
c.Conn.Write(c.writeBuf.Bytes())
}
c.writeBuf.Reset()
@@ -187,8 +203,3 @@ func (c *spdyRemoteConnRecorder) storeStreamID(sf spdyFrame, header http.Header)
c.resizeStreamID.Store(id)
}
}
-
-type spdyResizeMsg struct {
- Width int `json:"width"`
- Height int `json:"height"`
-}
diff --git a/cmd/k8s-operator/spdy-remote-conn-recorder_test.go b/k8s-operator/session-recording/spdy/conn_test.go
index 95f5a8bfc..ce8c9ae49 100644
--- a/cmd/k8s-operator/spdy-remote-conn-recorder_test.go
+++ b/k8s-operator/session-recording/spdy/conn_test.go
@@ -3,19 +3,17 @@
//go:build !plan9
-package main
+package spdy
import (
- "bytes"
"encoding/json"
- "net"
"reflect"
- "sync"
"testing"
"go.uber.org/zap"
+ "tailscale.com/k8s-operator/session-recording/fakes"
+ "tailscale.com/k8s-operator/session-recording/tsrecorder"
"tailscale.com/tstest"
- "tailscale.com/tstime"
)
// Test_Writes tests that 1 or more Write calls to spdyRemoteConnRecorder
@@ -56,13 +54,13 @@ func Test_Writes(t *testing.T) {
name: "single_write_stdout_data_frame_with_payload",
inputs: [][]byte{{0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x5, 0x1, 0x2, 0x3, 0x4, 0x5}},
wantForwarded: []byte{0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x5, 0x1, 0x2, 0x3, 0x4, 0x5},
- wantRecorded: castLine(t, []byte{0x1, 0x2, 0x3, 0x4, 0x5}, cl),
+ wantRecorded: fakes.CastLine(t, []byte{0x1, 0x2, 0x3, 0x4, 0x5}, cl),
},
{
name: "single_write_stderr_data_frame_with_payload",
inputs: [][]byte{{0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x5, 0x1, 0x2, 0x3, 0x4, 0x5}},
wantForwarded: []byte{0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x5, 0x1, 0x2, 0x3, 0x4, 0x5},
- wantRecorded: castLine(t, []byte{0x1, 0x2, 0x3, 0x4, 0x5}, cl),
+ wantRecorded: fakes.CastLine(t, []byte{0x1, 0x2, 0x3, 0x4, 0x5}, cl),
},
{
name: "single_data_frame_unknow_stream_with_payload",
@@ -73,13 +71,13 @@ func Test_Writes(t *testing.T) {
name: "control_frame_and_data_frame_split_across_two_writes",
inputs: [][]byte{{0x80, 0x3, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1}, {0x0, 0x0, 0x0, 0x5, 0x1, 0x2, 0x3, 0x4, 0x5}},
wantForwarded: []byte{0x80, 0x3, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x5, 0x1, 0x2, 0x3, 0x4, 0x5},
- wantRecorded: castLine(t, []byte{0x1, 0x2, 0x3, 0x4, 0x5}, cl),
+ wantRecorded: fakes.CastLine(t, []byte{0x1, 0x2, 0x3, 0x4, 0x5}, cl),
},
{
name: "single_first_write_stdout_data_frame_with_payload",
inputs: [][]byte{{0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x5, 0x1, 0x2, 0x3, 0x4, 0x5}},
wantForwarded: []byte{0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x5, 0x1, 0x2, 0x3, 0x4, 0x5},
- wantRecorded: append(asciinemaResizeMsg(t, 10, 20), castLine(t, []byte{0x1, 0x2, 0x3, 0x4, 0x5}, cl)...),
+ wantRecorded: append(fakes.AsciinemaResizeMsg(t, 10, 20), fakes.CastLine(t, []byte{0x1, 0x2, 0x3, 0x4, 0x5}, cl)...),
width: 10,
height: 20,
firstWrite: true,
@@ -87,25 +85,21 @@ func Test_Writes(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- tc := &testConn{}
- sr := &testSessionRecorder{}
- rec := &recorder{
- conn: sr,
- clock: cl,
- start: cl.Now(),
- }
+ tc := &fakes.TestConn{}
+ sr := &fakes.TestSessionRecorder{}
+ rec := tsrecorder.New(sr, cl, cl.Now(), true)
c := &spdyRemoteConnRecorder{
Conn: tc,
log: zl.Sugar(),
rec: rec,
- ch: CastHeader{
+ ch: tsrecorder.CastHeader{
Width: tt.width,
Height: tt.height,
},
}
if !tt.firstWrite {
- // this test case does not intend to test that cast header gets written once
+ // This test case does not intend to test that cast header gets written once.
c.writeCastHeaderOnce.Do(func() {})
}
@@ -118,13 +112,13 @@ func Test_Writes(t *testing.T) {
}
// Assert that the expected bytes have been forwarded to the original destination.
- gotForwarded := tc.writeBuf.Bytes()
+ gotForwarded := tc.WriteBufBytes()
if !reflect.DeepEqual(gotForwarded, tt.wantForwarded) {
t.Errorf("expected bytes not forwarded, wants\n%v\ngot\n%v", tt.wantForwarded, gotForwarded)
}
// Assert that the expected bytes have been forwarded to the session recorder.
- gotRecorded := sr.buf.Bytes()
+ gotRecorded := sr.Bytes()
if !reflect.DeepEqual(gotRecorded, tt.wantRecorded) {
t.Errorf("expected bytes not recorded, wants\n%v\ngot\n%v", tt.wantRecorded, gotRecorded)
}
@@ -197,13 +191,9 @@ func Test_Reads(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- tc := &testConn{}
- sr := &testSessionRecorder{}
- rec := &recorder{
- conn: sr,
- clock: cl,
- start: cl.Now(),
- }
+ tc := &fakes.TestConn{}
+ sr := &fakes.TestSessionRecorder{}
+ rec := tsrecorder.New(sr, cl, cl.Now(), true)
c := &spdyRemoteConnRecorder{
Conn: tc,
log: zl.Sugar(),
@@ -213,9 +203,8 @@ func Test_Reads(t *testing.T) {
for i, input := range tt.inputs {
c.zlibReqReader = reader
- tc.readBuf.Reset()
- _, err := tc.readBuf.Write(input)
- if err != nil {
+ tc.ResetReadBuf()
+ if err := tc.WriteReadBufBytes(input); err != nil {
t.Fatalf("writing bytes to test conn: %v", err)
}
_, err = c.Read(make([]byte, len(input)))
@@ -244,83 +233,11 @@ func Test_Reads(t *testing.T) {
}
}
-func castLine(t *testing.T, p []byte, clock tstime.Clock) []byte {
- t.Helper()
- j, err := json.Marshal([]any{
- clock.Now().Sub(clock.Now()).Seconds(),
- "o",
- string(p),
- })
- if err != nil {
- t.Fatalf("error marshalling cast line: %v", err)
- }
- return append(j, '\n')
-}
-
func resizeMsgBytes(t *testing.T, width, height int) []byte {
t.Helper()
- bs, err := json.Marshal(spdyResizeMsg{Width: width, Height: height})
+ bs, err := json.Marshal(tsrecorder.ResizeMsg{Width: width, Height: height})
if err != nil {
t.Fatalf("error marshalling resizeMsg: %v", err)
}
return bs
}
-
-func asciinemaResizeMsg(t *testing.T, width, height int) []byte {
- t.Helper()
- ch := CastHeader{
- Width: width,
- Height: height,
- }
- bs, err := json.Marshal(ch)
- if err != nil {
- t.Fatalf("error marshalling CastHeader: %v", err)
- }
- return append(bs, '\n')
-}
-
-type testConn struct {
- net.Conn
- // writeBuf contains whatever was send to the conn via Write.
- writeBuf bytes.Buffer
- // readBuf contains whatever was sent to the conn via Read.
- readBuf bytes.Buffer
- sync.RWMutex // protects the following
- closed bool
-}
-
-var _ net.Conn = &testConn{}
-
-func (tc *testConn) Read(b []byte) (int, error) {
- return tc.readBuf.Read(b)
-}
-
-func (tc *testConn) Write(b []byte) (int, error) {
- return tc.writeBuf.Write(b)
-}
-
-func (tc *testConn) Close() error {
- tc.Lock()
- defer tc.Unlock()
- tc.closed = true
- return nil
-}
-func (tc *testConn) isClosed() bool {
- tc.Lock()
- defer tc.Unlock()
- return tc.closed
-}
-
-type testSessionRecorder struct {
- // buf holds data that was sent to the session recorder.
- buf bytes.Buffer
-}
-
-func (t *testSessionRecorder) Write(b []byte) (int, error) {
- return t.buf.Write(b)
-}
-
-func (t *testSessionRecorder) Close() error {
- t.buf.Reset()
- return nil
-}
diff --git a/cmd/k8s-operator/spdy-frame.go b/k8s-operator/session-recording/spdy/frame.go
index 0ddefdfa1..54b29d33a 100644
--- a/cmd/k8s-operator/spdy-frame.go
+++ b/k8s-operator/session-recording/spdy/frame.go
@@ -3,7 +3,7 @@
//go:build !plan9
-package main
+package spdy
import (
"bytes"
diff --git a/cmd/k8s-operator/spdy-frame_test.go b/k8s-operator/session-recording/spdy/frame_test.go
index 416ddfc8b..c6aa4cf01 100644
--- a/cmd/k8s-operator/spdy-frame_test.go
+++ b/k8s-operator/session-recording/spdy/frame_test.go
@@ -3,7 +3,7 @@
//go:build !plan9
-package main
+package spdy
import (
"bytes"
diff --git a/cmd/k8s-operator/zlib-reader.go b/k8s-operator/session-recording/spdy/zlib-reader.go
index b29772be3..1eb654be3 100644
--- a/cmd/k8s-operator/zlib-reader.go
+++ b/k8s-operator/session-recording/spdy/zlib-reader.go
@@ -3,7 +3,7 @@
//go:build !plan9
-package main
+package spdy
import (
"bytes"
diff --git a/k8s-operator/session-recording/tsrecorder/header.go b/k8s-operator/session-recording/tsrecorder/header.go
new file mode 100644
index 000000000..45c50ca1e
--- /dev/null
+++ b/k8s-operator/session-recording/tsrecorder/header.go
@@ -0,0 +1,54 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build !plan9
+
+package tsrecorder
+
+import "tailscale.com/tailcfg"
+
+// CastHeader is the asciicast header to be sent to the recorder at the start of
+// the recording of a session.
+// https://docs.asciinema.org/manual/asciicast/v2/#header
+type CastHeader struct {
+ // Version is the asciinema file format version.
+ Version int `json:"version"`
+
+ // Width is the terminal width in characters.
+ Width int `json:"width"`
+
+ // Height is the terminal height in characters.
+ Height int `json:"height"`
+
+ // Timestamp is the unix timestamp of when the recording started.
+ Timestamp int64 `json:"timestamp"`
+
+ // Tailscale-specific fields: SrcNode is the full MagicDNS name of the
+ // tailnet node originating the connection, without the trailing dot.
+ SrcNode string `json:"srcNode"`
+
+ // SrcNodeID is the node ID of the tailnet node originating the connection.
+ SrcNodeID tailcfg.StableNodeID `json:"srcNodeID"`
+
+ // SrcNodeTags is the list of tags on the node originating the connection (if any).
+ SrcNodeTags []string `json:"srcNodeTags,omitempty"`
+
+ // SrcNodeUserID is the user ID of the node originating the connection (if not tagged).
+ SrcNodeUserID tailcfg.UserID `json:"srcNodeUserID,omitempty"` // if not tagged
+
+ // SrcNodeUser is the LoginName of the node originating the connection (if not tagged).
+ SrcNodeUser string `json:"srcNodeUser,omitempty"`
+
+ Command string
+
+ // Kubernetes-specific fields:
+ Kubernetes *Kubernetes `json:"kubernetes,omitempty"`
+}
+
+// Kubernetes contains 'kubectl exec' session specific information for
+// tsrecorder.
+type Kubernetes struct {
+ PodName string
+ Namespace string
+ Container string
+}
diff --git a/cmd/k8s-operator/recorder.go b/k8s-operator/session-recording/tsrecorder/tsrecorder.go
index ae17f3820..4ce78a882 100644
--- a/cmd/k8s-operator/recorder.go
+++ b/k8s-operator/session-recording/tsrecorder/tsrecorder.go
@@ -3,7 +3,9 @@
//go:build !plan9
-package main
+// Package tsrecorder contains functionality to send recorded kubectl-exec
+// sessions to tsrecorder.
+package tsrecorder
import (
"encoding/json"
@@ -16,9 +18,18 @@ import (
"tailscale.com/tstime"
)
-// recorder knows how to send the provided bytes to the configured tsrecorder
+func New(conn io.WriteCloser, clock tstime.Clock, start time.Time, failOpen bool) *Client {
+ return &Client{
+ start: start,
+ clock: clock,
+ conn: conn,
+ failOpen: failOpen,
+ }
+}
+
+// Client knows how to send the provided bytes to the configured tsrecorder
// instance in asciinema format.
-type recorder struct {
+type Client struct {
start time.Time
clock tstime.Clock
@@ -36,15 +47,15 @@ type recorder struct {
// Write appends timestamp to the provided bytes and sends them to the
// configured tsrecorder.
-func (rec *recorder) Write(p []byte) (err error) {
+func (c *Client) Write(p []byte) (err error) {
if len(p) == 0 {
return nil
}
- if rec.backOff {
+ if c.backOff {
return nil
}
j, err := json.Marshal([]any{
- rec.clock.Now().Sub(rec.start).Seconds(),
+ c.clock.Now().Sub(c.start).Seconds(),
"o",
string(p),
})
@@ -52,37 +63,42 @@ func (rec *recorder) Write(p []byte) (err error) {
return fmt.Errorf("error marhalling payload: %w", err)
}
j = append(j, '\n')
- if err := rec.writeCastLine(j); err != nil {
- if !rec.failOpen {
+ if err := c.WriteCastLine(j); err != nil {
+ if !c.failOpen {
return fmt.Errorf("error writing payload to recorder: %w", err)
}
- rec.backOff = true
+ c.backOff = true
}
return nil
}
-func (rec *recorder) Close() error {
- rec.mu.Lock()
- defer rec.mu.Unlock()
- if rec.conn == nil {
+func (c *Client) Close() error {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ if c.conn == nil {
return nil
}
- err := rec.conn.Close()
- rec.conn = nil
+ err := c.conn.Close()
+ c.conn = nil
return err
}
// writeCastLine sends bytes to the tsrecorder. The bytes should be in
// asciinema format.
-func (rec *recorder) writeCastLine(j []byte) error {
- rec.mu.Lock()
- defer rec.mu.Unlock()
- if rec.conn == nil {
+func (c *Client) WriteCastLine(j []byte) error {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ if c.conn == nil {
return errors.New("recorder closed")
}
- _, err := rec.conn.Write(j)
+ _, err := c.conn.Write(j)
if err != nil {
return fmt.Errorf("recorder write error: %w", err)
}
return nil
}
+
+type ResizeMsg struct {
+ Width int `json:"width"`
+ Height int `json:"height"`
+}
diff --git a/k8s-operator/session-recording/ws/conn.go b/k8s-operator/session-recording/ws/conn.go
new file mode 100644
index 000000000..88bbc2a7f
--- /dev/null
+++ b/k8s-operator/session-recording/ws/conn.go
@@ -0,0 +1,244 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build !plan9
+
+// package ws has functionality to parse 'kubectl exec' sessions streamed using
+// WebSockets protocol.
+package ws
+
+import (
+ "bytes"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "sync"
+
+ "go.uber.org/zap"
+ "k8s.io/apimachinery/pkg/util/remotecommand"
+ "tailscale.com/k8s-operator/session-recording/tsrecorder"
+ "tailscale.com/util/multierr"
+)
+
+// New returns a wrapper around net.Conn that intercepts reads and writes for a
+// websocket streaming session over the provided net.Conn, parses the data as
+// websocket messages and sends message payloads for STDIN/STDOUT streams to a
+// tsrecorder instance using the provided client. Caller must ensure that the
+// session is streamed using WebSockets protocol.
+func New(c net.Conn, rec *tsrecorder.Client, ch tsrecorder.CastHeader, log *zap.SugaredLogger) net.Conn {
+ return &conn{
+ Conn: c,
+ rec: rec,
+ ch: ch,
+ log: log,
+ }
+}
+
+// conn is a wrapper around net.Conn. It reads the bytestream
+// for a 'kubectl exec' session, sends session recording data to the configured
+// recorder and forwards the raw bytes to the original destination.
+// A new conn is created per session.
+// conn only knows to how to read a 'kubectl exec' session that is streamed using WebSocket protocol.
+// https://www.rfc-editor.org/rfc/rfc6455
+type conn struct {
+ net.Conn
+ // rec knows how to send data to a tsrecorder instance.
+ rec *tsrecorder.Client
+ // ch is the asiinema CastHeader for a session.
+ ch tsrecorder.CastHeader
+ log *zap.SugaredLogger
+
+ rmu sync.Mutex // sequences reads
+ // currentReadMsg contains parsed contents of a websocket binary data message that
+ // is currently being read from the underlying net.Conn.
+ currentReadMsg *message
+ // readBuf contains bytes for a currently parsed binary data message
+ // read from the underlying conn. If the message is masked, it is
+ // unmasked in place, so having this buffer allows us to avoid modifying
+ // the original byte array.
+ readBuf bytes.Buffer
+
+ wmu sync.Mutex // sequences writes
+ writeCastHeaderOnce sync.Once
+ closed bool
+ // writeBuf contains bytes for a currently parsed binary data message
+ // being written to the underlying conn. If the message is masked, it is
+ // unmasked in place, so having this buffer allows us to avoid modifying
+ // the original byte array.
+ writeBuf bytes.Buffer
+ // currentWriteMsg contains parsed contents of a websocket binary data message that
+ // is currently being written to the underlying net.Conn.
+ currentWriteMsg *message
+}
+
+// Read reads bytes from the original connection and parses them as websocket
+// message fragments. If the message is for the resize stream, sets the width
+// and height of the CastHeader for this connection.
+// The fragment can be incomplete.
+func (c *conn) Read(b []byte) (int, error) {
+ c.rmu.Lock()
+ defer c.rmu.Unlock()
+ n, err := c.Conn.Read(b)
+ if err != nil {
+ // It seems that we sometimes get a wrapped io.EOF, but the
+ // caller checks for io.EOF with ==.
+ if errors.Is(err, io.EOF) {
+ err = io.EOF
+ }
+ return 0, err
+ }
+
+ typ := messageType(opcode(b))
+ if typ == noOpcode && c.currentReadMsg != nil && !c.currentReadMsg.isFinalized { // subsequent fragment
+ typ = c.currentReadMsg.typ
+ }
+
+ // A control message can not be fragmented and we are not interested in
+ // these messages. Just return.
+ if isControlMessage(typ) {
+ return n, nil
+ }
+
+ // The only data message type that Kubernetes supports is binary message.
+ // If we received another message type, return and let the API server close the connection.
+ // https://github.com/kubernetes/client-go/blob/release-1.30/tools/remotecommand/websocket.go#L281
+ if typ != binaryMessage {
+ c.log.Info("[unexpected] received a data message with a type that is not binary message type %d", typ)
+ return n, nil
+ }
+ if _, err := c.readBuf.Write(b[:n]); err != nil {
+ return 0, fmt.Errorf("[unexpected] error writing message contents to read buffer: %w", err)
+ }
+
+ readMsg := &message{typ: typ} // start a new message...
+ // ... or pick up an already started one if the previous fragment was not final.
+ if c.currentReadMsg != nil && !c.currentReadMsg.isFinalized {
+ readMsg = c.currentReadMsg
+ }
+
+ ok, err := readMsg.Parse(c.readBuf.Bytes(), c.log)
+ if err != nil {
+ return 0, fmt.Errorf("error parsing message: %v", err)
+ }
+ if !ok { // incomplete fragment
+ return n, nil
+ }
+ c.readBuf.Next(len(readMsg.raw))
+
+ if readMsg.isFinalized {
+ // Stream IDs for websocket streams are static.
+ // https://github.com/kubernetes/client-go/blob/v0.30.0-rc.1/tools/remotecommand/websocket.go#L218
+ if readMsg.streamID.Load() == remotecommand.StreamResize {
+ var err error
+ var msg tsrecorder.ResizeMsg
+ if err = json.Unmarshal(readMsg.payload, &msg); err != nil {
+ return 0, fmt.Errorf("error umarshalling resize message: %w", err)
+ }
+ c.ch.Width = msg.Width
+ c.ch.Height = msg.Height
+ }
+ }
+ c.currentReadMsg = readMsg
+ return n, err
+}
+
+// Write parses the written bytes as WebSocket message fragment. If the message
+// is for stdout or stderr streams, it is written to the configured tsrecorder.
+// A message fragment can be incomplete.
+func (c *conn) Write(b []byte) (int, error) {
+ c.wmu.Lock()
+ defer c.wmu.Unlock()
+
+ typ := messageType(opcode(b))
+ // If we are in process of parsing a message fragment, the received
+ // bytes are not structured as a message fragment and can not be used to
+ // determine a message fragment.
+ if len(c.writeBuf.Bytes()) > 0 { // buffer contains previous incomplete fragment
+ typ = c.currentWriteMsg.typ
+ }
+
+ if isControlMessage(typ) {
+ n, err := c.Conn.Write(b)
+ return n, err
+ }
+
+ if _, err := c.writeBuf.Write(b); err != nil {
+ c.log.Errorf("write: error writing to write buf: %v", err)
+ return 0, fmt.Errorf("[unexpected] error writing to internal write buffer: %w", err)
+ }
+
+ writeMsg := &message{typ: typ} // start a new message...
+ // ... or continue the existing one if it has not been finalized.
+ if c.currentWriteMsg != nil && !c.currentWriteMsg.isFinalized {
+ writeMsg = c.currentWriteMsg
+ }
+
+ ok, err := writeMsg.Parse(c.writeBuf.Bytes(), c.log)
+ if err != nil {
+ c.log.Errorf("write: parsing a message errored: %v", err)
+ return 0, fmt.Errorf("write: error parsing message: %v", err)
+ }
+ c.currentWriteMsg = writeMsg
+ if !ok { // incomplete fragment
+ return len(b), nil
+ }
+ c.writeBuf.Next(len(writeMsg.raw)) // advance frame
+
+ if len(writeMsg.payload) != 0 && writeMsg.isFinalized {
+ if writeMsg.streamID.Load() == remotecommand.StreamStdOut || writeMsg.streamID.Load() == remotecommand.StreamStdErr {
+ var err error
+ c.writeCastHeaderOnce.Do(func() {
+ var j []byte
+ j, err = json.Marshal(c.ch)
+ if err != nil {
+ c.log.Infof("error marhsalling conn: %v", err)
+ return
+ }
+ j = append(j, '\n')
+ err = c.rec.WriteCastLine(j)
+ if err != nil {
+ c.log.Errorf("received error from recorder: %v", err)
+ }
+ })
+ if err != nil {
+ return 0, fmt.Errorf("error writing CastHeader: %w", err)
+ }
+ if err := c.rec.Write(writeMsg.payload); err != nil {
+ return 0, fmt.Errorf("error writing message to recorder: %v", err)
+ }
+ }
+ }
+ _, err = c.Conn.Write(c.currentWriteMsg.raw)
+ if err != nil {
+ c.log.Errorf("write: error writing to conn: %v", err)
+ }
+ return len(b), err
+}
+
+func (c *conn) Close() error {
+ c.wmu.Lock()
+ defer c.wmu.Unlock()
+ if c.closed {
+ return nil
+ }
+ // TODO: only do this if this is a normal closure rather than the
+ // reocrding has failed.
+ if c.writeBuf.Len() > 0 {
+ c.Conn.Write(c.writeBuf.Bytes())
+ }
+ c.closed = true
+ connCloseErr := c.Conn.Close()
+ recCloseErr := c.rec.Close()
+ return multierr.New(connCloseErr, recCloseErr)
+}
+
+// opcode reads the websocket message opcode that denotes the message type.
+// opcode is contained in bits [4-8] of the message.
+// https://www.rfc-editor.org/rfc/rfc6455#section-5.2
+func opcode(b []byte) int {
+ // 0xf = 00001111; b & 00001111 zeroes out bits [0 - 3] of b
+ var mask byte = 0xf
+ return int(b[0] & mask)
+}
diff --git a/k8s-operator/session-recording/ws/conn_test.go b/k8s-operator/session-recording/ws/conn_test.go
new file mode 100644
index 000000000..a64b89c56
--- /dev/null
+++ b/k8s-operator/session-recording/ws/conn_test.go
@@ -0,0 +1,171 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build !plan9
+
+package ws
+
+import (
+ "reflect"
+ "testing"
+
+ "go.uber.org/zap"
+ "k8s.io/apimachinery/pkg/util/remotecommand"
+ "tailscale.com/k8s-operator/session-recording/fakes"
+ "tailscale.com/k8s-operator/session-recording/tsrecorder"
+ "tailscale.com/tstest"
+)
+
+func Test_conn_Read(t *testing.T) {
+ zl, err := zap.NewDevelopment()
+ if err != nil {
+ t.Fatal(err)
+ }
+ // Resize stream ID + {"width": 10, "height": 20}
+ testResizeMsg := []byte{byte(remotecommand.StreamResize), 0x7b, 0x22, 0x77, 0x69, 0x64, 0x74, 0x68, 0x22, 0x3a, 0x31, 0x30, 0x2c, 0x22, 0x68, 0x65, 0x69, 0x67, 0x68, 0x74, 0x22, 0x3a, 0x32, 0x30, 0x7d}
+ lenResizeMsgPayload := byte(len(testResizeMsg))
+
+ tests := []struct {
+ name string
+ inputs [][]byte
+ wantWidth int
+ wantHeight int
+ }{
+ {
+ name: "single_read_control_message",
+ inputs: [][]byte{{0x88, 0x0}},
+ },
+ {
+ name: "single_read_resize_message",
+ inputs: [][]byte{append([]byte{0x82, lenResizeMsgPayload}, testResizeMsg...)},
+ wantWidth: 10,
+ wantHeight: 20,
+ },
+ {
+ name: "two_reads_resize_message",
+ inputs: [][]byte{{0x2, 0x9, 0x4, 0x7b, 0x22, 0x77, 0x69, 0x64, 0x74, 0x68, 0x22}, {0x80, 0x11, 0x4, 0x3a, 0x31, 0x30, 0x2c, 0x22, 0x68, 0x65, 0x69, 0x67, 0x68, 0x74, 0x22, 0x3a, 0x32, 0x30, 0x7d}},
+ wantWidth: 10,
+ wantHeight: 20,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ tc := &fakes.TestConn{}
+ tc.ResetReadBuf()
+ c := &conn{
+ Conn: tc,
+ log: zl.Sugar(),
+ }
+ for i, input := range tt.inputs {
+ if err := tc.WriteReadBufBytes(input); err != nil {
+ t.Fatalf("writing bytes to test conn: %v", err)
+ }
+ _, err := c.Read(make([]byte, len(input)))
+ if err != nil {
+ t.Errorf("[%d] conn.Read() errored %v", i, err)
+ return
+ }
+ }
+ if tt.wantHeight != 0 || tt.wantWidth != 0 {
+ if tt.wantWidth != c.ch.Width {
+ t.Errorf("wants width: %v, got %v", tt.wantWidth, c.ch.Width)
+ }
+ if tt.wantHeight != c.ch.Height {
+ t.Errorf("want height: %v, got %v", tt.wantHeight, c.ch.Height)
+ }
+ }
+ })
+ }
+}
+
+func Test_conn_Write(t *testing.T) {
+ zl, err := zap.NewDevelopment()
+ if err != nil {
+ t.Fatal(err)
+ }
+ cl := tstest.NewClock(tstest.ClockOpts{})
+ tests := []struct {
+ name string
+ inputs [][]byte
+ wantForwarded []byte
+ wantRecorded []byte
+ firstWrite bool
+ width int
+ height int
+ }{
+ {
+ name: "single_write_control_frame",
+ inputs: [][]byte{{0x88, 0x0}},
+ wantForwarded: []byte{0x88, 0x0},
+ },
+ {
+ name: "single_write_stdout_data_message",
+ inputs: [][]byte{{0x82, 0x3, 0x1, 0x7, 0x8}},
+ wantForwarded: []byte{0x82, 0x3, 0x1, 0x7, 0x8},
+ wantRecorded: fakes.CastLine(t, []byte{0x7, 0x8}, cl),
+ },
+ {
+ name: "single_write_stderr_data_message",
+ inputs: [][]byte{{0x82, 0x3, 0x2, 0x7, 0x8}},
+ wantForwarded: []byte{0x82, 0x3, 0x2, 0x7, 0x8},
+ wantRecorded: fakes.CastLine(t, []byte{0x7, 0x8}, cl),
+ },
+ {
+ name: "single_write_stdin_data_message",
+ inputs: [][]byte{{0x82, 0x3, 0x0, 0x7, 0x8}},
+ wantForwarded: []byte{0x82, 0x3, 0x0, 0x7, 0x8},
+ },
+ {
+ name: "single_write_stdout_data_message_with_cast_header",
+ inputs: [][]byte{{0x82, 0x3, 0x1, 0x7, 0x8}},
+ wantForwarded: []byte{0x82, 0x3, 0x1, 0x7, 0x8},
+ wantRecorded: append(fakes.AsciinemaResizeMsg(t, 10, 20), fakes.CastLine(t, []byte{0x7, 0x8}, cl)...),
+ width: 10,
+ height: 20,
+ firstWrite: true,
+ },
+ {
+ name: "two_writes_stdout_data_message",
+ inputs: [][]byte{{0x2, 0x3, 0x1, 0x7, 0x8}, {0x80, 0x6, 0x1, 0x1, 0x2, 0x3, 0x4, 0x5}},
+ wantForwarded: []byte{0x2, 0x3, 0x1, 0x7, 0x8, 0x80, 0x6, 0x1, 0x1, 0x2, 0x3, 0x4, 0x5},
+ wantRecorded: fakes.CastLine(t, []byte{0x7, 0x8, 0x1, 0x2, 0x3, 0x4, 0x5}, cl),
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ tc := &fakes.TestConn{}
+ sr := &fakes.TestSessionRecorder{}
+ rec := tsrecorder.New(sr, cl, cl.Now(), true)
+ c := &conn{
+ Conn: tc,
+ log: zl.Sugar(),
+ ch: tsrecorder.CastHeader{
+ Width: tt.width,
+ Height: tt.height,
+ },
+ rec: rec,
+ }
+ if !tt.firstWrite {
+ // This test case does not intend to test that cast header gets written once.
+ c.writeCastHeaderOnce.Do(func() {})
+ }
+ for i, input := range tt.inputs {
+ _, err := c.Write(input)
+ if err != nil {
+ t.Fatalf("[%d] conn.Write() errored: %v", i, err)
+ }
+ }
+ // Assert that the expected bytes have been forwarded to the original destination.
+ gotForwarded := tc.WriteBufBytes()
+ if !reflect.DeepEqual(gotForwarded, tt.wantForwarded) {
+ t.Errorf("expected bytes not forwarded, wants\n%v\ngot\n%v", tt.wantForwarded, gotForwarded)
+ }
+
+ // Assert that the expected bytes have been forwarded to the session recorder.
+ gotRecorded := sr.Bytes()
+ if !reflect.DeepEqual(gotRecorded, tt.wantRecorded) {
+ t.Errorf("expected bytes not recorded, wants\n%v\ngot\n%v", tt.wantRecorded, gotRecorded)
+ }
+ })
+ }
+}
diff --git a/k8s-operator/session-recording/ws/message.go b/k8s-operator/session-recording/ws/message.go
new file mode 100644
index 000000000..bf33e6bb2
--- /dev/null
+++ b/k8s-operator/session-recording/ws/message.go
@@ -0,0 +1,253 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build !plan9
+
+package ws
+
+import (
+ "encoding/binary"
+ "fmt"
+ "sync/atomic"
+
+ "github.com/pkg/errors"
+ "go.uber.org/zap"
+)
+
+const (
+ noOpcode messageType = 0 // continuation frame for fragmented messages
+ binaryMessage messageType = 2
+)
+
+// messageType is the type of a websocket data or control message as defined by opcode.
+// https://www.rfc-editor.org/rfc/rfc6455#section-5.2
+// Known types of control messages are close, ping and pong.
+// https://www.rfc-editor.org/rfc/rfc6455#section-5.5
+// The only data message type supported by Kubernetes is binary message
+// https://github.com/kubernetes/client-go/blob/v0.30.0-rc.1/tools/remotecommand/websocket.go#L281
+type messageType int
+
+// message is a parsed Websocket Message.
+type message struct {
+ // payload is the contents of the so far parsed Websocket
+ // data Message payload, potentially from multiple fragments written by
+ // multiple invocations of Parse. As per RFC 6455 We can assume that the
+ // fragments will always arrive in order and data messages will not be
+ // interleaved.
+ payload []byte
+
+ // isFinalized is set to true if msgPayload contains full contents of
+ // the message (the final fragment has been received).
+ isFinalized bool
+
+ // streamID is the stream to which the message belongs, i.e stdin, stout
+ // etc. It is one of the stream IDs defined in
+ // https://github.com/kubernetes/apimachinery/commit/73d12d09c5be8703587b5127416eb83dc3b7e182#diff-291f96e8632d04d2d20f5fb00f6b323492670570d65434e8eac90c7a442d13bdR23-R36
+ streamID atomic.Uint32
+
+ // typ is the type of a WebsocketMessage as defined by its opcode
+ // https://www.rfc-editor.org/rfc/rfc6455#section-5.2
+ typ messageType
+ raw []byte
+}
+
+// Parse accepts a websocket message fragment as a byte slice and parses its contents.
+// The fragment can be:
+// - a fragment that consists of a whole message
+// - an initial fragment for a message for which we expect more fragments
+// - a subsequent fragment for a message that we are currently parsing and whose so-far parsed contents are stored in msg.
+// It is not expected that the byte slice would contain an incomplete fragment or fragment for a different message than the one currently being parsed (if any).
+// Message fragment structure:
+// 0 1 2 3
+// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+// +-+-+-+-+-------+-+-------------+-------------------------------+
+// |F|R|R|R| opcode|M| Payload len | Extended payload length |
+// |I|S|S|S| (4) |A| (7) | (16/64) |
+// |N|V|V|V| |S| | (if payload len==126/127) |
+// | |1|2|3| |K| | |
+// +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
+// | Extended payload length continued, if payload len == 127 |
+// + - - - - - - - - - - - - - - - +-------------------------------+
+// | |Masking-key, if MASK set to 1 |
+// +-------------------------------+-------------------------------+
+// | Masking-key (continued) | Payload Data |
+// +-------------------------------- - - - - - - - - - - - - - - - +
+// : Payload Data continued ... :
+// + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
+// | Payload Data continued ... |
+// +---------------------------------------------------------------+
+// https://www.rfc-editor.org/rfc/rfc6455#section-5.2
+//
+// Fragmentation rules:
+// An unfragmented message consists of a single frame with the FIN
+// bit set (Section 5.2) and an opcode other than 0.
+// A fragmented message consists of a single frame with the FIN bit
+// clear and an opcode other than 0, followed by zero or more frames
+// with the FIN bit clear and the opcode set to 0, and terminated by
+// a single frame with the FIN bit set and an opcode of 0.
+// https://www.rfc-editor.org/rfc/rfc6455#section-5.4
+func (msg *message) Parse(b []byte, log *zap.SugaredLogger) (bool, error) {
+ if msg.typ != binaryMessage {
+ return false, fmt.Errorf("[unexpected] internal error: attempted to parse a message with type %d", msg.typ)
+ }
+
+ msg.isFinalized = isFinalFragment(b)
+
+ maskSet := isMasked(b)
+
+ payloadLength, payloadOffset, maskOffset, err := fragmentDimensions(b, maskSet)
+ if err != nil {
+ return false, fmt.Errorf("error determining payload length: %w", err)
+ }
+ log.Debugf("parse: parsing a message with payload length: %d payload offset: %d maskOffset: %d mask set: %t, is finalized: %t", payloadLength, payloadOffset, maskOffset, maskSet, msg.isFinalized)
+
+ if len(b) < int(payloadOffset)+int(payloadLength) { // incomplete fragment
+ return false, nil
+ }
+ msg.raw = make([]byte, int(payloadOffset)+int(payloadLength))
+ copy(msg.raw, b[:payloadOffset+payloadLength])
+
+ // Extract the payload.
+ msgPayload := b[payloadOffset : payloadOffset+payloadLength]
+
+ // Unmask the payload if needed.
+ if maskSet {
+ m := b[maskOffset:payloadOffset]
+ var mask [4]byte
+ copy(mask[:], m)
+ maskBytes(mask, msgPayload)
+ }
+
+ // Determine what stream the message is for. Stream ID of a Kubernetes
+ // streaming session is a 32bit integer, stored in the first byte of the
+ // message payload.
+ // https://github.com/kubernetes/apimachinery/commit/73d12d09c5be8703587b5127416eb83dc3b7e182#diff-291f96e8632d04d2d20f5fb00f6b323492670570d65434e8eac90c7a442d13bdR23-R36
+ if len(msgPayload) == 0 {
+ return false, errors.New("[unexpected] received a message fragment with no stream ID")
+ }
+
+ streamId := uint32(msgPayload[0])
+ if msg.streamID.Load() != 0 && msg.streamID.Load() != streamId {
+ return false, fmt.Errorf("[unexpected] received message fragments with mismatched streamIDs %d and %d", msg.streamID.Load(), streamId)
+ }
+ msg.streamID.Store(streamId)
+
+ // This is normal, Kubernetes seem to send a couple data messages with
+ // no payloads at the start.
+ if len(msgPayload) < 2 {
+ return true, nil
+ }
+ msgPayload = msgPayload[1:] // remove the stream ID byte
+ msg.payload = append(msg.payload, msgPayload...)
+ return true, nil
+}
+
+// maskBytes applies mask to bytes in place.
+// https://www.rfc-editor.org/rfc/rfc6455#section-5.3
+func maskBytes(key [4]byte, b []byte) {
+ for i := range b {
+ b[i] = b[i] ^ key[i%4]
+ }
+}
+
+// isControlMessage returns true if the message type is one of the know control
+// frame message types.
+// https://www.rfc-editor.org/rfc/rfc6455#section-5.5
+func isControlMessage(t messageType) bool {
+ const (
+ closeMessage messageType = 8
+ pingMessage messageType = 9
+ pongMessage messageType = 10
+ )
+ return t == closeMessage || t == pingMessage || t == pongMessage
+}
+
+// isFinalFragment can be called with websocket message fragment and returns true if
+// the fragment is the final fragment of a websocket message.
+func isFinalFragment(b []byte) bool {
+ // Extract FIN bit. FIN bit is the first bit of a message fragment.
+ const finBitMask byte = 1 << 7
+ finBit := b[0] & finBitMask
+ return finBit != 0
+}
+
+// isMasked can be called with a websocket message fragment and returns true if
+// the payload of the message is masked. It uses the mask bit to determine if
+// the payload is masked.
+// https://www.rfc-editor.org/rfc/rfc6455#section-5.3
+func isMasked(b []byte) bool {
+ return extractFirstBit(b[1]) != 0
+}
+
+// extractFirstBit extracts first bit of a byte by zeroing out all the other
+// bits.
+func extractFirstBit(b byte) byte {
+ const mask byte = 1 << 7
+ return b & mask
+}
+
+// zeroFirstBit returns the provided byte with the first bit set to 0.
+func zeroFirstBit(b byte) byte {
+ const revMask byte = 1 << 7
+ return b & (^revMask)
+}
+
+// fragmentDimensions returns payload length as well as payload offset and mask offset.
+func fragmentDimensions(b []byte, maskSet bool) (payloadLength, payloadOffset, maskOffset int64, _ error) {
+
+ // payload length can be stored either in bits [9-15] or in bytes 2, 3
+ // or in bytes 2, 3, 4, 5, 6, 7.
+ // https://www.rfc-editor.org/rfc/rfc6455#section-5.2
+ // 0 1 2 3
+ // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+ // +-+-+-+-+-------+-+-------------+-------------------------------+
+ // |F|R|R|R| opcode|M| Payload len | Extended payload length |
+ // |I|S|S|S| (4) |A| (7) | (16/64) |
+ // |N|V|V|V| |S| | (if payload len==126/127) |
+ // | |1|2|3| |K| | |
+ // +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
+ // | Extended payload length continued, if payload len == 127 |
+ // + - - - - - - - - - - - - - - - +-------------------------------+
+ payloadLengthIndicator := zeroFirstBit(b[1])
+ var lengthOffset int64
+ switch {
+ case payloadLengthIndicator < 126:
+ lengthOffset = 1
+ maskOffset = 2
+ payloadLength = int64(payloadLengthIndicator)
+ case payloadLengthIndicator == 126:
+ maskOffset = 4
+ lengthOffset = 2
+ payloadLength = extractInt64(b, lengthOffset, 2)
+ case payloadLengthIndicator == 127:
+ maskOffset = 10
+ lengthOffset = 2
+ payloadLength = extractInt64(b, lengthOffset, 6)
+ default:
+ return -1, -1, -1, fmt.Errorf("unexpected payload length indicator value: %v", payloadLengthIndicator)
+ }
+
+ // Masking key can take up 0 or 4 bytes- we need to take that into
+ // account when determining payload offset.
+ // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+ // ....
+ // + - - - - - - - - - - - - - - - +-------------------------------+
+ // | |Masking-key, if MASK set to 1 |
+ // +-------------------------------+-------------------------------+
+ // | Masking-key (continued) | Payload Data |
+ // + - - - - - - - - - - - - - - - +-------------------------------+
+ // ...
+ if maskSet {
+ payloadOffset = maskOffset + 4
+ } else {
+ payloadOffset = maskOffset
+ }
+ return
+}
+
+func extractInt64(b []byte, offset, length int64) int64 {
+ payloadLengthBytes := b[offset : offset+length]
+ payloadLengthBytesPadded := append(make([]byte, 8-len(payloadLengthBytes)), payloadLengthBytes...)
+
+ return int64(binary.BigEndian.Uint64(payloadLengthBytesPadded))
+}
diff --git a/k8s-operator/session-recording/ws/message_test.go b/k8s-operator/session-recording/ws/message_test.go
new file mode 100644
index 000000000..63a80ade9
--- /dev/null
+++ b/k8s-operator/session-recording/ws/message_test.go
@@ -0,0 +1,125 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build !plan9
+
+package ws
+
+import (
+ "reflect"
+ "testing"
+
+ "go.uber.org/zap"
+)
+
+func Test_msg_Parse(t *testing.T) {
+ zl, err := zap.NewDevelopment()
+ if err != nil {
+ t.Fatalf("error creating a test logger: %v", err)
+ }
+ testMask := [4]byte{1, 2, 3, 4}
+ tests := []struct {
+ name string
+ b []byte
+ initialPayload []byte
+ wantPayload []byte
+ wantIsFinalized bool
+ wantStreamID uint32
+ }{
+ {
+ name: "single_fragment_stdout_stream_no_payload_no_mask",
+ b: []byte{0x82, 0x1, 0x1},
+ wantPayload: nil,
+ wantIsFinalized: true,
+ wantStreamID: 1,
+ },
+ {
+ name: "single_fragment_stderr_steam_no_payload_has_mask",
+ b: append([]byte{0x82, 0x81, 0x1, 0x2, 0x3, 0x4}, maskedBytes(testMask, []byte{0x2})...),
+ wantPayload: nil,
+ wantIsFinalized: true,
+ wantStreamID: 2,
+ },
+ {
+ name: "single_fragment_stdout_stream_no_mask_has_payload",
+ b: []byte{0x82, 0x3, 0x1, 0x7, 0x8},
+ wantPayload: []byte{0x7, 0x8},
+ wantIsFinalized: true,
+ wantStreamID: 1,
+ },
+ {
+ name: "single_fragment_stdout_stream_has_mask_has_payload",
+ b: append([]byte{0x82, 0x83, 0x1, 0x2, 0x3, 0x4}, maskedBytes(testMask, []byte{0x1, 0x7, 0x8})...),
+ wantPayload: []byte{0x7, 0x8},
+ wantIsFinalized: true,
+ wantStreamID: 1,
+ },
+ {
+ name: "initial_fragment_stdout_stream_no_mask_has_payload",
+ b: []byte{0x2, 0x3, 0x1, 0x7, 0x8},
+ wantPayload: []byte{0x7, 0x8},
+ wantStreamID: 1,
+ },
+ {
+ name: "initial_fragment_stdout_stream_has_mask_has_payload",
+ b: append([]byte{0x2, 0x83, 0x1, 0x2, 0x3, 0x4}, maskedBytes(testMask, []byte{0x1, 0x7, 0x8})...),
+ wantPayload: []byte{0x7, 0x8},
+ wantStreamID: 1,
+ },
+ {
+ name: "subsequent_fragment_stdout_stream_no_mask_has_payload",
+ b: []byte{0x0, 0x3, 0x1, 0x7, 0x8},
+ initialPayload: []byte{0x1, 0x2, 0x3},
+ wantPayload: []byte{0x1, 0x2, 0x3, 0x7, 0x8},
+ wantStreamID: 1,
+ },
+ {
+ name: "subsequent_fragment_stdout_stream_has_mask_has_payload",
+ b: append([]byte{0x0, 0x83, 0x1, 0x2, 0x3, 0x4}, maskedBytes(testMask, []byte{0x1, 0x7, 0x8})...),
+ initialPayload: []byte{0x1, 0x2, 0x3},
+ wantPayload: []byte{0x1, 0x2, 0x3, 0x7, 0x8},
+ wantStreamID: 1,
+ },
+ {
+ name: "final_fragment_stdout_stream_no_mask_has_payload",
+ b: []byte{0x80, 0x3, 0x1, 0x7, 0x8},
+ initialPayload: []byte{0x1, 0x2, 0x3},
+ wantIsFinalized: true,
+ wantPayload: []byte{0x1, 0x2, 0x3, 0x7, 0x8},
+ wantStreamID: 1,
+ },
+ {
+ name: "final_fragment_stdout_stream_has_mask_has_payload",
+ b: append([]byte{0x80, 0x83, 0x1, 0x2, 0x3, 0x4}, maskedBytes(testMask, []byte{0x1, 0x7, 0x8})...),
+ initialPayload: []byte{0x1, 0x2, 0x3},
+ wantIsFinalized: true,
+ wantPayload: []byte{0x1, 0x2, 0x3, 0x7, 0x8},
+ wantStreamID: 1,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ msg := &message{
+ typ: binaryMessage,
+ payload: tt.initialPayload,
+ }
+ if _, err := msg.Parse(tt.b, zl.Sugar()); err != nil {
+ t.Errorf("msg.Parse() errored %v", err)
+ }
+ if msg.isFinalized != tt.wantIsFinalized {
+ t.Errorf("wants message to be finalized: %t, got: %t", tt.wantIsFinalized, msg.isFinalized)
+ }
+ if msg.streamID.Load() != tt.wantStreamID {
+ t.Errorf("wants stream ID: %d, got: %d", tt.wantStreamID, msg.streamID.Load())
+ }
+ if !reflect.DeepEqual(msg.payload, tt.wantPayload) {
+ t.Errorf("unexpected message payload after Parse, wants %b, got %b", tt.wantPayload, msg.payload)
+ }
+ })
+ }
+}
+
+func maskedBytes(mask [4]byte, b []byte) []byte {
+ maskBytes(mask, b)
+ return b
+}