summaryrefslogtreecommitdiffhomepage
path: root/talpid-wireguard/src/wireguard_kernel/netlink_tunnel.rs
blob: 407b4cdd25fb1512569b15612bf9087bc3304937 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
use std::pin::Pin;

use futures::Future;
use talpid_tunnel_config_client::DaitaSettings;

use crate::config::MULLVAD_INTERFACE_NAME;

use super::{
    super::stats::{Stats, StatsMap},
    Config, Error, Handle, Tunnel, TunnelError,
    wg_message::DeviceNla,
};

/// A WireGuard tunnel whose network device is created, configured and removed over netlink.
///
/// Construct it with `NetlinkTunnel::new`; it is torn down through `Tunnel::stop`.
pub struct NetlinkTunnel {
    /// Index of the network interface returned by `create_device` when the tunnel was set up.
    interface_index: u32,
    /// Netlink connections used to manage the device, its IP addresses and its WireGuard state.
    netlink_connections: Handle,
    /// Runtime handle used by the synchronous methods to `block_on` async netlink calls.
    tokio_handle: tokio::runtime::Handle,
}

impl NetlinkTunnel {
    /// Creates the WireGuard interface and applies `config` to it.
    ///
    /// The device is created under `MULLVAD_INTERFACE_NAME` with the MTU taken from `config`.
    /// If applying the configuration afterwards fails, the freshly created device is deleted
    /// again (a failure to delete is only logged) and the original setup error is returned.
    ///
    /// All netlink work is driven synchronously through `tokio_handle.block_on`; tokio documents
    /// that call as panicking when made from within an async execution context.
    pub fn new(tokio_handle: tokio::runtime::Handle, config: &Config) -> Result<Self, Error> {
        // The original handle is moved into the tunnel, so block on a clone of it.
        let runtime = tokio_handle.clone();
        runtime.block_on(async move {
            let mut netlink_connections = Handle::connect().await?;
            let device_name = MULLVAD_INTERFACE_NAME.to_string();
            let mtu = config.mtu as u32;
            let interface_index = netlink_connections
                .create_device(device_name, mtu)
                .await?;

            let mut tunnel = Self {
                interface_index,
                netlink_connections,
                tokio_handle,
            };

            let setup_outcome = tunnel.setup(config).await;
            match setup_outcome {
                Ok(()) => Ok(tunnel),
                Err(setup_error) => {
                    // Don't leave a half-configured interface lying around.
                    let cleanup = tunnel
                        .netlink_connections
                        .delete_device(interface_index)
                        .await;
                    if let Err(teardown_err) = cleanup {
                        log::error!(
                            "Failed to tear down WireGuard interface after failing to apply config: {}",
                            teardown_err
                        );
                    }
                    Err(setup_error)
                }
            }
        })
    }

    /// Pushes the WireGuard configuration to the device, then assigns each tunnel address
    /// to the interface. The first failing step aborts the setup and its error is returned.
    async fn setup(&mut self, config: &Config) -> Result<(), Error> {
        let index = self.interface_index;

        self.netlink_connections
            .wg_handle
            .set_config(index, config)
            .await?;

        for address in config.tunnel.addresses.iter().copied() {
            self.netlink_connections
                .set_ip_address(index, address)
                .await?;
        }

        Ok(())
    }
}

#[async_trait::async_trait]
impl Tunnel for NetlinkTunnel {
    fn get_interface_name(&self) -> String {
        let mut wg = self.netlink_connections.wg_handle.clone();
        let result = self.tokio_handle.block_on(async move {
            let device = wg.get_by_index(self.interface_index).await?;
            for nla in device.nlas {
                if let DeviceNla::IfName(name) = nla {
                    return Ok(name);
                }
            }
            Err(Error::Truncated)
        });

        match result {
            Ok(name) => name.to_string_lossy().to_string(),
            Err(err) => {
                log::error!(
                    "Failed to deduce interface name at runtime, will attempt to use the default name. {}",
                    err
                );
                MULLVAD_INTERFACE_NAME.to_string()
            }
        }
    }

    fn stop(self: Box<Self>) -> std::result::Result<(), TunnelError> {
        let Self {
            mut netlink_connections,
            interface_index,
            tokio_handle,
        } = *self;
        tokio_handle.block_on(async move {
            if let Err(err) = netlink_connections.delete_device(interface_index).await {
                log::error!("Failed to remove WireGuard device: {}", err);
                Err(TunnelError::FatalStartWireguardError(Box::new(err)))
            } else {
                Ok(())
            }
        })
    }

    async fn get_tunnel_stats(&self) -> std::result::Result<StatsMap, TunnelError> {
        let interface_index = self.interface_index;
        let mut wg = self.netlink_connections.wg_handle.clone();
        let device = wg.get_by_index(interface_index).await.map_err(|err| {
            log::error!("Failed to fetch WireGuard device config: {}", err);
            TunnelError::GetConfigError
        })?;
        Ok(Stats::parse_device_message(&device))
    }

    fn set_config(
        &mut self,
        config: Config,
        daita: Option<DaitaSettings>,
    ) -> Pin<Box<dyn Future<Output = std::result::Result<(), TunnelError>> + Send + 'static>> {
        let mut wg = self.netlink_connections.wg_handle.clone();
        let interface_index = self.interface_index;
        Box::pin(async move {
            if daita.is_some() {
                // Outright fail to start - this tunnel type does not support DAITA.
                return Err(TunnelError::DaitaNotSupported);
            }

            wg.set_config(interface_index, &config)
                .await
                .map_err(|err| {
                    log::error!("Failed to set WireGuard device config: {}", err);
                    TunnelError::SetConfigError
                })
        })
    }
}