summaryrefslogtreecommitdiffhomepage
path: root/talpid-wireguard/src/connectivity/mock.rs
blob: 103c691c9a58027229fe241ecbb27eed2c705e73 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
use std::future::Future;
use std::pin::Pin;
use talpid_tunnel_config_client::DaitaSettings;
use tokio::time::Instant;

use super::Check;
use super::check::{CancelToken, ConnState, PingState};
use super::pinger;

use crate::{Config, Tunnel, TunnelError};
use pinger::Pinger;

// Convenient re-exports
pub use crate::stats::{Stats, StatsMap};

/// Test double for [`Pinger`] that performs no real network I/O.
///
/// The [`Default`] value has no hook installed, so simulated pings are simply accepted.
#[derive(Default)]
pub(crate) struct MockPinger {
    /// Optional hook invoked on every simulated ping; `None` means nothing is invoked.
    on_send_ping: Option<Box<dyn FnMut() + Send + Sync + 'static>>,
}

/// Test double for a [`Tunnel`] whose statistics are produced by a stored closure.
pub(crate) struct MockTunnel {
    /// Produces the value returned from `get_tunnel_stats`.
    on_get_stats: Box<dyn Fn() -> Result<StatsMap, TunnelError> + Send + Sync + 'static>,
}

/// Build a [`Check`] for tests that pings through the supplied `pinger`.
///
/// The connection state comes from `ConnState::new(now, Default::default())`.
/// Returns the checker together with its [`CancelToken`].
pub fn mock_checker(now: Instant, pinger: Box<dyn Pinger>) -> (Check, CancelToken) {
    // Arguments are evaluated left to right: connection state first, then ping state.
    Check::mock(
        ConnState::new(now, Default::default()),
        PingState::new_with(pinger),
    )
}

pub fn connected_state(timestamp: Instant) -> ConnState {
    const PEER: [u8; 32] = [0u8; 32];
    let mut stats = StatsMap::new();
    stats.insert(PEER, Stats::default());
    ConnState::Connected {
        rx_timestamp: timestamp,
        tx_timestamp: timestamp,
        stats,
    }
}

impl MockTunnel {
    /// Key of the single peer reported by the canned constructors below.
    const PEER: [u8; 32] = [0u8; 32];

    /// Create a mock tunnel whose `get_tunnel_stats` returns whatever `f` returns.
    pub fn new<F: Fn() -> Result<StatsMap, TunnelError> + Send + Sync + 'static>(f: F) -> Self {
        Self {
            on_get_stats: Box::new(f),
        }
    }

    /// Convert `self` into a boxed [`Tunnel`] trait object.
    pub fn boxed(self) -> Box<dyn Tunnel> {
        Box::new(self)
    }

    /// Stats map containing exactly one `Stats::default()` entry, keyed by `PEER`.
    fn single_peer_stats() -> StatsMap {
        let mut map = StatsMap::new();
        map.insert(Self::PEER, Stats::default());
        map
    }

    /// Mock tunnel whose single peer's `tx_bytes` and `rx_bytes` each grow by one on
    /// every `get_tunnel_stats` call.
    pub fn always_incrementing() -> Self {
        // The callback is `Fn` (shared access) but must mutate the counters between
        // calls, hence the mutex.
        let peers = std::sync::Mutex::new(Self::single_peer_stats());
        Self::new(move || {
            let mut peers = peers.lock().expect("mock stats mutex poisoned");
            for traffic in peers.values_mut() {
                traffic.tx_bytes += 1;
                traffic.rx_bytes += 1;
            }
            Ok(peers.clone())
        })
    }

    /// Mock tunnel that reports the same `Stats::default()` entry for its single peer
    /// on every `get_tunnel_stats` call.
    pub fn never_incrementing() -> Self {
        Self::new(|| Ok(Self::single_peer_stats()))
    }
}

// Every control operation is a successful no-op; only the stats query is delegated to
// the stored callback.
#[async_trait::async_trait]
impl Tunnel for MockTunnel {
    fn get_interface_name(&self) -> String {
        String::from("mock-tunnel")
    }

    fn stop(self: Box<Self>) -> Result<(), TunnelError> {
        // Nothing to tear down.
        Ok(())
    }

    async fn get_tunnel_stats(&self) -> Result<StatsMap, TunnelError> {
        let produce_stats = &self.on_get_stats;
        produce_stats()
    }

    fn set_config(
        &mut self,
        _config: Config,
        _daita: Option<DaitaSettings>,
    ) -> Pin<Box<dyn Future<Output = std::result::Result<(), TunnelError>> + Send>> {
        // The new configuration is ignored; report success immediately.
        Box::pin(std::future::ready(Ok::<(), TunnelError>(())))
    }
}

// Simulated pings always succeed; the optional hook lets a test observe each attempt.
#[async_trait::async_trait]
impl Pinger for MockPinger {
    async fn send_icmp(&mut self) -> Result<(), pinger::Error> {
        if let Some(notify) = &mut self.on_send_ping {
            notify();
        }
        Ok(())
    }
}