Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
234 changes: 164 additions & 70 deletions src/multistream_select/dialer_select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ use crate::{
codec::unsigned_varint::UnsignedVarint,
error::{self, Error, ParseError, SubstreamError},
multistream_select::{
drain_trailing_protocols,
protocol::{
webrtc_encode_multistream_message, HeaderLine, Message, MessageIO, Protocol,
ProtocolError, PROTO_MULTISTREAM_1_0,
Expand Down Expand Up @@ -300,6 +299,12 @@ pub enum HandshakeResult {
/// The returned tuple contains the negotiated protocol and response
/// that must be sent to remote peer.
Succeeded(ProtocolName),

/// The proposed protocol was rejected by the remote peer.
///
/// The caller should check if there are remaining fallback protocols to try
/// via [`WebRtcDialerState::propose_next_fallback()`].
Rejected,
}

/// Handshake state.
Expand Down Expand Up @@ -328,21 +333,27 @@ pub struct WebRtcDialerState {
impl WebRtcDialerState {
/// Propose protocol to remote peer.
///
/// `fallback_names` must be in preference order, the first element is the
/// next protocol to try.
///
/// Return [`WebRtcDialerState`] which is used to drive forward the negotiation and an encoded
/// `multistream-select` message that contains the protocol proposal for the substream.
pub fn propose(
protocol: ProtocolName,
fallback_names: Vec<ProtocolName>,
mut fallback_names: Vec<ProtocolName>,
) -> crate::Result<(Self, Vec<u8>)> {
let message = webrtc_encode_multistream_message(
std::iter::once(protocol.clone())
.chain(fallback_names.clone())
.filter_map(|protocol| Protocol::try_from(protocol.as_ref()).ok())
.map(Message::Protocol),
Message::Protocol(
Protocol::try_from(protocol.as_ref()).map_err(|_| Error::InvalidData)?,
),
true,
)?
.freeze()
.to_vec();

// Reverse fallback_names so that we can pop from it.
fallback_names.reverse();

Ok((
Self {
protocol,
Expand All @@ -353,69 +364,86 @@ impl WebRtcDialerState {
))
}

/// Propose the next fallback protocol to the remote peer.
///
/// Returns `None` if there are no more fallback protocols to try.
/// Returns `Some(message)` with the encoded message to send, containing the protocol name.
pub fn propose_next_fallback(&mut self) -> crate::Result<Option<Vec<u8>>> {
if self.fallback_names.is_empty() {
return Ok(None);
}

// UNWRAP: fallback_names has just been checked to not be empty.
let next = self.fallback_names.pop().unwrap();
Comment thread
gab8i marked this conversation as resolved.
Outdated
self.protocol = next;

let message = webrtc_encode_multistream_message(
Message::Protocol(
Protocol::try_from(self.protocol.as_ref()).map_err(|_| Error::InvalidData)?,
),
false,
)?
.freeze()
Comment on lines +383 to +384
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: This and above ? operator alters the state of fallback_names which we have just popped. I think this is ok since we terminate the substream immediately and return negotiation-error?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it works because if propose_next_fallback fails then also on_outbound_opening_channel_data fails and the data channel is getting closed!

.to_vec();

Ok(Some(message))
}

/// Register response to [`WebRtcDialerState`].
pub fn register_response(
&mut self,
payload: Vec<u8>,
) -> Result<HandshakeResult, crate::error::NegotiationError> {
// All multistream-select messages are length-prefixed. Since this code path is not using
// multistream_select::protocol::MessageIO, we need to decode and remove the length here.
let remaining: &[u8] = &payload;
let (len, tail) = unsigned_varint::decode::usize(remaining).map_err(|error| {
tracing::debug!(
let bytes = Bytes::from(payload);
let mut remaining = bytes.clone();
Comment thread
gab8i marked this conversation as resolved.
Outdated

while !remaining.is_empty() {
let (len, tail) = unsigned_varint::decode::usize(&remaining).map_err(|error| {
tracing::debug!(
target: LOG_TARGET,
?error,
message = ?payload,
"Failed to decode length-prefix in multistream message");
error::NegotiationError::ParseError(ParseError::InvalidData)
})?;
message = ?remaining,
"Failed to decode length-prefix in multistream message",
);
error::NegotiationError::ParseError(ParseError::InvalidData)
})?;

let len_size = remaining.len() - tail.len();
let bytes = Bytes::from(payload);
let payload = bytes.slice(len_size..len_size + len);
let remaining = bytes.slice(len_size + len..);
let message = Message::decode(payload);

tracing::trace!(
target: LOG_TARGET,
?message,
"Decoded message while registering response",
);

let mut protocols = match message {
Ok(Message::Header(HeaderLine::V1)) => {
vec![PROTO_MULTISTREAM_1_0]
let len_size = remaining.len() - tail.len();

if len > tail.len() {
Comment thread
gab8i marked this conversation as resolved.
tracing::debug!(
target: LOG_TARGET,
message = ?tail,
length_prefix = len,
actual_length = tail.len(),
"Truncated multistream message",
);
return Err(error::NegotiationError::ParseError(ParseError::InvalidData));
}
Ok(Message::Protocol(protocol)) => vec![protocol],
Ok(Message::Protocols(protocols)) => protocols,
Ok(Message::NotAvailable) =>
return match &self.state {
HandshakeState::WaitingProtocol => Err(
error::NegotiationError::MultistreamSelectError(NegotiationError::Failed),
),
_ => Err(error::NegotiationError::StateMismatch),
},
Ok(Message::ListProtocols) => return Err(error::NegotiationError::StateMismatch),
Err(_) => return Err(error::NegotiationError::ParseError(ParseError::InvalidData)),
};

protocols.extend(drain_trailing_protocols(remaining)?);
let payload = remaining.slice(len_size..len_size + len);
remaining = remaining.slice(len_size + len..);
let message = Message::decode(payload);
Comment thread
gab8i marked this conversation as resolved.
Outdated

let mut protocol_iter = protocols.into_iter();
loop {
match (&self.state, protocol_iter.next()) {
(HandshakeState::WaitingResponse, None) =>
return Err(crate::error::NegotiationError::StateMismatch),
(HandshakeState::WaitingResponse, Some(protocol)) => {
if protocol == PROTO_MULTISTREAM_1_0 {
self.state = HandshakeState::WaitingProtocol;
} else {
return Err(crate::error::NegotiationError::MultistreamSelectError(
NegotiationError::Failed,
));
}
tracing::trace!(
target: LOG_TARGET,
?message,
"Decoded message while registering response",
);

match (&self.state, message) {
(HandshakeState::WaitingResponse, Ok(Message::Header(HeaderLine::V1))) => {
self.state = HandshakeState::WaitingProtocol;
}
(HandshakeState::WaitingResponse, Ok(Message::Protocol(_))) => {
return Err(crate::error::NegotiationError::MultistreamSelectError(
NegotiationError::Failed,
));
}
(_, Ok(Message::NotAvailable)) => {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could be a protocol violation if the header was not received.

The spec says that for the first message the header should be echoed back:

> /multistream/1.0.0\n/ipfs/kad/1.0.0\n
< /multistream/1.0.0\nna\n

However, we currently support missing the header:

> /multistream/1.0.0\n/ipfs/kad/1.0.0\n
< na\n

Which will make us propose the next fallback protocol that will most definetely not be accepted since we are still in the HandshakeState::WaitingResponse state.

In case the our state is waiting protocol this is fine. However, if we are still waiting the header we should return a state missmatch / error

// This is not oki, we ahve not received the headr
(HandshakeState::WaitingResponse, Ok(Message::NotAvailable)) => return Ok(HandshakeResult::StateMismatch),
(HandshakeState::WaitingProtocol, Ok(Message::NotAvailable)) => return Ok(HandshakeResult::Rejected),

Lets also add a bit more trace logs / debugs here especially for the HandshakeResult::Failed cases

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice catch! This needs to be addressed to avoid any protocol issue! Solution should be pretty basic:

                (HandshakeState::WaitingResponse, Ok(msg)) => {
                   tracing::trace!(
                     target: LOG_TARGET,
                     ?msg,
                     "Expected header response from peer, got different message"
                   );
                   return Err(crate::error::NegotiationError::MultistreamSelectError(
                       NegotiationError::Failed,
                   ));
               }

return Ok(HandshakeResult::Rejected);
}
(HandshakeState::WaitingProtocol, Some(protocol)) => {
(HandshakeState::WaitingProtocol, Ok(Message::Protocol(protocol))) => {
if protocol == PROTO_MULTISTREAM_1_0 {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could improve a bit the robustness of this function wrt the state machine management.

Currently, we accept only [HEADER ++ protocol].

If the buffer contains an extra message [HEADER ++ protocol ++ first_message/handshake] because smoldot/other impl decides to send us optimistically an extra packet to avoid excessive flushing, then we'd lose the trailing bytes.

We should have at least an warning to easily detect if the remaining bytes are non empty

Copy link
Copy Markdown
Contributor Author

@gab8i gab8i May 14, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Making this function more reliable would imply to change quite a lot the code, but what you describe could totally be implemented by a peer which can send those 3 things out at once. Right now it is only expected to interact with smoldot which seems to not do so.

I've pushed a change which warn if trailing bytes are found after having decoded valid multistream-select messages. Is this ok for now?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep warning should be sufficient for now since the fix is quite intrusive. Could you create an issue to not forget about it 🙏 ?

return Err(crate::error::NegotiationError::StateMismatch);
}
Expand All @@ -434,11 +462,16 @@ impl WebRtcDialerState {
NegotiationError::Failed,
));
}
(HandshakeState::WaitingProtocol, None) => {
return Ok(HandshakeResult::NotReady);
_ => {
Comment thread
gab8i marked this conversation as resolved.
return Err(crate::error::NegotiationError::StateMismatch);
}
}
}

match &self.state {
HandshakeState::WaitingProtocol => Ok(HandshakeResult::NotReady),
HandshakeState::WaitingResponse => Err(crate::error::NegotiationError::StateMismatch),
}
}
}

Expand All @@ -447,7 +480,7 @@ mod tests {
use super::*;
use crate::multistream_select::{listener_select_proto, protocol::MSG_MULTISTREAM_1_0};
use bytes::BufMut;
use std::time::Duration;
use std::{sync::OnceState, time::Duration};
Comment thread
gab8i marked this conversation as resolved.
Outdated
#[tokio::test]
async fn select_proto_basic() {
async fn run(version: Version) {
Expand Down Expand Up @@ -813,6 +846,7 @@ mod tests {
)
.unwrap();

// Initial message should only contain the main protocol, not the fallback.
let mut bytes = BytesMut::with_capacity(32);
bytes.put_u8(MSG_MULTISTREAM_1_0.len() as u8);
Message::Header(HeaderLine::V1).encode(&mut bytes).unwrap();
Expand All @@ -821,15 +855,73 @@ mod tests {
bytes.put_u8((proto1.as_ref().len() + 1) as u8); // + 1 for \n
Message::Protocol(proto1).encode(&mut bytes).unwrap();

let proto2 = Protocol::try_from(&b"/sup/proto/1"[..]).expect("valid protocol name");
bytes.put_u8((proto2.as_ref().len() + 1) as u8); // + 1 for \n
Message::Protocol(proto2).encode(&mut bytes).unwrap();

let expected_message = bytes.freeze().to_vec();

assert_eq!(message, expected_message);
}

#[test]
fn propose_next_fallback() {
let (mut dialer_state, _message) = WebRtcDialerState::propose(
ProtocolName::from("/13371338/proto/1"),
vec![ProtocolName::from("/sup/proto/1")],
)
.unwrap();

// Simulate receiving header-only response, transitioning to WaitingProtocol.
let mut header_bytes = BytesMut::with_capacity(32);
header_bytes.put_u8(MSG_MULTISTREAM_1_0.len() as u8);
Message::Header(HeaderLine::V1).encode(&mut header_bytes).unwrap();
// Append "na" to simulate rejection.
let na_bytes = b"na\n";
header_bytes.put_u8(na_bytes.len() as u8);
header_bytes.put_slice(na_bytes);

match dialer_state.register_response(header_bytes.freeze().to_vec()) {
Ok(HandshakeResult::Rejected) => {}
event => panic!("expected Rejected, got: {event:?}"),
}

// After having received the header the state stays always at WaitingProtocol.
assert!(matches!(
dialer_state.state,
HandshakeState::WaitingProtocol
));

// Now propose the next fallback.
let fallback_message = dialer_state
.propose_next_fallback()
.expect("no error")
.expect("should have a fallback");

let mut expected = BytesMut::with_capacity(32);
let proto = Protocol::try_from(&b"/sup/proto/1"[..]).expect("valid protocol name");
expected.put_u8((proto.as_ref().len() + 1) as u8);
Message::Protocol(proto).encode(&mut expected).unwrap();
let protocol_message = expected.freeze().to_vec();
assert_eq!(fallback_message, protocol_message);

// No more fallbacks.
assert!(dialer_state.propose_next_fallback().unwrap().is_none());

match dialer_state.register_response(protocol_message) {
Ok(HandshakeResult::Succeeded(_)) => {}
event => panic!("expected Succeeded, got: {event:?}"),
}

let mut na_response = BytesMut::with_capacity(32);
let na_bytes = b"na\n";
na_response.put_u8(na_bytes.len() as u8);
na_response.put_slice(na_bytes);

dialer_state.state = HandshakeState::WaitingProtocol;

match dialer_state.register_response(na_response.to_vec()) {
Ok(HandshakeResult::Rejected) => {}
event => panic!("expected Rejected, got: {event:?}"),
}
}

#[test]
fn register_response_header_only() {
let mut bytes = BytesMut::with_capacity(32);
Expand Down Expand Up @@ -872,9 +964,10 @@ mod tests {

#[test]
fn negotiate_main_protocol() {
let message = webrtc_encode_multistream_message(vec![Message::Protocol(
Protocol::try_from(&b"/13371338/proto/1"[..]).unwrap(),
)])
let message = webrtc_encode_multistream_message(
Message::Protocol(Protocol::try_from(&b"/13371338/proto/1"[..]).unwrap()),
true,
)
.unwrap()
.freeze();

Expand All @@ -894,9 +987,10 @@ mod tests {

#[test]
fn negotiate_fallback_protocol() {
let message = webrtc_encode_multistream_message(vec![Message::Protocol(
Protocol::try_from(&b"/sup/proto/1"[..]).unwrap(),
)])
let message = webrtc_encode_multistream_message(
Message::Protocol(Protocol::try_from(&b"/sup/proto/1"[..]).unwrap()),
true,
)
.unwrap()
.freeze();

Expand Down
2 changes: 1 addition & 1 deletion src/multistream_select/length_delimited.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use std::{
};

const MAX_LEN_BYTES: u16 = 2;
const MAX_FRAME_SIZE: u16 = (1 << (MAX_LEN_BYTES * 8 - MAX_LEN_BYTES)) - 1;
pub(super) const MAX_FRAME_SIZE: u16 = (1 << (MAX_LEN_BYTES * 8 - MAX_LEN_BYTES)) - 1;
const DEFAULT_BUFFER_SIZE: usize = 64;
const LOG_TARGET: &str = "litep2p::multistream-select";

Expand Down
Loading
Loading