-
Notifications
You must be signed in to change notification settings - Fork 35
feat(webrtc): multistream-select protocol implementation #573
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 7 commits
6a6af0d
a293ad2
f1e889c
752c7df
2b7738f
5a12aa0
bf1204e
d98ed84
c6bdbae
f1a3be5
6c90434
ea5d8f9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -24,7 +24,6 @@ use crate::{ | |
| codec::unsigned_varint::UnsignedVarint, | ||
| error::{self, Error, ParseError, SubstreamError}, | ||
| multistream_select::{ | ||
| drain_trailing_protocols, | ||
| protocol::{ | ||
| webrtc_encode_multistream_message, HeaderLine, Message, MessageIO, Protocol, | ||
| ProtocolError, PROTO_MULTISTREAM_1_0, | ||
|
|
@@ -300,6 +299,12 @@ pub enum HandshakeResult { | |
| /// The returned tuple contains the negotiated protocol and response | ||
| /// that must be sent to remote peer. | ||
| Succeeded(ProtocolName), | ||
|
|
||
| /// The proposed protocol was rejected by the remote peer. | ||
| /// | ||
| /// The caller should check if there are remaining fallback protocols to try | ||
| /// via [`WebRtcDialerState::propose_next_fallback()`]. | ||
| Rejected, | ||
| } | ||
|
|
||
| /// Handshake state. | ||
|
|
@@ -328,21 +333,27 @@ pub struct WebRtcDialerState { | |
| impl WebRtcDialerState { | ||
| /// Propose protocol to remote peer. | ||
| /// | ||
| /// `fallback_names` must be in preference order, the first element is the | ||
| /// next protocol to try. | ||
| /// | ||
| /// Return [`WebRtcDialerState`] which is used to drive forward the negotiation and an encoded | ||
| /// `multistream-select` message that contains the protocol proposal for the substream. | ||
| pub fn propose( | ||
| protocol: ProtocolName, | ||
| fallback_names: Vec<ProtocolName>, | ||
| mut fallback_names: Vec<ProtocolName>, | ||
| ) -> crate::Result<(Self, Vec<u8>)> { | ||
| let message = webrtc_encode_multistream_message( | ||
| std::iter::once(protocol.clone()) | ||
| .chain(fallback_names.clone()) | ||
| .filter_map(|protocol| Protocol::try_from(protocol.as_ref()).ok()) | ||
| .map(Message::Protocol), | ||
| Message::Protocol( | ||
| Protocol::try_from(protocol.as_ref()).map_err(|_| Error::InvalidData)?, | ||
| ), | ||
| true, | ||
| )? | ||
| .freeze() | ||
| .to_vec(); | ||
|
|
||
| // Reverse fallback_names so that we can pop from it. | ||
| fallback_names.reverse(); | ||
|
|
||
| Ok(( | ||
| Self { | ||
| protocol, | ||
|
|
@@ -353,69 +364,86 @@ impl WebRtcDialerState { | |
| )) | ||
| } | ||
|
|
||
| /// Propose the next fallback protocol to the remote peer. | ||
| /// | ||
| /// Returns `None` if there are no more fallback protocols to try. | ||
| /// Returns `Some(message)` with the encoded message to send, containing the protocol name. | ||
| pub fn propose_next_fallback(&mut self) -> crate::Result<Option<Vec<u8>>> { | ||
| if self.fallback_names.is_empty() { | ||
| return Ok(None); | ||
| } | ||
|
|
||
| // UNWRAP: fallback_names has just been checked to not be empty. | ||
| let next = self.fallback_names.pop().unwrap(); | ||
| self.protocol = next; | ||
|
|
||
| let message = webrtc_encode_multistream_message( | ||
| Message::Protocol( | ||
| Protocol::try_from(self.protocol.as_ref()).map_err(|_| Error::InvalidData)?, | ||
| ), | ||
| false, | ||
| )? | ||
| .freeze() | ||
|
Comment on lines
+383
to
+384
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: This and above
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, it works because if |
||
| .to_vec(); | ||
|
|
||
| Ok(Some(message)) | ||
| } | ||
|
|
||
| /// Register response to [`WebRtcDialerState`]. | ||
| pub fn register_response( | ||
| &mut self, | ||
| payload: Vec<u8>, | ||
| ) -> Result<HandshakeResult, crate::error::NegotiationError> { | ||
| // All multistream-select messages are length-prefixed. Since this code path is not using | ||
| // multistream_select::protocol::MessageIO, we need to decode and remove the length here. | ||
| let remaining: &[u8] = &payload; | ||
| let (len, tail) = unsigned_varint::decode::usize(remaining).map_err(|error| { | ||
| tracing::debug!( | ||
| let bytes = Bytes::from(payload); | ||
| let mut remaining = bytes.clone(); | ||
|
gab8i marked this conversation as resolved.
Outdated
|
||
|
|
||
| while !remaining.is_empty() { | ||
| let (len, tail) = unsigned_varint::decode::usize(&remaining).map_err(|error| { | ||
| tracing::debug!( | ||
| target: LOG_TARGET, | ||
| ?error, | ||
| message = ?payload, | ||
| "Failed to decode length-prefix in multistream message"); | ||
| error::NegotiationError::ParseError(ParseError::InvalidData) | ||
| })?; | ||
| message = ?remaining, | ||
| "Failed to decode length-prefix in multistream message", | ||
| ); | ||
| error::NegotiationError::ParseError(ParseError::InvalidData) | ||
| })?; | ||
|
|
||
| let len_size = remaining.len() - tail.len(); | ||
| let bytes = Bytes::from(payload); | ||
| let payload = bytes.slice(len_size..len_size + len); | ||
| let remaining = bytes.slice(len_size + len..); | ||
| let message = Message::decode(payload); | ||
|
|
||
| tracing::trace!( | ||
| target: LOG_TARGET, | ||
| ?message, | ||
| "Decoded message while registering response", | ||
| ); | ||
|
|
||
| let mut protocols = match message { | ||
| Ok(Message::Header(HeaderLine::V1)) => { | ||
| vec![PROTO_MULTISTREAM_1_0] | ||
| let len_size = remaining.len() - tail.len(); | ||
|
|
||
| if len > tail.len() { | ||
|
gab8i marked this conversation as resolved.
|
||
| tracing::debug!( | ||
| target: LOG_TARGET, | ||
| message = ?tail, | ||
| length_prefix = len, | ||
| actual_length = tail.len(), | ||
| "Truncated multistream message", | ||
| ); | ||
| return Err(error::NegotiationError::ParseError(ParseError::InvalidData)); | ||
| } | ||
| Ok(Message::Protocol(protocol)) => vec![protocol], | ||
| Ok(Message::Protocols(protocols)) => protocols, | ||
| Ok(Message::NotAvailable) => | ||
| return match &self.state { | ||
| HandshakeState::WaitingProtocol => Err( | ||
| error::NegotiationError::MultistreamSelectError(NegotiationError::Failed), | ||
| ), | ||
| _ => Err(error::NegotiationError::StateMismatch), | ||
| }, | ||
| Ok(Message::ListProtocols) => return Err(error::NegotiationError::StateMismatch), | ||
| Err(_) => return Err(error::NegotiationError::ParseError(ParseError::InvalidData)), | ||
| }; | ||
|
|
||
| protocols.extend(drain_trailing_protocols(remaining)?); | ||
| let payload = remaining.slice(len_size..len_size + len); | ||
| remaining = remaining.slice(len_size + len..); | ||
| let message = Message::decode(payload); | ||
|
gab8i marked this conversation as resolved.
Outdated
|
||
|
|
||
| let mut protocol_iter = protocols.into_iter(); | ||
| loop { | ||
| match (&self.state, protocol_iter.next()) { | ||
| (HandshakeState::WaitingResponse, None) => | ||
| return Err(crate::error::NegotiationError::StateMismatch), | ||
| (HandshakeState::WaitingResponse, Some(protocol)) => { | ||
| if protocol == PROTO_MULTISTREAM_1_0 { | ||
| self.state = HandshakeState::WaitingProtocol; | ||
| } else { | ||
| return Err(crate::error::NegotiationError::MultistreamSelectError( | ||
| NegotiationError::Failed, | ||
| )); | ||
| } | ||
| tracing::trace!( | ||
| target: LOG_TARGET, | ||
| ?message, | ||
| "Decoded message while registering response", | ||
| ); | ||
|
|
||
| match (&self.state, message) { | ||
| (HandshakeState::WaitingResponse, Ok(Message::Header(HeaderLine::V1))) => { | ||
| self.state = HandshakeState::WaitingProtocol; | ||
| } | ||
| (HandshakeState::WaitingResponse, Ok(Message::Protocol(_))) => { | ||
| return Err(crate::error::NegotiationError::MultistreamSelectError( | ||
| NegotiationError::Failed, | ||
| )); | ||
| } | ||
| (_, Ok(Message::NotAvailable)) => { | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This could be a protocol violation if the header was not received. The spec says that for the first message the header should be echoed back: However, we currently support missing the header: Which will make us propose the next fallback protocol that will most definetely not be accepted since we are still in the In case the our state is waiting protocol this is fine. However, if we are still waiting the header we should return a state missmatch / error // This is not oki, we ahve not received the headr
(HandshakeState::WaitingResponse, Ok(Message::NotAvailable)) => return Ok(HandshakeResult::StateMismatch),
(HandshakeState::WaitingProtocol, Ok(Message::NotAvailable)) => return Ok(HandshakeResult::Rejected),Lets also add a bit more trace logs / debugs here especially for the
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nice catch! This needs to be addressed to avoid any protocol issue! Solution should be pretty basic: (HandshakeState::WaitingResponse, Ok(msg)) => {
tracing::trace!(
target: LOG_TARGET,
?msg,
"Expected header response from peer, got different message"
);
return Err(crate::error::NegotiationError::MultistreamSelectError(
NegotiationError::Failed,
));
} |
||
| return Ok(HandshakeResult::Rejected); | ||
| } | ||
| (HandshakeState::WaitingProtocol, Some(protocol)) => { | ||
| (HandshakeState::WaitingProtocol, Ok(Message::Protocol(protocol))) => { | ||
| if protocol == PROTO_MULTISTREAM_1_0 { | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We could improve a bit the robustness of this function wrt the state machine management. Currently, we accept only If the buffer contains an extra message We should have at least an warning to easily detect if the remaining bytes are non empty
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Making this function more reliable would imply to change quite a lot the code, but what you describe could totally be implemented by a peer which can send those 3 things out at once. Right now it is only expected to interact with smoldot which seems to not do so. I've pushed a change which warn if trailing bytes are found after having decoded valid multistream-select messages. Is this ok for now?
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yep warning should be sufficient for now since the fix is quite intrusive. Could you create an issue to not forget about it 🙏 ? |
||
| return Err(crate::error::NegotiationError::StateMismatch); | ||
| } | ||
|
|
@@ -434,11 +462,16 @@ impl WebRtcDialerState { | |
| NegotiationError::Failed, | ||
| )); | ||
| } | ||
| (HandshakeState::WaitingProtocol, None) => { | ||
| return Ok(HandshakeResult::NotReady); | ||
| _ => { | ||
|
gab8i marked this conversation as resolved.
|
||
| return Err(crate::error::NegotiationError::StateMismatch); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| match &self.state { | ||
| HandshakeState::WaitingProtocol => Ok(HandshakeResult::NotReady), | ||
| HandshakeState::WaitingResponse => Err(crate::error::NegotiationError::StateMismatch), | ||
| } | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -447,7 +480,7 @@ mod tests { | |
| use super::*; | ||
| use crate::multistream_select::{listener_select_proto, protocol::MSG_MULTISTREAM_1_0}; | ||
| use bytes::BufMut; | ||
| use std::time::Duration; | ||
| use std::{sync::OnceState, time::Duration}; | ||
|
gab8i marked this conversation as resolved.
Outdated
|
||
| #[tokio::test] | ||
| async fn select_proto_basic() { | ||
| async fn run(version: Version) { | ||
|
|
@@ -813,6 +846,7 @@ mod tests { | |
| ) | ||
| .unwrap(); | ||
|
|
||
| // Initial message should only contain the main protocol, not the fallback. | ||
| let mut bytes = BytesMut::with_capacity(32); | ||
| bytes.put_u8(MSG_MULTISTREAM_1_0.len() as u8); | ||
| Message::Header(HeaderLine::V1).encode(&mut bytes).unwrap(); | ||
|
|
@@ -821,15 +855,73 @@ mod tests { | |
| bytes.put_u8((proto1.as_ref().len() + 1) as u8); // + 1 for \n | ||
| Message::Protocol(proto1).encode(&mut bytes).unwrap(); | ||
|
|
||
| let proto2 = Protocol::try_from(&b"/sup/proto/1"[..]).expect("valid protocol name"); | ||
| bytes.put_u8((proto2.as_ref().len() + 1) as u8); // + 1 for \n | ||
| Message::Protocol(proto2).encode(&mut bytes).unwrap(); | ||
|
|
||
| let expected_message = bytes.freeze().to_vec(); | ||
|
|
||
| assert_eq!(message, expected_message); | ||
| } | ||
|
|
||
| #[test] | ||
| fn propose_next_fallback() { | ||
| let (mut dialer_state, _message) = WebRtcDialerState::propose( | ||
| ProtocolName::from("/13371338/proto/1"), | ||
| vec![ProtocolName::from("/sup/proto/1")], | ||
| ) | ||
| .unwrap(); | ||
|
|
||
| // Simulate receiving header-only response, transitioning to WaitingProtocol. | ||
| let mut header_bytes = BytesMut::with_capacity(32); | ||
| header_bytes.put_u8(MSG_MULTISTREAM_1_0.len() as u8); | ||
| Message::Header(HeaderLine::V1).encode(&mut header_bytes).unwrap(); | ||
| // Append "na" to simulate rejection. | ||
| let na_bytes = b"na\n"; | ||
| header_bytes.put_u8(na_bytes.len() as u8); | ||
| header_bytes.put_slice(na_bytes); | ||
|
|
||
| match dialer_state.register_response(header_bytes.freeze().to_vec()) { | ||
| Ok(HandshakeResult::Rejected) => {} | ||
| event => panic!("expected Rejected, got: {event:?}"), | ||
| } | ||
|
|
||
| // After having received the header the state stays always at WaitingProtocol. | ||
| assert!(matches!( | ||
| dialer_state.state, | ||
| HandshakeState::WaitingProtocol | ||
| )); | ||
|
|
||
| // Now propose the next fallback. | ||
| let fallback_message = dialer_state | ||
| .propose_next_fallback() | ||
| .expect("no error") | ||
| .expect("should have a fallback"); | ||
|
|
||
| let mut expected = BytesMut::with_capacity(32); | ||
| let proto = Protocol::try_from(&b"/sup/proto/1"[..]).expect("valid protocol name"); | ||
| expected.put_u8((proto.as_ref().len() + 1) as u8); | ||
| Message::Protocol(proto).encode(&mut expected).unwrap(); | ||
| let protocol_message = expected.freeze().to_vec(); | ||
| assert_eq!(fallback_message, protocol_message); | ||
|
|
||
| // No more fallbacks. | ||
| assert!(dialer_state.propose_next_fallback().unwrap().is_none()); | ||
|
|
||
| match dialer_state.register_response(protocol_message) { | ||
| Ok(HandshakeResult::Succeeded(_)) => {} | ||
| event => panic!("expected Succeeded, got: {event:?}"), | ||
| } | ||
|
|
||
| let mut na_response = BytesMut::with_capacity(32); | ||
| let na_bytes = b"na\n"; | ||
| na_response.put_u8(na_bytes.len() as u8); | ||
| na_response.put_slice(na_bytes); | ||
|
|
||
| dialer_state.state = HandshakeState::WaitingProtocol; | ||
|
|
||
| match dialer_state.register_response(na_response.to_vec()) { | ||
| Ok(HandshakeResult::Rejected) => {} | ||
| event => panic!("expected Rejected, got: {event:?}"), | ||
| } | ||
| } | ||
|
|
||
| #[test] | ||
| fn register_response_header_only() { | ||
| let mut bytes = BytesMut::with_capacity(32); | ||
|
|
@@ -872,9 +964,10 @@ mod tests { | |
|
|
||
| #[test] | ||
| fn negotiate_main_protocol() { | ||
| let message = webrtc_encode_multistream_message(vec![Message::Protocol( | ||
| Protocol::try_from(&b"/13371338/proto/1"[..]).unwrap(), | ||
| )]) | ||
| let message = webrtc_encode_multistream_message( | ||
| Message::Protocol(Protocol::try_from(&b"/13371338/proto/1"[..]).unwrap()), | ||
| true, | ||
| ) | ||
| .unwrap() | ||
| .freeze(); | ||
|
|
||
|
|
@@ -894,9 +987,10 @@ mod tests { | |
|
|
||
| #[test] | ||
| fn negotiate_fallback_protocol() { | ||
| let message = webrtc_encode_multistream_message(vec![Message::Protocol( | ||
| Protocol::try_from(&b"/sup/proto/1"[..]).unwrap(), | ||
| )]) | ||
| let message = webrtc_encode_multistream_message( | ||
| Message::Protocol(Protocol::try_from(&b"/sup/proto/1"[..]).unwrap()), | ||
| true, | ||
| ) | ||
| .unwrap() | ||
| .freeze(); | ||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.