summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorDavid Anderson <dave@tailscale.com>2025-07-29 10:41:18 -0700
committerDavid Anderson <dave@tailscale.com>2025-07-29 10:59:40 -0700
commit2907b24fb0d5d0fafb2b7d04be59df9f7e009e04 (patch)
treea66dcdd0efd4cc98e86459c6d6aa6175294e87dc
parente37432afb7acb012576b8df483d31492317b790b (diff)
downloadtailscale-push-tyyxlsmpmlvz.tar.xz
tailscale-push-tyyxlsmpmlvz.zip
WIP: arena-based packet buffer abstractionpush-tyyxlsmpmlvz
Signed-off-by: David Anderson <dave@tailscale.com>
-rw-r--r--net/pktbuf/arena.go47
-rw-r--r--net/pktbuf/buffer.go344
-rw-r--r--net/pktbuf/packet.go268
3 files changed, 659 insertions, 0 deletions
diff --git a/net/pktbuf/arena.go b/net/pktbuf/arena.go
new file mode 100644
index 000000000..748f3718f
--- /dev/null
+++ b/net/pktbuf/arena.go
@@ -0,0 +1,47 @@
+package pktbuf
+
+import "slices"
+
+// Arena is an arena-based memory allocator for byte slices.
+type Arena struct {
+ mem []byte
+ high int // high water mark for previous arena cycles
+ avg float32
+}
+
+const initialArenaChunkSize = 4096
+
+// Get allocates and returns a byte slice of the given size.
+//
+// The allocation remains valid until the next call to UnsafelyReset.
+func (a *Arena) Get(sz int) []byte {
+ a.mem = slices.Grow(a.mem, sz)
+ ln := len(a.mem)
+ a.mem = a.mem[:ln+sz]
+ ret := a.mem[ln : ln+sz : ln+sz]
+ // compiler should turn this into an efficient memset.
+ for i := range ret {
+ ret[i] = 0
+ }
+ return ret
+}
+
+const shrinkHysteresis = 1024
+
+// Reset clears the arena for reuse. Past allocations are unaffected.
+func (a *Arena) Reset() {
+ a.mem = nil
+}
+
+// UnsafelyReset clears the arena for reuse. Unlike Reset,
+// UnsafelyReset reuses the arena's existing storage for future
+// allocations, so callers MUST cease all use of previously allocated
+// slices prior to reset.
+func (a *Arena) UnsafelyReset() {
+ a.high = max(a.high, len(a.mem))
+ a.avg = 0.9*a.avg + 0.1*float32(len(a.mem))
+ if avgInt := int(a.avg); avgInt < a.high-shrinkHysteresis {
+ a.mem = make([]byte, 0, avgInt)
+ }
+ a.mem = a.mem[:0]
+}
diff --git a/net/pktbuf/buffer.go b/net/pktbuf/buffer.go
new file mode 100644
index 000000000..c2454c4d9
--- /dev/null
+++ b/net/pktbuf/buffer.go
@@ -0,0 +1,344 @@
+package pktbuf
+
+import (
+ "iter"
+ "slices"
+)
+
+// A chunkBuffer is like a byte slice, but internally the bytes are
+// stored as a list of chunks ([][]byte), with spare nil slices on
+// either side to allow for efficient insertion and deletion of
+// chunks.
+//
+// Most chunkBuffer operations require a linear traversal of the chunk
+// list. As such, it's intended for uses where the number of chunks is
+// low enough that this linear traversal is very fast. Using a
+// chunkBuffer with up to 100 chunks is probably fine, but beyond that
+// you probably want to use something like a rope instead, which
+// scales up gracefully but has poor spatial locality and memory
+// access patterns at smaller scale.
+type chunkBuffer struct {
+ chunks [][]byte
+ // start and end are indices in chunks of the chunks currently
+ // being used. That is, chunks[start:end] is the range of non-nil
+ // slices.
+ start, end int
+ length int
+}
+
+// len reports the number of bytes in the buffer.
+func (c *chunkBuffer) len() int {
+ return c.length
+}
+
+// startGap reports the number of unused chunk slots at the start of
+// the buffer.
+func (c *chunkBuffer) startGap() int {
+ return c.start
+}
+
+// endGap reports the number of unused chunk slots at the end of the
+// buffer.
+func (c *chunkBuffer) endGap() int {
+ return len(c.chunks) - c.end
+}
+
+// grow increases the buffer's chunk capacity to have at least minGap
+// unused chunk slots at both the start and end of the buffer.
+func (c *chunkBuffer) grow(minGap int) {
+ used := c.end - c.start
+ minLen := used + 2*minGap
+
+ // Depending on the operations that took place in the past, the
+ // position of the in-use chunks might be lopsided (e.g. only 1
+ // slot available at the start but 32 at the end).
+ //
+ // In that case, as long as the minimum gap requirement is met,
+ // this logic will avoid taking the hit of a reallocation. The
+ // rest of the code below will boil down to just re-centering the
+ // chunks within the slice.
+ tgt := min(len(c.chunks), 16)
+ for tgt < minLen {
+ tgt *= 2
+ }
+
+ c.chunks = slices.Grow(c.chunks, tgt-len(c.chunks))
+ c.chunks = c.chunks[:cap(c.chunks)]
+
+ gap := (tgt - used) / 2
+ copy(c.chunks[gap:], c.chunks[c.start:c.end])
+ c.start = gap
+ c.end = gap + used
+}
+
+// ensureStartGap ensures that at least minGap unused chunk slots are
+// available at the start of the buffer.
+func (c *chunkBuffer) ensureStartGap(minGap int) {
+ if c.startGap() < minGap {
+ c.grow(minGap)
+ }
+}
+
+// ensureEndGap ensures that at least minGap unused chunk slots are
+// available at the end of the buffer.
+func (c *chunkBuffer) ensureEndGap(minGap int) {
+ if c.endGap() < minGap {
+ c.grow(minGap)
+ }
+}
+
+// append adds bs to the end of the buffer.
+//
+// The caller must not mutate bs after appending it.
+func (c *chunkBuffer) append(bss ...[]byte) {
+ c.ensureEndGap(len(bss))
+ for _, bs := range bss {
+ c.chunks[c.end] = slices.Clip(bs)
+ c.end++
+ c.length += len(bs)
+ }
+}
+
+// prepend adds bs to the start of the buffer.
+//
+// The caller must not mutate bs after prepending it.
+func (c *chunkBuffer) prepend(bss ...[]byte) {
+ c.ensureStartGap(len(bss))
+ for _, bs := range bss {
+ c.start--
+ c.chunks[c.start] = slices.Clip(bs)
+ c.length += len(bs)
+ }
+}
+
+// insert inserts bs at the given offset in the buffer.
+func (c *chunkBuffer) insert(bs []byte, off int) {
+ idx := c.mkGap(off, 1)
+ c.chunks[idx] = slices.Clip(bs)
+ c.length += len(bs)
+}
+
+// splice splices the chunks of other into the buffer at the given
+// offset.
+//
+// After calling splice, other is empty and can be reused.
+func (c *chunkBuffer) splice(other *chunkBuffer, off int) {
+ sz := other.end - other.start
+ if sz == 0 {
+ return
+ }
+ idx := c.mkGap(off, sz)
+ copy(c.chunks[idx:idx+sz], other.chunks[c.start:c.end])
+ c.length += other.length
+ other.chunks = deleteCompact(other.chunks, 0, len(other.chunks))
+ other.start = len(other.chunks) / 2
+ other.end = len(other.chunks) / 2
+ other.length = 0
+}
+
+// deletePrefix removes sz bytes from the start of the buffer.
+func (c *chunkBuffer) deletePrefix(sz int) {
+ origSz := sz
+ for c.start != c.end {
+ if len(c.chunks[c.start]) >= sz {
+ c.chunks[c.start] = nil
+ c.start++
+ continue
+ }
+ if sz > 0 {
+ c.chunks[c.start] = slices.Clip(c.chunks[c.start][sz:])
+ }
+ break
+ }
+ c.length = max(0, c.length-origSz)
+}
+
+// deleteSuffix removes sz bytes from the end of the buffer.
+func (c *chunkBuffer) deleteSuffix(sz int) {
+ origSz := sz
+ for c.start != c.end {
+ if len(c.chunks[c.end-1]) >= sz {
+ c.chunks[c.end-1] = nil
+ c.end--
+ continue
+ }
+ if sz > 0 {
+ c.chunks[c.end-1] = c.chunks[c.end-1][sz:]
+ }
+ break
+ }
+ c.length -= max(0, c.length-origSz)
+}
+
+// delete removes the byte range [off:off+sz] from the buffer.
+func (c *chunkBuffer) delete(off, sz int) {
+ deleteStart := -1
+ for i, chunk := range c.chunks[c.start:c.end] {
+ if len(chunk) > off {
+ deleteStart = i
+ break
+ }
+ off -= len(chunk)
+ }
+ if off > 0 {
+ c.chunks[deleteStart] = slices.Clip(c.chunks[deleteStart][:off])
+ sz -= off
+ off = 0
+ deleteStart++
+ }
+ deleteEnd := -1
+ for i, chunk := range c.chunks[deleteStart:c.end] {
+ if len(chunk) > sz {
+ deleteEnd = i
+ break
+ }
+ sz -= len(chunk)
+ }
+ if sz > 0 {
+ c.chunks[deleteEnd] = c.chunks[deleteEnd][sz:]
+ }
+ c.chunks = deleteCompact(c.chunks, deleteStart, deleteEnd)
+}
+
+// extract removes the byte range [off:off+sz] from the buffer, and
+// returns it as a new buffer.
+func (c *chunkBuffer) extract(off, sz int) chunkBuffer {
+ startIdx := c.mkGap(off, 0)
+ endIdx := c.mkGap(off+sz, 0)
+ retSz := endIdx - startIdx
+ var ret chunkBuffer
+ ret.ensureEndGap(retSz)
+ copy(ret.chunks[c.start:], c.chunks[startIdx:endIdx])
+ ret.length = sz
+ c.chunks = deleteCompact(c.chunks, startIdx, endIdx)
+ c.length -= sz
+ return ret
+}
+
+// mkGap creates a gap of sz nil chunks at the given byte offset.
+//
+// Returns the index in c.chunks of the start of the gap. To fill the
+// gap, copy into c.chunks[returnedIdx:returnedIdx+sz].
+func (c *chunkBuffer) mkGap(off int, sz int) int {
+ switch {
+ case off == 0:
+ c.ensureStartGap(sz)
+ c.start -= sz
+ return c.start
+ case off == c.len():
+ c.ensureEndGap(sz)
+ ret := c.end
+ c.end += sz
+ return ret
+ default:
+ at := 0
+ for i, chunk := range c.chunks[c.start:c.end] {
+ switch {
+ case at == off:
+ // The right chunk boundary already exists, just need
+ // to make room.
+ if sz > 0 {
+ c.ensureEndGap(sz)
+ copy(c.chunks[i+sz:], c.chunks[i:c.end])
+ c.end += sz
+ }
+ return i
+ case at+len(chunk) < off:
+ at += len(chunk)
+ off -= len(chunk)
+ continue
+ default:
+ // Need to split the chunk to create the correct boundary.
+ c.ensureEndGap(sz + 1)
+ copy(c.chunks[i+sz+1:], c.chunks[i+1:c.end])
+ c.chunks[i+sz] = c.chunks[i][off-at:]
+ c.chunks[i] = c.chunks[i][:off-at]
+ c.end += sz + 1
+ return i + 1
+ }
+ }
+ panic("requested offset outside of slice range")
+ }
+}
+
+// allChunks returns the currently in-use chunks.
+//
+// The returned chunks are only valid until the next mutation of the
+// chunkBuffer.
+func (c *chunkBuffer) allChunks() [][]byte {
+ return c.chunks[c.start:c.end]
+}
+
+// slices iterates over the currently in-use chunks.
+//
+// The chunkBuffer must not be mutated while the iterator is active.
+func (c *chunkBuffer) slices(off, sz int) iter.Seq[[]byte] {
+ return func(yield func([]byte) bool) {
+ next, stop := iter.Pull(slices.Values(c.chunks[c.start:c.end]))
+ defer stop()
+ var (
+ chunk []byte
+ ok bool
+ )
+ for off > 0 {
+ chunk, ok = next()
+ if !ok {
+ panic("requested slices offset is out of bounds")
+ }
+ if len(chunk) > off {
+ break
+ }
+ off -= len(chunk)
+ }
+
+ // First chunk to output needs extra calculations to account
+ // for an offset within the chunk. The loop after that can
+ // skip that extra math.
+ end := min(off+sz, len(chunk))
+ if !yield(chunk[off:end]) {
+ return
+ }
+ sz -= end - off
+
+ for sz > 0 {
+ chunk, ok = next()
+ if !ok {
+ panic("requested slice endpoint is out of bounds")
+ }
+ end := min(sz, len(chunk))
+ if !yield(chunk[:end]) {
+ return
+ }
+ sz -= end
+ }
+ }
+}
+
+// readAt reads exactly len(bs) bytes into bs from the given offset in
+// the chunkBuffer.
+//
+// Panics if the range to read is out of bounds.
+func (c *chunkBuffer) readAt(bs []byte, off int) {
+ for chunk := range c.slices(off, len(bs)) {
+ copy(bs, chunk)
+ bs = bs[len(chunk):]
+ }
+}
+
+// writeAt writes bs to the given offset in the chunkBuffer.
+//
+// Panics if the range to write is out of bounds.
+func (c *chunkBuffer) writeAt(bs []byte, off int) {
+ for chunk := range c.slices(off, len(bs)) {
+ copy(chunk, bs)
+ bs = bs[len(chunk):]
+ }
+}
+
+// deleteCompact is similar to slices.Delete, but doesn't shrink the
+// length of bs. Instead, elements past the deletion point are shifted
+// backwards, and leftover trailing elements are nil'd.
+func deleteCompact(bs [][]byte, start, end int) [][]byte {
+ ln := len(bs)
+ return slices.Delete(bs, start, end)[:ln:ln]
+}
diff --git a/net/pktbuf/packet.go b/net/pktbuf/packet.go
new file mode 100644
index 000000000..3397decb8
--- /dev/null
+++ b/net/pktbuf/packet.go
@@ -0,0 +1,268 @@
+package pktbuf
+
+import (
+ "encoding/binary"
+
+ "golang.org/x/net/ipv4"
+ "golang.org/x/net/ipv6"
+)
+
+// A Segment is a chunk of bytes extracted from a Packet.
+//
+// The bytes are not accessible directly through the Segment. The only
+// valid operation on Segments is to reattach them to a Packet.
+type Segment struct {
+ arena *Arena
+ buf chunkBuffer
+}
+
+// A Packet is a bunch of bytes with attached metadata.
+type Packet[Meta any] struct {
+ arena *Arena
+ buf chunkBuffer
+ Meta Meta
+}
+
+// NewPacket allocates a new packet from the given arena, containing
+// sz zero bytes and with the given metadata attached.
+func NewPacket[Meta any](arena *Arena, sz int, meta Meta) *Packet[Meta] {
+ ret := &Packet[Meta]{
+ arena: arena,
+ Meta: meta,
+ }
+ ret.Grow(sz)
+ return ret
+}
+
+// Extract removes the slice [off:off+sz] from the packet, and returns
+// it as a Segment.
+func (p *Packet[Meta]) Extract(off, sz int) Segment {
+ return Segment{
+ arena: p.arena,
+ buf: p.buf.extract(off, sz),
+ }
+}
+
+// Append appends the given Segments to the end of the packet.
+func (p *Packet[Meta]) Append(segs ...Segment) {
+ for _, seg := range segs {
+ if seg.arena != p.arena {
+ panic("cannot append segment from different arena")
+ }
+ p.buf.append(seg.buf.allChunks()...)
+ }
+}
+
+// AppendBytes appends bs to the end of the packet.
+//
+// bs is copied into a fresh allocation from the packet's Arena.
+func (p *Packet[Meta]) AppendBytes(bs []byte) {
+ b := p.arena.Get(len(bs))
+ copy(b, bs)
+ p.buf.append(b)
+}
+
+// Prepend prepends the given Segments to the start of the packet.
+func (p *Packet[Meta]) Prepend(segs ...Segment) {
+ for _, seg := range segs {
+ if seg.arena != p.arena {
+ panic("cannot prepend segment from different arena")
+ }
+ p.buf.prepend(seg.buf.allChunks()...)
+ }
+}
+
+// PrependBytes prepends the given bytes to the start of the packet.
+//
+// bs is copied into a fresh allocation from the packet's Arena.
+func (p *Packet[Meta]) PrependBytes(bs []byte) {
+ b := p.arena.Get(len(bs))
+ copy(b, bs)
+ p.buf.prepend(b)
+}
+
+// Insert inserts seg into the packet at the given offset.
+func (p *Packet[Meta]) Insert(off int, seg Segment) {
+ p.buf.splice(&seg.buf, off)
+}
+
+// Grow adds sz zero bytes to the end of the packet.
+func (p *Packet[Meta]) Grow(sz int) {
+ if sz == 0 {
+ return
+ }
+ p.buf.append(p.arena.Get(sz))
+}
+
+// GrowFront adds sz zero bytes to the start of the packet.
+func (p *Packet[Meta]) GrowFront(sz int) {
+ if sz == 0 {
+ return
+ }
+ p.buf.prepend(p.arena.Get(sz))
+}
+
+// WriteAt writes bs to the given offset in the packet.
+//
+// Panics if the range [off:off+len(bs)] is out of bounds.
+func (p *Packet[Meta]) WriteAt(bs []byte, off int64) {
+ p.buf.writeAt(bs, int(off))
+}
+
+// ReadAt reads len(bs) bytes from the given offset in the packet.
+//
+// Panics if the range [off:off+len(bs)] is out of bounds.
+func (p *Packet[Meta]) ReadAt(bs []byte, off int64) {
+ p.buf.readAt(bs, int(off))
+}
+
+// Uint8 returns the value of the byte at off in the packet.
+//
+// Panics if off is out of bounds.
+func (p *Packet[Meta]) Uint8(off int64) byte {
+ var bs [1]byte
+ p.ReadAt(bs[:], off)
+ return bs[0]
+}
+
+// Uint16BE returns the big-endian 16-bit value at off in the packet.
+//
+// Panics if off is out of bounds.
+func (p *Packet[Meta]) Uint16BE(off int64) uint16 {
+ var bs [2]byte
+ p.ReadAt(bs[:], off)
+ return binary.BigEndian.Uint16(bs[:])
+}
+
+// Uint16LE returns the little-endian 16-bit value at off in the
+// packet.
+//
+// Panics if off is out of bounds.
+func (p *Packet[Meta]) Uint16LE(off int64) uint16 {
+ var bs [2]byte
+ p.ReadAt(bs[:], off)
+ return binary.LittleEndian.Uint16(bs[:])
+}
+
+// Uint32BE returns the big-endian 32-bit value at off in the
+// packet.
+//
+// Panics if off is out of bounds.
+func (p *Packet[Meta]) Uint32BE(off int64) uint32 {
+ var bs [4]byte
+ p.ReadAt(bs[:], off)
+ return binary.BigEndian.Uint32(bs[:])
+}
+
+// Uint32LE returns the little-endian 32-bit value at off in the
+// packet.
+//
+// Panics if off is out of bounds.
+func (p *Packet[Meta]) Uint32LE(off int64) uint32 {
+ var bs [4]byte
+ p.ReadAt(bs[:], off)
+ return binary.LittleEndian.Uint32(bs[:])
+}
+
+// Uint64BE returns the big-endian 64-bit value at off in the
+// packet.
+//
+// Panics if off is out of bounds.
+func (p *Packet[Meta]) Uint64BE(off int64) uint64 {
+ var bs [8]byte
+ p.ReadAt(bs[:], off)
+ return binary.BigEndian.Uint64(bs[:])
+}
+
+// Uint64LE returns the little-endian 64-bit value at off in the
+// packet.
+//
+// Panics if off is out of bounds.
+func (p *Packet[Meta]) Uint64LE(off int64) uint64 {
+ var bs [8]byte
+ p.ReadAt(bs[:], off)
+ return binary.LittleEndian.Uint64(bs[:])
+}
+
+// PutUint8 writes v at the given offset.
+//
+// Panics if off is out of bounds.
+func (p *Packet[Meta]) PutUint8(v byte, off int64) {
+ var bs [1]byte
+ bs[0] = v
+ p.buf.writeAt(bs[:], int(off))
+}
+
+// PutUint16BE writes v in big-endian order at the given offset.
+//
+// Panics if off is out of bounds.
+func (p *Packet[Meta]) PutUint16BE(v uint16, off int64) {
+ var bs [2]byte
+ binary.BigEndian.PutUint16(bs[:], v)
+ p.WriteAt(bs[:], off)
+}
+
+// PutUint16LE writes v in little-endian order at the given offset.
+//
+// Panics if off is out of bounds.
+func (p *Packet[Meta]) PutUint16LE(v uint16, off int64) {
+ var bs [2]byte
+ binary.LittleEndian.PutUint16(bs[:], v)
+ p.WriteAt(bs[:], off)
+}
+
+// PutUint32BE writes v in big-endian order at the given offset.
+//
+// Panics if off is out of bounds.
+func (p *Packet[Meta]) PutUint32BE(v uint32, off int64) {
+ var bs [4]byte
+ binary.BigEndian.PutUint32(bs[:], v)
+ p.WriteAt(bs[:], off)
+}
+
+// PutUint32LE writes v in little-endian order at the given offset.
+//
+// Panics if off is out of bounds.
+func (p *Packet[Meta]) PutUint32LE(v uint32, off int64) {
+ var bs [4]byte
+ binary.LittleEndian.PutUint32(bs[:], v)
+ p.WriteAt(bs[:], off)
+}
+
+// PutUint64BE writes v in big-endian order at the given offset.
+//
+// Panics if off is out of bounds.
+func (p *Packet[Meta]) PutUint64BE(v uint64, off int64) {
+ var bs [8]byte
+ binary.BigEndian.PutUint64(bs[:], v)
+ p.WriteAt(bs[:], off)
+}
+
+// PutUint64LE writes v in little-endian order at the given offset.
+//
+// Panics if off is out of bounds.
+func (p *Packet[Meta]) PutUint64LE(v uint64, off int64) {
+ var bs [8]byte
+ binary.LittleEndian.PutUint64(bs[:], v)
+ p.WriteAt(bs[:], off)
+}
+
+// Message4 constructs an ipv4.Message from the packet.
+//
+// The ipv4.Message is only valid until the next mutation of the
+// packet.
+func (p *Packet[Meta]) Message4() ipv4.Message {
+ return ipv4.Message{
+ Buffers: p.buf.allChunks(),
+ }
+}
+
+// Message6 constructs an ipv6.Message from the packet.
+//
+// The ipv6.Message is only valid until the next mutation of the
+// packet.
+func (p *Packet[Meta]) Message6() ipv6.Message {
+ return ipv6.Message{
+ Buffers: p.buf.allChunks(),
+ }
+}