summaryrefslogtreecommitdiffhomepage
path: root/talpid-routing/src/debounce.rs
blob: ff767555b810aba0cc7f6832c594aed1b4c53420 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
#![allow(dead_code)]

use std::{
    sync::mpsc::{RecvTimeoutError, Sender, channel},
    time::{Duration, Instant},
};

/// Debounces calls to a wrapped callback so that a burst of requests results in a single call.
///
/// Every call to [`BurstGuard::trigger`] (or [`BurstGuard::trigger_with_period`]) is forwarded to
/// a background thread. The first trigger opens a burst. The callback runs once a whole quiet
/// period passes with no further trigger; every new trigger restarts that wait. A steady stream of
/// triggers cannot postpone the callback forever, however: once `longest_buffer_period` (or the
/// latest trigger's own period, if that is longer) has elapsed since the burst began, further
/// triggers stop delaying the callback.
pub struct BurstGuard {
    /// Default quiet period: the callback runs once no trigger has arrived for this long.
    buffer_period: Duration,
    /// Bound, measured from the first trigger of a burst, after which new triggers no longer
    /// postpone the callback.
    longest_buffer_period: Duration,
    /// Channel used to hand triggers and shutdown requests to the background thread.
    sender: Sender<BurstGuardEvent>,
}

/// Messages sent from a `BurstGuard` handle to its worker thread.
enum BurstGuardEvent {
    /// Ask the worker to stop. The worker sends `()` on the enclosed sender before it exits, so
    /// the requester can wait for the acknowledgement.
    Shutdown(Sender<()>),
    /// Register a trigger. The payload is the quiet period to wait for before running the
    /// callback.
    Trigger(Duration),
}

impl BurstGuard {
    /// Create a new burst guard and spawn the worker thread that invokes `callback`.
    ///
    /// * `buffer_period` - quiet period used by [`BurstGuard::trigger`]: once a burst has started,
    ///   the callback runs after no trigger has arrived for this long.
    /// * `longest_buffer_period` - bound, measured from the first trigger of a burst, after which
    ///   the callback runs even if triggers keep arriving. If the most recent trigger carries a
    ///   longer period than this, that period is used as the bound instead.
    /// * `callback` - function run on the worker thread at the end of each burst.
    pub fn new<F: Fn() + Send + 'static>(
        buffer_period: Duration,
        longest_buffer_period: Duration,
        callback: F,
    ) -> Self {
        let (sender, listener) = channel();
        std::thread::spawn(move || {
            // The `stop` implementation assumes that this thread will not call `callback` again
            // if the listener has been dropped.
            while let Ok(message) = listener.recv() {
                match message {
                    BurstGuardEvent::Trigger(mut period) => {
                        let start = Instant::now();
                        loop {
                            // Deadline for this burst, measured from `start`.
                            let max_period = std::cmp::max(longest_buffer_period, period);
                            // Wait for the quiet period, but never beyond the deadline. Waiting
                            // a full `period` here could delay the callback by up to one extra
                            // `period` past the deadline when the burst stops just before it.
                            let timeout =
                                std::cmp::min(period, max_period.saturating_sub(start.elapsed()));
                            match listener.recv_timeout(timeout) {
                                Ok(BurstGuardEvent::Trigger(new_period)) => {
                                    period = new_period;
                                    let max_period = std::cmp::max(longest_buffer_period, period);
                                    if start.elapsed() >= max_period {
                                        callback();
                                        break;
                                    }
                                }
                                Ok(BurstGuardEvent::Shutdown(tx)) => {
                                    let _ = tx.send(());
                                    return;
                                }
                                Err(RecvTimeoutError::Timeout) => {
                                    // Either the quiet period elapsed or the deadline was hit.
                                    callback();
                                    break;
                                }
                                Err(RecvTimeoutError::Disconnected) => {
                                    // Every sender is gone; exit without running the callback.
                                    break;
                                }
                            }
                        }
                    }
                    BurstGuardEvent::Shutdown(tx) => {
                        let _ = tx.send(());
                        return;
                    }
                }
            }
        });
        Self {
            sender,
            buffer_period,
            longest_buffer_period,
        }
    }

    /// Stop the worker thread and wait for it to acknowledge.
    ///
    /// When `stop` returns, the `BurstGuard` thread is guaranteed to not make any further calls
    /// to `callback`. If a callback is running when `stop` is called, this blocks until that
    /// callback has returned and the shutdown request has been processed.
    pub fn stop(self) {
        let (sender, listener) = channel();
        // If we could not send then it means the thread has already shut down and we can return
        if self.sender.send(BurstGuardEvent::Shutdown(sender)).is_ok() {
            // We do not care what the result is, if it is OK it means the thread shut down, if
            // it is Err it also means it shut down.
            let _ = listener.recv();
        }
    }

    /// Ask the worker thread to stop without waiting for it to acknowledge.
    ///
    /// Unlike [`BurstGuard::stop`], this gives no guarantee about `callback`. A call that is
    /// already running may still be in progress when this returns. The same holds for a call
    /// triggered by events queued ahead of the shutdown request.
    pub fn stop_nonblocking(self) {
        let (sender, _listener) = channel();
        let _ = self.sender.send(BurstGuardEvent::Shutdown(sender));
    }

    /// Register a trigger that uses the default `buffer_period` as its quiet period.
    ///
    /// Returns immediately; the callback runs later on the worker thread.
    ///
    /// # Panics
    ///
    /// Panics if the worker thread is no longer running (see
    /// [`BurstGuard::trigger_with_period`]).
    pub fn trigger(&self) {
        // NOTE(review): `detect_flood!` comes from `talpid_types`; presumably it reports
        // unusually frequent calls — confirm against that crate.
        talpid_types::detect_flood!();
        self.trigger_with_period(self.buffer_period)
    }

    /// Register a trigger whose quiet period is `buffer_period` instead of the default.
    ///
    /// Returns immediately; the callback runs later on the worker thread.
    ///
    /// # Panics
    ///
    /// Panics if the worker thread is no longer running. This can happen if `callback`
    /// panicked, which unwinds the worker thread and drops its receiver.
    pub fn trigger_with_period(&self, buffer_period: Duration) {
        self.sender
            .send(BurstGuardEvent::Trigger(buffer_period))
            .unwrap();
    }
}