summaryrefslogtreecommitdiffhomepage
path: root/util/linuxfw/fake.go
blob: 1886e25429537103cb4d5725bd0ddd135cd2ee18 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
// Copyright (c) Tailscale Inc & contributors
// SPDX-License-Identifier: BSD-3-Clause

//go:build linux

package linuxfw

import (
	"errors"
	"fmt"
	"os"
	"strconv"
	"strings"
)

// fakeIPTables is a fake, in-memory implementation of the iptables
// operations used by this package (insert, append, exists, delete, list and
// chain management). No real netfilter state is touched.
type fakeIPTables struct {
	// n maps a "table/chain" key to the ordered list of rules in that chain.
	// Each rule is stored as its arguments joined with single spaces. A key
	// being present (even with a nil slice) is what makes a chain "exist".
	n map[string][]string
}

// fakeRule describes a single rule by its table, chain and argument list.
// NOTE(review): not referenced in this file; presumably used by tests
// elsewhere in the package — confirm before changing or removing.
type fakeRule struct {
	table, chain string
	args         []string
}

// newFakeIPTables returns a fakeIPTables pre-populated with an empty set of
// built-in chains: filter INPUT/OUTPUT/FORWARD, nat PREROUTING/OUTPUT/
// POSTROUTING and mangle FORWARD. Any other chain must be created with
// NewChain before rules can be added to it.
func newFakeIPTables() *fakeIPTables {
	builtins := []string{
		"filter/INPUT",
		"filter/OUTPUT",
		"filter/FORWARD",
		"nat/PREROUTING",
		"nat/OUTPUT",
		"nat/POSTROUTING",
		"mangle/FORWARD",
	}
	chains := make(map[string][]string, len(builtins))
	for _, key := range builtins {
		// A present key with a nil rule list marks the chain as existing
		// but empty.
		chains[key] = nil
	}
	return &fakeIPTables{n: chains}
}

// Insert adds a rule to table/chain at the 1-based position pos (the first
// rule is position 1). Existing rules at or after pos move one place later.
// The rule is stored as args joined with single spaces.
//
// Valid positions are 1 through len(rules)+1, where len(rules)+1 appends.
// Any other position returns a "bad position" error. An unknown table/chain
// also returns an error, and the fake's state is left unchanged in both cases.
func (n *fakeIPTables) Insert(table, chain string, pos int, args ...string) error {
	k := table + "/" + chain
	rules, ok := n.n[k]
	if !ok {
		return fmt.Errorf("unknown table/chain %s", k)
	}
	// Reject pos < 1 explicitly. Otherwise rules[pos-1:] and rules[pos-1]
	// below would use a negative index and panic.
	if pos < 1 || pos > len(rules)+1 {
		return fmt.Errorf("bad position %d in %s", pos, k)
	}
	// Grow by one, shift the tail right by one slot, then drop the new rule
	// into the gap.
	rules = append(rules, "")
	copy(rules[pos:], rules[pos-1:])
	rules[pos-1] = strings.Join(args, " ")
	n.n[k] = rules
	return nil
}

// Append adds a rule to the end of table/chain. It delegates to Insert with
// the position just past the last rule. For an unknown chain that position
// is 1, and Insert reports the unknown table/chain error.
func (n *fakeIPTables) Append(table, chain string, args ...string) error {
	end := len(n.n[table+"/"+chain]) + 1
	return n.Insert(table, chain, end, args...)
}

// Exists reports whether table/chain holds a rule exactly equal to args
// joined with single spaces. It returns an error if the chain is unknown.
func (n *fakeIPTables) Exists(table, chain string, args ...string) (bool, error) {
	key := table + "/" + chain
	rules, found := n.n[key]
	if !found {
		return false, fmt.Errorf("unknown table/chain %s", key)
	}
	want := strings.Join(args, " ")
	for i := range rules {
		if rules[i] == want {
			return true, nil
		}
	}
	return false, nil
}

// Delete removes the first rule in table/chain that equals args joined with
// single spaces. It returns an error if the chain is unknown or no rule
// matches. The chain's slice is compacted in place.
func (n *fakeIPTables) Delete(table, chain string, args ...string) error {
	key := table + "/" + chain
	rules, found := n.n[key]
	if !found {
		return fmt.Errorf("unknown table/chain %s", key)
	}
	target := strings.Join(args, " ")
	for idx := 0; idx < len(rules); idx++ {
		if rules[idx] != target {
			continue
		}
		n.n[key] = append(rules[:idx], rules[idx+1:]...)
		return nil
	}
	return fmt.Errorf("delete of unknown rule %q from %s", target, key)
}

// List returns the rules in table/chain, in order. Each rule is the
// space-joined argument string it was inserted with. It returns an error if
// the chain is unknown.
//
// The result is a copy, so callers may modify it freely. Returning the
// internal slice would let callers mutate the fake's state. Later in-place
// shifting by Insert/Delete would also silently rewrite a slice a caller
// already holds. A nil (empty, never-populated) chain yields a nil slice.
func (n *fakeIPTables) List(table, chain string) ([]string, error) {
	k := table + "/" + chain
	rules, ok := n.n[k]
	if !ok {
		return nil, fmt.Errorf("unknown table/chain %s", k)
	}
	if rules == nil {
		return nil, nil
	}
	out := make([]string, len(rules))
	copy(out, rules)
	return out, nil
}

// ClearChain removes every rule from table/chain while keeping the chain
// itself. An unknown chain yields the error text "exitcode:1".
// NOTE(review): that text presumably imitates how a failed iptables
// invocation is reported, so callers can treat it as "chain missing" —
// confirm against the callers' error checks before changing it.
func (n *fakeIPTables) ClearChain(table, chain string) error {
	key := table + "/" + chain
	if _, found := n.n[key]; !found {
		return errors.New("exitcode:1")
	}
	n.n[key] = nil
	return nil
}

// NewChain creates an empty table/chain. It returns an error if the chain
// already exists, in which case its rules are left untouched.
func (n *fakeIPTables) NewChain(table, chain string) error {
	key := table + "/" + chain
	_, exists := n.n[key]
	if exists {
		return fmt.Errorf("table/chain %s already exists", key)
	}
	n.n[key] = nil
	return nil
}

// DeleteChain removes table/chain entirely. The chain must exist and must
// contain no rules; otherwise an error is returned and nothing changes.
func (n *fakeIPTables) DeleteChain(table, chain string) error {
	key := table + "/" + chain
	rules, found := n.n[key]
	switch {
	case !found:
		return fmt.Errorf("unknown table/chain %s", key)
	case len(rules) > 0:
		return fmt.Errorf("table/chain %s is not empty", key)
	}
	delete(n.n, key)
	return nil
}

// NewFakeIPTablesRunner returns a NetfilterRunner backed by in-memory
// fakeIPTables instances instead of real iptables.
//
// An IPv4 fake is always created. The IPv6 fake is controlled by the
// TS_TEST_FAKE_NETFILTER_6 environment variable. It is disabled only when
// the value parses (via strconv.ParseBool) as false. An unset or unparseable
// value enables it.
func NewFakeIPTablesRunner() NetfilterRunner {
	v4 := newFakeIPTables()

	enable6, parseErr := strconv.ParseBool(os.Getenv("TS_TEST_FAKE_NETFILTER_6"))
	has6 := enable6 || parseErr != nil

	// Left as a nil interface when IPv6 is disabled.
	var v6 iptablesInterface
	if has6 {
		v6 = newFakeIPTables()
	}

	// NOTE(review): the last three positional fields all receive the IPv6
	// availability flag; their individual meanings are defined on
	// iptablesRunner elsewhere — confirm there.
	return &iptablesRunner{v4, v6, has6, has6, has6}
}