summaryrefslogtreecommitdiffhomepage
path: root/util/zstdframe/options.go
blob: b4b0f2b85304c09927c7dedd0bf431ed3ad63e69 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause

package zstdframe

import (
	"math/bits"
	"strconv"
	"sync"

	"github.com/klauspost/compress/zstd"
	"tailscale.com/util/must"
)

// Option is an option that can be passed to [AppendEncode] or [AppendDecode].
//
// Its only method, isOption, is unexported and acts purely as a marker
// for the option types declared in this package.
type Option interface{ isOption() }

// encoderLevel is the [Option] representation of a [zstd.EncoderLevel].
type encoderLevel int

// Constants that implement [Option] and can be passed to [AppendEncode].
// Each one wraps the zstd speed level of the corresponding name.
const (
	FastestCompression = encoderLevel(zstd.SpeedFastest)
	DefaultCompression = encoderLevel(zstd.SpeedDefault)
	BetterCompression  = encoderLevel(zstd.SpeedBetterCompression)
	BestCompression    = encoderLevel(zstd.SpeedBestCompression)
)

// isOption marks encoderLevel as an [Option].
func (encoderLevel) isOption() {}

// EncoderLevel returns an [Option] that selects the compression level
// used when encoding.
//
// It is provided for compatibility with [zstd.EncoderLevel] values.
// Most callers should instead pass one of these constants directly:
//   - [FastestCompression]
//   - [DefaultCompression]
//   - [BetterCompression]
//   - [BestCompression]
//
// When no level is supplied, [DefaultCompression] is used.
// Decoding ignores this option.
func EncoderLevel(level zstd.EncoderLevel) Option {
	return encoderLevel(level)
}

// withChecksum is the [Option] value produced by [WithChecksum].
type withChecksum bool

// isOption marks withChecksum as an [Option].
func (withChecksum) isOption() {}

// WithChecksum returns an [Option] that controls checksum handling:
// when encoding, whether a checksum is produced, and
// when decoding, whether the checksum is verified.
// Unless this option says otherwise, checksums are produced and verified.
func WithChecksum(check bool) Option {
	return withChecksum(check)
}

// maxDecodedSize is the [Option] returned by [MaxDecodedSize]
// for limits that are not an exact power of two.
// It stores the limit itself.
type maxDecodedSize uint64

// isOption marks maxDecodedSize as an [Option].
func (maxDecodedSize) isOption() {}

// maxDecodedSizeLog2 is the [Option] returned by [MaxDecodedSize]
// for limits that are an exact power of two.
// It stores log2 of the limit rather than the limit itself.
type maxDecodedSizeLog2 uint8 // uint8 avoids allocation when storing into interface

// isOption marks maxDecodedSizeLog2 as an [Option].
func (maxDecodedSizeLog2) isOption() {}

// MaxDecodedSize returns an [Option] that caps the decoded size,
// as a defense against hostile content.
// No limit applies unless this option is given.
// Encoding ignores this option.
func MaxDecodedSize(maxSize uint64) Option {
	// Exactly one set bit means maxSize is a power of two,
	// which can be carried in the compact log2 representation.
	if isPowerOfTwo := bits.OnesCount64(maxSize) == 1; !isPowerOfTwo {
		return maxDecodedSize(maxSize)
	}
	return maxDecodedSizeLog2(log2(maxSize))
}

// maxWindowSizeLog2 is the [Option] value produced by [MaxWindowSize].
// It stores log2 of the maximum window size.
type maxWindowSizeLog2 uint8 // uint8 avoids allocation when storing into interface

// isOption marks maxWindowSizeLog2 as an [Option].
func (maxWindowSizeLog2) isOption() {}

// MaxWindowSize returns an [Option] that caps the window size.
// The size must be a power-of-two within the range
// [[zstd.MinWindowSize], [zstd.MaxWindowSize]];
// MaxWindowSize panics if it is not.
//
// Compression and decompression then use a LZ77 rolling window
// no larger than the given size. This lowers memory requirements
// but adversely affects the compression ratio.
// When decompressing, a LZ77 back reference that exceeds
// the given maximum window size is reported as an error.
//
// For decompression, [MaxDecodedSize] is generally more useful.
func MaxWindowSize(maxSize uint64) Option {
	// The checks run in this order so that the panic message
	// for any given invalid value stays the same.
	if maxSize < zstd.MinWindowSize {
		panic("maximum window size cannot be less than " + strconv.FormatUint(zstd.MinWindowSize, 10))
	}
	if bits.OnesCount64(maxSize) != 1 {
		panic("maximum window size must be a power-of-two")
	}
	if maxSize > zstd.MaxWindowSize {
		panic("maximum window size cannot be greater than " + strconv.FormatUint(zstd.MaxWindowSize, 10))
	}
	return maxWindowSizeLog2(log2(maxSize))
}

// lowMemory is the [Option] value produced by [LowMemory].
type lowMemory bool

// isOption marks lowMemory as an [Option].
func (lowMemory) isOption() {}

// LowMemory returns an [Option] that asks the encoder and decoder
// to use less memory at the cost of speed.
// By default, more memory is used in exchange for better speed.
func LowMemory(low bool) Option {
	return lowMemory(low)
}

// encoderPools holds one pool of idle encoders per distinct encoderOptions
// value, so getEncoder only reuses an encoder that was built with
// the same configuration.
var encoderPools sync.Map // map[encoderOptions]*sync.Pool -> *zstd.Encoder

// encoderOptions is the resolved encoder configuration built by getEncoder.
// It is used as the key into encoderPools.
//
// level is the compression level, maxWindowLog2 is log2 of the maximum
// window size (zero means no window size option is passed to zstd),
// checksum reports whether a checksum is produced, and
// lowMemory requests the lower-memory encoder mode.
type encoderOptions struct {
	level         zstd.EncoderLevel
	maxWindowLog2 uint8
	checksum      bool
	lowMemory     bool
}

// encoder pairs a *zstd.Encoder with the pool it was obtained from,
// so that putEncoder can return it to the same pool.
type encoder struct {
	pool *sync.Pool
	*zstd.Encoder
}

// getEncoder returns an encoder configured according to opts.
// It reuses an idle encoder from the pool for that configuration when one
// is available and constructs a new one otherwise.
// Options that do not apply to encoding are ignored.
// The result can be handed back with putEncoder.
func getEncoder(opts ...Option) encoder {
	// Resolve the options on top of the defaults.
	cfg := encoderOptions{level: zstd.SpeedDefault, checksum: true}
	for _, o := range opts {
		switch v := o.(type) {
		case encoderLevel:
			cfg.level = zstd.EncoderLevel(v)
		case maxWindowSizeLog2:
			cfg.maxWindowLog2 = uint8(v)
		case withChecksum:
			cfg.checksum = bool(v)
		case lowMemory:
			cfg.lowMemory = bool(v)
		}
	}

	// Find the pool for this exact configuration. Trying Load first means
	// a new sync.Pool is only allocated when none is registered yet.
	entry, found := encoderPools.Load(cfg)
	if !found {
		entry, _ = encoderPools.LoadOrStore(cfg, new(sync.Pool))
	}
	pool := entry.(*sync.Pool)

	if pooled, _ := pool.Get().(*zstd.Encoder); pooled != nil {
		return encoder{pool, pooled}
	}

	// Nothing pooled: build a fresh encoder. The final array slot is only
	// included in the option list when a window size was requested.
	zopts := [...]zstd.EOption{
		// Set concurrency=1 to ensure synchronous operation.
		zstd.WithEncoderConcurrency(1),
		// In stateless compression, the data is already in a single buffer,
		// so we might as well encode it as a single segment,
		// which ensures that the Frame_Content_Size is always populated,
		// informing decoders up-front the expected decompressed size.
		zstd.WithSingleSegment(true),
		// Ensure strict compliance with RFC 8878, section 3.1.,
		// where zstandard "is made up of one or more frames".
		zstd.WithZeroFrames(true),
		zstd.WithEncoderLevel(cfg.level),
		zstd.WithEncoderCRC(cfg.checksum),
		zstd.WithLowerEncoderMem(cfg.lowMemory),
		nil, // reserved for zstd.WithWindowSize
	}
	n := len(zopts) - 1 // number of options in use; excludes the reserved slot
	if cfg.maxWindowLog2 > 0 {
		zopts[n] = zstd.WithWindowSize(1 << cfg.maxWindowLog2)
		n++
	}
	return encoder{pool, must.Get(zstd.NewWriter(nil, zopts[:n]...))}
}

// putEncoder returns e's underlying encoder to the pool it was taken from.
func putEncoder(e encoder) {
	e.pool.Put(e.Encoder)
}

// decoderPools holds one pool of idle decoders per distinct decoderOptions
// value, so getDecoder only reuses a decoder that was built with
// the same configuration.
var decoderPools sync.Map // map[decoderOptions]*sync.Pool -> *zstd.Decoder

// decoderOptions is the resolved decoder configuration built by getDecoder.
// It is used as the key into decoderPools.
//
// maxSizeLog2 is log2 (rounded up) of the maximum decoded size,
// maxWindowLog2 is log2 of the maximum window size (zero means no window
// option is passed to zstd), checksum reports whether the checksum is
// verified, and lowMemory requests the lower-memory decoder mode.
type decoderOptions struct {
	maxSizeLog2   uint8
	maxWindowLog2 uint8
	checksum      bool
	lowMemory     bool
}

// decoder pairs a *zstd.Decoder with the pool it was obtained from,
// so that putDecoder can return it to the same pool.
type decoder struct {
	pool *sync.Pool
	*zstd.Decoder

	// maxSize is the exact decoded-size limit checked by DecodeAll.
	// It is kept outside decoderOptions, which only records the limit
	// as a power-of-two exponent.
	maxSize uint64
}

// getDecoder returns a decoder configured according to opts.
// It reuses an idle decoder from the pool for that configuration when one
// is available and constructs a new one otherwise.
// Options that do not apply to decoding are ignored.
// The result can be handed back with putDecoder.
//
// The underlying zstd decoder is only given a power-of-two memory bound.
// The exact limit travels in the returned decoder and is applied by
// its DecodeAll method.
func getDecoder(opts ...Option) decoder {
	// Resolve the options on top of the defaults.
	var limit uint64 = 1 << 63
	cfg := decoderOptions{maxSizeLog2: 63, checksum: true}
	for _, o := range opts {
		switch v := o.(type) {
		case maxDecodedSizeLog2:
			limit = 1 << uint8(v)
			cfg.maxSizeLog2 = uint8(v)
		case maxDecodedSize:
			limit = uint64(v)
			cfg.maxSizeLog2 = uint8(log2(limit))
		case maxWindowSizeLog2:
			cfg.maxWindowLog2 = uint8(v)
		case withChecksum:
			cfg.checksum = bool(v)
		case lowMemory:
			cfg.lowMemory = bool(v)
		}
	}

	// Find the pool for this exact configuration. Trying Load first means
	// a new sync.Pool is only allocated when none is registered yet.
	entry, found := decoderPools.Load(cfg)
	if !found {
		entry, _ = decoderPools.LoadOrStore(cfg, new(sync.Pool))
	}
	pool := entry.(*sync.Pool)

	if pooled, _ := pool.Get().(*zstd.Decoder); pooled != nil {
		return decoder{pool, pooled, limit}
	}

	// Nothing pooled: build a fresh decoder. The memory bound exponent is
	// clamped to the range [10, 63]. The final array slot is only included
	// in the option list when a window size was requested.
	memLog2 := min(max(10, cfg.maxSizeLog2), 63)
	zopts := [...]zstd.DOption{
		// Set concurrency=1 to ensure synchronous operation.
		zstd.WithDecoderConcurrency(1),
		zstd.WithDecoderMaxMemory(1 << memLog2),
		zstd.IgnoreChecksum(!cfg.checksum),
		zstd.WithDecoderLowmem(cfg.lowMemory),
		nil, // reserved for zstd.WithDecoderMaxWindow
	}
	n := len(zopts) - 1 // number of options in use; excludes the reserved slot
	if cfg.maxWindowLog2 > 0 {
		zopts[n] = zstd.WithDecoderMaxWindow(1 << cfg.maxWindowLog2)
		n++
	}
	return decoder{pool, must.Get(zstd.NewReader(nil, zopts[:n]...)), limit}
}

// putDecoder returns d's underlying decoder to the pool it was taken from.
func putDecoder(d decoder) {
	d.pool.Put(d.Decoder)
}

// DecodeAll decodes src with the embedded decoder's DecodeAll and
// additionally enforces the exact maximum decoded size.
//
// The embedded decoder is only configured to enforce MaxDecodedSize
// by powers-of-two, so the number of bytes added beyond dst is compared
// against the precise limit here. If it is larger,
// [zstd.ErrDecoderSizeExceeded] is returned alongside the decoded bytes.
func (d decoder) DecodeAll(src, dst []byte) ([]byte, error) {
	out, err := d.Decoder.DecodeAll(src, dst)
	if err != nil {
		return out, err
	}
	if added := uint64(len(out) - len(dst)); added > d.maxSize {
		return out, zstd.ErrDecoderSizeExceeded
	}
	return out, nil
}

// log2 computes log2 of x rounded up to the nearest integer,
// that is, the smallest n such that 1<<n >= x.
//
// log2(0) reports 0. Without that special case, x-1 would wrap around
// to the maximum uint64 and the result would be 64, which would turn
// a zero size limit into the largest possible power-of-two bound.
//
// For x greater than 1<<63 the result is 64.
func log2(x uint64) int {
	if x == 0 {
		return 0
	}
	// The bit length of x-1 is the rounded-up base-2 logarithm of x.
	return bits.Len64(x - 1)
}