summaryrefslogtreecommitdiffhomepage
path: root/wireguard/libwg/interfacewatcher/interfacewatcher_windows.go
blob: fcb53b858434db19845c0a33e7b89fdb56056eee (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
 * Copyright (C) 2021 Mullvad VPN AB. All Rights Reserved.
 */

package interfacewatcher

import (
	"sync"
	"time"
	
	"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
)

// Event identifies one interface addition reported by winipcfg: the LUID of
// the interface together with the address family of the reported row.
// Events are compared with ==, so two events match only when both fields are
// equal.
type Event struct {
	Luid   winipcfg.LUID
	Family winipcfg.AddressFamily
}

// interfaceWatcher records interface-addition notifications delivered through a
// winipcfg interface change callback and lets a single Join call wait until a
// specific set of them has been observed.
type interfaceWatcher struct {
	// ready is signalled once every entry in wantedEvents has been seen.
	// NewWatcher gives it a one-slot buffer so the sender, which holds
	// processingMutex, does not wait for a receiver.
	ready chan bool
	// processingMutex guards seenEvents, wantedEvents and expired, which are
	// touched both by the change callback and by Join.
	processingMutex sync.Mutex
	// interfaceChangeCallback is the winipcfg registration handle; it is set
	// to nil by Destroy.
	interfaceChangeCallback *winipcfg.InterfaceChangeCallback
	// seenEvents holds the interface additions observed while not expired.
	seenEvents []Event
	// wantedEvents is the set a running Join is waiting for; empty otherwise.
	wantedEvents []Event
	// expired is set once the watcher has fired or Join has finished waiting;
	// after that, notifications are ignored and Join returns false.
	expired bool
}

// NewWatcher registers an interface change callback with winipcfg and returns a
// watcher that records each distinct (LUID, address family) pair reported by a
// MibAddInstance notification. Call Join to wait for a particular set of such
// events and Destroy to unregister the callback. The returned error is the one
// from winipcfg.RegisterInterfaceChangeCallback, in which case the watcher is nil.
func NewWatcher() (*interfaceWatcher, error) {
	iw := &interfaceWatcher{
		// One slot: evaluateEvents sends while holding processingMutex and
		// must not block, even if Join has already given up waiting.
		ready: make(chan bool, 1),
	}
	var err error
	iw.interfaceChangeCallback, err = winipcfg.RegisterInterfaceChangeCallback(func(notificationType winipcfg.MibNotificationType, iface *winipcfg.MibIPInterfaceRow) {
		// Only interface additions are of interest.
		if notificationType != winipcfg.MibAddInstance {
			return
		}

		iw.processingMutex.Lock()
		defer iw.processingMutex.Unlock()

		if iw.expired {
			return
		}

		event := Event{Luid: iface.InterfaceLUID, Family: iface.Family}

		// Matching only tests membership, so a repeated event cannot change
		// the outcome (state is re-evaluated after every real change). Skipping
		// duplicates keeps seenEvents bounded by the number of distinct
		// interface/family pairs instead of growing with every notification.
		for _, seen := range iw.seenEvents {
			if seen == event {
				return
			}
		}
		iw.seenEvents = append(iw.seenEvents, event)

		if len(iw.wantedEvents) != 0 {
			iw.evaluateEvents()
		}
	})
	if err != nil {
		return nil, err
	}
	return iw, nil
}

// evaluateEvents checks whether every entry of wantedEvents appears in
// seenEvents. When that holds, it marks the watcher expired and signals Join
// through the ready channel. Every caller holds processingMutex and has already
// checked that the watcher is not expired, so at most one value is ever sent on
// the single-slot channel.
func (iw *interfaceWatcher) evaluateEvents() {
	// Quadratic overall, but both slices are expected to hold very few items.
	wasSeen := func(target Event) bool {
		for i := range iw.seenEvents {
			if iw.seenEvents[i] == target {
				return true
			}
		}
		return false
	}

	for i := range iw.wantedEvents {
		if !wasSeen(iw.wantedEvents[i]) {
			return
		}
	}

	iw.expired = true
	iw.ready <- true
}

// Join waits until every event in wantedEvents has been observed by the
// watcher, or until timeoutSeconds have elapsed. Events seen before Join was
// called count too. It returns true if all wanted events were seen in time and
// false on timeout. It returns false immediately, without waiting, when
// wantedEvents is empty or the watcher has already expired.
//
// A watcher can be joined only once: after Join has waited, the watcher is
// expired and every later call returns false.
func (iw *interfaceWatcher) Join(wantedEvents []Event, timeoutSeconds int) bool {
	iw.processingMutex.Lock()
	if iw.expired || len(wantedEvents) == 0 {
		iw.processingMutex.Unlock()
		return false
	}
	// Store a private copy: the change callback reads wantedEvents from another
	// goroutine, so it must not alias a slice the caller may still modify.
	iw.wantedEvents = append([]Event(nil), wantedEvents...)
	// Some or all wanted events may already have been recorded.
	iw.evaluateEvents()
	iw.processingMutex.Unlock()

	// An explicit timer is released as soon as Join returns, whereas a
	// time.After timer stays alive until it fires (prior to Go 1.23).
	timer := time.NewTimer(time.Duration(timeoutSeconds) * time.Second)
	defer timer.Stop()

	var result bool
	select {
	case <-iw.ready:
		result = true
	case <-timer.C:
		result = false
	}

	iw.processingMutex.Lock()
	iw.wantedEvents = nil
	iw.expired = true
	iw.processingMutex.Unlock()

	return result
}

// Destroy unregisters the interface change callback installed by NewWatcher and
// clears the handle. Calling it again after it has completed does nothing.
func (iw *interfaceWatcher) Destroy() {
	cb := iw.interfaceChangeCallback
	if cb == nil {
		return
	}
	cb.Unregister()
	iw.interfaceChangeCallback = nil
}