summaryrefslogtreecommitdiffhomepage
path: root/types/opt/bool.go
blob: 0a3ee67ad2a6ecf0cc135faf2e7d83a5bf734eda (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause

// Package opt defines optional types.
package opt

import (
	"fmt"
	"strconv"
)

// Bool represents an optional boolean to be JSON-encoded. The string
// is either "true", "false", or the empty string to mean unset.
//
// As a special case, the underlying string may also be the string
// "unset" as a synonym for the empty string. This lets the
// explicit unset value be exchanged over an encoding/json "omitempty"
// field without it being dropped.
//
// Any other string is not a valid Bool: Get reports it as not set and
// MarshalJSON rejects it with an error.
type Bool string

// NewBool returns a Bool that is set to b. The result holds the string
// "true" or "false", as produced by strconv.FormatBool.
func NewBool(b bool) Bool {
	text := strconv.FormatBool(b)
	return Bool(text)
}

// Set stores v in b, overwriting whatever b held before. Afterwards b is
// the string "true" or "false".
func (b *Bool) Set(v bool) {
	text := strconv.FormatBool(v)
	*b = Bool(text)
}

// Clear resets b to the empty string, which is the unset state.
func (b *Bool) Clear() {
	*b = Bool("")
}

// Get returns the boolean held by b. ok is true only when b is exactly
// "true" or "false". For every other string, including "" and "unset",
// Get returns (false, false).
func (b Bool) Get() (v bool, ok bool) {
	if b == "true" {
		return true, true
	}
	if b == "false" {
		return false, true
	}
	return false, false
}

// Scan implements database/sql.Scanner.
//
// A nil src clears b to the unset (empty) state. A bool is stored as
// "true" or "false". An int64 is stored as "false" when it is zero and
// "true" otherwise. Any other type is rejected with an error and b is
// left untouched.
func (b *Bool) Scan(src any) error {
	var scanned Bool
	switch val := src.(type) {
	case nil:
		scanned = ""
	case bool:
		if val {
			scanned = "true"
		} else {
			scanned = "false"
		}
	case int64:
		if val != 0 {
			scanned = "true"
		} else {
			scanned = "false"
		}
	default:
		return fmt.Errorf("opt.Bool.Scan: invalid type %T: %v", val, val)
	}
	*b = scanned
	return nil
}

// EqualBool reports whether b is set and holds the same value as v.
// It reports false when b is unset or is not a valid bool.
func (b Bool) EqualBool(v bool) bool {
	got, isSet := b.Get()
	if !isSet {
		return false
	}
	return got == v
}

// Pre-encoded JSON literals handed out by Bool.MarshalJSON. The same
// backing arrays are returned on every call, so they must be treated as
// read-only.
var (
	trueBytes  = []byte("true")
	falseBytes = []byte("false")
	nullBytes  = []byte("null")
)

// MarshalJSON encodes b as the JSON literal true or false when it is set,
// and as null when it is "" or "unset". Any other underlying string is
// reported as an error.
//
// The returned slice is a shared package-level value; callers must not
// modify it.
func (b Bool) MarshalJSON() ([]byte, error) {
	if b == "true" {
		return trueBytes, nil
	}
	if b == "false" {
		return falseBytes, nil
	}
	if b == "" || b == "unset" {
		return nullBytes, nil
	}
	return nil, fmt.Errorf("invalid opt.Bool value %q", string(b))
}

// UnmarshalJSON decodes the JSON literals true, false and null into b.
// null is stored as "unset" rather than the empty string. Any other input
// yields an error and leaves b unchanged.
func (b *Bool) UnmarshalJSON(j []byte) error {
	var parsed Bool
	switch string(j) {
	case "null":
		parsed = "unset"
	case "false":
		parsed = "false"
	case "true":
		parsed = "true"
	default:
		return fmt.Errorf("invalid opt.Bool value %q", j)
	}
	*b = parsed
	return nil
}

// BoolFlag is a wrapper for Bool that implements [flag.Value].
type BoolFlag struct {
	// Bool is the value that Set writes to and String reads from.
	// String tolerates a nil pointer here; Set does not check for one.
	*Bool
}

// Set parses s with [strconv.ParseBool] and stores the result in the
// wrapped Bool. If s cannot be parsed, the parse error is returned and the
// wrapped Bool is left unchanged.
func (b *BoolFlag) Set(s string) error {
	parsed, err := strconv.ParseBool(s)
	if err == nil {
		b.Bool.Set(parsed)
	}
	return err
}

// String returns "true" or "false" when the wrapped Bool is set. It returns
// the empty string when the Bool is unset, and also when b itself or its
// wrapped pointer is nil.
func (b *BoolFlag) String() string {
	if b != nil && b.Bool != nil {
		if val, isSet := b.Bool.Get(); isSet {
			return strconv.FormatBool(val)
		}
	}
	return ""
}