summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorJoakim Hulthe <joakim.hulthe@mullvad.net>2025-04-10 18:06:45 +0200
committerJoakim Hulthe <joakim.hulthe@mullvad.net>2025-04-10 18:06:45 +0200
commitc870d0df0dc26c3ddf72de1fe69a4c4a1afe28b2 (patch)
tree0ff05d0f8d245cf8a8be63368f52019c937f8429
parent2142bb3c1cb8c7ecb52751fdc4c03622870d7733 (diff)
parent2d8215081909c6fd1ceb434d151a3ed6a10f636a (diff)
downloadmullvadvpn-c870d0df0dc26c3ddf72de1fe69a4c4a1afe28b2.tar.xz
mullvadvpn-c870d0df0dc26c3ddf72de1fe69a4c4a1afe28b2.zip
Merge branch 'use-ring-buffer-instead-of-clearing-fragments-arbitrarily-des-1966'
-rw-r--r--mullvad-masque-proxy/src/fragment.rs185
1 files changed, 149 insertions, 36 deletions
diff --git a/mullvad-masque-proxy/src/fragment.rs b/mullvad-masque-proxy/src/fragment.rs
index c699ed2d78..3d9fc94273 100644
--- a/mullvad-masque-proxy/src/fragment.rs
+++ b/mullvad-masque-proxy/src/fragment.rs
@@ -1,5 +1,5 @@
use std::{
- collections::BTreeMap,
+ collections::{BTreeMap, VecDeque},
time::{Duration, Instant},
};
@@ -8,8 +8,24 @@ use h3::proto::varint::VarInt;
use crate::FRAGMENT_HEADER_SIZE_FRAGMENTED;
-#[derive(Default)]
+/// The index of the first fragment of a packet.
+const FRAGMENT_INDEX_START: u8 = 1;
+
+/// The maximum number of unassembled fragments that we buffer.
+// 255 is the theoretical maximum number of fragments for a single packet.
+const FRAGMENT_BUFFER_CAP: usize = 255;
+
pub struct Fragments {
+ /// FIFO queue of fragment indices. Used to mitigate floods of unordered packet fragments.
+ ///
+ /// When receiving a fragment, push it to the back of the queue. When the queue length exceeds
+ /// [FRAGMENT_BUFFER_CAP], pop the first element and remove it from [Self::fragment_map].
+ fragment_index_fifo: VecDeque<u16>,
+
+ /// Map of fragmented packets.
+ ///
+ /// If fragments are arriving in order, this should never hold more than one set of fragments.
+ // TODO: would a hashmap be faster?
fragment_map: BTreeMap<u16, Vec<Fragment>>,
}
@@ -29,6 +45,15 @@ pub enum DefragError {
#[error("Packet is too large to fragment")]
pub struct PacketTooLarge(pub usize);
+impl Default for Fragments {
+ fn default() -> Self {
+ Self {
+ fragment_index_fifo: VecDeque::with_capacity(FRAGMENT_BUFFER_CAP),
+ fragment_map: Default::default(),
+ }
+ }
+}
+
impl Fragments {
// TODO: Let caller provide output buffer.
pub fn handle_incoming_packet(
@@ -60,33 +85,60 @@ impl Fragments {
time_received: Instant::now(),
};
+ // ensure that the fifo has capacity before pushing the new fragment id
+ if self.fragment_index_fifo.len() >= FRAGMENT_BUFFER_CAP {
+ let id = self.fragment_index_fifo.pop_front().expect("fifo is full");
+ if self.fragment_map.remove(&id).is_some() && cfg!(debug_assertions) {
+ println!("Fragment was discarded before reassembly");
+ };
+ }
+ self.fragment_index_fifo.push_back(id);
+
+ debug_assert_eq!(
+ self.fragment_index_fifo.capacity(),
+ FRAGMENT_BUFFER_CAP,
+ "fragment_index_fifo must never grow",
+ );
+
let fragments = self.fragment_map.entry(id).or_default();
fragments.push(fragment);
- Ok(self.try_fetch(id, fragment_count))
+ Ok(self.try_reassemble(id, fragment_count))
}
// TODO: Let caller provide output buffer.
- fn try_fetch(&mut self, id: u16, fragment_count: u8) -> Option<Bytes> {
+ fn try_reassemble(&mut self, id: u16, fragment_count: u8) -> Option<Bytes> {
// establish that there are enough fragments to reconstruct the whole packet
- let payload = {
- let fragments = self.fragment_map.get_mut(&id)?;
+ let fragments = &self.fragment_map[&id];
+ if fragments.len() != fragment_count.into() {
+ return None;
+ }
- if fragments.len() != fragment_count.into() {
- return None;
- }
+ // looks like a valid fragment set. pop it from the map.
+ let mut fragments = self.fragment_map.remove(&id).expect("fragment must exist");
+
+ fragments.sort_unstable_by_key(|f| f.index);
- fragments.sort_by_key(|f| f.index);
- let mut payload =
- BytesMut::with_capacity(fragments.iter().map(|f| f.payload.len()).sum());
- for fragment in fragments {
- payload.extend_from_slice(&fragment.payload);
+ // assert that fragments are in the correct order
+ // TODO: is this excessively paranoid?
+ let fragments_missing = (FRAGMENT_INDEX_START..)
+ .zip(&fragments)
+ .any(|(expected_index, fragment)| fragment.index != expected_index);
+ if fragments_missing {
+ if cfg!(debug_assertions) {
+ println!("Discarding unordered fragment set");
}
- payload
- };
+ return None;
+ }
- self.fragment_map.remove(&id);
- Some(payload.into())
+ // smush the fragments together
+ let mut payload = BytesMut::with_capacity(fragments.iter().map(|f| f.payload.len()).sum());
+ for fragment in fragments {
+ payload.extend_from_slice(&fragment.payload);
+ }
+ let payload = payload.freeze();
+
+ Some(payload)
}
pub fn clear_old_fragments(&mut self, max_age: Duration) {
@@ -130,8 +182,7 @@ pub fn fragment_packet(
crate::HTTP_MASQUE_FRAGMENTED_DATAGRAM_CONTEXT_ID.encode(&mut fragment);
fragment.put_u16(packet_id);
fragment.put_u8(
- // fragment indexes start at 1
- u8::try_from(fragment_index + 1)
+ u8::try_from(fragment_index + usize::from(FRAGMENT_INDEX_START))
.expect("fragment index must fit in an u8, since num_fragments fits is an u8"),
);
fragment.put_u8(fragment_count);
@@ -151,28 +202,90 @@ mod test {
#[test]
fn test_fragment_reconstruction() {
use rand::{seq::SliceRandom, thread_rng};
-
- let payload = (0..255).collect::<Vec<u8>>();
- let max_payload_size = 50;
- let packet_id = 76;
-
let mut fragments = Fragments::default();
- let mut payload_clone = Bytes::from(payload.clone());
- let mut fragment_buf = fragment_packet(max_payload_size, &mut payload_clone, packet_id)
- .unwrap()
- .collect::<Vec<_>>();
+ 'outer: for packet_id in 1..255u16 {
+ let payload = (0..packet_id as u8).collect::<Vec<u8>>();
+ let max_payload_size = 50;
- fragment_buf.shuffle(&mut thread_rng());
+ let mut payload_clone = Bytes::from(payload.clone());
+ let mut fragment_buf = fragment_packet(max_payload_size, &mut payload_clone, packet_id)
+ .unwrap()
+ .collect::<Vec<_>>();
- for fragment in fragment_buf {
- if let Some(reconstructed_packet) = fragments.handle_incoming_packet(fragment).unwrap()
- {
- assert_eq!(payload.as_slice(), reconstructed_packet.as_ref());
- return;
+ fragment_buf.shuffle(&mut thread_rng());
+
+ for fragment in fragment_buf {
+ if let Some(reconstructed_packet) =
+ fragments.handle_incoming_packet(fragment).unwrap()
+ {
+ assert_eq!(payload.as_slice(), reconstructed_packet.as_ref());
+ continue 'outer;
+ }
}
+
+ panic!("Failed to reconstruct packet");
}
+ }
+
+ #[test]
+ fn test_fragment_cap() {
+ use rand::{seq::SliceRandom, thread_rng};
+
+ // test whether we can reassemble a fragmented packet when we receive a flood of bad fragments
+ // interspersed with our good fragments. returns true if reassembly was successful.
+ let fragment_survives_flood = |number_of_bad_fragments| {
+ let mut fragments = Fragments::default();
+
+ let packet_id = 123;
+ let bad_packet_id = 321;
+
+ let payload = (0..255).collect::<Vec<u8>>();
+ let max_payload_size = 50;
+
+ let mut payload_clone = Bytes::from(payload.clone());
+ let mut fragment_buf = fragment_packet(max_payload_size, &mut payload_clone, packet_id)
+ .unwrap()
+ .collect::<Vec<_>>();
+
+ fragment_buf.shuffle(&mut thread_rng());
+
+ // send one fragment
+ let packet = fragments
+ .handle_incoming_packet(fragment_buf.pop().unwrap())
+ .unwrap();
+ assert!(packet.is_none(), "haven't sent all fragments yet");
+
+ // then send a bunch of fragments to fill the queue
+ let mut bad_payload = Bytes::from([0u8; 2].to_vec());
+ let incomplete_fragment = fragment_packet(
+ 1 + FRAGMENT_HEADER_SIZE_FRAGMENTED,
+ &mut bad_payload,
+ bad_packet_id,
+ )
+ .unwrap()
+ .next()
+ .unwrap();
+ for _ in fragment_buf.len()..number_of_bad_fragments {
+ let packet = fragments
+ .handle_incoming_packet(incomplete_fragment.clone())
+ .unwrap();
+ assert!(packet.is_none());
+ }
+
+ for fragment in fragment_buf {
+ if let Some(reconstructed_packet) =
+ fragments.handle_incoming_packet(fragment).unwrap()
+ {
+ assert_eq!(payload.as_slice(), reconstructed_packet.as_ref());
+ return true;
+ }
+ }
+
+ false
+ };
- panic!("Failed to reconstruct packet");
+ assert!(fragment_survives_flood(FRAGMENT_BUFFER_CAP - 1));
+ assert!(!fragment_survives_flood(FRAGMENT_BUFFER_CAP));
}
}