diff options
| author | Joakim Hulthe <joakim.hulthe@mullvad.net> | 2025-04-10 18:06:45 +0200 |
|---|---|---|
| committer | Joakim Hulthe <joakim.hulthe@mullvad.net> | 2025-04-10 18:06:45 +0200 |
| commit | c870d0df0dc26c3ddf72de1fe69a4c4a1afe28b2 (patch) | |
| tree | 0ff05d0f8d245cf8a8be63368f52019c937f8429 | |
| parent | 2142bb3c1cb8c7ecb52751fdc4c03622870d7733 (diff) | |
| parent | 2d8215081909c6fd1ceb434d151a3ed6a10f636a (diff) | |
| download | mullvadvpn-c870d0df0dc26c3ddf72de1fe69a4c4a1afe28b2.tar.xz mullvadvpn-c870d0df0dc26c3ddf72de1fe69a4c4a1afe28b2.zip | |
Merge branch 'use-ring-buffer-instead-of-clearing-fragments-arbitrarily-des-1966'
| -rw-r--r-- | mullvad-masque-proxy/src/fragment.rs | 185 |
1 files changed, 149 insertions, 36 deletions
diff --git a/mullvad-masque-proxy/src/fragment.rs b/mullvad-masque-proxy/src/fragment.rs index c699ed2d78..3d9fc94273 100644 --- a/mullvad-masque-proxy/src/fragment.rs +++ b/mullvad-masque-proxy/src/fragment.rs @@ -1,5 +1,5 @@ use std::{ - collections::BTreeMap, + collections::{BTreeMap, VecDeque}, time::{Duration, Instant}, }; @@ -8,8 +8,24 @@ use h3::proto::varint::VarInt; use crate::FRAGMENT_HEADER_SIZE_FRAGMENTED; -#[derive(Default)] +/// The index of the first fragment of a packet. +const FRAGMENT_INDEX_START: u8 = 1; + +/// The maximum number of unassembled fragments that we buffer. +// 255 is the theoretical maximum number of fragments for a single packet. +const FRAGMENT_BUFFER_CAP: usize = 255; + pub struct Fragments { + /// FIFO queue of fragment indices. Used to mitigate floods of unordered packet fragments. + /// + /// When receiving a fragment, push it to the back of the queue. When the queue length exceeds + /// [FRAGMENT_BUFFER_CAP], pop the first element and remove it from [Self::fragment_map]. + fragment_index_fifo: VecDeque<u16>, + + /// Map of fragmented packets. + /// + /// If fragments are arriving in order, this should never hold more than one set of fragments. + // TODO: would a hashmap be faster? fragment_map: BTreeMap<u16, Vec<Fragment>>, } @@ -29,6 +45,15 @@ pub enum DefragError { #[error("Packet is too large to fragment")] pub struct PacketTooLarge(pub usize); +impl Default for Fragments { + fn default() -> Self { + Self { + fragment_index_fifo: VecDeque::with_capacity(FRAGMENT_BUFFER_CAP), + fragment_map: Default::default(), + } + } +} + impl Fragments { // TODO: Let caller provide output buffer. pub fn handle_incoming_packet( @@ -60,33 +85,60 @@ impl Fragments { time_received: Instant::now(), }; + // ensure that the fifo has capacity before pushing the new fragment id + if self.fragment_index_fifo.len() >= FRAGMENT_BUFFER_CAP { + let id = self.fragment_index_fifo.pop_front().expect("fifo is full"); + if self.fragment_map.remove(&id).is_some() && cfg!(debug_assertions) { + println!("Fragment was discarded before reassembly"); + }; + } + self.fragment_index_fifo.push_back(id); + + debug_assert_eq!( + self.fragment_index_fifo.capacity(), + FRAGMENT_BUFFER_CAP, + "fragment_index_fifo must never grow", + ); + let fragments = self.fragment_map.entry(id).or_default(); fragments.push(fragment); - Ok(self.try_fetch(id, fragment_count)) + Ok(self.try_reassemble(id, fragment_count)) } // TODO: Let caller provide output buffer. - fn try_fetch(&mut self, id: u16, fragment_count: u8) -> Option<Bytes> { + fn try_reassemble(&mut self, id: u16, fragment_count: u8) -> Option<Bytes> { // establish that there are enough fragments to reconstruct the whole packet - let payload = { - let fragments = self.fragment_map.get_mut(&id)?; + let fragments = &self.fragment_map[&id]; + if fragments.len() != fragment_count.into() { + return None; + } - if fragments.len() != fragment_count.into() { - return None; - } + // looks like a valid fragment set. pop it from the map. + let mut fragments = self.fragment_map.remove(&id).expect("fragment must exist"); + + fragments.sort_unstable_by_key(|f| f.index); - fragments.sort_by_key(|f| f.index); - let mut payload = - BytesMut::with_capacity(fragments.iter().map(|f| f.payload.len()).sum()); - for fragment in fragments { - payload.extend_from_slice(&fragment.payload); + // assert that fragments are in the correct order + // TODO: is this excessively paranoid? + let fragments_missing = (FRAGMENT_INDEX_START..) + .zip(&fragments) + .any(|(expected_index, fragment)| fragment.index != expected_index); + if fragments_missing { + if cfg!(debug_assertions) { + println!("Discarding unordered fragment set"); } - payload - }; + return None; + } - self.fragment_map.remove(&id); - Some(payload.into()) + // smush the fragments together + let mut payload = BytesMut::with_capacity(fragments.iter().map(|f| f.payload.len()).sum()); + for fragment in fragments { + payload.extend_from_slice(&fragment.payload); + } + let payload = payload.freeze(); + + Some(payload) } pub fn clear_old_fragments(&mut self, max_age: Duration) { @@ -130,8 +182,7 @@ pub fn fragment_packet( crate::HTTP_MASQUE_FRAGMENTED_DATAGRAM_CONTEXT_ID.encode(&mut fragment); fragment.put_u16(packet_id); fragment.put_u8( - // fragment indexes start at 1 - u8::try_from(fragment_index + 1) + u8::try_from(fragment_index + usize::from(FRAGMENT_INDEX_START)) .expect("fragment index must fit in an u8, since num_fragments fits is an u8"), ); fragment.put_u8(fragment_count); @@ -151,28 +202,90 @@ mod test { #[test] fn test_fragment_reconstruction() { use rand::{seq::SliceRandom, thread_rng}; - - let payload = (0..255).collect::<Vec<u8>>(); - let max_payload_size = 50; - let packet_id = 76; - let mut fragments = Fragments::default(); - let mut payload_clone = Bytes::from(payload.clone()); - let mut fragment_buf = fragment_packet(max_payload_size, &mut payload_clone, packet_id) - .unwrap() - .collect::<Vec<_>>(); + 'outer: for packet_id in 1..255u16 { + let payload = (0..packet_id as u8).collect::<Vec<u8>>(); + let max_payload_size = 50; - fragment_buf.shuffle(&mut thread_rng()); + let mut payload_clone = Bytes::from(payload.clone()); + let mut fragment_buf = fragment_packet(max_payload_size, &mut payload_clone, packet_id) + .unwrap() + .collect::<Vec<_>>(); - for fragment in fragment_buf { - if let Some(reconstructed_packet) = fragments.handle_incoming_packet(fragment).unwrap() - { - assert_eq!(payload.as_slice(), reconstructed_packet.as_ref()); - return; + fragment_buf.shuffle(&mut thread_rng()); + + for fragment in fragment_buf { + if let Some(reconstructed_packet) = + fragments.handle_incoming_packet(fragment).unwrap() + { + assert_eq!(payload.as_slice(), reconstructed_packet.as_ref()); + continue 'outer; + } } + + panic!("Failed to reconstruct packet"); } + } + + #[test] + fn test_fragment_cap() { + use rand::{seq::SliceRandom, thread_rng}; + + // test whether we can reassemble a fragmented packet when we receive a flood of bad fragments + // interspersed with our good fragments. returns true if reassembly was successful. + let fragment_survives_flood = |number_of_bad_fragments| { + let mut fragments = Fragments::default(); + + let packet_id = 123; + let bad_packet_id = 321; + + let payload = (0..255).collect::<Vec<u8>>(); + let max_payload_size = 50; + + let mut payload_clone = Bytes::from(payload.clone()); + let mut fragment_buf = fragment_packet(max_payload_size, &mut payload_clone, packet_id) + .unwrap() + .collect::<Vec<_>>(); + + fragment_buf.shuffle(&mut thread_rng()); + + // send one fragment + let packet = fragments + .handle_incoming_packet(fragment_buf.pop().unwrap()) + .unwrap(); + assert!(packet.is_none(), "haven't sent all fragments yet"); + + // then send a bunch of fragments to fill the queue + let mut bad_payload = Bytes::from([0u8; 2].to_vec()); + let incomplete_fragment = fragment_packet( + 1 + FRAGMENT_HEADER_SIZE_FRAGMENTED, + &mut bad_payload, + bad_packet_id, + ) + .unwrap() + .next() + .unwrap(); + for _ in fragment_buf.len()..number_of_bad_fragments { + let packet = fragments + .handle_incoming_packet(incomplete_fragment.clone()) + .unwrap(); + assert!(packet.is_none()); + } + + for fragment in fragment_buf { + if let Some(reconstructed_packet) = + fragments.handle_incoming_packet(fragment).unwrap() + { + assert_eq!(payload.as_slice(), reconstructed_packet.as_ref()); + return true; + } + } + + false + }; - panic!("Failed to reconstruct packet"); + assert!(fragment_survives_flood(FRAGMENT_BUFFER_CAP - 1)); + assert!(!fragment_survives_flood(FRAGMENT_BUFFER_CAP)); } } |
