summaryrefslogtreecommitdiffhomepage
path: root/tool/gocross/env.go
blob: 9d8a4f1b390b40f4bda35d6b78e23904277e0315 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause

package main

import (
	"fmt"
	"os"
	"sort"
	"strings"
)

// Environment starts from an initial set of environment variables and tracks
// pending mutations to them. It can apply those mutations to the environment,
// or produce debugging output that shows the changes it would make.
//
// init is the baseline: the variables the Environment was created with, kept
// up to date by Apply as pending mutations are committed. set holds pending
// assignments and unset holds pending removals; Set and Unset keep a key out
// of one map when they add it to the other. setenv and unsetenv are the
// functions Apply calls to perform the actual mutations.
type Environment struct {
	init, set map[string]string
	unset     map[string]bool

	setenv   func(string, string) error
	unsetenv func(string) error
}

// NewEnvironment returns an Environment initialized from os.Environ, with
// os.Setenv and os.Unsetenv as the functions Apply uses.
//
// It panics if an os.Environ entry contains no "=" separator.
func NewEnvironment() *Environment {
	environ := os.Environ()
	init := make(map[string]string, len(environ))
	for _, kv := range environ {
		k, v, ok := splitEnvEntry(kv)
		if !ok {
			panic("bad environ provided")
		}
		init[k] = v
	}

	return newEnvironmentForTest(init, os.Setenv, os.Unsetenv)
}

// splitEnvEntry splits an os.Environ-style "KEY=VALUE" entry into its key and
// value. ok is false if kv contains no "=".
//
// On Windows the environment can contain entries such as "=C:=C:\dir" whose
// key itself begins with "=". A single leading "=" is therefore treated as
// part of the key, the same convention os/exec uses when de-duplicating
// environments. Splitting at the first "=" would instead yield an empty key
// and a mangled value, and every such entry would collide on the key "".
func splitEnvEntry(kv string) (key, val string, ok bool) {
	i := strings.Index(kv, "=")
	if i == 0 {
		// Look for the separator after the leading "=". If there is none,
		// this leaves i == 0, i.e. an empty key, as plain splitting would.
		i = strings.Index(kv[1:], "=") + 1
	}
	if i < 0 {
		return "", "", false
	}
	return kv[:i], kv[i+1:], true
}

// newEnvironmentForTest builds an Environment around the given baseline
// variables and mutation functions, with no pending changes. The init map is
// used as-is (not copied), so Apply writes through to it. NewEnvironment calls
// this with os.Setenv and os.Unsetenv; tests can pass fakes instead.
func newEnvironmentForTest(init map[string]string, setenv func(string, string) error, unsetenv func(string) error) *Environment {
	env := new(Environment)
	env.init = init
	env.set = make(map[string]string)
	env.unset = make(map[string]bool)
	env.setenv, env.unsetenv = setenv, unsetenv
	return env
}

// Set records a pending assignment of v to the environment variable k,
// cancelling any pending Unset of k. The process environment is only changed
// when Apply is called.
func (e *Environment) Set(k, v string) {
	delete(e.unset, k)
	e.set[k] = v
}

// Unset records a pending removal of the environment variable k, discarding
// any pending Set of k. The process environment is only changed when Apply is
// called.
func (e *Environment) Unset(k string) {
	e.unset[k] = true
	delete(e.set, k)
}

// IsSet reports whether the environment variable k is set, taking pending
// mutations into account: a pending Unset hides k, otherwise k counts as set
// if it has a pending Set or is present in the baseline (even with an empty
// value).
func (e *Environment) IsSet(k string) bool {
	if e.unset[k] {
		return false
	}
	_, pending := e.set[k]
	_, baseline := e.init[k]
	return pending || baseline
}

// Get returns the value the environment variable k has once pending mutations
// are taken into account, or defaultVal if it is not set. A pending Set takes
// precedence over the baseline value; a pending Unset yields defaultVal.
func (e *Environment) Get(k, defaultVal string) string {
	if !e.unset[k] {
		if v, ok := e.set[k]; ok {
			return v
		}
		if v, ok := e.init[k]; ok {
			return v
		}
	}
	return defaultVal
}

// Apply applies all pending mutations to the environment via the configured
// setenv/unsetenv functions. Each successfully applied mutation is folded
// into the baseline and removed from the pending set, so a later Diff no
// longer reports it.
//
// If a call fails, Apply stops and returns an error naming the variable. The
// underlying error is wrapped, so it is available via errors.Is/errors.As.
// Mutations applied before the failure have already been folded into the
// baseline. The failed mutation and any not yet attempted remain pending.
// Map iteration order is unspecified, so the order of the calls is too.
func (e *Environment) Apply() error {
	for k, v := range e.set {
		if err := e.setenv(k, v); err != nil {
			return fmt.Errorf("setting %q: %w", k, err)
		}
		e.init[k] = v
		// Deleting the current key while ranging over a map is permitted.
		delete(e.set, k)
	}
	for k := range e.unset {
		if err := e.unsetenv(k); err != nil {
			return fmt.Errorf("unsetting %q: %w", k, err)
		}
		delete(e.init, k)
		delete(e.unset, k)
	}
	return nil
}

// Diff returns a newline-separated, human-readable description of the pending
// mutations, one line per variable of the form "KEY=NEW (was OLD)", where
// "<nil>" stands for "not set". Lines are sorted so the output is
// deterministic despite map iteration order. It returns "" when nothing is
// pending.
func (e *Environment) Diff() string {
	out := make([]string, 0, len(e.set)+len(e.unset))
	for k, v := range e.set {
		if old, ok := e.init[k]; ok {
			out = append(out, fmt.Sprintf("%s=%s (was %s)", k, v, old))
			continue
		}
		out = append(out, fmt.Sprintf("%s=%s (was <nil>)", k, v))
	}
	for k := range e.unset {
		if old, ok := e.init[k]; ok {
			out = append(out, fmt.Sprintf("%s=<nil> (was %s)", k, old))
			continue
		}
		out = append(out, fmt.Sprintf("%s=<nil> (was <nil>)", k))
	}
	sort.Strings(out)
	return strings.Join(out, "\n")
}