diff options
| author | Joakim Hulthe <joakim.hulthe@mullvad.net> | 2025-09-02 12:46:48 +0200 |
|---|---|---|
| committer | Sebastian Holmin <sebastian.holmin@mullvad.net> | 2025-09-02 15:22:04 +0200 |
| commit | e2fb69137df7c565f84e09b8b8038d1c00546ad1 (patch) | |
| tree | f87fa5b508152cf5b3cfc15da36fcc0e97e288c7 | |
| parent | 7bba5411e7b2f6f8586d441d0fc04dda3197608e (diff) | |
| download | mullvadvpn-e2fb69137df7c565f84e09b8b8038d1c00546ad1.tar.xz mullvadvpn-e2fb69137df7c565f84e09b8b8038d1c00546ad1.zip | |
Make masque fragment reassembly slightly more efficient
| -rw-r--r-- | mullvad-masque-proxy/src/fragment.rs | 81 |
1 files changed, 40 insertions, 41 deletions
diff --git a/mullvad-masque-proxy/src/fragment.rs b/mullvad-masque-proxy/src/fragment.rs index 6f4a84b0d9..d4cfca5496 100644 --- a/mullvad-masque-proxy/src/fragment.rs +++ b/mullvad-masque-proxy/src/fragment.rs @@ -1,4 +1,4 @@ -use std::collections::{BTreeMap, VecDeque}; +use std::collections::{BTreeMap, VecDeque, btree_map}; use bytes::{Buf, BufMut, Bytes, BytesMut}; use h3::proto::varint::VarInt; @@ -22,6 +22,8 @@ pub struct Fragments { /// Map of fragmented packets. /// /// If fragments are arriving in order, this should never hold more than one set of fragments. + /// + /// INVARIANT: The `Vec` is sorted by `Fragment::index` // TODO: would a hashmap be faster? fragment_map: BTreeMap<u16, Vec<Fragment>>, } @@ -37,6 +39,9 @@ pub enum DefragError { #[error("Too few fragments in fragmented packet")] TooFewFragments, + + #[error("Received a fragment twice")] + DuplicateFragment, } // When a packet is larger than u16::MAX, it can't be fragmented. @@ -108,49 +113,42 @@ impl Fragments { "fragment_index_fifo must never grow", ); - let fragments = self.fragment_map.entry(id).or_default(); - fragments.push(fragment); + let entry = self.fragment_map.entry(id); - let reassembled = self.try_reassemble(id, fragment_count) - .map(DefragReceived::Reassembled) - // TODO: This may also occur if a packet is discarded - .unwrap_or(DefragReceived::Fragment); - Ok(reassembled) - } + let mut entry = match entry { + btree_map::Entry::Occupied(occupied) => occupied, - // TODO: Let caller provide output buffer. - fn try_reassemble(&mut self, id: u16, fragment_count: u8) -> Option<Bytes> { - // establish that there are enough fragments to reconstruct the whole packet - let fragments = &self.fragment_map[&id]; - if fragments.len() != fragment_count.into() { - return None; - } + // if this is the first received fragment, don't bother trying to reassemble + btree_map::Entry::Vacant(vacant) => { + let mut fragment_list = Vec::with_capacity(2); // two fragments should be the norm + fragment_list.push(fragment); + vacant.insert(fragment_list); + return Ok(DefragReceived::Fragment); + } + }; - // looks like a valid fragment set. pop it from the map. - let mut fragments = self.fragment_map.remove(&id).expect("fragment must exist"); + let fragments = entry.get_mut(); - fragments.sort_unstable_by_key(|f| f.index); + // insert the fragment such that the list is sorted + match fragments.binary_search_by_key(&fragment.index, |f| f.index) { + Err(insert_here) => fragments.insert(insert_here, fragment), + Ok(_) => return Err(DefragError::DuplicateFragment), + }; - // assert that fragments are in the correct order - // TODO: is this excessively paranoid? - let fragments_missing = (FRAGMENT_INDEX_START..) - .zip(&fragments) - .any(|(expected_index, fragment)| fragment.index != expected_index); - if fragments_missing { - if cfg!(debug_assertions) { - log::debug!("Discarding unordered fragment set"); - } - return None; + // establish that there are enough fragments to reconstruct the whole packet + if fragments.len() != fragment_count.into() { + return Ok(DefragReceived::Fragment); } + let fragments = entry.remove(); + // smush the fragments together let mut payload = BytesMut::with_capacity(fragments.iter().map(|f| f.payload.len()).sum()); for fragment in fragments { payload.extend_from_slice(&fragment.payload); } - let payload = payload.freeze(); - Some(payload) + Ok(DefragReceived::Reassembled(payload.freeze())) } } @@ -275,8 +273,8 @@ mod test { let fragment_survives_flood = |number_of_bad_fragments| { let mut fragments = Fragments::default(); - let packet_id = 123; - let bad_packet_id = 321; + let packet_id = 1; + let mut bad_packet_ids = 2..0xffff; let payload = (0..255).collect::<Vec<u8>>(); let max_payload_size = 50; @@ -299,15 +297,16 @@ mod test { // then send a bunch of fragments to fill the queue let mut bad_payload = Bytes::from([0u8; 2].to_vec()); - let incomplete_fragment = fragment_packet( - 1 + FRAGMENT_HEADER_SIZE_FRAGMENTED, - &mut bad_payload, - bad_packet_id, - ) - .unwrap() - .next() - .unwrap(); for _ in fragment_buf.len()..number_of_bad_fragments { + let incomplete_fragment = fragment_packet( + 1 + FRAGMENT_HEADER_SIZE_FRAGMENTED, + &mut bad_payload, + bad_packet_ids.next().unwrap(), + ) + .unwrap() + .next() + .unwrap(); + let packet = fragments .handle_incoming_packet(incomplete_fragment.clone()) .unwrap(); |
