From 6a6af0d10d6c4d0103f6a4bb5d1b7d95fd12c3fa Mon Sep 17 00:00:00 2001 From: Haiko Schol Date: Wed, 18 Feb 2026 17:26:59 +0700 Subject: [PATCH 01/11] webrtc: Make multistream-select negotiation spec compliant on outbound substreams --- src/multistream_select/dialer_select.rs | 187 ++++++++++++++++-------- src/transport/webrtc/connection.rs | 126 +++++++++++++--- 2 files changed, 230 insertions(+), 83 deletions(-) diff --git a/src/multistream_select/dialer_select.rs b/src/multistream_select/dialer_select.rs index 86793660b..e0131eb6c 100644 --- a/src/multistream_select/dialer_select.rs +++ b/src/multistream_select/dialer_select.rs @@ -24,7 +24,6 @@ use crate::{ codec::unsigned_varint::UnsignedVarint, error::{self, Error, ParseError, SubstreamError}, multistream_select::{ - drain_trailing_protocols, protocol::{ webrtc_encode_multistream_message, HeaderLine, Message, MessageIO, Protocol, ProtocolError, PROTO_MULTISTREAM_1_0, @@ -300,6 +299,12 @@ pub enum HandshakeResult { /// The returned tuple contains the negotiated protocol and response /// that must be sent to remote peer. Succeeded(ProtocolName), + + /// The proposed protocol was rejected by the remote peer. + /// + /// The caller should check if there are remaining fallback protocols to try + /// via [`WebRtcDialerState::propose_next_fallback()`]. + Rejected, } /// Handshake state. @@ -334,12 +339,9 @@ impl WebRtcDialerState { protocol: ProtocolName, fallback_names: Vec, ) -> crate::Result<(Self, Vec)> { - let message = webrtc_encode_multistream_message( - std::iter::once(protocol.clone()) - .chain(fallback_names.clone()) - .filter_map(|protocol| Protocol::try_from(protocol.as_ref()).ok()) - .map(Message::Protocol), - )? + let message = webrtc_encode_multistream_message(std::iter::once(Message::Protocol( + Protocol::try_from(protocol.as_ref()).map_err(|_| Error::InvalidData)?, + )))? .freeze() .to_vec(); @@ -353,69 +355,83 @@ impl WebRtcDialerState { )) } + /// Propose the next fallback protocol to the remote peer. + /// + /// Returns `None` if there are no more fallback protocols to try. + /// Returns `Some(message)` with the encoded message to send, containing the protocol name. + pub fn propose_next_fallback(&mut self) -> crate::Result>> { + if self.fallback_names.is_empty() { + return Ok(None); + } + + let next = self.fallback_names.remove(0); + self.protocol = next; + self.state = HandshakeState::WaitingResponse; + + let message = webrtc_encode_multistream_message(std::iter::once(Message::Protocol( + Protocol::try_from(self.protocol.as_ref()).map_err(|_| Error::InvalidData)?, + )))? + .freeze() + .to_vec(); + + Ok(Some(message)) + } + /// Register response to [`WebRtcDialerState`]. pub fn register_response( &mut self, payload: Vec, ) -> Result { - // All multistream-select messages are length-prefixed. Since this code path is not using - // multistream_select::protocol::MessageIO, we need to decode and remove the length here. - let remaining: &[u8] = &payload; - let (len, tail) = unsigned_varint::decode::usize(remaining).map_err(|error| { - tracing::debug!( + let bytes = Bytes::from(payload); + let mut remaining = bytes.clone(); + + while !remaining.is_empty() { + let (len, tail) = unsigned_varint::decode::usize(&remaining).map_err(|error| { + tracing::debug!( target: LOG_TARGET, ?error, - message = ?payload, - "Failed to decode length-prefix in multistream message"); - error::NegotiationError::ParseError(ParseError::InvalidData) - })?; + message = ?remaining, + "Failed to decode length-prefix in multistream message", + ); + error::NegotiationError::ParseError(ParseError::InvalidData) + })?; - let len_size = remaining.len() - tail.len(); - let bytes = Bytes::from(payload); - let payload = bytes.slice(len_size..len_size + len); - let remaining = bytes.slice(len_size + len..); - let message = Message::decode(payload); - - tracing::trace!( - target: LOG_TARGET, - ?message, - "Decoded message while registering response", - ); - - let mut protocols = match message { - Ok(Message::Header(HeaderLine::V1)) => { - vec![PROTO_MULTISTREAM_1_0] + let len_size = remaining.len() - tail.len(); + + if len > tail.len() { + tracing::debug!( + target: LOG_TARGET, + message = ?tail, + length_prefix = len, + actual_length = tail.len(), + "Truncated multistream message", + ); + return Err(error::NegotiationError::ParseError(ParseError::InvalidData)); } - Ok(Message::Protocol(protocol)) => vec![protocol], - Ok(Message::Protocols(protocols)) => protocols, - Ok(Message::NotAvailable) => - return match &self.state { - HandshakeState::WaitingProtocol => Err( - error::NegotiationError::MultistreamSelectError(NegotiationError::Failed), - ), - _ => Err(error::NegotiationError::StateMismatch), - }, - Ok(Message::ListProtocols) => return Err(error::NegotiationError::StateMismatch), - Err(_) => return Err(error::NegotiationError::ParseError(ParseError::InvalidData)), - }; - protocols.extend(drain_trailing_protocols(remaining)?); + let payload = remaining.slice(len_size..len_size + len); + remaining = remaining.slice(len_size + len..); + let message = Message::decode(payload); - let mut protocol_iter = protocols.into_iter(); - loop { - match (&self.state, protocol_iter.next()) { - (HandshakeState::WaitingResponse, None) => - return Err(crate::error::NegotiationError::StateMismatch), - (HandshakeState::WaitingResponse, Some(protocol)) => { - if protocol == PROTO_MULTISTREAM_1_0 { - self.state = HandshakeState::WaitingProtocol; - } else { - return Err(crate::error::NegotiationError::MultistreamSelectError( - NegotiationError::Failed, - )); - } + tracing::trace!( + target: LOG_TARGET, + ?message, + "Decoded message while registering response", + ); + + match (&self.state, message) { + (HandshakeState::WaitingResponse, Ok(Message::Header(HeaderLine::V1))) => { + self.state = HandshakeState::WaitingProtocol; } - (HandshakeState::WaitingProtocol, Some(protocol)) => { + (HandshakeState::WaitingResponse, Ok(Message::Protocol(_))) => { + return Err(crate::error::NegotiationError::MultistreamSelectError( + NegotiationError::Failed, + )); + } + (_, Ok(Message::NotAvailable)) => { + return Ok(HandshakeResult::Rejected); + } + (HandshakeState::WaitingProtocol, Ok(Message::Protocol(protocol))) => { if protocol == PROTO_MULTISTREAM_1_0 { return Err(crate::error::NegotiationError::StateMismatch); } @@ -434,11 +450,16 @@ impl WebRtcDialerState { NegotiationError::Failed, )); } - (HandshakeState::WaitingProtocol, None) => { - return Ok(HandshakeResult::NotReady); + _ => { + return Err(crate::error::NegotiationError::StateMismatch); } } } + + match &self.state { + HandshakeState::WaitingProtocol => Ok(HandshakeResult::NotReady), + HandshakeState::WaitingResponse => Err(crate::error::NegotiationError::StateMismatch), + } } } @@ -813,6 +834,7 @@ mod tests { ) .unwrap(); + // Initial message should only contain the main protocol, not the fallback. let mut bytes = BytesMut::with_capacity(32); bytes.put_u8(MSG_MULTISTREAM_1_0.len() as u8); Message::Header(HeaderLine::V1).encode(&mut bytes).unwrap(); @@ -821,15 +843,52 @@ mod tests { bytes.put_u8((proto1.as_ref().len() + 1) as u8); // + 1 for \n Message::Protocol(proto1).encode(&mut bytes).unwrap(); - let proto2 = Protocol::try_from(&b"/sup/proto/1"[..]).expect("valid protocol name"); - bytes.put_u8((proto2.as_ref().len() + 1) as u8); // + 1 for \n - Message::Protocol(proto2).encode(&mut bytes).unwrap(); - let expected_message = bytes.freeze().to_vec(); assert_eq!(message, expected_message); } + #[test] + fn propose_next_fallback() { + let (mut dialer_state, _message) = WebRtcDialerState::propose( + ProtocolName::from("/13371338/proto/1"), + vec![ProtocolName::from("/sup/proto/1")], + ) + .unwrap(); + + // Simulate receiving header-only response, transitioning to WaitingProtocol. + let mut header_bytes = BytesMut::with_capacity(32); + header_bytes.put_u8(MSG_MULTISTREAM_1_0.len() as u8); + let _ = Message::Header(HeaderLine::V1).encode(&mut header_bytes).unwrap(); + // Append "na" to simulate rejection. + let na_bytes = b"na\n"; + header_bytes.put_u8(na_bytes.len() as u8); + header_bytes.put_slice(na_bytes); + + match dialer_state.register_response(header_bytes.freeze().to_vec()) { + Ok(HandshakeResult::Rejected) => {} + event => panic!("expected Rejected, got: {event:?}"), + } + + // Now propose the next fallback. + let fallback_message = dialer_state + .propose_next_fallback() + .expect("no error") + .expect("should have a fallback"); + + let mut expected = BytesMut::with_capacity(32); + expected.put_u8(MSG_MULTISTREAM_1_0.len() as u8); + let _ = Message::Header(HeaderLine::V1).encode(&mut expected).unwrap(); + let proto = Protocol::try_from(&b"/sup/proto/1"[..]).expect("valid protocol name"); + expected.put_u8((proto.as_ref().len() + 1) as u8); + let _ = Message::Protocol(proto).encode(&mut expected).unwrap(); + + assert_eq!(fallback_message, expected.freeze().to_vec()); + + // No more fallbacks. + assert!(dialer_state.propose_next_fallback().unwrap().is_none()); + } + #[test] fn register_response_header_only() { let mut bytes = BytesMut::with_capacity(32); diff --git a/src/transport/webrtc/connection.rs b/src/transport/webrtc/connection.rs index f01520160..94608c8db 100644 --- a/src/transport/webrtc/connection.rs +++ b/src/transport/webrtc/connection.rs @@ -21,7 +21,8 @@ use crate::{ error::{Error, ParseError, SubstreamError}, multistream_select::{ - webrtc_listener_negotiate, HandshakeResult, ListenerSelectResult, WebRtcDialerState, + webrtc_listener_negotiate, HandshakeResult, ListenerSelectResult, NegotiationError, + WebRtcDialerState, }, protocol::{Direction, Permit, ProtocolCommand, ProtocolSet, SubstreamKeepAlive}, substream::Substream, @@ -411,23 +412,75 @@ impl WebRtcConnection { ParseError::InvalidData.into(), ))?; - let HandshakeResult::Succeeded(protocol) = dialer_state.register_response(message)? else { - tracing::trace!( - target: LOG_TARGET, - peer = ?self.peer, - ?channel_id, - "multistream-select handshake not ready", - ); + let protocol = match dialer_state.register_response(message)? { + HandshakeResult::Succeeded(protocol) => protocol, + HandshakeResult::NotReady => { + tracing::trace!( + target: LOG_TARGET, + peer = ?self.peer, + ?channel_id, + "multistream-select handshake not ready", + ); - self.channels.insert( - channel_id, - ChannelState::OutboundOpening { - context, - dialer_state, - }, - ); + self.channels.insert( + channel_id, + ChannelState::OutboundOpening { + context, + dialer_state, + }, + ); + + return Ok(None); + } + HandshakeResult::Rejected => match dialer_state.propose_next_fallback() { + Ok(Some(message)) => { + tracing::trace!( + target: LOG_TARGET, + peer = ?self.peer, + ?channel_id, + "protocol rejected, trying next fallback", + ); + + let message = WebRtcMessage::encode(message, None); + self.rtc + .channel(channel_id) + .ok_or(Error::ChannelDoesntExist) + .map_err(|_| { + SubstreamError::NegotiationError(NegotiationError::Failed.into()) + })? + .write(true, message.as_ref()) + .map_err(|_| { + SubstreamError::NegotiationError(NegotiationError::Failed.into()) + })?; - return Ok(None); + self.channels.insert( + channel_id, + ChannelState::OutboundOpening { + context, + dialer_state, + }, + ); + + return Ok(None); + } + Ok(None) => { + tracing::debug!( + target: LOG_TARGET, + peer = ?self.peer, + ?channel_id, + "all protocols rejected by remote peer", + ); + + return Err(SubstreamError::NegotiationError( + NegotiationError::Failed.into(), + )); + } + Err(_) => { + return Err(SubstreamError::NegotiationError( + NegotiationError::Failed.into(), + )); + } + }, }; let ChannelContext { @@ -468,13 +521,13 @@ impl WebRtcConnection { ) -> crate::Result<()> { let message = WebRtcMessage::decode(&data)?; - tracing::trace!( + tracing::debug!( target: LOG_TARGET, peer = ?self.peer, ?channel_id, flag = ?message.flag, data_len = message.payload.as_ref().map_or(0usize, |payload| payload.len()), - "handle inbound message", + "handle inbound message on open channel", ); self.handles @@ -495,6 +548,15 @@ impl WebRtcConnection { /// Handle data received from a channel. async fn on_inbound_data(&mut self, channel_id: ChannelId, data: Vec) -> crate::Result<()> { + tracing::debug!( + target: LOG_TARGET, + peer = ?self.peer, + ?channel_id, + data_len = data.len(), + channel_state = ?self.channels.get(&channel_id), + "received channel data", + ); + let Some(state) = self.channels.remove(&channel_id) else { tracing::warn!( target: LOG_TARGET, @@ -700,7 +762,19 @@ impl WebRtcConnection { pub async fn run_event_loop(mut self) { loop { // poll output until we get a timeout - let timeout = match self.rtc.poll_output().unwrap() { + let output = match self.rtc.poll_output() { + Ok(output) => output, + Err(error) => { + tracing::debug!( + target: LOG_TARGET, + peer = ?self.peer, + ?error, + "poll_output failed, closing connection", + ); + return self.on_connection_closed().await; + } + }; + let timeout = match output { Output::Timeout(v) => v, Output::Transmit(v) => { tracing::trace!( @@ -849,6 +923,20 @@ impl WebRtcConnection { keep_alive, connection_id: _, }) => { + // Check if the connection is still healthy before opening new substreams. + // This prevents panics when trying to open channels on a shutting-down + // SCTP association. + if !self.rtc.is_alive() || !self.rtc.is_connected() { + tracing::debug!( + target: LOG_TARGET, + peer = ?self.peer, + ?protocol, + is_alive = self.rtc.is_alive(), + is_connected = self.rtc.is_connected(), + "rejecting substream open: connection not healthy", + ); + continue; + } self.on_open_substream( protocol, fallback_names, From a293ad20bd7c876e32df8427b5d3ada83cba50d7 Mon Sep 17 00:00:00 2001 From: Haiko Schol Date: Wed, 18 Feb 2026 17:56:11 +0700 Subject: [PATCH 02/11] webrtc: Make multistream-select negotiation spec compliant on inbound substreams --- src/transport/webrtc/connection.rs | 39 +++++++++++++++++++++++------- 1 file changed, 30 insertions(+), 9 deletions(-) diff --git a/src/transport/webrtc/connection.rs b/src/transport/webrtc/connection.rs index 94608c8db..b2d20e048 100644 --- a/src/transport/webrtc/connection.rs +++ b/src/transport/webrtc/connection.rs @@ -309,16 +309,22 @@ impl WebRtcConnection { /// Handle data received to an opening inbound channel. /// /// The first message received over an inbound channel is the `multistream-select` handshake. - /// This handshake contains the protocol (and potentially fallbacks for that protocol) that - /// remote peer wants to use for this channel. Parse the handshake and check if any of the - /// proposed protocols are supported by the local node. If not, send rejection to remote peer - /// and close the channel. If the local node supports one of the protocols, send confirmation - /// for the protocol to remote peer and report an opened substream to the selected protocol. + /// This handshake contains the protocol the remote peer wants to use for this channel. Parse + /// the handshake and check whether the proposed protocol is supported by the local node. + /// If not, send rejection to remote peer and but keep the channel open so that the peer can + /// propose a fallback. If the local node support the protocol, send confirmation for the + /// protocol to remote peer and report an opened substream to the selected protocol. + /// + /// Returns `Ok(Some(...))` if the protocol was accepted and the substream opened, + /// `Ok(None)` if the proposed protocol was rejected (the `na` response has been sent + /// and the channel should remain in [`ChannelState::InboundOpening`] so the dialer can + /// propose another protocol per back-and-forth multistream-select negotiation), + /// or `Err(...)` on a fatal error (channel should be closed). async fn on_inbound_opening_channel_data( &mut self, channel_id: ChannelId, data: Vec, - ) -> crate::Result<(SubstreamId, SubstreamHandle, Option)> { + ) -> crate::Result)>> { tracing::trace!( target: LOG_TARGET, peer = ?self.peer, @@ -344,7 +350,16 @@ impl WebRtcConnection { ) .map_err(Error::WebRtc)?; - let protocol = negotiated.ok_or(Error::SubstreamDoesntExist)?; + let Some(protocol) = negotiated else { + tracing::trace!( + target: LOG_TARGET, + peer = ?self.peer, + ?channel_id, + "inbound protocol rejected, keeping channel open for back-and-forth negotiation", + ); + return Ok(None); + }; + let substream_id = self.protocol_set.next_substream_id(); let codec = self.protocol_set.protocol_codec(&protocol); let opening_permit = self.protocol_set.try_get_permit().ok_or(Error::ConnectionClosed)?; @@ -372,7 +387,7 @@ impl WebRtcConnection { opening_permit, ) .await - .map(|_| (substream_id, handle, lifetime_permit)) + .map(|_| Some((substream_id, handle, lifetime_permit))) .map_err(Into::into) } @@ -571,7 +586,7 @@ impl WebRtcConnection { match state { ChannelState::InboundOpening => { match self.on_inbound_opening_channel_data(channel_id, data).await { - Ok((substream_id, handle, lifetime_permit)) => { + Ok(Some((substream_id, handle, lifetime_permit))) => { self.handles.insert(channel_id, handle); self.channels.insert( channel_id, @@ -582,6 +597,12 @@ impl WebRtcConnection { }, ); } + Ok(None) => { + // Protocol was rejected but `na` response was sent. Keep the + // channel open in `InboundOpening` so the dialer can propose + // another protocol (back-and-forth multistream-select). + self.channels.insert(channel_id, ChannelState::InboundOpening); + } Err(error) => { tracing::debug!( target: LOG_TARGET, From f1e889c7c0caf778aacf1e3b9131e4b9805f9257 Mon Sep 17 00:00:00 2001 From: Haiko Schol Date: Wed, 18 Feb 2026 19:16:09 +0700 Subject: [PATCH 03/11] webrtc: Refactor webrtc_encode_multistream_message to take a single Message --- src/multistream_select/dialer_select.rs | 16 ++++++------- src/multistream_select/listener_select.rs | 24 ++++++++----------- src/multistream_select/protocol.rs | 29 ++++++++++------------- 3 files changed, 31 insertions(+), 38 deletions(-) diff --git a/src/multistream_select/dialer_select.rs b/src/multistream_select/dialer_select.rs index e0131eb6c..34028e491 100644 --- a/src/multistream_select/dialer_select.rs +++ b/src/multistream_select/dialer_select.rs @@ -339,9 +339,9 @@ impl WebRtcDialerState { protocol: ProtocolName, fallback_names: Vec, ) -> crate::Result<(Self, Vec)> { - let message = webrtc_encode_multistream_message(std::iter::once(Message::Protocol( + let message = webrtc_encode_multistream_message(Message::Protocol( Protocol::try_from(protocol.as_ref()).map_err(|_| Error::InvalidData)?, - )))? + ))? .freeze() .to_vec(); @@ -368,9 +368,9 @@ impl WebRtcDialerState { self.protocol = next; self.state = HandshakeState::WaitingResponse; - let message = webrtc_encode_multistream_message(std::iter::once(Message::Protocol( + let message = webrtc_encode_multistream_message(Message::Protocol( Protocol::try_from(self.protocol.as_ref()).map_err(|_| Error::InvalidData)?, - )))? + ))? .freeze() .to_vec(); @@ -931,9 +931,9 @@ mod tests { #[test] fn negotiate_main_protocol() { - let message = webrtc_encode_multistream_message(vec![Message::Protocol( + let message = webrtc_encode_multistream_message(Message::Protocol( Protocol::try_from(&b"/13371338/proto/1"[..]).unwrap(), - )]) + )) .unwrap() .freeze(); @@ -953,9 +953,9 @@ mod tests { #[test] fn negotiate_fallback_protocol() { - let message = webrtc_encode_multistream_message(vec![Message::Protocol( + let message = webrtc_encode_multistream_message(Message::Protocol( Protocol::try_from(&b"/sup/proto/1"[..]).unwrap(), - )]) + )) .unwrap() .freeze(); diff --git a/src/multistream_select/listener_select.rs b/src/multistream_select/listener_select.rs index 6faa2fe07..9d83ab817 100644 --- a/src/multistream_select/listener_select.rs +++ b/src/multistream_select/listener_select.rs @@ -374,9 +374,7 @@ pub fn webrtc_listener_negotiate( if protocol.as_ref() == supported.as_bytes() { return Ok(ListenerSelectResult::Accepted { protocol: supported.clone(), - message: webrtc_encode_multistream_message(std::iter::once( - Message::Protocol(protocol), - ))?, + message: webrtc_encode_multistream_message(Message::Protocol(protocol))?, }); } } @@ -388,7 +386,7 @@ pub fn webrtc_listener_negotiate( ); Ok(ListenerSelectResult::Rejected { - message: webrtc_encode_multistream_message(std::iter::once(Message::NotAvailable))?, + message: webrtc_encode_multistream_message(Message::NotAvailable)?, }) } @@ -407,10 +405,9 @@ mod tests { ProtocolName::from("/13371338/proto/3"), ProtocolName::from("/13371338/proto/4"), ]; - let message = webrtc_encode_multistream_message(vec![ - Message::Protocol(Protocol::try_from(&b"/13371338/proto/1"[..]).unwrap()), - Message::Protocol(Protocol::try_from(&b"/sup/proto/1"[..]).unwrap()), - ]) + let message = webrtc_encode_multistream_message(Message::Protocol( + Protocol::try_from(&b"/13371338/proto/1"[..]).unwrap(), + )) .unwrap() .freeze(); @@ -447,10 +444,10 @@ mod tests { // `webrtc_listener_negotiate()` should reject this invalid message. The error can either be // `InvalidData` because the message is malformed or `StateMismatch` because the message is // not expected at this point in the protocol. - let message = webrtc_encode_multistream_message(std::iter::once(Message::Protocols(vec![ + let message = webrtc_encode_multistream_message(Message::Protocols(vec![ Protocol::try_from(&b"/13371338/proto/1"[..]).unwrap(), Protocol::try_from(&b"/sup/proto/1"[..]).unwrap(), - ]))) + ])) .unwrap() .freeze(); @@ -534,9 +531,9 @@ mod tests { ProtocolName::from("/13371338/proto/3"), ProtocolName::from("/13371338/proto/4"), ]; - let message = webrtc_encode_multistream_message(vec![Message::Protocol( + let message = webrtc_encode_multistream_message(Message::Protocol( Protocol::try_from(&b"/13371339/proto/1"[..]).unwrap(), - )]) + )) .unwrap() .freeze(); @@ -545,8 +542,7 @@ mod tests { Ok(ListenerSelectResult::Rejected { message }) => { assert_eq!( message, - webrtc_encode_multistream_message(std::iter::once(Message::NotAvailable)) - .unwrap() + webrtc_encode_multistream_message(Message::NotAvailable).unwrap() ); } Ok(ListenerSelectResult::Accepted { protocol, message }) => panic!("message accepted"), diff --git a/src/multistream_select/protocol.rs b/src/multistream_select/protocol.rs index 2d327bef1..b5fa8b203 100644 --- a/src/multistream_select/protocol.rs +++ b/src/multistream_select/protocol.rs @@ -234,24 +234,21 @@ impl Message { /// /// This implementation may not be compliant with the multistream-select protocol spec. /// The only purpose of this was to get the `multistream-select` protocol working with smoldot. -pub fn webrtc_encode_multistream_message( - messages: impl IntoIterator, -) -> crate::Result { +pub fn webrtc_encode_multistream_message(message: Message) -> crate::Result { // encode `/multistream-select/1.0.0` header let mut bytes = BytesMut::with_capacity(32); - let message = Message::Header(HeaderLine::V1); - message.encode(&mut bytes).map_err(|_| Litep2pError::InvalidData)?; - let mut header = UnsignedVarint::encode(bytes)?; - - // encode each message - for message in messages { - let mut proto_bytes = BytesMut::with_capacity(256); - message.encode(&mut proto_bytes).map_err(|_| Litep2pError::InvalidData)?; - let mut proto_bytes = UnsignedVarint::encode(proto_bytes)?; - header.append(&mut proto_bytes); - } - - Ok(BytesMut::from(&header[..])) + Message::Header(HeaderLine::V1) + .encode(&mut bytes) + .map_err(|_| Litep2pError::InvalidData)?; + let mut output = UnsignedVarint::encode(bytes)?; + + // encode the message + let mut msg_bytes = BytesMut::with_capacity(256); + message.encode(&mut msg_bytes).map_err(|_| Litep2pError::InvalidData)?; + let mut msg_bytes = UnsignedVarint::encode(msg_bytes)?; + output.append(&mut msg_bytes); + + Ok(BytesMut::from(&output[..])) } /// A `MessageIO` implements a [`Stream`] and [`Sink`] of [`Message`]s. From 752c7dfdeb75b50b79c12847f62fb483f220fab6 Mon Sep 17 00:00:00 2001 From: Haiko Schol Date: Tue, 3 Mar 2026 16:15:19 +0700 Subject: [PATCH 04/11] webrtc: Support multistream-select header and protocol in separate protobuf messages --- src/multistream_select/dialer_select.rs | 32 ++- src/multistream_select/length_delimited.rs | 2 +- src/multistream_select/listener_select.rs | 306 +++++++++++++++------ src/multistream_select/mod.rs | 66 +---- src/multistream_select/protocol.rs | 76 +++-- src/transport/webrtc/connection.rs | 34 ++- 6 files changed, 325 insertions(+), 191 deletions(-) diff --git a/src/multistream_select/dialer_select.rs b/src/multistream_select/dialer_select.rs index 34028e491..721199700 100644 --- a/src/multistream_select/dialer_select.rs +++ b/src/multistream_select/dialer_select.rs @@ -339,9 +339,12 @@ impl WebRtcDialerState { protocol: ProtocolName, fallback_names: Vec, ) -> crate::Result<(Self, Vec)> { - let message = webrtc_encode_multistream_message(Message::Protocol( - Protocol::try_from(protocol.as_ref()).map_err(|_| Error::InvalidData)?, - ))? + let message = webrtc_encode_multistream_message( + Message::Protocol( + Protocol::try_from(protocol.as_ref()).map_err(|_| Error::InvalidData)?, + ), + true, + )? .freeze() .to_vec(); @@ -368,9 +371,12 @@ impl WebRtcDialerState { self.protocol = next; self.state = HandshakeState::WaitingResponse; - let message = webrtc_encode_multistream_message(Message::Protocol( - Protocol::try_from(self.protocol.as_ref()).map_err(|_| Error::InvalidData)?, - ))? + let message = webrtc_encode_multistream_message( + Message::Protocol( + Protocol::try_from(self.protocol.as_ref()).map_err(|_| Error::InvalidData)?, + ), + true, + )? .freeze() .to_vec(); @@ -931,9 +937,10 @@ mod tests { #[test] fn negotiate_main_protocol() { - let message = webrtc_encode_multistream_message(Message::Protocol( - Protocol::try_from(&b"/13371338/proto/1"[..]).unwrap(), - )) + let message = webrtc_encode_multistream_message( + Message::Protocol(Protocol::try_from(&b"/13371338/proto/1"[..]).unwrap()), + true, + ) .unwrap() .freeze(); @@ -953,9 +960,10 @@ mod tests { #[test] fn negotiate_fallback_protocol() { - let message = webrtc_encode_multistream_message(Message::Protocol( - Protocol::try_from(&b"/sup/proto/1"[..]).unwrap(), - )) + let message = webrtc_encode_multistream_message( + Message::Protocol(Protocol::try_from(&b"/sup/proto/1"[..]).unwrap()), + true, + ) .unwrap() .freeze(); diff --git a/src/multistream_select/length_delimited.rs b/src/multistream_select/length_delimited.rs index 7052d6299..9e8693a95 100644 --- a/src/multistream_select/length_delimited.rs +++ b/src/multistream_select/length_delimited.rs @@ -28,7 +28,7 @@ use std::{ }; const MAX_LEN_BYTES: u16 = 2; -const MAX_FRAME_SIZE: u16 = (1 << (MAX_LEN_BYTES * 8 - MAX_LEN_BYTES)) - 1; +pub(super) const MAX_FRAME_SIZE: u16 = (1 << (MAX_LEN_BYTES * 8 - MAX_LEN_BYTES)) - 1; const DEFAULT_BUFFER_SIZE: usize = 64; const LOG_TARGET: &str = "litep2p::multistream-select"; diff --git a/src/multistream_select/listener_select.rs b/src/multistream_select/listener_select.rs index 9d83ab817..9b00c09a1 100644 --- a/src/multistream_select/listener_select.rs +++ b/src/multistream_select/listener_select.rs @@ -25,10 +25,9 @@ use crate::{ codec::unsigned_varint::UnsignedVarint, error::{self, Error}, multistream_select::{ - drain_trailing_protocols, protocol::{ webrtc_encode_multistream_message, HeaderLine, Message, MessageIO, Protocol, - ProtocolError, PROTO_MULTISTREAM_1_0, + ProtocolError, }, Negotiated, NegotiationError, }, @@ -333,50 +332,121 @@ pub enum ListenerSelectResult { protocol: ProtocolName, /// `multistream-select` message. - message: BytesMut, + message: Bytes, }, /// Requested protocol is not available. Rejected { /// `multistream-select` message. - message: BytesMut, + message: Bytes, + }, + + /// The multistream-select header was received but no protocol was proposed yet. + /// The caller should send the `message` (header echo) and wait for the next payload. + PendingProtocol { + /// `multistream-select` message (header echo). + message: Bytes, }, } +/// Decode a single varint-length-prefixed multistream-select message from `data`, +/// advancing past the consumed bytes. +fn decode_multistream_message(data: &mut Bytes) -> Result { + let (len, tail) = unsigned_varint::decode::usize(data).map_err(|error| { + tracing::debug!( + target: LOG_TARGET, + ?error, + message = ?data, + "Failed to decode length-prefix in multistream message", + ); + error::NegotiationError::ParseError(error::ParseError::InvalidData) + })?; + + if len > tail.len() { + tracing::debug!( + target: LOG_TARGET, + length_prefix = len, + actual_length = tail.len(), + "Truncated multistream message", + ); + return Err(error::NegotiationError::ParseError( + error::ParseError::InvalidData, + )); + } + + let len_size = data.len() - tail.len(); + let payload = data.slice(len_size..len_size + len); + *data = data.slice(len_size + len..); + + Message::decode(payload).map_err(|error| { + tracing::debug!(target: LOG_TARGET, ?error, "Failed to decode multistream message"); + error::NegotiationError::ParseError(error::ParseError::InvalidData) + }) +} + /// Negotiate protocols for listener. /// -/// Parse protocols offered by the remote peer and check if any of the offered protocols match -/// locally available protocols. If a match is found, return an encoded multistream-select -/// response and the negotiated protocol. If parsing fails or no match is found, return an error. +/// Parse the protocol offered by the remote peer and check if it matches any locally available +/// protocol. The `header_received` parameter indicates whether the multistream-select header +/// has already been exchanged in a previous round. pub fn webrtc_listener_negotiate( supported_protocols: Vec, mut payload: Bytes, + header_received: bool, ) -> crate::Result { - let protocols = drain_trailing_protocols(payload)?; - let mut protocol_iter = protocols.into_iter(); + // Save for zero-copy header echo (Bytes::clone is O(1)). + let raw_payload = payload.clone(); + + let first_msg = decode_multistream_message(&mut payload)?; + + let (protocol, header_in_this_payload) = match first_msg { + Message::Header(HeaderLine::V1) => { + if payload.is_empty() { + // Header only — echo the exact received bytes back (zero alloc). + return Ok(ListenerSelectResult::PendingProtocol { + message: raw_payload, + }); + } + // Header + protocol in same payload. + match decode_multistream_message(&mut payload)? { + Message::Protocol(protocol) => (protocol, true), + _ => + return Err(Error::NegotiationError( + error::NegotiationError::ParseError(error::ParseError::InvalidData), + )), + } + } + // Protocol without header is only valid if the header was already exchanged. + Message::Protocol(protocol) if header_received => (protocol, false), + _ => + return Err(Error::NegotiationError( + error::NegotiationError::MultistreamSelectError(NegotiationError::Failed), + )), + }; - // skip the multistream-select header because it's not part of user protocols but verify it's - // present - if protocol_iter.next() != Some(PROTO_MULTISTREAM_1_0) { + // Reject messages with unexpected trailing data. + if !payload.is_empty() { return Err(Error::NegotiationError( - error::NegotiationError::MultistreamSelectError(NegotiationError::Failed), + error::NegotiationError::ParseError(error::ParseError::InvalidData), )); } - for protocol in protocol_iter { - tracing::trace!( - target: LOG_TARGET, - protocol = ?std::str::from_utf8(protocol.as_ref()), - "listener: checking protocol", - ); + tracing::trace!( + target: LOG_TARGET, + protocol = ?std::str::from_utf8(protocol.as_ref()), + "listener: checking protocol", + ); - for supported in supported_protocols.iter() { - if protocol.as_ref() == supported.as_bytes() { - return Ok(ListenerSelectResult::Accepted { - protocol: supported.clone(), - message: webrtc_encode_multistream_message(Message::Protocol(protocol))?, - }); - } + for supported in supported_protocols.iter() { + if protocol.as_ref() == supported.as_bytes() { + return Ok(ListenerSelectResult::Accepted { + protocol: supported.clone(), + message: webrtc_encode_multistream_message( + Message::Protocol(protocol), + header_in_this_payload, + )? + .freeze(), + }); } } @@ -386,7 +456,8 @@ pub fn webrtc_listener_negotiate( ); Ok(ListenerSelectResult::Rejected { - message: webrtc_encode_multistream_message(Message::NotAvailable)?, + message: webrtc_encode_multistream_message(Message::NotAvailable, header_in_this_payload)? + .freeze(), }) } @@ -405,16 +476,18 @@ mod tests { ProtocolName::from("/13371338/proto/3"), ProtocolName::from("/13371338/proto/4"), ]; - let message = webrtc_encode_multistream_message(Message::Protocol( - Protocol::try_from(&b"/13371338/proto/1"[..]).unwrap(), - )) + let message = webrtc_encode_multistream_message( + Message::Protocol(Protocol::try_from(&b"/13371338/proto/1"[..]).unwrap()), + true, + ) .unwrap() .freeze(); - match webrtc_listener_negotiate(local_protocols, message) { + match webrtc_listener_negotiate(local_protocols, message, false) { Err(error) => panic!("error received: {error:?}"), Ok(ListenerSelectResult::Rejected { .. }) => panic!("message rejected"), - Ok(ListenerSelectResult::Accepted { protocol, message }) => { + Ok(ListenerSelectResult::PendingProtocol { .. }) => panic!("unexpected pending"), + Ok(ListenerSelectResult::Accepted { protocol, .. }) => { assert_eq!(protocol, ProtocolName::from("/13371338/proto/1")); } } @@ -429,32 +502,19 @@ mod tests { ProtocolName::from("/13371338/proto/3"), ProtocolName::from("/13371338/proto/4"), ]; - // The invalid message is really two multistream-select messages inside one `WebRtcMessage`: - // 1. the multistream-select header - // 2. an "ls response" message (that does not contain another header) - // - // This is invalid for two reasons: - // 1. It is malformed. Either the header is followed by one or more `Message::Protocol` - // instances or the header is part of the "ls response". - // 2. This sequence of messages is not spec compliant. A listener receives one of the - // following on an inbound substream: - // - a multistream-select header followed by a `Message::Protocol` instance - // - a multistream-select header followed by an "ls" message (<\n>) - // - // `webrtc_listener_negotiate()` should reject this invalid message. The error can either be - // `InvalidData` because the message is malformed or `StateMismatch` because the message is - // not expected at this point in the protocol. - let message = webrtc_encode_multistream_message(Message::Protocols(vec![ - Protocol::try_from(&b"/13371338/proto/1"[..]).unwrap(), - Protocol::try_from(&b"/sup/proto/1"[..]).unwrap(), - ])) + let message = webrtc_encode_multistream_message( + Message::Protocols(vec![ + Protocol::try_from(&b"/13371338/proto/1"[..]).unwrap(), + Protocol::try_from(&b"/sup/proto/1"[..]).unwrap(), + ]), + true, + ) .unwrap() .freeze(); - match webrtc_listener_negotiate(local_protocols, message) { + match webrtc_listener_negotiate(local_protocols, message, false) { Err(error) => assert!(std::matches!( error, - // something has gone off the rails here... Error::NegotiationError(error::NegotiationError::ParseError( error::ParseError::InvalidData )), @@ -473,18 +533,15 @@ mod tests { ProtocolName::from("/13371338/proto/4"), ]; - // send only header line + // Send only header line with varint length prefix. let mut bytes = BytesMut::with_capacity(32); - let message = Message::Header(HeaderLine::V1); - message.encode(&mut bytes).map_err(|_| Error::InvalidData).unwrap(); + Message::Header(HeaderLine::V1).encode(&mut bytes).unwrap(); + let payload = Bytes::from(UnsignedVarint::encode(bytes).unwrap()); - match webrtc_listener_negotiate(local_protocols, bytes.freeze()) { - Err(error) => assert!(std::matches!( - error, - Error::NegotiationError(error::NegotiationError::ParseError( - error::ParseError::InvalidData - )), - )), + match webrtc_listener_negotiate(local_protocols, payload.clone(), false) { + Ok(ListenerSelectResult::PendingProtocol { message }) => { + assert_eq!(message, payload); + } event => panic!("invalid event: {event:?}"), } } @@ -499,19 +556,14 @@ mod tests { ProtocolName::from("/13371338/proto/4"), ]; - // header line missing - let mut bytes = BytesMut::with_capacity(256); - vec![&b"/13371338/proto/1"[..], &b"/sup/proto/1"[..]] - .into_iter() - .for_each(|proto| { - bytes.put_u8((proto.len() + 1) as u8); - - Message::Protocol(Protocol::try_from(proto).unwrap()) - .encode(&mut bytes) - .unwrap(); - }); + // Single protocol, no header. + let mut bytes = BytesMut::with_capacity(64); + Message::Protocol(Protocol::try_from(&b"/13371338/proto/1"[..]).unwrap()) + .encode(&mut bytes) + .unwrap(); + let payload = Bytes::from(UnsignedVarint::encode(bytes).unwrap()); - match webrtc_listener_negotiate(local_protocols, bytes.freeze()) { + match webrtc_listener_negotiate(local_protocols, payload, false) { Err(error) => assert!(std::matches!( error, Error::NegotiationError(error::NegotiationError::MultistreamSelectError( @@ -524,28 +576,118 @@ mod tests { #[test] fn protocol_not_supported() { - let mut local_protocols = vec![ + let local_protocols = vec![ ProtocolName::from("/13371338/proto/1"), ProtocolName::from("/sup/proto/1"), ProtocolName::from("/13371338/proto/2"), ProtocolName::from("/13371338/proto/3"), ProtocolName::from("/13371338/proto/4"), ]; - let message = webrtc_encode_multistream_message(Message::Protocol( - Protocol::try_from(&b"/13371339/proto/1"[..]).unwrap(), - )) + let message = webrtc_encode_multistream_message( + Message::Protocol(Protocol::try_from(&b"/13371339/proto/1"[..]).unwrap()), + true, + ) .unwrap() .freeze(); - match webrtc_listener_negotiate(local_protocols, message) { + match webrtc_listener_negotiate(local_protocols, message, false) { Err(error) => panic!("error received: {error:?}"), Ok(ListenerSelectResult::Rejected { message }) => { assert_eq!( message, - webrtc_encode_multistream_message(Message::NotAvailable).unwrap() + webrtc_encode_multistream_message(Message::NotAvailable, true) + .unwrap() + .freeze() + ); + } + Ok(ListenerSelectResult::Accepted { .. }) => panic!("message accepted"), + Ok(ListenerSelectResult::PendingProtocol { .. }) => panic!("unexpected pending"), + } + } + + #[test] + fn protocols_not_supported() { + let local_protocols = vec![ProtocolName::from("/13371338/proto/1")]; + + // Round 1: send header only → PendingProtocol (header echo). + let mut bytes = BytesMut::with_capacity(32); + Message::Header(HeaderLine::V1).encode(&mut bytes).unwrap(); + let header_payload = Bytes::from(UnsignedVarint::encode(bytes).unwrap()); + + match webrtc_listener_negotiate(local_protocols.clone(), header_payload.clone(), false) { + Ok(ListenerSelectResult::PendingProtocol { message }) => { + assert_eq!(message, header_payload); + } + event => panic!("expected PendingProtocol, got {event:?}"), + } + + // Round 2: send first protocol (not supported) → Rejected (na, no header). + let mut bytes = BytesMut::with_capacity(64); + Message::Protocol(Protocol::try_from(&b"/unsupported/proto/1"[..]).unwrap()) + .encode(&mut bytes) + .unwrap(); + let proto1_payload = Bytes::from(UnsignedVarint::encode(bytes).unwrap()); + + match webrtc_listener_negotiate(local_protocols.clone(), proto1_payload, true) { + Ok(ListenerSelectResult::Rejected { message }) => { + assert_eq!( + message, + webrtc_encode_multistream_message(Message::NotAvailable, false) + .unwrap() + .freeze() + ); + } + event => panic!("expected Rejected, got {event:?}"), + } + + // Round 3: send second protocol (also not supported) → Rejected (na, no header). + let mut bytes = BytesMut::with_capacity(64); + Message::Protocol(Protocol::try_from(&b"/unsupported/proto/2"[..]).unwrap()) + .encode(&mut bytes) + .unwrap(); + let proto2_payload = Bytes::from(UnsignedVarint::encode(bytes).unwrap()); + + match webrtc_listener_negotiate(local_protocols, proto2_payload, true) { + Ok(ListenerSelectResult::Rejected { message }) => { + assert_eq!( + message, + webrtc_encode_multistream_message(Message::NotAvailable, false) + .unwrap() + .freeze() ); } - Ok(ListenerSelectResult::Accepted { protocol, message }) => panic!("message accepted"), + event => panic!("expected Rejected, got {event:?}"), + } + } + + #[test] + fn header_only_then_protocol() { + let local_protocols = vec![ProtocolName::from("/13371338/proto/1")]; + + // Call 1: header only → PendingProtocol. + let mut bytes = BytesMut::with_capacity(32); + Message::Header(HeaderLine::V1).encode(&mut bytes).unwrap(); + let header_payload = Bytes::from(UnsignedVarint::encode(bytes).unwrap()); + + match webrtc_listener_negotiate(local_protocols.clone(), header_payload.clone(), false) { + Ok(ListenerSelectResult::PendingProtocol { message }) => { + assert_eq!(message, header_payload); + } + event => panic!("expected PendingProtocol, got {event:?}"), + } + + // Call 2: protocol only (header_received=true) → Accepted. + let mut bytes = BytesMut::with_capacity(64); + Message::Protocol(Protocol::try_from(&b"/13371338/proto/1"[..]).unwrap()) + .encode(&mut bytes) + .unwrap(); + let proto_payload = Bytes::from(UnsignedVarint::encode(bytes).unwrap()); + + match webrtc_listener_negotiate(local_protocols, proto_payload, true) { + Ok(ListenerSelectResult::Accepted { protocol, .. }) => { + assert_eq!(protocol, ProtocolName::from("/13371338/proto/1")); + } + event => panic!("expected Accepted, got {event:?}"), } } } diff --git a/src/multistream_select/mod.rs b/src/multistream_select/mod.rs index f195b1f3d..762ba3022 100644 --- a/src/multistream_select/mod.rs +++ b/src/multistream_select/mod.rs @@ -75,7 +75,7 @@ mod listener_select; mod negotiated; mod protocol; -use crate::error::{self, ParseError}; +use crate::error; pub use crate::multistream_select::{ dialer_select::{dialer_select_proto, DialerSelectFuture, HandshakeResult, WebRtcDialerState}, listener_select::{ @@ -86,10 +86,6 @@ pub use crate::multistream_select::{ protocol::{HeaderLine, Message, Protocol, ProtocolError, PROTO_MULTISTREAM_1_0}, }; -use bytes::Bytes; - -const LOG_TARGET: &str = "litep2p::multistream-select"; - /// Supported multistream-select versions. #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum Version { @@ -137,63 +133,3 @@ impl Default for Version { Version::V1 } } - -// This function is only used in the WebRTC transport. It expects one or more multistream-select -// messages in `remaining` and returns a list of protocols that were decoded from them. -fn drain_trailing_protocols( - mut remaining: Bytes, -) -> Result, error::NegotiationError> { - let mut protocols = vec![]; - - loop { - if remaining.is_empty() { - break; - } - - let (len, tail) = unsigned_varint::decode::usize(&remaining).map_err(|error| { - tracing::debug!( - target: LOG_TARGET, - ?error, - message = ?remaining, - "Failed to decode length-prefix in multistream message"); - error::NegotiationError::ParseError(ParseError::InvalidData) - })?; - - if len > tail.len() { - tracing::debug!( - target: LOG_TARGET, - message = ?tail, - length_prefix = len, - actual_length = tail.len(), - "Truncated multistream message", - ); - - return Err(error::NegotiationError::ParseError(ParseError::InvalidData)); - } - - let len_size = remaining.len() - tail.len(); - let payload = remaining.slice(len_size..len_size + len); - let res = Message::decode(payload); - - match res { - Ok(Message::Header(HeaderLine::V1)) => protocols.push(PROTO_MULTISTREAM_1_0), - Ok(Message::Protocol(protocol)) => protocols.push(protocol), - Ok(Message::Protocols(_)) => - return Err(error::NegotiationError::ParseError(ParseError::InvalidData)), - Err(error) => { - tracing::debug!( - target: LOG_TARGET, - ?error, - message = ?tail[..len], - "Failed to decode multistream message", - ); - return Err(error::NegotiationError::ParseError(ParseError::InvalidData)); - } - _ => return Err(error::NegotiationError::StateMismatch), - } - - remaining = remaining.slice(len_size + len..); - } - - Ok(protocols) -} diff --git a/src/multistream_select/protocol.rs b/src/multistream_select/protocol.rs index b5fa8b203..a73d3156d 100644 --- a/src/multistream_select/protocol.rs +++ b/src/multistream_select/protocol.rs @@ -26,7 +26,6 @@ //! `MessageReader`. use crate::{ - codec::unsigned_varint::UnsignedVarint, error::Error as Litep2pError, multistream_select::{ length_delimited::{LengthDelimited, LengthDelimitedReader}, @@ -132,6 +131,25 @@ pub enum Message { } impl Message { + /// Returns the exact encoded byte length of this message, without allocating. + pub fn encoded_len(&self) -> usize { + match self { + Message::Header(HeaderLine::V1) => MSG_MULTISTREAM_1_0.len(), + Message::Protocol(p) => p.0.as_ref().len() + 1, + Message::ListProtocols => MSG_LS.len(), + Message::NotAvailable => MSG_PROTOCOL_NA.len(), + Message::Protocols(ps) => { + let mut len = 1usize; // trailing \n + let mut buf = unsigned_varint::encode::usize_buffer(); + for p in ps { + let proto_len = p.0.as_ref().len() + 1; + len += unsigned_varint::encode::usize(proto_len, &mut buf).len() + proto_len; + } + len + } + } + } + /// Encodes a `Message` into its byte representation. pub fn encode(&self, dest: &mut BytesMut) -> Result<(), ProtocolError> { match self { @@ -228,27 +246,43 @@ impl Message { } } -/// Create `multistream-select` message from an iterator of `Message`s. +/// Encode a single `multistream-select` message, optionally preceded by the protocol header. /// -/// # Note -/// -/// This implementation may not be compliant with the multistream-select protocol spec. -/// The only purpose of this was to get the `multistream-select` protocol working with smoldot. -pub fn webrtc_encode_multistream_message(message: Message) -> crate::Result { - // encode `/multistream-select/1.0.0` header - let mut bytes = BytesMut::with_capacity(32); - Message::Header(HeaderLine::V1) - .encode(&mut bytes) - .map_err(|_| Litep2pError::InvalidData)?; - let mut output = UnsignedVarint::encode(bytes)?; - - // encode the message - let mut msg_bytes = BytesMut::with_capacity(256); - message.encode(&mut msg_bytes).map_err(|_| Litep2pError::InvalidData)?; - let mut msg_bytes = UnsignedVarint::encode(msg_bytes)?; - output.append(&mut msg_bytes); - - Ok(BytesMut::from(&output[..])) +/// When `prepend_header` is `true` the `/multistream/1.0.0` header line is written before the +/// message. Everything is written into a single `BytesMut` allocation. +pub fn webrtc_encode_multistream_message( + message: Message, + prepend_header: bool, +) -> crate::Result { + let msg_len = message.encoded_len(); + let header_len = MSG_MULTISTREAM_1_0.len(); + let mut varint_buf = unsigned_varint::encode::usize_buffer(); + + let capacity = { + let msg_varint_len = unsigned_varint::encode::usize(msg_len, &mut varint_buf).len(); + let total = if prepend_header { + let header_varint_len = + unsigned_varint::encode::usize(header_len, &mut varint_buf).len(); + header_varint_len + header_len + msg_varint_len + msg_len + } else { + msg_varint_len + msg_len + }; + total.min(super::length_delimited::MAX_FRAME_SIZE as usize) + }; + + let mut output = BytesMut::with_capacity(capacity); + + if prepend_header { + output.extend_from_slice(unsigned_varint::encode::usize(header_len, &mut varint_buf)); + Message::Header(HeaderLine::V1) + .encode(&mut output) + .map_err(|_| Litep2pError::InvalidData)?; + } + + output.extend_from_slice(unsigned_varint::encode::usize(msg_len, &mut varint_buf)); + message.encode(&mut output).map_err(|_| Litep2pError::InvalidData)?; + + Ok(output) } /// A `MessageIO` implements a [`Stream`] and [`Sink`] of [`Message`]s. diff --git a/src/transport/webrtc/connection.rs b/src/transport/webrtc/connection.rs index b2d20e048..ffef8a7ae 100644 --- a/src/transport/webrtc/connection.rs +++ b/src/transport/webrtc/connection.rs @@ -149,7 +149,10 @@ enum ChannelState { Closing, /// Inbound channel is opening. - InboundOpening, + InboundOpening { + /// Whether the multistream-select header has already been received/sent. + header_received: bool, + }, /// Outbound channel is opening. OutboundOpening { @@ -264,7 +267,12 @@ impl WebRtcConnection { "inbound channel opened, wait for `multistream-select` message", ); - self.channels.insert(channel_id, ChannelState::InboundOpening); + self.channels.insert( + channel_id, + ChannelState::InboundOpening { + header_received: false, + }, + ); return Ok(()); }; @@ -324,6 +332,7 @@ impl WebRtcConnection { &mut self, channel_id: ChannelId, data: Vec, + header_received: bool, ) -> crate::Result)>> { tracing::trace!( target: LOG_TARGET, @@ -336,9 +345,10 @@ impl WebRtcConnection { let protocols = self.protocol_set.protocols_with_keep_alives(); let protocol_names = protocols.keys().cloned().collect(); let (response, negotiated) = - match webrtc_listener_negotiate(protocol_names, payload.into())? { + match webrtc_listener_negotiate(protocol_names, payload.into(), header_received)? { ListenerSelectResult::Accepted { protocol, message } => (message, Some(protocol)), - ListenerSelectResult::Rejected { message } => (message, None), + ListenerSelectResult::Rejected { message } + | ListenerSelectResult::PendingProtocol { message } => (message, None), }; self.rtc @@ -584,8 +594,9 @@ impl WebRtcConnection { }; match state { - ChannelState::InboundOpening => { - match self.on_inbound_opening_channel_data(channel_id, data).await { + ChannelState::InboundOpening { header_received } => { + match self.on_inbound_opening_channel_data(channel_id, data, header_received).await + { Ok(Some((substream_id, handle, lifetime_permit))) => { self.handles.insert(channel_id, handle); self.channels.insert( @@ -598,10 +609,13 @@ impl WebRtcConnection { ); } Ok(None) => { - // Protocol was rejected but `na` response was sent. Keep the - // channel open in `InboundOpening` so the dialer can propose - // another protocol (back-and-forth multistream-select). - self.channels.insert(channel_id, ChannelState::InboundOpening); + // Header has been exchanged after any successful round. + self.channels.insert( + channel_id, + ChannelState::InboundOpening { + header_received: true, + }, + ); } Err(error) => { tracing::debug!( From 2b7738fed413ae8b46522f92b639523fba557add Mon Sep 17 00:00:00 2001 From: gab Date: Tue, 5 May 2026 11:14:28 +0200 Subject: [PATCH 05/11] fix(webrtc): don't re-send multistream header After the multistream-select header has been exchanged, only protocol and na messages should flow between peers. propose_next_fallback was incorrectly transitioning state back to WaitingResponse (expecting another header) and wrapping the outgoing protocol message with the header. Keep the state at WaitingProtocol and emit the protocol message without the header prefix. Also the receiving side has been fixed, no header is expected if it was already received. --- src/multistream_select/dialer_select.rs | 38 +++++++++++++++++------ src/multistream_select/listener_select.rs | 3 +- 2 files changed, 31 insertions(+), 10 deletions(-) diff --git a/src/multistream_select/dialer_select.rs b/src/multistream_select/dialer_select.rs index 721199700..7a8c2f3bc 100644 --- a/src/multistream_select/dialer_select.rs +++ b/src/multistream_select/dialer_select.rs @@ -369,13 +369,12 @@ impl WebRtcDialerState { let next = self.fallback_names.remove(0); self.protocol = next; - self.state = HandshakeState::WaitingResponse; let message = webrtc_encode_multistream_message( Message::Protocol( Protocol::try_from(self.protocol.as_ref()).map_err(|_| Error::InvalidData)?, ), - true, + false, )? .freeze() .to_vec(); @@ -474,7 +473,7 @@ mod tests { use super::*; use crate::multistream_select::{listener_select_proto, protocol::MSG_MULTISTREAM_1_0}; use bytes::BufMut; - use std::time::Duration; + use std::{sync::OnceState, time::Duration}; #[tokio::test] async fn select_proto_basic() { async fn run(version: Version) { @@ -865,7 +864,7 @@ mod tests { // Simulate receiving header-only response, transitioning to WaitingProtocol. let mut header_bytes = BytesMut::with_capacity(32); header_bytes.put_u8(MSG_MULTISTREAM_1_0.len() as u8); - let _ = Message::Header(HeaderLine::V1).encode(&mut header_bytes).unwrap(); + Message::Header(HeaderLine::V1).encode(&mut header_bytes).unwrap(); // Append "na" to simulate rejection. let na_bytes = b"na\n"; header_bytes.put_u8(na_bytes.len() as u8); @@ -876,6 +875,12 @@ mod tests { event => panic!("expected Rejected, got: {event:?}"), } + // After having received the header the state stays always at WaitingProtocol. + assert!(matches!( + dialer_state.state, + HandshakeState::WaitingProtocol + )); + // Now propose the next fallback. let fallback_message = dialer_state .propose_next_fallback() @@ -883,16 +888,31 @@ mod tests { .expect("should have a fallback"); let mut expected = BytesMut::with_capacity(32); - expected.put_u8(MSG_MULTISTREAM_1_0.len() as u8); - let _ = Message::Header(HeaderLine::V1).encode(&mut expected).unwrap(); let proto = Protocol::try_from(&b"/sup/proto/1"[..]).expect("valid protocol name"); expected.put_u8((proto.as_ref().len() + 1) as u8); - let _ = Message::Protocol(proto).encode(&mut expected).unwrap(); - - assert_eq!(fallback_message, expected.freeze().to_vec()); + Message::Protocol(proto).encode(&mut expected).unwrap(); + let protocol_message = expected.freeze().to_vec(); + assert_eq!(fallback_message, protocol_message); // No more fallbacks. assert!(dialer_state.propose_next_fallback().unwrap().is_none()); + + match dialer_state.register_response(protocol_message) { + Ok(HandshakeResult::Succeeded(_)) => {} + event => panic!("expected Succeeded, got: {event:?}"), + } + + let mut na_response = BytesMut::with_capacity(32); + let na_bytes = b"na\n"; + na_response.put_u8(na_bytes.len() as u8); + na_response.put_slice(na_bytes); + + dialer_state.state = HandshakeState::WaitingProtocol; + + match dialer_state.register_response(na_response.to_vec()) { + Ok(HandshakeResult::Rejected) => {} + event => panic!("expected Rejected, got: {event:?}"), + } } #[test] diff --git a/src/multistream_select/listener_select.rs b/src/multistream_select/listener_select.rs index 9b00c09a1..580897d71 100644 --- a/src/multistream_select/listener_select.rs +++ b/src/multistream_select/listener_select.rs @@ -400,7 +400,8 @@ pub fn webrtc_listener_negotiate( let first_msg = decode_multistream_message(&mut payload)?; let (protocol, header_in_this_payload) = match first_msg { - Message::Header(HeaderLine::V1) => { + // Header is expected only if not already received. + Message::Header(HeaderLine::V1) if !header_received => { if payload.is_empty() { // Header only — echo the exact received bytes back (zero alloc). return Ok(ListenerSelectResult::PendingProtocol { From 5a12aa08361dc78a4394219df4e0514ac7ace34b Mon Sep 17 00:00:00 2001 From: gab Date: Thu, 14 May 2026 11:18:51 +0200 Subject: [PATCH 06/11] fix(webrtc): fail encoding if message exceeds MAX_FRAME_SIZE --- src/multistream_select/protocol.rs | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/src/multistream_select/protocol.rs b/src/multistream_select/protocol.rs index a73d3156d..a16a64449 100644 --- a/src/multistream_select/protocol.rs +++ b/src/multistream_select/protocol.rs @@ -260,16 +260,26 @@ pub fn webrtc_encode_multistream_message( let capacity = { let msg_varint_len = unsigned_varint::encode::usize(msg_len, &mut varint_buf).len(); - let total = if prepend_header { + if prepend_header { let header_varint_len = unsigned_varint::encode::usize(header_len, &mut varint_buf).len(); header_varint_len + header_len + msg_varint_len + msg_len } else { msg_varint_len + msg_len - }; - total.min(super::length_delimited::MAX_FRAME_SIZE as usize) + } }; + if capacity > super::length_delimited::MAX_FRAME_SIZE as usize { + tracing::debug!( + target: LOG_TARGET, + capacity, + max = super::length_delimited::MAX_FRAME_SIZE, + ?message, + "encoded multistream message exceeds MAX_FRAME_SIZE", + ); + return Err(Litep2pError::InvalidData); + } + let mut output = BytesMut::with_capacity(capacity); if prepend_header { From bf1204ee080c5b0448d49c68b74f63bc8c9fca6f Mon Sep 17 00:00:00 2001 From: gab Date: Thu, 14 May 2026 11:20:04 +0200 Subject: [PATCH 07/11] opt(webrtc): reverse fallback_names and pop each iteration --- src/multistream_select/dialer_select.rs | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/multistream_select/dialer_select.rs b/src/multistream_select/dialer_select.rs index 7a8c2f3bc..b6069ec8b 100644 --- a/src/multistream_select/dialer_select.rs +++ b/src/multistream_select/dialer_select.rs @@ -333,11 +333,14 @@ pub struct WebRtcDialerState { impl WebRtcDialerState { /// Propose protocol to remote peer. /// + /// `fallback_names` must be in preference order, the first element is the + /// next protocol to try. + /// /// Return [`WebRtcDialerState`] which is used to drive forward the negotiation and an encoded /// `multistream-select` message that contains the protocol proposal for the substream. pub fn propose( protocol: ProtocolName, - fallback_names: Vec, + mut fallback_names: Vec, ) -> crate::Result<(Self, Vec)> { let message = webrtc_encode_multistream_message( Message::Protocol( @@ -348,6 +351,9 @@ impl WebRtcDialerState { .freeze() .to_vec(); + // Reverse fallback_names so that we can pop from it. + fallback_names.reverse(); + Ok(( Self { protocol, @@ -367,7 +373,8 @@ impl WebRtcDialerState { return Ok(None); } - let next = self.fallback_names.remove(0); + // UNWRAP: fallback_names has just been checked to not be empty. + let next = self.fallback_names.pop().unwrap(); self.protocol = next; let message = webrtc_encode_multistream_message( From d98ed84ec71b21b6d815138b851df85994b975f5 Mon Sep 17 00:00:00 2001 From: gab Date: Thu, 14 May 2026 16:20:53 +0200 Subject: [PATCH 08/11] chore(webrtc): cleanup multistream-select (tracing, ListProtocols guard) --- src/multistream_select/dialer_select.rs | 25 +++++++------ src/multistream_select/listener_select.rs | 14 ++++++- src/multistream_select/protocol.rs | 2 +- src/transport/webrtc/connection.rs | 45 +++++++++++++++++------ 4 files changed, 60 insertions(+), 26 deletions(-) diff --git a/src/multistream_select/dialer_select.rs b/src/multistream_select/dialer_select.rs index b6069ec8b..533eda138 100644 --- a/src/multistream_select/dialer_select.rs +++ b/src/multistream_select/dialer_select.rs @@ -33,7 +33,7 @@ use crate::{ types::protocol::ProtocolName, }; -use bytes::{Bytes, BytesMut}; +use bytes::{Buf, Bytes, BytesMut}; use futures::prelude::*; use std::{ convert::TryFrom as _, @@ -369,12 +369,10 @@ impl WebRtcDialerState { /// Returns `None` if there are no more fallback protocols to try. /// Returns `Some(message)` with the encoded message to send, containing the protocol name. pub fn propose_next_fallback(&mut self) -> crate::Result>> { - if self.fallback_names.is_empty() { + let Some(next) = self.fallback_names.pop() else { return Ok(None); - } + }; - // UNWRAP: fallback_names has just been checked to not be empty. - let next = self.fallback_names.pop().unwrap(); self.protocol = next; let message = webrtc_encode_multistream_message( @@ -394,8 +392,7 @@ impl WebRtcDialerState { &mut self, payload: Vec, ) -> Result { - let bytes = Bytes::from(payload); - let mut remaining = bytes.clone(); + let mut remaining = Bytes::from(payload); while !remaining.is_empty() { let (len, tail) = unsigned_varint::decode::usize(&remaining).map_err(|error| { @@ -408,8 +405,6 @@ impl WebRtcDialerState { error::NegotiationError::ParseError(ParseError::InvalidData) })?; - let len_size = remaining.len() - tail.len(); - if len > tail.len() { tracing::debug!( target: LOG_TARGET, @@ -421,8 +416,9 @@ impl WebRtcDialerState { return Err(error::NegotiationError::ParseError(ParseError::InvalidData)); } - let payload = remaining.slice(len_size..len_size + len); - remaining = remaining.slice(len_size + len..); + let len_size = remaining.len() - tail.len(); + remaining.advance(len_size); + let payload = remaining.split_to(len); let message = Message::decode(payload); tracing::trace!( @@ -462,6 +458,11 @@ impl WebRtcDialerState { NegotiationError::Failed, )); } + (HandshakeState::WaitingProtocol, Ok(Message::ListProtocols)) => { + return Err(crate::error::NegotiationError::MultistreamSelectError( + NegotiationError::ProtocolError(ProtocolError::InvalidMessage), + )); + } _ => { return Err(crate::error::NegotiationError::StateMismatch); } @@ -480,7 +481,7 @@ mod tests { use super::*; use crate::multistream_select::{listener_select_proto, protocol::MSG_MULTISTREAM_1_0}; use bytes::BufMut; - use std::{sync::OnceState, time::Duration}; + use std::time::Duration; #[tokio::test] async fn select_proto_basic() { async fn run(version: Version) { diff --git a/src/multistream_select/listener_select.rs b/src/multistream_select/listener_select.rs index 580897d71..672e1edd0 100644 --- a/src/multistream_select/listener_select.rs +++ b/src/multistream_select/listener_select.rs @@ -411,10 +411,15 @@ pub fn webrtc_listener_negotiate( // Header + protocol in same payload. match decode_multistream_message(&mut payload)? { Message::Protocol(protocol) => (protocol, true), - _ => + _ => { + tracing::trace!( + target: LOG_TARGET, + "failed to decode multistream message", + ); return Err(Error::NegotiationError( error::NegotiationError::ParseError(error::ParseError::InvalidData), - )), + )); + } } } // Protocol without header is only valid if the header was already exchanged. @@ -427,6 +432,11 @@ pub fn webrtc_listener_negotiate( // Reject messages with unexpected trailing data. if !payload.is_empty() { + tracing::trace!( + target: LOG_TARGET, + ?payload, + "rejecting message with unexpected trailing data", + ); return Err(Error::NegotiationError( error::NegotiationError::ParseError(error::ParseError::InvalidData), )); diff --git a/src/multistream_select/protocol.rs b/src/multistream_select/protocol.rs index a16a64449..29b25e630 100644 --- a/src/multistream_select/protocol.rs +++ b/src/multistream_select/protocol.rs @@ -270,7 +270,7 @@ pub fn webrtc_encode_multistream_message( }; if capacity > super::length_delimited::MAX_FRAME_SIZE as usize { - tracing::debug!( + tracing::warn!( target: LOG_TARGET, capacity, max = super::length_delimited::MAX_FRAME_SIZE, diff --git a/src/transport/webrtc/connection.rs b/src/transport/webrtc/connection.rs index ffef8a7ae..8d9dab2ef 100644 --- a/src/transport/webrtc/connection.rs +++ b/src/transport/webrtc/connection.rs @@ -467,16 +467,31 @@ impl WebRtcConnection { ); let message = WebRtcMessage::encode(message, None); - self.rtc - .channel(channel_id) - .ok_or(Error::ChannelDoesntExist) - .map_err(|_| { - SubstreamError::NegotiationError(NegotiationError::Failed.into()) - })? - .write(true, message.as_ref()) - .map_err(|_| { - SubstreamError::NegotiationError(NegotiationError::Failed.into()) - })?; + + let Some(mut channel) = self.rtc.channel(channel_id) else { + tracing::trace!( + target: LOG_TARGET, + peer = ?self.peer, + ?channel_id, + "protocol rejected received for non-existing channel", + ); + return Err(SubstreamError::NegotiationError( + NegotiationError::Failed.into(), + )); + }; + + if let Err(err) = channel.write(true, message.as_ref()) { + tracing::trace!( + target: LOG_TARGET, + peer = ?self.peer, + ?channel_id, + ?err, + "failed to write multistream-select fallback proposal", + ); + return Err(SubstreamError::NegotiationError( + NegotiationError::Failed.into(), + )); + }; self.channels.insert( channel_id, @@ -500,7 +515,15 @@ impl WebRtcConnection { NegotiationError::Failed.into(), )); } - Err(_) => { + Err(e) => { + tracing::trace!( + target: LOG_TARGET, + peer = ?self.peer, + ?channel_id, + ?e, + "dialer failed proposing next fallback", + ); + return Err(SubstreamError::NegotiationError( NegotiationError::Failed.into(), )); From c6bdbae915748193903e3d695209e83536fe2552 Mon Sep 17 00:00:00 2001 From: gab Date: Thu, 14 May 2026 16:55:17 +0200 Subject: [PATCH 09/11] fix(webrtc): mutlistream-select header expectation and trailing bytes handling Only expect the header as first response of the multistream-select protocol. Warn if the remote peer has pipelined something after the protocol. --- src/multistream_select/dialer_select.rs | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/src/multistream_select/dialer_select.rs b/src/multistream_select/dialer_select.rs index 533eda138..687bc9286 100644 --- a/src/multistream_select/dialer_select.rs +++ b/src/multistream_select/dialer_select.rs @@ -427,16 +427,32 @@ impl WebRtcDialerState { "Decoded message while registering response", ); + let check_trailing_bytes = |bytes: &Bytes| { + // remote may be optimistically pipelining the first application frame + if !bytes.is_empty() { + tracing::warn!( + bytes_len = bytes.len(), + "trailing bytes after multistream-select negotiation were discarded" + ); + } + }; + match (&self.state, message) { (HandshakeState::WaitingResponse, Ok(Message::Header(HeaderLine::V1))) => { self.state = HandshakeState::WaitingProtocol; } - (HandshakeState::WaitingResponse, Ok(Message::Protocol(_))) => { + (HandshakeState::WaitingResponse, Ok(msg)) => { + tracing::trace!( + target: LOG_TARGET, + ?msg, + "Expected header response from peer, got different message" + ); return Err(crate::error::NegotiationError::MultistreamSelectError( NegotiationError::Failed, )); } - (_, Ok(Message::NotAvailable)) => { + (HandshakeState::WaitingProtocol, Ok(Message::NotAvailable)) => { + check_trailing_bytes(&remaining); return Ok(HandshakeResult::Rejected); } (HandshakeState::WaitingProtocol, Ok(Message::Protocol(protocol))) => { @@ -445,11 +461,13 @@ impl WebRtcDialerState { } if self.protocol.as_bytes() == protocol.as_ref() { + check_trailing_bytes(&remaining); return Ok(HandshakeResult::Succeeded(self.protocol.clone())); } for fallback in &self.fallback_names { if fallback.as_bytes() == protocol.as_ref() { + check_trailing_bytes(&remaining); return Ok(HandshakeResult::Succeeded(fallback.clone())); } } From f1a3be5c23292dd1314e56a5214c808b989c1c23 Mon Sep 17 00:00:00 2001 From: gab Date: Mon, 18 May 2026 18:42:58 +0200 Subject: [PATCH 10/11] fix(webrtc): multistream-select, only accept latest proposed protocol --- src/multistream_select/dialer_select.rs | 31 +++++++++++++++++++------ 1 file changed, 24 insertions(+), 7 deletions(-) diff --git a/src/multistream_select/dialer_select.rs b/src/multistream_select/dialer_select.rs index 687bc9286..dc74ed460 100644 --- a/src/multistream_select/dialer_select.rs +++ b/src/multistream_select/dialer_select.rs @@ -465,13 +465,6 @@ impl WebRtcDialerState { return Ok(HandshakeResult::Succeeded(self.protocol.clone())); } - for fallback in &self.fallback_names { - if fallback.as_bytes() == protocol.as_ref() { - check_trailing_bytes(&remaining); - return Ok(HandshakeResult::Succeeded(fallback.clone())); - } - } - return Err(crate::error::NegotiationError::MultistreamSelectError( NegotiationError::Failed, )); @@ -1026,4 +1019,28 @@ mod tests { _ => panic!("invalid event"), } } + + #[test] + fn reject_unproposed_fallback_confirmation() { + let (mut dialer_state, _message) = WebRtcDialerState::propose( + ProtocolName::from("/13371338/proto/1"), + vec![ProtocolName::from("/sup/proto/1")], + ) + .unwrap(); + + // The dialer has only proposed the main protocol. The fallback is stored for a + // later round and must not be accepted until `propose_next_fallback()` sends it. + let mut response = BytesMut::with_capacity(64); + response.put_u8(MSG_MULTISTREAM_1_0.len() as u8); + Message::Header(HeaderLine::V1).encode(&mut response).unwrap(); + + let fallback = Protocol::try_from(&b"/sup/proto/1"[..]).expect("valid protocol name"); + response.put_u8((fallback.as_ref().len() + 1) as u8); + Message::Protocol(fallback).encode(&mut response).unwrap(); + + match dialer_state.register_response(response.freeze().to_vec()) { + Err(error::NegotiationError::MultistreamSelectError(NegotiationError::Failed)) => {} + event => panic!("expected unproposed fallback to be rejected, got: {event:?}"), + } + } } From 6c9043483fe7413758d9f7f5327186f36dacc33c Mon Sep 17 00:00:00 2001 From: gab Date: Tue, 19 May 2026 10:29:11 +0200 Subject: [PATCH 11/11] test(webrtc): fix `WebRtcDialerState` `negotiate_fallback_protocol` test --- src/multistream_select/dialer_select.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/multistream_select/dialer_select.rs b/src/multistream_select/dialer_select.rs index dc74ed460..e95e129af 100644 --- a/src/multistream_select/dialer_select.rs +++ b/src/multistream_select/dialer_select.rs @@ -1012,6 +1012,8 @@ mod tests { ) .unwrap(); + dialer_state.propose_next_fallback(); + match dialer_state.register_response(message.to_vec()) { Ok(HandshakeResult::Succeeded(negotiated)) => { assert_eq!(negotiated, ProtocolName::from("/sup/proto/1"))