diff --git a/src/transport/webrtc/connection.rs b/src/transport/webrtc/connection.rs index d7e31a07..e7d53b7f 100644 --- a/src/transport/webrtc/connection.rs +++ b/src/transport/webrtc/connection.rs @@ -30,7 +30,7 @@ use crate::{ webrtc::{ schema::webrtc::message::Flag, substream::{Event as SubstreamEvent, Substream as WebRtcSubstream, SubstreamHandle}, - util::WebRtcMessage, + util::{extract_framed_message, WebRtcMessage}, }, Endpoint, }, @@ -38,6 +38,7 @@ use crate::{ PeerId, }; +use bytes::{Bytes, BytesMut}; use futures::{task::AtomicWaker, Stream, StreamExt}; use indexmap::IndexMap; use str0m::{ @@ -251,6 +252,18 @@ pub struct WebRtcConnection { /// Substream handles. handles: SubstreamHandleSet, + + /// Inbound data channel byte buffer for reassembling full protobuf frames. + /// + /// The libp2p-go msgio implementation issues two separate `Write` calls: + /// - varint length + /// - protobuf body + /// + /// These will become two distinct SCTP messages on the data channel. + /// + /// Accumulate raw bytes here and only attempt protobuf decode once a + /// full `varint length ++ body` frame is available. 
+ recv_buffers: HashMap, } impl WebRtcConnection { @@ -278,6 +291,7 @@ impl WebRtcConnection { pending_messages: HashMap::new(), channels: HashMap::new(), handles: SubstreamHandleSet::new(), + recv_buffers: HashMap::new(), } } @@ -455,6 +469,7 @@ impl WebRtcConnection { self.channels.remove(&channel_id); self.pending_messages.remove(&channel_id); self.handles.remove(&channel_id); + self.recv_buffers.remove(&channel_id); Ok(()) } @@ -476,7 +491,7 @@ impl WebRtcConnection { async fn on_inbound_opening_channel_data( &mut self, channel_id: ChannelId, - data: Vec, + data: Bytes, header_received: bool, ) -> crate::Result)>> { tracing::trace!( @@ -558,7 +573,7 @@ impl WebRtcConnection { async fn on_outbound_opening_channel_data( &mut self, channel_id: ChannelId, - data: Vec, + data: Bytes, mut dialer_state: WebRtcDialerState, context: ChannelContext, ) -> Result, SubstreamError> { @@ -683,8 +698,9 @@ impl WebRtcConnection { async fn on_open_channel_data( &mut self, channel_id: ChannelId, - data: Vec, + data: Bytes, ) -> crate::Result<()> { + // Decode errors are not recoverable. let message = WebRtcMessage::decode(&data)?; tracing::debug!( @@ -713,6 +729,13 @@ impl WebRtcConnection { } /// Handle data received from a channel. + /// + /// Bytes are accumulated in a per-channel buffer and only handed to the per-state + /// dispatcher once a complete `varint length ++ protobuf body` frame is available. + /// + /// This handles peers (go-libp2p's pbio writer) that split varint and body + /// across two SCTP messages, while remaining a no-op for peers that send the whole + /// frame in one message (smoldot). 
async fn on_inbound_data(&mut self, channel_id: ChannelId, data: Vec) -> crate::Result<()> { tracing::debug!( target: LOG_TARGET, @@ -723,6 +746,31 @@ impl WebRtcConnection { "received channel data", ); + self.recv_buffers.entry(channel_id).or_default().extend_from_slice(&data); + + loop { + let Some(buffer) = self.recv_buffers.get_mut(&channel_id) else { + return Ok(()); + }; + + let Some(body) = extract_framed_message(buffer)? else { + return Ok(()); + }; + + self.dispatch_framed_message(channel_id, body).await?; + // If the channel was closed/removed during dispatch, stop draining its buffer. + if !self.channels.contains_key(&channel_id) { + return Ok(()); + } + } + } + + /// Dispatch a single reassembled protobuf body to the per-channel-state handler. + async fn dispatch_framed_message( + &mut self, + channel_id: ChannelId, + data: Bytes, + ) -> crate::Result<()> { let Some(state) = self.channels.remove(&channel_id) else { tracing::warn!( target: LOG_TARGET, diff --git a/src/transport/webrtc/opening.rs b/src/transport/webrtc/opening.rs index 937dd0a6..ec35dad9 100644 --- a/src/transport/webrtc/opening.rs +++ b/src/transport/webrtc/opening.rs @@ -23,11 +23,15 @@ use crate::{ config::Role, crypto::{ed25519::Keypair, noise::NoiseContext}, - transport::{webrtc::util::WebRtcMessage, Endpoint}, + transport::{ + webrtc::util::{extract_framed_message, WebRtcMessage}, + Endpoint, + }, types::ConnectionId, Error, PeerId, }; +use bytes::BytesMut; use multiaddr::{Multiaddr, Protocol}; use multihash_codetable::Code; use str0m::{ @@ -111,6 +115,18 @@ pub struct OpeningWebRtcConnection { /// Local address. local_address: SocketAddr, + + /// Inbound noise-channel byte buffer for reassembling protobuf frames. + /// + /// The libp2p-go msgio implementation issues two separate `Write` calls: + /// - varint length + /// - protobuf body + /// + /// These will become two distinct SCTP messages on the data channel. 
+ /// + /// Accumulate raw bytes here and only attempt protobuf decode once a + /// full `varint length ++ body` frame is available. + noise_recv_buffer: BytesMut, } /// Connection state. @@ -168,6 +184,7 @@ impl OpeningWebRtcConnection { id_keypair, peer_address, local_address, + noise_recv_buffer: BytesMut::new(), } } @@ -199,7 +216,12 @@ impl OpeningWebRtcConnection { /// /// Create the first Noise handshake message and send it to remote peer. fn on_noise_channel_open(&mut self) -> crate::Result<()> { - tracing::trace!(target: LOG_TARGET, "send initial noise handshake"); + tracing::trace!( + target: LOG_TARGET, + connection_id = ?self.connection_id, + peer = ?self.peer_address, + "send initial noise handshake", + ); let State::Opened { mut context } = std::mem::replace(&mut self.state, State::Poisoned) else { @@ -234,6 +256,8 @@ impl OpeningWebRtcConnection { if let Err(error) = self.rtc.handle_input(Input::Timeout(Instant::now())) { tracing::error!( target: LOG_TARGET, + connection_id = ?self.connection_id, + peer = ?self.peer_address, ?error, "failed to handle timeout for `Rtc`" ); @@ -254,8 +278,38 @@ impl OpeningWebRtcConnection { /// /// If the peer is accepted, [`OpeningWebRtcConnection::on_accept()`] is called which creates /// the final Noise message and sends it to the remote peer, concluding the handshake. - fn on_noise_channel_data(&mut self, data: Vec) -> crate::Result { - tracing::trace!(target: LOG_TARGET, "handle noise handshake reply"); + fn on_noise_channel_data(&mut self, data: Vec) -> crate::Result> { + tracing::trace!( + target: LOG_TARGET, + connection_id = ?self.connection_id, + peer = ?self.peer_address, + len = data.len(), + buffered = self.noise_recv_buffer.len(), + "noise channel data received", + ); + + self.noise_recv_buffer.extend_from_slice(&data); + + let body = match extract_framed_message(&mut self.noise_recv_buffer)? 
{ + Some(body) => body, + None => { + tracing::trace!( + target: LOG_TARGET, + connection_id = ?self.connection_id, + peer = ?self.peer_address, + buffered = self.noise_recv_buffer.len(), + "incomplete noise frame, waiting for more bytes", + ); + return Ok(None); + } + }; + + tracing::trace!( + target: LOG_TARGET, + connection_id = ?self.connection_id, + peer = ?self.peer_address, + "handle noise handshake reply", + ); let State::HandshakeSent { mut context } = std::mem::replace(&mut self.state, State::Poisoned) @@ -263,11 +317,13 @@ impl OpeningWebRtcConnection { return Err(Error::InvalidState); }; - let message = WebRtcMessage::decode(&data)?.payload.ok_or(Error::InvalidData)?; + let message = WebRtcMessage::decode(&body)?.payload.ok_or(Error::InvalidData)?; let remote_peer_id = context.get_remote_peer_id(&message)?; tracing::trace!( target: LOG_TARGET, + connection_id = ?self.connection_id, + peer = ?self.peer_address, ?remote_peer_id, "remote reply parsed successfully", ); @@ -293,16 +349,21 @@ impl OpeningWebRtcConnection { .with(Protocol::Certhash(certificate)) .with(Protocol::P2p(remote_peer_id.into())); - Ok(WebRtcEvent::ConnectionOpened { + Ok(Some(WebRtcEvent::ConnectionOpened { peer: remote_peer_id, endpoint: Endpoint::listener(address, self.connection_id), - }) + })) } /// Accept connection by sending the final Noise handshake message /// and return the `Rtc` object for further use. 
pub fn on_accept(mut self) -> crate::Result { - tracing::trace!(target: LOG_TARGET, "accept webrtc connection"); + tracing::trace!( + target: LOG_TARGET, + connection_id = ?self.connection_id, + peer = ?self.peer_address, + "accept webrtc connection", + ); let State::Validating { mut context } = std::mem::replace(&mut self.state, State::Poisoned) else { @@ -350,7 +411,11 @@ impl OpeningWebRtcConnection { match self.rtc.accepts(&message) { true => self.rtc.handle_input(message).map_err(|error| { - tracing::debug!(target: LOG_TARGET, source = ?self.peer_address, ?error, "failed to handle data"); + tracing::debug!( + target: LOG_TARGET, + peer = ?self.peer_address, + ?error, "failed to handle data" + ); Error::InputRejected }), false => { @@ -369,6 +434,8 @@ impl OpeningWebRtcConnection { if !self.rtc.is_alive() { tracing::debug!( target: LOG_TARGET, + connection_id = ?self.connection_id, + peer = ?self.peer_address, "`Rtc` is not alive, closing `WebRtcConnection`" ); @@ -382,6 +449,7 @@ impl OpeningWebRtcConnection { tracing::debug!( target: LOG_TARGET, connection_id = ?self.connection_id, + peer = ?self.peer_address, ?error, "`WebRtcConnection::poll_process()` failed", ); @@ -394,6 +462,8 @@ impl OpeningWebRtcConnection { Output::Transmit(transmit) => { tracing::trace!( target: LOG_TARGET, + connection_id = ?self.connection_id, + peer = ?self.peer_address, "transmit data", ); @@ -406,13 +476,19 @@ impl OpeningWebRtcConnection { Output::Event(e) => match e { Event::IceConnectionStateChange(v) => if v == IceConnectionState::Disconnected { - tracing::trace!(target: LOG_TARGET, "ice connection closed"); + tracing::trace!( + target: LOG_TARGET, + connection_id = ?self.connection_id, + peer = ?self.peer_address, + "ice connection closed", + ); return WebRtcEvent::ConnectionClosed; }, Event::ChannelOpen(channel_id, name) => { tracing::trace!( target: LOG_TARGET, connection_id = ?self.connection_id, + peer = ?self.peer_address, ?channel_id, ?name, "channel opened", @@ 
-422,6 +498,7 @@ impl OpeningWebRtcConnection { tracing::warn!( target: LOG_TARGET, connection_id = ?self.connection_id, + peer = ?self.peer_address, ?channel_id, "ignoring opened channel", ); @@ -432,6 +509,7 @@ impl OpeningWebRtcConnection { tracing::debug!( target: LOG_TARGET, connection_id = ?self.connection_id, + peer = ?self.peer_address, ?error, "noise channel open failed", ); @@ -441,6 +519,8 @@ impl OpeningWebRtcConnection { Event::ChannelData(data) => { tracing::trace!( target: LOG_TARGET, + connection_id = ?self.connection_id, + peer = ?self.peer_address, "data received over channel", ); @@ -449,17 +529,20 @@ impl OpeningWebRtcConnection { target: LOG_TARGET, channel_id = ?data.id, connection_id = ?self.connection_id, + peer = ?self.peer_address, "ignoring data from channel", ); continue; } match self.on_noise_channel_data(data.data) { - Ok(event) => return event, + Ok(Some(event)) => return event, + Ok(None) => continue, Err(error) => { tracing::debug!( target: LOG_TARGET, connection_id = ?self.connection_id, + peer = ?self.peer_address, ?error, "noise channel data handling failed", ); @@ -468,7 +551,13 @@ impl OpeningWebRtcConnection { } } Event::ChannelClose(channel_id) => { - tracing::debug!(target: LOG_TARGET, ?channel_id, "channel closed"); + tracing::debug!( + target: LOG_TARGET, + connection_id = ?self.connection_id, + peer = ?self.peer_address, + ?channel_id, + "channel closed", + ); } Event::Connected => match std::mem::replace(&mut self.state, State::Poisoned) { State::Closed => { @@ -483,6 +572,7 @@ impl OpeningWebRtcConnection { Err(err) => { tracing::error!( target: LOG_TARGET, + connection_id = ?self.connection_id, peer = ?self.peer_address, "NoiseContext failed with error {err}", ); @@ -493,6 +583,7 @@ impl OpeningWebRtcConnection { tracing::debug!( target: LOG_TARGET, + connection_id = ?self.connection_id, peer = ?self.peer_address, "connection opened", ); @@ -502,6 +593,7 @@ impl OpeningWebRtcConnection { state => { tracing::debug!( 
target: LOG_TARGET, + connection_id = ?self.connection_id, peer = ?self.peer_address, ?state, "invalid state for connection" diff --git a/src/transport/webrtc/substream.rs b/src/transport/webrtc/substream.rs index 6599a1ef..9acfbe44 100644 --- a/src/transport/webrtc/substream.rs +++ b/src/transport/webrtc/substream.rs @@ -19,7 +19,10 @@ // DEALINGS IN THE SOFTWARE. use crate::{ - transport::webrtc::{schema::webrtc::message::Flag, util::WebRtcMessage}, + transport::webrtc::{ + schema::webrtc::message::Flag, + util::{WebRtcMessage, MAX_FRAME_SIZE}, + }, Error, }; @@ -36,9 +39,6 @@ use std::{ time::Duration, }; -/// Maximum frame size. -const MAX_FRAME_SIZE: usize = 16384; - /// Timeout for waiting on FIN_ACK after sending FIN. /// Matches go-libp2p's 5 second stream close timeout. const FIN_ACK_TIMEOUT: Duration = Duration::from_secs(5); diff --git a/src/transport/webrtc/util.rs b/src/transport/webrtc/util.rs index ae050d50..3ce786ab 100644 --- a/src/transport/webrtc/util.rs +++ b/src/transport/webrtc/util.rs @@ -23,8 +23,15 @@ use crate::{ transport::webrtc::schema::{self, webrtc::message::Flag}, }; +use bytes::{Bytes, BytesMut}; use prost::Message; +/// Logging target for the file. +const LOG_TARGET: &str = "litep2p::webrtc"; + +/// Maximum size of a single framed WebRTC body in bytes. +pub const MAX_FRAME_SIZE: usize = 16 * 1024; + /// WebRTC message. #[derive(Debug)] pub struct WebRtcMessage { @@ -72,34 +79,23 @@ impl WebRtcMessage { out_buf } - /// Decode payload into [`WebRtcMessage`]. - /// - /// Decodes the varint length prefix directly from the slice without allocations, - /// then decodes the protobuf message from the remaining bytes. + /// Decode a protobuf-encoded [`schema::webrtc::Message`] body with no varint length prefix. /// /// # Flag handling /// /// Unknown flag values (e.g., from a newer protocol version) are logged as warnings /// and treated as `None` for forward compatibility. 
This allows the message payload /// to still be processed even if the flag is not recognized. - pub fn decode(payload: &[u8]) -> Result { - // Decode varint length prefix directly from slice (no allocation) - // Returns (decoded_length, remaining_bytes_after_varint) - let (len, remaining) = - unsigned_varint::decode::usize(payload).map_err(|_| ParseError::InvalidData)?; - - // Get exactly `len` bytes of protobuf data (no allocation) - let protobuf_data = remaining.get(..len).ok_or(ParseError::InvalidData)?; - + pub fn decode(protobuf_data: &[u8]) -> Result { match schema::webrtc::Message::decode(protobuf_data) { Ok(message) => { let flag = message.flag.and_then(|f| match Flag::try_from(f) { Ok(flag) => Some(flag), Err(_) => { tracing::warn!( - target: "litep2p::webrtc", + target: LOG_TARGET, ?f, - "received message with unknown flag value, ignoring flag" + "Received message with unknown flag value, ignoring flag" ); None } @@ -114,14 +110,86 @@ impl WebRtcMessage { } } +/// Try to extract one complete `varint length ++ body` frame from the front of `buffer`. +pub fn extract_framed_message(buffer: &mut BytesMut) -> Result, ParseError> { + let (len, remaining) = match unsigned_varint::decode::usize(buffer) { + Ok(decoded) => decoded, + // More bytes may arrive and complete the varint. + Err(unsigned_varint::decode::Error::Insufficient) => { + tracing::trace!( + target: LOG_TARGET, + buffer_len = buffer.len(), + "Received incomplete SCTP varint header, waiting for more data" + ); + return Ok(None); + } + // Permanent failures. + Err(err) => { + tracing::debug!( + target: LOG_TARGET, + ?err, + buffer_len = buffer.len(), + "Permanent error encountered during SCTP varint framing" + ); + return Err(ParseError::InvalidData); + } + }; + + // Reject oversized frames before waiting for the body. 
+ if len > MAX_FRAME_SIZE { + tracing::debug!( + target: LOG_TARGET, + declared_len = len, + max = MAX_FRAME_SIZE, + "Rejecting oversized SCTP frame" + ); + return Err(ParseError::InvalidData); + } + + if remaining.len() < len { + tracing::trace!( + target: LOG_TARGET, + expected_body_len = len, + available_body_len = remaining.len(), + "Received incomplete SCTP payload, waiting for more data" + ); + return Ok(None); + } + + let varint_len = buffer.len() - remaining.len(); + // Slice off the whole frame, then drop the varint header and freeze the body. + let mut frame = buffer.split_to(varint_len + len); + let _ = frame.split_to(varint_len); + + tracing::trace!( + target: LOG_TARGET, + message_len = len, + varint_len, + "Successfully extracted SCTP framed message" + ); + + Ok(Some(frame.freeze())) +} + #[cfg(test)] mod tests { use super::*; + /// Strip the unsigned-varint length prefix that [`WebRtcMessage::encode`] prepends, + /// returning the bare protobuf body that [`WebRtcMessage::decode`] expects. 
+ fn protobuf_body(encoded: &[u8]) -> &[u8] { + let (len, rest) = unsigned_varint::decode::usize(encoded).unwrap(); + &rest[..len] + } + + fn buf(bytes: &[u8]) -> BytesMut { + BytesMut::from(bytes) + } + #[test] fn with_payload_no_flag() { let message = WebRtcMessage::encode("Hello, world!".as_bytes().to_vec(), None); - let decoded = WebRtcMessage::decode(&message).unwrap(); + let decoded = WebRtcMessage::decode(protobuf_body(&message)).unwrap(); assert_eq!(decoded.payload, Some("Hello, world!".as_bytes().to_vec())); assert_eq!(decoded.flag, None); @@ -131,7 +199,7 @@ mod tests { fn with_payload_and_flag() { let message = WebRtcMessage::encode("Hello, world!".as_bytes().to_vec(), Some(Flag::StopSending)); - let decoded = WebRtcMessage::decode(&message).unwrap(); + let decoded = WebRtcMessage::decode(protobuf_body(&message)).unwrap(); assert_eq!(decoded.payload, Some("Hello, world!".as_bytes().to_vec())); assert_eq!(decoded.flag, Some(Flag::StopSending)); @@ -140,9 +208,302 @@ mod tests { #[test] fn no_payload_with_flag() { let message = WebRtcMessage::encode(vec![], Some(Flag::ResetStream)); - let decoded = WebRtcMessage::decode(&message).unwrap(); + let decoded = WebRtcMessage::decode(protobuf_body(&message)).unwrap(); assert_eq!(decoded.payload, None); assert_eq!(decoded.flag, Some(Flag::ResetStream)); } + + #[test] + fn extract_single_frame_one_chunk() { + // The common case: a peer (e.g. smoldot) sends the whole `varint ++ body` + // in a single SCTP message. Extraction should succeed and drain the buffer. + let frame = WebRtcMessage::encode(b"hello".to_vec(), None); + let mut buffer = buf(&frame); + + let body = extract_framed_message(&mut buffer).unwrap().expect("complete frame"); + assert_eq!(&body[..], protobuf_body(&frame)); + assert!(buffer.is_empty(), "buffer fully drained"); + } + + #[test] + fn extract_single_frame_empty_body() { + // A zero-length body is a legal frame: `0x00` varint, no body bytes. 
+ let mut buffer = buf(&[0x00]); + + let body = extract_framed_message(&mut buffer).unwrap().expect("zero-length frame"); + assert!(body.is_empty()); + assert!(buffer.is_empty()); + } + + #[test] + fn extract_frame_split_varint_then_body() { + // go-libp2p's pbio writer issues two `Write` calls (varint, then body) + // which surface as two SCTP messages. + let frame = WebRtcMessage::encode(b"split-across-sctp".to_vec(), None); + let (len, rest) = unsigned_varint::decode::usize(&frame).unwrap(); + let varint_bytes = &frame[..frame.len() - rest.len()]; + let body_bytes = &rest[..len]; + + let mut buffer = BytesMut::new(); + + // SCTP message #1: just the varint. No complete frame yet. + buffer.extend_from_slice(varint_bytes); + assert!(extract_framed_message(&mut buffer).unwrap().is_none()); + assert_eq!( + &buffer[..], + varint_bytes, + "varint preserved for next attempt" + ); + + // SCTP message #2: the body arrives. Extraction now succeeds. + buffer.extend_from_slice(body_bytes); + let body = extract_framed_message(&mut buffer).unwrap().expect("frame now complete"); + assert_eq!(&body[..], body_bytes); + assert!(buffer.is_empty()); + } + + #[test] + fn extract_frame_split_with_multi_byte_varint() { + // Real noise frames are bigger than 128 bytes, so the varint itself is multi-byte. + // Verify the split still works when the varint takes 2 bytes. 
+ let payload = vec![0xab; 300]; + let frame = WebRtcMessage::encode(payload, None); + let (len, rest) = unsigned_varint::decode::usize(&frame).unwrap(); + let varint_bytes = &frame[..frame.len() - rest.len()]; + let body_bytes = &rest[..len]; + assert!(varint_bytes.len() >= 2, "expected multi-byte varint"); + + let mut buffer = BytesMut::new(); + buffer.extend_from_slice(varint_bytes); + assert!(extract_framed_message(&mut buffer).unwrap().is_none()); + + buffer.extend_from_slice(body_bytes); + let body = extract_framed_message(&mut buffer).unwrap().expect("complete frame"); + assert_eq!(body.len(), len); + assert_eq!(&body[..], body_bytes); + assert!(buffer.is_empty()); + } + + #[test] + fn extract_frame_with_partial_varint() { + // Even more adversarial: the varint itself is split across SCTP messages. + // First byte alone has the high bit set, so the varint isn't decodable yet. + // This must be classified as "incomplete" (Ok(None)), not "malformed" (Err) — + // more bytes will fix it. + let payload = vec![0xcd; 300]; + let frame = WebRtcMessage::encode(payload, None); + let (len, rest) = unsigned_varint::decode::usize(&frame).unwrap(); + let varint_bytes = &frame[..frame.len() - rest.len()]; + assert!(varint_bytes.len() >= 2); + + let mut buffer = BytesMut::new(); + + // First byte of the varint only — undecodable. + buffer.extend_from_slice(&varint_bytes[..1]); + assert!(extract_framed_message(&mut buffer).unwrap().is_none()); + + // Remainder of varint arrives, body still missing. + buffer.extend_from_slice(&varint_bytes[1..]); + assert!(extract_framed_message(&mut buffer).unwrap().is_none()); + + // Body arrives — frame now complete. 
+ buffer.extend_from_slice(&rest[..len]); + let body = extract_framed_message(&mut buffer).unwrap().expect("complete frame"); + assert_eq!(body.len(), len); + assert!(buffer.is_empty()); + } + + #[test] + fn extract_from_empty_buffer() { + let mut buffer = BytesMut::new(); + assert!(extract_framed_message(&mut buffer).unwrap().is_none()); + assert!(buffer.is_empty()); + } + + #[test] + fn extract_two_frames_concatenated() { + // Inbound path drains in a loop: if two frames are coalesced into one SCTP + // message, two consecutive extractions must each yield one body. + let frame_a = WebRtcMessage::encode(b"first".to_vec(), None); + let frame_b = WebRtcMessage::encode(b"second".to_vec(), None); + let body_a = protobuf_body(&frame_a).to_vec(); + let body_b = protobuf_body(&frame_b).to_vec(); + + let mut buffer = BytesMut::new(); + buffer.extend_from_slice(&frame_a); + buffer.extend_from_slice(&frame_b); + + let extracted_a = extract_framed_message(&mut buffer).unwrap().expect("first frame"); + assert_eq!(&extracted_a[..], &body_a[..]); + + let extracted_b = extract_framed_message(&mut buffer).unwrap().expect("second frame"); + assert_eq!(&extracted_b[..], &body_b[..]); + + assert!(buffer.is_empty()); + assert!(extract_framed_message(&mut buffer).unwrap().is_none()); + } + + #[test] + fn extract_frame_then_partial_next_frame() { + // One complete frame followed by the start of a second frame: the first + // frame is returned and the partial bytes of frame #2 remain buffered. 
+ let frame_a = WebRtcMessage::encode(b"complete".to_vec(), None); + let frame_b = WebRtcMessage::encode(b"incoming".to_vec(), None); + let body_a = protobuf_body(&frame_a).to_vec(); + + let mut buffer = BytesMut::new(); + buffer.extend_from_slice(&frame_a); + buffer.extend_from_slice(&frame_b[..2]); // partial second frame + + let extracted = extract_framed_message(&mut buffer).unwrap().expect("first frame"); + assert_eq!(&extracted[..], &body_a[..]); + assert_eq!(&buffer[..], &frame_b[..2], "partial second frame preserved"); + + // Second extraction is a no-op until the rest of frame_b arrives. + assert!(extract_framed_message(&mut buffer).unwrap().is_none()); + buffer.extend_from_slice(&frame_b[2..]); + + let extracted_b = + extract_framed_message(&mut buffer).unwrap().expect("second frame complete"); + assert_eq!(&extracted_b[..], protobuf_body(&frame_b)); + assert!(buffer.is_empty()); + } + + #[test] + fn extract_body_arrives_byte_by_byte() { + // Worst-case fragmentation: every body byte arrives in its own SCTP message. 
+ let payload: Vec = (0..50u8).collect(); + let frame = WebRtcMessage::encode(payload.clone(), None); + let (len, rest) = unsigned_varint::decode::usize(&frame).unwrap(); + let varint_bytes = &frame[..frame.len() - rest.len()]; + let body_bytes = rest[..len].to_vec(); + + let mut buffer = BytesMut::new(); + buffer.extend_from_slice(varint_bytes); + assert!(extract_framed_message(&mut buffer).unwrap().is_none()); + + for (i, byte) in body_bytes.iter().enumerate() { + buffer.extend_from_slice(&[*byte]); + if i + 1 < body_bytes.len() { + assert!( + extract_framed_message(&mut buffer).unwrap().is_none(), + "should still be waiting at byte {i}", + ); + } + } + + let extracted = extract_framed_message(&mut buffer).unwrap().expect("complete frame"); + assert_eq!(&extracted[..], &body_bytes[..]); + assert!(buffer.is_empty()); + } + + #[test] + fn extract_does_not_consume_on_failure() { + // On `Ok(None)`, the buffer must be left exactly as-is so the caller can + // append more bytes and retry. Verify both for the partial-varint and + // partial-body cases. + let frame = WebRtcMessage::encode(vec![0u8; 200], None); + let (len, rest) = unsigned_varint::decode::usize(&frame).unwrap(); + let varint_bytes = &frame[..frame.len() - rest.len()]; + + // Partial varint. + let mut buffer = buf(&varint_bytes[..1]); + let snapshot = buffer.clone(); + assert!(extract_framed_message(&mut buffer).unwrap().is_none()); + assert_eq!(&buffer[..], &snapshot[..]); + + // Complete varint, partial body. + let mut buffer = buf(varint_bytes); + buffer.extend_from_slice(&rest[..len / 2]); + let snapshot = buffer.clone(); + assert!(extract_framed_message(&mut buffer).unwrap().is_none()); + assert_eq!(&buffer[..], &snapshot[..]); + } + + #[test] + fn extract_rejects_overlong_varint() { + // A malicious peer sends a varint whose length prefix exceeds usize. With + // a string of all-`0x80` bytes (every continuation byte present, value zero), + // the accumulated value eventually overflows usize. 
This must surface as + // `Err(InvalidData)` — not `Ok(None)`, otherwise the buffer would grow + // unboundedly while waiting for "more bytes" that can never help. + let mut buffer = buf(&[0x80u8; 11]); + let err = extract_framed_message(&mut buffer).expect_err("overlong varint must error"); + assert!(matches!(err, ParseError::InvalidData)); + } + + #[test] + fn extract_rejects_oversized_frame() { + // A malicious peer declares a body just over `MAX_FRAME_SIZE` and then dribbles + // bytes in. Without this cap, the buffer would grow without bound waiting for + // the body to complete. With the cap, the oversized varint is rejected the + // moment it decodes — regardless of how many body bytes have actually arrived. + let oversized = MAX_FRAME_SIZE + 1; + let mut varint_buf = unsigned_varint::encode::usize_buffer(); + let varint = unsigned_varint::encode::usize(oversized, &mut varint_buf); + + // Just the varint, no body — would otherwise be `Ok(None)` (incomplete body). + let mut buffer = buf(varint); + let err = extract_framed_message(&mut buffer).expect_err("oversized frame must error"); + assert!(matches!(err, ParseError::InvalidData)); + + // Even with a partial body, the result is still `Err` — the check happens + // before the body-length check. + let mut buffer = buf(varint); + buffer.extend_from_slice(&[0u8; 100]); + let err = extract_framed_message(&mut buffer).expect_err("oversized frame must error"); + assert!(matches!(err, ParseError::InvalidData)); + } + + #[test] + fn extract_accepts_max_frame_size() { + // A body of exactly `MAX_FRAME_SIZE` bytes is the largest legal frame and must + // still extract — the cap is "≤ MAX_FRAME_SIZE", not "< MAX_FRAME_SIZE". 
+ let body = vec![0xa5u8; MAX_FRAME_SIZE]; + let mut buffer = BytesMut::new(); + let mut varint_buf = unsigned_varint::encode::usize_buffer(); + buffer.extend_from_slice(unsigned_varint::encode::usize(body.len(), &mut varint_buf)); + buffer.extend_from_slice(&body); + + let extracted = + extract_framed_message(&mut buffer).unwrap().expect("max-size frame extracts"); + assert_eq!(extracted.len(), MAX_FRAME_SIZE); + assert_eq!(&extracted[..], &body[..]); + assert!(buffer.is_empty()); + } + + #[test] + fn extract_rejects_non_minimal_varint() { + // `0x80 0x00` decodes to value 0 but is non-minimal — a single `0x00` byte + // is the canonical encoding. The decoder rejects this with `NotMinimal`, + // which we propagate as `Err` to avoid wedging the inbound buffer. + let mut buffer = buf(&[0x80u8, 0x00]); + let err = extract_framed_message(&mut buffer).expect_err("non-minimal varint must error"); + assert!(matches!(err, ParseError::InvalidData)); + } + + #[test] + fn extract_returns_zero_copy_body() { + // The returned `Bytes` should be a view over the same allocation as the + // input buffer — no fresh allocation, no copy. Verify by checking that the + // returned `Bytes` shares the same pointer as the slice that lives in the + // buffer before extraction. + let frame = WebRtcMessage::encode(vec![0u8; 256], None); + let mut buffer = buf(&frame); + + let (len, rest) = unsigned_varint::decode::usize(&buffer).unwrap(); + let varint_len = buffer.len() - rest.len(); + // SAFETY: we just decoded the varint, so this is the body's start address. + let expected_ptr = buffer.as_ptr().wrapping_add(varint_len); + let expected_len = len; + + let body = extract_framed_message(&mut buffer).unwrap().expect("complete frame"); + assert_eq!(body.len(), expected_len); + assert_eq!( + body.as_ptr(), + expected_ptr, + "Bytes must be a zero-copy view" + ); + } }