diff --git a/src/protocol/notification/negotiation.rs b/src/protocol/notification/negotiation.rs index 9c53c7606..6509d8274 100644 --- a/src/protocol/notification/negotiation.rs +++ b/src/protocol/notification/negotiation.rs @@ -79,6 +79,7 @@ pub enum HandshakeEvent { } /// Outbound substream's handshake state +#[derive(Debug)] enum HandshakeState { /// Send handshake to remote peer. SendHandshake, @@ -218,6 +219,13 @@ impl Stream for HandshakeService { inner.substreams.iter_mut() { if let Poll::Ready(()) = timer.poll_unpin(cx) { + tracing::trace!( + target: LOG_TARGET, + ?peer, + ?direction, + ?state, + "handshake negotiation timed out", + ); return Poll::Ready(Some(( *peer, HandshakeEvent::NegotiationError { @@ -285,10 +293,36 @@ impl Stream for HandshakeService { }, HandshakeState::ReadHandshake => match pinned.poll_next(cx) { Poll::Ready(Some(Ok(handshake))) => { + tracing::trace!( + target: LOG_TARGET, + ?peer, + handshake_len = handshake.len(), + "successfully read handshake from substream", + ); inner.ready.push_back((*peer, *direction, handshake.freeze().into())); continue 'outer; } - Poll::Ready(Some(Err(_))) | Poll::Ready(None) => { + Poll::Ready(Some(Err(error))) => { + tracing::trace!( + target: LOG_TARGET, + ?peer, + ?error, + "error reading handshake from substream", + ); + return Poll::Ready(Some(( + *peer, + HandshakeEvent::NegotiationError { + peer: *peer, + direction: *direction, + }, + ))); + } + Poll::Ready(None) => { + tracing::trace!( + target: LOG_TARGET, + ?peer, + "substream closed while reading handshake", + ); return Poll::Ready(Some(( *peer, HandshakeEvent::NegotiationError { diff --git a/src/transport/webrtc/connection.rs b/src/transport/webrtc/connection.rs index e7d53b7fa..b1322b94a 100644 --- a/src/transport/webrtc/connection.rs +++ b/src/transport/webrtc/connection.rs @@ -29,7 +29,7 @@ use crate::{ transport::{ webrtc::{ schema::webrtc::message::Flag, - substream::{Event as SubstreamEvent, Substream as WebRtcSubstream, SubstreamHandle}, + substream::{Message, Substream as WebRtcSubstream, SubstreamHandle}, util::{extract_framed_message, WebRtcMessage}, }, Endpoint, @@ -145,7 +145,7 @@ impl SubstreamHandleSet { } impl Stream for SubstreamHandleSet { - type Item = (ChannelId, Option); + type Item = (ChannelId, Option); fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let len = match self.handles.len() { @@ -465,8 +465,50 @@ impl WebRtcConnection { "channel closed", ); - self.pending_outbound.remove(&channel_id); - self.channels.remove(&channel_id); + // If this was a pending outbound channel (waiting for DCEP ACK from remote), + // report the failure so the protocol handler can retry. + if let Some(context) = self.pending_outbound.remove(&channel_id) { + tracing::debug!( + target: LOG_TARGET, + peer = ?self.peer, + ?channel_id, + protocol = %context.protocol, + substream_id = ?context.substream_id, + "outbound channel closed before opening, reporting failure", + ); + + let _ = self + .protocol_set + .report_substream_open_failure( + context.protocol, + context.substream_id, + SubstreamError::ConnectionClosed, + ) + .await; + } + + if let Some(ChannelState::OutboundOpening { context, .. }) = + self.channels.remove(&channel_id) + { + tracing::debug!( + target: LOG_TARGET, + peer = ?self.peer, + ?channel_id, + protocol = %context.protocol, + substream_id = ?context.substream_id, + "outbound channel closed during negotiation, reporting failure", + ); + + let _ = self + .protocol_set + .report_substream_open_failure( + context.protocol, + context.substream_id, + SubstreamError::ConnectionClosed, + ) + .await; + } + self.pending_messages.remove(&channel_id); self.handles.remove(&channel_id); self.recv_buffers.remove(&channel_id); @@ -501,6 +543,7 @@ impl WebRtcConnection { "handle opening inbound substream", ); + // Decode errors are not recoverable. let payload = WebRtcMessage::decode(&data)?.payload.ok_or(Error::InvalidData)?; let protocols = self.protocol_set.protocols_with_keep_alives(); let protocol_names = protocols.keys().cloned().collect(); @@ -889,6 +932,7 @@ impl WebRtcConnection { self.rtc.direct_api().close_data_channel(channel_id); self.channels.insert(channel_id, ChannelState::Closing); + self.handles.remove(&channel_id); } }, ChannelState::Closing => { @@ -972,6 +1016,29 @@ impl WebRtcConnection { "connection closed", ); + let mut report_failure = async |context: &ChannelContext| { + let _ = self + .protocol_set + .report_substream_open_failure( + context.protocol.clone(), + context.substream_id, + SubstreamError::ConnectionClosed, + ) + .await; + }; + + // Drain pending outbound opens (data channel not yet acked). + for (_, context) in self.pending_outbound.drain() { + report_failure(&context).await; + } + + // Drain channels still in OutboundOpening (multistream-select in flight). + for (_, state) in self.channels.drain() { + if let ChannelState::OutboundOpening { context, .. } = state { + report_failure(&context).await; + } + } + let _ = self .protocol_set .report_connection_closed(self.peer, self.endpoint.connection_id()) @@ -1004,7 +1071,26 @@ impl WebRtcConnection { "transmit data", ); - self.socket.try_send_to(&v.contents, v.destination).unwrap(); + if let Err(error) = self.socket.try_send_to(&v.contents, v.destination) { + if error.kind() == std::io::ErrorKind::WouldBlock { + tracing::trace!( + target: LOG_TARGET, + peer = ?self.peer, + destination = ?v.destination, + "UDP send buffer full, dropping datagram (str0m will retransmit)", + ); + } else { + tracing::debug!( + target: LOG_TARGET, + peer = ?self.peer, + destination = ?v.destination, + ?error, + "failed to send datagram, closing connection", + ); + return self.on_connection_closed().await; + } + } + continue; } Output::Event(v) => match v { @@ -1076,27 +1162,44 @@ impl WebRtcConnection { }, }; - let duration = timeout - Instant::now(); - if duration.is_zero() { - self.rtc.handle_input(Input::Timeout(Instant::now())).unwrap(); - continue; - } - tokio::select! { biased; datagram = self.dgram_rx.recv() => match datagram { Some(datagram) => { + let contents = match datagram.as_slice().try_into() { + Ok(contents) => contents, + Err(error) => { + tracing::debug!( + target: LOG_TARGET, + peer = ?self.peer, + ?error, + datagram_len = datagram.len(), + "failed to parse inbound datagram, closing connection", + ); + + return self.on_connection_closed().await; + } + }; + let input = Input::Receive( Instant::now(), Receive { proto: Str0mProtocol::Udp, source: self.peer_address, destination: self.local_address, - contents: datagram.as_slice().try_into().unwrap(), + contents, }, ); - self.rtc.handle_input(input).unwrap(); + if let Err(error) = self.rtc.handle_input(input) { + tracing::debug!( + target: LOG_TARGET, + peer = ?self.peer, + ?error, + "str0m rejected inbound datagram, closing connection", + ); + return self.on_connection_closed().await; + } } None => { tracing::trace!( @@ -1121,7 +1224,7 @@ impl WebRtcConnection { self.channels.insert(channel_id, ChannelState::Closing); self.handles.remove(&channel_id); } - Some((channel_id, Some(SubstreamEvent::Message { payload, flag }))) => { + Some((channel_id, Some(Message { payload, flag }))) => { if let Err(error) = self.on_outbound_data(channel_id, payload, flag) { tracing::debug!( target: LOG_TARGET, @@ -1131,11 +1234,11 @@ impl WebRtcConnection { "failed to send data to remote peer", ); - self.channels.insert(channel_id, ChannelState::Closing); self.rtc.direct_api().close_data_channel(channel_id); + self.channels.insert(channel_id, ChannelState::Closing); + self.handles.remove(&channel_id); } } - Some((_, Some(SubstreamEvent::RecvClosed))) => {} }, command = self.protocol_set.next() => match command { None | Some(ProtocolCommand::ForceClose) => { @@ -1178,8 +1281,17 @@ impl WebRtcConnection { ); } }, - _ = tokio::time::sleep(duration) => { - self.rtc.handle_input(Input::Timeout(Instant::now())).unwrap(); + _ = tokio::time::sleep(timeout.saturating_duration_since(Instant::now())) => { + if let Err(error) = self.rtc.handle_input(Input::Timeout(Instant::now())) { + tracing::debug!( + target: LOG_TARGET, + peer = ?self.peer, + ?error, + "str0m rejected timeout input, closing connection", + ); + + return self.on_connection_closed().await; + } } } } diff --git a/src/transport/webrtc/mod.rs b/src/transport/webrtc/mod.rs index 49844d137..31d2f896a 100644 --- a/src/transport/webrtc/mod.rs +++ b/src/transport/webrtc/mod.rs @@ -224,14 +224,18 @@ impl WebRtcTransport { pass: &str, source: SocketAddr, destination: SocketAddr, - ) -> (Rtc, ChannelId) { + ) -> crate::Result<(Rtc, ChannelId)> { let mut rtc = Rtc::builder() .set_ice_lite(true) .set_dtls_cert(self.dtls_cert.clone()) .set_fingerprint_verification(false) .build(std::time::Instant::now()); - rtc.add_local_candidate(Candidate::host(destination, Str0mProtocol::Udp).unwrap()); - rtc.add_remote_candidate(Candidate::host(source, Str0mProtocol::Udp).unwrap()); + rtc.add_local_candidate( + Candidate::host(destination, Str0mProtocol::Udp).map_err(RtcError::Ice)?, + ); + rtc.add_remote_candidate( + Candidate::host(source, Str0mProtocol::Udp).map_err(RtcError::Ice)?, + ); rtc.direct_api() .set_remote_fingerprint(REMOTE_FINGERPRINT.parse().expect("parse() to succeed")); rtc.direct_api().set_remote_ice_credentials(IceCreds { @@ -243,18 +247,18 @@ impl WebRtcTransport { pass: pass.to_owned(), }); rtc.direct_api().set_ice_controlling(false); - rtc.direct_api().start_dtls(false).unwrap(); + rtc.direct_api().start_dtls(false)?; rtc.direct_api().start_sctp(false); let noise_channel_id = rtc.direct_api().create_data_channel(ChannelConfig { label: "noise".to_string(), - ordered: false, - reliability: Default::default(), + ordered: true, + reliability: str0m::channel::Reliability::Reliable, negotiated: Some(0), protocol: "".to_string(), }); - (rtc, noise_channel_id) + Ok((rtc, noise_channel_id)) } /// Poll opening connection. @@ -365,37 +369,36 @@ impl WebRtcTransport { let contents: DatagramRecv = buffer.as_slice().try_into().map_err(|_| Error::InvalidData)?; - // Handle non stun packets. - if !is_stun_packet(&buffer) { - tracing::debug!( + // If an opening connection already exists for this source, route all packets to it + if let Some(opening_conn) = self.opening.get_mut(&source) { + tracing::trace!( target: LOG_TARGET, ?source, - "received non-stun message" + is_stun = is_stun_packet(&buffer), + "routing packet to existing opening connection" ); - match self.opening.get_mut(&source) { - Some(connection) => - if let Err(error) = connection.on_input(contents) { - tracing::error!( - target: LOG_TARGET, - ?error, - ?source, - "failed to handle inbound datagram" - ); - }, - None => { - tracing::warn!( - target: LOG_TARGET, - ?source, - "received non-stun message from unknown peer", - ); - return Err(Error::InvalidData); - } - }; - + if let Err(error) = opening_conn.on_input(contents) { + tracing::error!( + target: LOG_TARGET, + ?error, + ?source, + "failed to handle inbound datagram" + ); + } return Ok(true); } + // No existing connection - this should be a STUN packet to create a new connection + if !is_stun_packet(&buffer) { + tracing::warn!( + target: LOG_TARGET, + ?source, + "received non-stun packet without existing connection, ignoring" + ); + return Ok(false); + } + let stun_message = str0m::ice::StunMessage::parse(&buffer).map_err(|_| Error::InvalidData)?; let Some((ufrag, pass)) = stun_message.split_username() else { @@ -411,24 +414,22 @@ impl WebRtcTransport { target: LOG_TARGET, ?source, ?ufrag, - ?pass, "received stun message" ); // create new `Rtc` object for the peer and give it the received STUN message - let (mut rtc, noise_channel_id) = - self.make_rtc_client(ufrag, pass, source, self.socket.local_addr().unwrap()); + let local_addr = self.socket.local_addr()?; + let (mut rtc, noise_channel_id) = self.make_rtc_client(ufrag, pass, source, local_addr)?; rtc.handle_input(Input::Receive( Instant::now(), Receive { source, proto: Str0mProtocol::Udp, - destination: self.socket.local_addr().unwrap(), + destination: local_addr, contents, }, - )) - .expect("client to handle input successfully"); + ))?; let connection_id = self.context.next_connection_id(); let connection = OpeningWebRtcConnection::new( diff --git a/src/transport/webrtc/opening.rs b/src/transport/webrtc/opening.rs index ec35dad9c..2d515ea01 100644 --- a/src/transport/webrtc/opening.rs +++ b/src/transport/webrtc/opening.rs @@ -409,24 +409,17 @@ impl OpeningWebRtcConnection { }, ); - match self.rtc.accepts(&message) { - true => self.rtc.handle_input(message).map_err(|error| { - tracing::debug!( - target: LOG_TARGET, - peer = ?self.peer_address, - ?error, "failed to handle data" - ); - Error::InputRejected - }), - false => { - tracing::warn!( - target: LOG_TARGET, - peer = ?self.peer_address, - "input rejected", - ); - Err(Error::InputRejected) - } - } + // Let str0m handle input validation internally, similar to how the initial STUN packet is + // handled + self.rtc.handle_input(message).map_err(|error| { + tracing::debug!( + target: LOG_TARGET, + source = ?self.peer_address, + ?error, + "failed to handle data" + ); + Error::InputRejected + }) } /// Progress the state of [`OpeningWebRtcConnection`]. diff --git a/src/transport/webrtc/substream.rs b/src/transport/webrtc/substream.rs index 9acfbe444..10e8bc68f 100644 --- a/src/transport/webrtc/substream.rs +++ b/src/transport/webrtc/substream.rs @@ -22,113 +22,219 @@ use crate::{ transport::webrtc::{ schema::webrtc::message::Flag, util::{WebRtcMessage, MAX_FRAME_SIZE}, + LOG_TARGET, }, Error, }; use bytes::{Buf, BufMut, BytesMut}; use futures::{task::AtomicWaker, Future, Stream}; -use parking_lot::Mutex; use tokio::sync::mpsc::{channel, Receiver, Sender}; use tokio_util::sync::PollSender; use std::{ + marker::PhantomData, pin::Pin, - sync::Arc, + sync::{ + atomic::{AtomicU8, Ordering}, + Arc, + }, task::{Context, Poll}, time::Duration, }; /// Timeout for waiting on FIN_ACK after sending FIN. -/// Matches go-libp2p's 5 second stream close timeout. -const FIN_ACK_TIMEOUT: Duration = Duration::from_secs(5); - -/// Substream event. -#[derive(Debug, PartialEq, Eq)] -pub enum Event { - /// Receiver closed. - RecvClosed, - - /// Send/receive message with optional flag. - Message { - payload: Vec, - flag: Option, - }, +/// Matches go-libp2p and js-libp2p's 10-second stream close timeout. +const FIN_ACK_TIMEOUT: Duration = Duration::from_secs(10); + +/// Substream Message. +#[derive(PartialEq, Eq, Debug)] +pub struct Message { + pub payload: Vec, + pub flag: Option, +} + +trait AtomicState { + fn from_u8(raw_state: u8) -> Self; + fn into_u8(self) -> u8; +} + +/// Shared state used to sync `Substream` and `SubstreamHandle`. +#[derive(Clone)] +struct SharedState { + inner: Arc, + _phantom: PhantomData, +} + +struct Inner { + state: AtomicU8, + waker: AtomicWaker, +} + +impl SharedState { + fn new(val: T) -> Self { + Self { + inner: Arc::new(Inner { + state: AtomicU8::new(val.into_u8()), + waker: AtomicWaker::new(), + }), + _phantom: Default::default(), + } + } + + fn set(&self, new_value: T) { + self.inner.state.store(new_value.into_u8(), Ordering::Release); + self.inner.waker.wake(); + } + + fn get(&self) -> T { + T::from_u8(self.inner.state.load(Ordering::Acquire)) + } + + // NOTE: this method should only be called from a single place, + // otherwise the waker is re-registered for another context + // and thus steals the previously registered waker. + fn register_and_get(&self, cx: &mut Context<'_>) -> T { + self.inner.waker.register(cx.waker()); + self.get() + } } /// Substream stream. -#[derive(Debug, Clone, Copy)] -enum State { +#[derive(Debug, Clone)] +#[repr(u8)] +enum ChannelState { /// Substream is fully open. - Open, + Open = 0, + /// ResetStream has been received. + /// + /// This preempts both writer and reader state. + Reset = 1, +} + +impl AtomicState for ChannelState { + fn from_u8(raw_state: u8) -> Self { + match raw_state { + 0 => ChannelState::Open, + 1 => ChannelState::Reset, + // Unreachable in practice: `into_u8` only ever stores a valid variant and the + // `AtomicU8` only returns a previously-stored byte. + // Return Reset defensively rather than a panic. + _ => ChannelState::Reset, + } + } + + fn into_u8(self) -> u8 { + self as u8 + } +} + +/// State of the reading side of the stream. +#[derive(Debug, Clone)] +#[repr(u8)] +enum ReaderState { + /// The reading stream is open. + Open = 0, + /// A Fin flag was received. + Fin = 1, + /// FinAck was sent back. + FinAck = 2, +} - /// Remote is no longer interested in receiving anything. - SendClosed, +impl AtomicState for ReaderState { + fn from_u8(raw_state: u8) -> Self { + match raw_state { + 0 => ReaderState::Open, + 1 => ReaderState::Fin, + 2 => ReaderState::FinAck, + // Unreachable in practice: `into_u8` only ever stores a valid variant and the + // `AtomicU8` only returns a previously-stored byte. + // Return FinAck defensively rather than a panic. + _ => ReaderState::FinAck, + } + } + + fn into_u8(self) -> u8 { + self as u8 + } +} - /// Shutdown initiated, flushing pending data before sending FIN. - Closing, +/// State of the writing side of the stream. +#[derive(Debug, Clone)] +#[repr(u8)] +enum WriterState { + /// The writing stream is open. + Open = 0, + /// A Fin flag was sent. + Fin = 1, + /// FinAck was received. + FinAck = 2, + /// StopSending was received. + StopSending = 3, +} - /// We sent FIN, waiting for FIN_ACK. - FinSent, +impl AtomicState for WriterState { + fn from_u8(raw_state: u8) -> Self { + match raw_state { + 0 => WriterState::Open, + 1 => WriterState::Fin, + 2 => WriterState::FinAck, + 3 => WriterState::StopSending, + // Unreachable in practice: `into_u8` only ever stores a valid variant and the + // `AtomicU8` only returns a previously-stored byte. + // Return FinAck defensively rather than a panic. + _ => WriterState::FinAck, + } + } - /// We received FIN_ACK, write half is closed. - FinAcked, + fn into_u8(self) -> u8 { + self as u8 + } } /// Channel-backed substream. Must be owned and polled by exactly one task at a time. pub struct Substream { - /// Substream state. - state: Arc>, - /// Read buffer. read_buffer: BytesMut, - + /// RX channel for receiving messages from `peer`. + rx: Receiver, /// TX channel for sending messages to `peer`, wrapped in a [`PollSender`] /// so that backpressure is driven by the caller's waker. - tx: PollSender, - - /// RX channel for receiving messages from `peer`. - rx: Receiver, - - /// Waker to notify when shutdown completes (FIN_ACK received). - shutdown_waker: Arc, - - /// Waker to notify when write state changes (e.g., STOP_SENDING received). - write_waker: Arc, - - /// Timeout for waiting on FIN_ACK after sending FIN. - /// Boxed to maintain Unpin for Substream while allowing the Sleep to be polled. - fin_ack_timeout: Option>>, + tx: Option>, + /// State of the channel. + channel_state: SharedState, + /// State of the writing half. + writer_state: SharedState, } impl Substream { /// Create new [`Substream`]. pub fn new() -> (Self, SubstreamHandle) { - let (outbound_tx, outbound_rx) = channel(256); - let (inbound_tx, inbound_rx) = channel(256); - let state = Arc::new(Mutex::new(State::Open)); - let shutdown_waker = Arc::new(AtomicWaker::new()); - let write_waker = Arc::new(AtomicWaker::new()); + // Tokio channels implement their own backpressure, + // which solves the Substream <-> SubstreamHandle backpressure problem. + let (outbound_message_tx, outbound_message_rx) = channel(256); + let (inbound_message_tx, inbound_message_rx) = channel(256); + + let channel_state = SharedState::new(ChannelState::Open); + let writer_state = SharedState::new(WriterState::Open); + let reader_state = SharedState::new(ReaderState::Open); let handle = SubstreamHandle { - inbound_tx, - outbound_tx: outbound_tx.clone(), - rx: outbound_rx, - state: Arc::clone(&state), - shutdown_waker: Arc::clone(&shutdown_waker), - write_waker: Arc::clone(&write_waker), - read_closed: std::sync::atomic::AtomicBool::new(false), + channel_state: channel_state.clone(), + writer_state: writer_state.clone(), + reader_state: reader_state.clone(), + message_tx: Some(inbound_message_tx), + message_rx: outbound_message_rx, + fin_ack_timeout: None, }; ( Self { - state, - tx: PollSender::new(outbound_tx), - rx: inbound_rx, read_buffer: BytesMut::new(), - shutdown_waker, - write_waker, - fin_ack_timeout: None, + tx: Some(PollSender::new(outbound_message_tx)), + rx: inbound_message_rx, + channel_state, + writer_state, }, handle, ) @@ -137,26 +243,23 @@ impl Substream { /// Substream handle that is given to the WebRTC transport backend. pub struct SubstreamHandle { - state: Arc>, - + /// State of the channel. + channel_state: SharedState, + /// State of the writing half. + writer_state: SharedState, + /// State of the reading half. + reader_state: SharedState, /// TX channel for sending inbound messages from `peer` to the associated `Substream`. - inbound_tx: Sender, - - /// TX channel for sending outbound messages to `peer` (e.g., FIN_ACK responses). - outbound_tx: Sender, - + /// + /// The sender is taken (dropped) when a FIN flag is received from the remote: + /// closing it signals to the `Substream` reader that no further inbound + /// payloads will arrive. + message_tx: Option>, /// RX channel for receiving outbound messages to `peer` from the associated `Substream`. - rx: Receiver, - - /// Waker to notify when shutdown completes (FIN_ACK received). - shutdown_waker: Arc, - - /// Waker to notify when write state changes (e.g., STOP_SENDING received). - write_waker: Arc, - - /// Whether we've already sent RecvClosed to the inbound channel. - /// Prevents duplicate RecvClosed events if multiple FIN messages are received. - read_closed: std::sync::atomic::AtomicBool, + message_rx: Receiver, + /// Timeout for waiting on FIN_ACK after sending FIN + /// Boxed to maintain Unpin for Substream while allowing the Sleep to be polled. + fin_ack_timeout: Option>>, } impl SubstreamHandle { @@ -166,18 +269,46 @@ impl SubstreamHandle { /// /// Payload is processed first (if present), then flags are handled. This ensures that /// a FIN message containing final data will deliver that data before signaling closure. - pub async fn on_message(&self, message: WebRtcMessage) -> crate::Result<()> { + pub async fn on_message(&mut self, message: WebRtcMessage) -> crate::Result<()> { + // If Reset was received then early return discarding messages. + // In practice this should never happen because SCTP guarantees the order + // of messages, thus no other message is expected after a Reset. + if matches!(self.channel_state.get(), ChannelState::Reset) { + return Ok(()); + } + // Process payload first, before handling flags. - // This ensures that if a FIN message contains data, we deliver it before closing. - if let Some(payload) = message.payload { - if !payload.is_empty() { - self.inbound_tx - .send(Event::Message { + match (self.message_tx.as_ref(), message.payload) { + (None, Some(payload)) if !payload.is_empty() => { + tracing::debug!( + target: LOG_TARGET, + payload_len = payload.len(), + "peer sent payload after FIN flag, spec violation" + ); + } + (Some(message_tx), Some(payload)) if !payload.is_empty() => { + // TODO: awaiting here makes the entire connection + // rely on the readers to be fast enough, a slow reader + // could cause this method to wait and thus stall the entire webrtc + // connection. Solution would be to implement reading + // backpressure, keeping track of pending incoming messages. + // https://github.com/paritytech/litep2p/issues/604 + let send_result = message_tx + .send(Message { payload, flag: None, }) - .await?; + .await; + + if let Err(err) = send_result { + tracing::debug!( + target: LOG_TARGET, + ?err, + "failed to propagate message to Substream" + ); + } } + _ => (), } // Now handle flags @@ -185,65 +316,60 @@ impl SubstreamHandle { match flag { Flag::Fin => { // Guard against duplicate FIN messages - only send RecvClosed once - if self.read_closed.swap(true, std::sync::atomic::Ordering::SeqCst) { + if matches!( + self.reader_state.get(), + ReaderState::Fin | ReaderState::FinAck + ) { // Already processed FIN, ignore duplicate - tracing::debug!( - target: "litep2p::webrtc::substream", - "received duplicate FIN, ignoring" - ); + tracing::debug!(target: LOG_TARGET, "received duplicate FIN, ignoring"); return Ok(()); } - - // Received FIN from remote, close our read half - self.inbound_tx.send(Event::RecvClosed).await?; - - // Send FIN_ACK back to remote using try_send to avoid blocking. - // If the channel is full, the remote will timeout waiting for FIN_ACK - // and handle it gracefully. This prevents deadlock if the outbound - // channel is blocked due to backpressure. - if let Err(e) = self.outbound_tx.try_send(Event::Message { - payload: vec![], - flag: Some(Flag::FinAck), - }) { - tracing::warn!( - target: "litep2p::webrtc::substream", - ?e, - "failed to send FIN_ACK, remote will timeout" - ); + self.reader_state.set(ReaderState::Fin); + if self.message_tx.take().is_none() { + tracing::warn!(target: LOG_TARGET, "message channel was already dropped"); } return Ok(()); } Flag::FinAck => { // Received FIN_ACK, we can now fully close our write half - let mut state = self.state.lock(); - if matches!(*state, State::FinSent) { - *state = State::FinAcked; - // Wake up any task waiting on shutdown - self.shutdown_waker.wake(); + let writer_state = self.writer_state.get(); + if matches!(writer_state, WriterState::Fin) { + self.writer_state.set(WriterState::FinAck); } else { - tracing::warn!( - target: "litep2p::webrtc::substream", - ?state, - "received FIN_ACK in unexpected state, ignoring" + // If FIN_ACK is received upon an unexpected writer_state + // tear down the connection. + tracing::debug!( + target: LOG_TARGET, + ?writer_state, + "received FIN_ACK in unexpected writer state, tearing down channel" ); + self.channel_state.set(ChannelState::Reset); + self.message_rx.close(); + let _ = self.message_tx.take(); + return Err(Error::ConnectionClosed); } return Ok(()); } Flag::StopSending => { - *self.state.lock() = State::SendClosed; - // Wake any blocked poll_write so it can see the state change - self.write_waker.wake(); + // Discard flag if already closed/closing. + if !matches!(self.channel_state.get(), ChannelState::Reset) + && !matches!( + self.writer_state.get(), + WriterState::Fin | WriterState::FinAck + ) + { + self.writer_state.set(WriterState::StopSending); + self.message_rx.close(); + } + return Ok(()); } Flag::ResetStream => { // RESET_STREAM abruptly terminates both sides of the stream // (matching go-libp2p behavior) - // Close the read side - let _ = self.inbound_tx.try_send(Event::RecvClosed); - // Close the write side - *self.state.lock() = State::SendClosed; - // Wake any blocked poll_write so it can see the state change - self.write_waker.wake(); + self.channel_state.set(ChannelState::Reset); + self.message_rx.close(); + let _ = self.message_tx.take(); return Err(Error::ConnectionClosed); } } @@ -251,29 +377,132 @@ impl SubstreamHandle { Ok(()) } + + // This function carries forward the writer half close process. + // + // The following behaviors are expected on: + // WriterState::Open|StopSending state + // - flush any pending message + // - send FIN flag + // - start timeout_fin_ack + // - transition to WriterState::Fin + // WriterState::Fin state + // - wait for FIN_ACK + // - handle timeout + // - transition to WriterState::FinAck + // WriterState::FinAck state: + // - do nothing, shutdown complete + fn poll_half_close(&mut self, cx: &mut Context<'_>) -> Poll> { + match self.writer_state.get() { + // First call to shutdown, if peer sent StopSending we are still + // free to send Fin to make sure this half closes properly. + WriterState::Open | WriterState::StopSending => { + // Initialize the timeout for FIN_ACK + let mut timeout = Box::pin(tokio::time::sleep(FIN_ACK_TIMEOUT)); + // Poll the timeout once to register it with tokio's timer + // This ensures we'll be woken when it expires + if timeout.as_mut().poll(cx).is_ready() { + tracing::error!( + target: LOG_TARGET, + "misconfigured timer is not supposed to be ready" + ); + } + self.fin_ack_timeout = Some(timeout); + self.writer_state.set(WriterState::Fin); + // Send message with FIN flag + Poll::Ready(Some(Message { + payload: vec![], + flag: Some(Flag::Fin), + })) + } + // Sent FIN, waiting for FIN_ACK - poll timeout and return Pending + WriterState::Fin => { + // Poll the timeout - if it fires, force shutdown completion + match self.fin_ack_timeout.as_mut() { + Some(timeout) => + if timeout.as_mut().poll(cx).is_ready() { + tracing::debug!( + target: LOG_TARGET, + "FIN_ACK timeout exceeded, forcing shutdown completion" + ); + } else { + return Poll::Pending; + }, + None => { + tracing::warn!( + target: LOG_TARGET, + "unexpected writer state, forcing shutdown completion" + ); + } + } + + // If the timeout is reached we treat it as having received the + // acknowledge but the channel is reset anyway. + self.channel_state.set(ChannelState::Reset); + + self.fin_ack_timeout = None; + Poll::Ready(Some(Message { + payload: vec![], + flag: Some(Flag::ResetStream), + })) + } + // Already received FIN_ACK, shutdown complete + WriterState::FinAck => Poll::Ready(None), + } + } } impl Stream for SubstreamHandle { - type Item = Event; + type Item = Message; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - // First, try to drain any pending outbound messages - match self.rx.poll_recv(cx) { - Poll::Ready(Some(event)) => return Poll::Ready(Some(event)), - Poll::Ready(None) => { - // Outbound channel closed (all senders dropped) - return Poll::Ready(None); - } - Poll::Pending => { - // No messages available, check if we should signal closure + // There are three states which need to be taken into consideration to poll the stream: + // - channel_state: it preempts any other state if Reset has been entered + // + // NOTE: channel_state's waker is already registered from Substream's writing + // half (poll_shutdown), so we don't register it again here. That's fine + // because every transition to ChannelState::Reset is driven from this same + // task. This match is just an early-out when the channel has already been reset. + if matches!(self.channel_state.get(), ChannelState::Reset) { + // If something went wrong, RESET_STREAM should have been sent, + // in that case the channel is treated as closed. + return Poll::Ready(None); + } + + // - reader_state: this is mainly driven by the `on_message` function which reacts to + // incoming messages, there are 2 side effects which connects the two streams: + // 1. If FIN arrived then FIN_ACK is expected to be sent back. + // 2. If FIN_ACK arrived the writer_state is updated. + if matches!(self.reader_state.register_and_get(cx), ReaderState::Fin) { + self.reader_state.set(ReaderState::FinAck); + return Poll::Ready(Some(Message { + payload: vec![], + flag: Some(Flag::FinAck), + })); + } + + // - writer_state: here messages sent from the `Substream` needs to be forwarded wrapped by + // the right flags. Based on the state the close procedure can be carried or messages can + // simply be forwarded. + let writer_state_stream_result = match self.writer_state.get() { + WriterState::Open => { + match self.message_rx.poll_recv(cx) { + // Writes are finished, start half close procedure. + Poll::Ready(None) => self.poll_half_close(cx), + res => res, + } } + WriterState::Fin | WriterState::StopSending => self.poll_half_close(cx), + WriterState::FinAck => Poll::Ready(None), + }; + + if !matches!(writer_state_stream_result, Poll::Ready(None)) { + return writer_state_stream_result; } - // Check if Substream has been dropped (inbound channel closed) - // When Substream is dropped, there will be no more outbound messages - // Since we've already tried to recv above and got Pending, we know the queue is empty - // Therefore, it's safe to signal closure - if self.inbound_tx.is_closed() { + // The writer state has reached conclusion, if the same applies to the reader + // state then graceful shutdown has been carried, close the Stream. + if matches!(self.reader_state.get(), ReaderState::FinAck) { return Poll::Ready(None); } @@ -299,9 +528,8 @@ impl tokio::io::AsyncRead for Substream { } match futures::ready!(self.rx.poll_recv(cx)) { - None | Some(Event::RecvClosed) => - Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())), - Some(Event::Message { payload, flag: _ }) => { + None => Poll::Ready(Ok(())), + Some(Message { payload, flag: _ }) => { if payload.len() > MAX_FRAME_SIZE { return Poll::Ready(Err(std::io::ErrorKind::PermissionDenied.into())); } @@ -327,34 +555,20 @@ impl tokio::io::AsyncWrite for Substream { cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { - // Register waker so we get notified on state changes (e.g., STOP_SENDING) - self.write_waker.register(cx.waker()); - - // Reject writes if we're closing or closed - match *self.state.lock() { - State::SendClosed | State::Closing | State::FinSent | State::FinAcked => { - return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())); - } - State::Open => {} - } + let Some(tx) = self.tx.as_mut() else { + return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())); + }; - match futures::ready!(self.tx.poll_reserve(cx)) { + // Backpressure delegated to tokio channel. + match futures::ready!(tx.poll_reserve(cx)) { Ok(()) => {} Err(_) => return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())), }; - // Re-check state after poll_reserve - it may have changed while we were waiting - match *self.state.lock() { - State::SendClosed | State::Closing | State::FinSent | State::FinAcked => { - return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())); - } - State::Open => {} - } - let num_bytes = std::cmp::min(MAX_FRAME_SIZE, buf.len()); let frame = buf[..num_bytes].to_vec(); - match self.tx.send_item(Event::Message { + match tx.send_item(Message { payload: frame, flag: None, }) { @@ -364,6 +578,17 @@ impl tokio::io::AsyncWrite for Substream { } fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + let Some(tx) = self.tx.as_ref() else { + // shutdown already ran + return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())); + }; + if tx.is_closed() { + // StopSending or ResetStream closed the receiver. Anything we've + // enqueued past this point will not be delivered to the peer. + return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())); + } + // Channel still open. `poll_write` already waits for channel capacity before returning, + // so by the time we get here the channel has accepted every byte we acknowledged. Poll::Ready(Ok(())) } @@ -371,107 +596,34 @@ impl tokio::io::AsyncWrite for Substream { mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll> { - // State machine for proper shutdown: - // 1. Transition to Closing (stops accepting new writes) - // 2. Flush pending data - // 3. Send FIN flag - // 4. Transition to FinSent - // 5. Wait for FIN_ACK - // 6. Transition to FinAcked and complete - - let current_state = *self.state.lock(); + // Backpressure is delegated, based on tokio channel and str0m reliably + // sending messages in order. + let _ = self.tx.take(); - match current_state { - // Already received FIN_ACK, shutdown complete - State::FinAcked => return Poll::Ready(Ok(())), - - // Sent FIN, waiting for FIN_ACK - poll timeout and return Pending - State::FinSent => { - // Register waker FIRST to avoid race condition with on_message - self.shutdown_waker.register(cx.waker()); - - // Re-check state after waker registration in case FIN_ACK arrived - // between the initial state check and waker registration - if matches!(*self.state.lock(), State::FinAcked) { - return Poll::Ready(Ok(())); - } - - // Poll the timeout - if it fires, force shutdown completion - if let Some(timeout) = self.fin_ack_timeout.as_mut() { - if timeout.as_mut().poll(cx).is_ready() { - tracing::debug!( - target: "litep2p::webrtc::substream", - "FIN_ACK timeout exceeded, forcing shutdown completion" - ); - *self.state.lock() = State::FinAcked; - return Poll::Ready(Ok(())); - } - } - - return Poll::Pending; - } - - // First call to shutdown - transition to Closing - State::Open => { - *self.state.lock() = State::Closing; - } - - State::Closing => { - // Already in closing state, continue with shutdown process. - // Guard against duplicate FIN sends: if timeout is already set, we've - // already sent FIN and are waiting for FIN_ACK. This shouldn't happen - // with correct AsyncWrite usage (&mut self), but provides defense in depth. - if self.fin_ack_timeout.is_some() { - self.shutdown_waker.register(cx.waker()); - return Poll::Pending; - } - } - - State::SendClosed => { - // Remote closed send, we can still send FIN - } + // Shutdown process is complete if either the channel entered a Reset + // state or the writer has received a FinAck or StopSending. + // + // NOTE: short-circuiting the waker registration here is fine because + // channel_state takes precedence over any writing state. + let shutdown = matches!(self.channel_state.register_and_get(cx), ChannelState::Reset) + || matches!( + self.writer_state.register_and_get(cx), + WriterState::FinAck | WriterState::StopSending + ); + + if shutdown { + Poll::Ready(Ok(())) + } else { + Poll::Pending } + } +} - // Flush any pending data - // Note: Currently poll_flush is a no-op, but the channel backpressure - // provides implicit flushing since we wait for poll_reserve below - futures::ready!(self.as_mut().poll_flush(cx))?; - - // Reserve space to send FIN - match futures::ready!(self.tx.poll_reserve(cx)) { - Ok(()) => {} - Err(_) => return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())), - }; - - // Send message with FIN flag - match self.tx.send_item(Event::Message { - payload: vec![], - flag: Some(Flag::Fin), - }) { - Ok(()) => { - // Race condition mitigation strategy: - // 1. Transition to FinSent FIRST so on_message can recognize FIN_ACK (if waker - // registered first, FIN_ACK would be ignored since state != FinSent) - // 2. Register waker so we'll be notified on future FIN_ACK arrivals - // 3. Re-check state to catch FIN_ACK that arrived between steps 1 and 2 (wake() - // called before waker registered has no effect, but state changed) - *self.state.lock() = State::FinSent; - self.shutdown_waker.register(cx.waker()); - if matches!(*self.state.lock(), State::FinAcked) { - return Poll::Ready(Ok(())); - } - - // Initialize the timeout for FIN_ACK - let mut timeout = Box::pin(tokio::time::sleep(FIN_ACK_TIMEOUT)); - // Poll the timeout once to register it with tokio's timer - // This ensures we'll be woken when it expires - let _ = timeout.as_mut().poll(cx); - self.fin_ack_timeout = Some(timeout); - - Poll::Pending - } - Err(_) => Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())), - } +impl Drop for SubstreamHandle { + fn drop(&mut self) { + // This allows to close all the pending channels if the SubstreamHandle + // has been dropped, if graceful shutdown already happened this is a no-op. + self.channel_state.set(ChannelState::Reset); } } @@ -489,7 +641,7 @@ mod tests { assert_eq!( handle.next().await, - Some(Event::Message { + Some(Message { payload: vec![0u8; 1337], flag: None }) @@ -509,22 +661,22 @@ mod tests { substream.write_all(&vec![0u8; (2 * MAX_FRAME_SIZE) + 1]).await.unwrap(); assert_eq!( - handle.rx.recv().await, - Some(Event::Message { + handle.message_rx.recv().await, + Some(Message { payload: vec![0u8; MAX_FRAME_SIZE], flag: None, }) ); assert_eq!( - handle.rx.recv().await, - Some(Event::Message { + handle.message_rx.recv().await, + Some(Message { payload: vec![0u8; MAX_FRAME_SIZE], flag: None, }) ); assert_eq!( - handle.rx.recv().await, - Some(Event::Message { + handle.message_rx.recv().await, + Some(Message { payload: vec![0u8; 1], flag: None, }) @@ -538,14 +690,61 @@ mod tests { } #[tokio::test] - async fn try_to_write_to_closed_substream() { - let (mut substream, handle) = Substream::new(); - *handle.state.lock() = State::SendClosed; + async fn handle_stop_sending_with_graceful_shutdown() { + let (_substream, mut handle) = Substream::new(); - match substream.write_all(&vec![0u8; 1337]).await { - Err(error) => assert_eq!(error.kind(), std::io::ErrorKind::BrokenPipe), - _ => panic!("invalid event"), - } + // Receiving StopSending + handle + .on_message(WebRtcMessage { + payload: None, + flag: Some(Flag::StopSending), + }) + .await + .unwrap(); + + // Expecting FIN to be sent immediately + assert_eq!( + handle.next().await, + Some(Message { + payload: vec![], + flag: Some(Flag::Fin), + }) + ); + assert!(matches!(handle.writer_state.get(), WriterState::Fin)); + + // Receiving FIN_ACK + handle + .on_message(WebRtcMessage { + payload: None, + flag: Some(Flag::FinAck), + }) + .await + .unwrap(); + + // Write side is closed now, not the read side. + assert!(matches!(handle.writer_state.get(), WriterState::FinAck)); + + // Receiving FIN + handle + .on_message(WebRtcMessage { + payload: None, + flag: Some(Flag::Fin), + }) + .await + .unwrap(); + + assert!(matches!(handle.reader_state.get(), ReaderState::Fin)); + // Expecing FIN_ACK to be sent immediately + assert_eq!( + handle.next().await, + Some(Message { + payload: vec![], + flag: Some(Flag::FinAck), + }) + ); + assert!(matches!(handle.reader_state.get(), ReaderState::FinAck)); + + assert_eq!(handle.next().await, None); } #[tokio::test] @@ -561,7 +760,7 @@ mod tests { assert_eq!( handle.next().await, - Some(Event::Message { + Some(Message { payload: vec![1u8; 1337], flag: None, }) @@ -569,7 +768,7 @@ mod tests { // After shutdown, should send FIN flag assert_eq!( handle.next().await, - Some(Event::Message { + Some(Message { payload: vec![], flag: Some(Flag::Fin) }) @@ -589,7 +788,7 @@ mod tests { #[tokio::test] async fn try_to_read_from_closed_substream() { - let (mut substream, handle) = Substream::new(); + let (mut substream, mut handle) = Substream::new(); handle .on_message(WebRtcMessage { payload: None, @@ -599,7 +798,9 @@ mod tests { .unwrap(); match substream.read(&mut vec![0u8; 256]).await { - Err(error) => assert_eq!(error.kind(), std::io::ErrorKind::BrokenPipe), + Ok(read_bytes) => { + assert_eq!(read_bytes, 0) + } _ => panic!("invalid event"), } } @@ -608,8 +809,10 @@ mod tests { async fn read_small_frame() { let (mut substream, handle) = Substream::new(); handle - .inbound_tx - .send(Event::Message { + .message_tx + .as_ref() + .unwrap() + .send(Message { payload: vec![1u8; 256], flag: None, }) @@ -643,8 +846,10 @@ mod tests { first.extend_from_slice(&vec![2u8; 256]); handle - .inbound_tx - .send(Event::Message { + .message_tx + .as_ref() + .unwrap() + .send(Message { payload: first, flag: None, }) @@ -686,16 +891,20 @@ mod tests { first.extend_from_slice(&vec![2u8; 256]); handle - .inbound_tx - .send(Event::Message { + .message_tx + .as_ref() + .unwrap() + .send(Message { payload: first, flag: None, }) .await .unwrap(); handle - .inbound_tx - .send(Event::Message { + .message_tx + .as_ref() + .unwrap() + .send(Message { payload: vec![4u8; 2048], flag: None, }) @@ -827,14 +1036,14 @@ mod tests { // Should receive FIN flag assert_eq!( handle.next().await, - Some(Event::Message { + Some(Message { payload: vec![], flag: Some(Flag::Fin) }) ); - // Verify state is FinSent - assert!(matches!(*handle.state.lock(), State::FinSent)); + // Verify state is Fin + assert!(matches!(handle.writer_state.get(), WriterState::Fin)); // Send FIN_ACK to complete shutdown cleanly (avoids waiting for timeout) handle @@ -858,7 +1067,7 @@ mod tests { // Substream should receive RecvClosed let mut buf = vec![0u8; 1024]; match substream.read(&mut buf).await { - Err(e) if e.kind() == std::io::ErrorKind::BrokenPipe => { + Ok(0) => { // Expected - read half closed } other => panic!("Unexpected result: {:?}", other), @@ -880,7 +1089,7 @@ mod tests { // Verify FIN_ACK was sent outbound to network assert_eq!( handle.next().await, - Some(Event::Message { + Some(Message { payload: vec![], flag: Some(Flag::FinAck) }) @@ -889,18 +1098,23 @@ mod tests { #[tokio::test] async fn fin_ack_received_transitions_to_fin_acked() { - let (mut substream, handle) = Substream::new(); + let (mut substream, mut handle) = Substream::new(); // Spawn shutdown since it waits for FIN_ACK let shutdown_task = tokio::spawn(async move { substream.shutdown().await.unwrap(); }); - // Wait a bit for FIN to be sent - tokio::task::yield_now().await; + assert_eq!( + handle.next().await, + Some(Message { + payload: vec![], + flag: Some(Flag::Fin) + }), + ); - // Verify we're in FinSent state - assert!(matches!(*handle.state.lock(), State::FinSent)); + // Verify we're in Fin state + assert!(matches!(handle.writer_state.get(), WriterState::Fin)); // Simulate receiving FIN_ACK from remote handle @@ -912,7 +1126,7 @@ mod tests { .unwrap(); // Should transition to FinAcked - assert!(matches!(*handle.state.lock(), State::FinAcked)); + assert!(matches!(handle.writer_state.get(), WriterState::FinAck)); // Shutdown should now complete shutdown_task.await.unwrap(); @@ -933,7 +1147,7 @@ mod tests { // Verify data was sent assert_eq!( handle.next().await, - Some(Event::Message { + Some(Message { payload: vec![1u8; 100], flag: None, }) @@ -942,7 +1156,7 @@ mod tests { // Verify FIN was sent assert_eq!( handle.next().await, - Some(Event::Message { + Some(Message { payload: vec![], flag: Some(Flag::Fin) }) @@ -958,7 +1172,7 @@ mod tests { .unwrap(); // Should be in FinAcked state - assert!(matches!(*handle.state.lock(), State::FinAcked)); + assert!(matches!(handle.writer_state.get(), WriterState::FinAck)); // Shutdown should now complete shutdown_task.await.unwrap(); @@ -966,7 +1180,7 @@ mod tests { #[tokio::test] async fn stop_sending_flag_closes_send_half() { - let (mut substream, handle) = Substream::new(); + let (mut substream, mut handle) = Substream::new(); // Simulate receiving STOP_SENDING handle @@ -978,7 +1192,10 @@ mod tests { .unwrap(); // Should transition to SendClosed - assert!(matches!(*handle.state.lock(), State::SendClosed)); + assert!(matches!( + handle.writer_state.get(), + WriterState::StopSending + )); // Attempting to write should fail match substream.write_all(&[0u8; 100]).await { @@ -989,8 +1206,7 @@ mod tests { #[tokio::test] async fn reset_stream_flag_closes_both_sides() { - use tokio::io::AsyncWriteExt; - let (mut substream, handle) = Substream::new(); + let (mut substream, mut handle) = Substream::new(); // Simulate receiving RESET_STREAM let result = handle @@ -1004,33 +1220,34 @@ mod tests { assert!(matches!(result, Err(Error::ConnectionClosed))); // Write side should be closed (state = SendClosed) - assert!(matches!(*handle.state.lock(), State::SendClosed)); + assert!(matches!(handle.channel_state.get(), ChannelState::Reset)); - // Attempting to write should fail - match substream.write_all(&[0u8; 100]).await { - Err(error) => assert_eq!(error.kind(), std::io::ErrorKind::BrokenPipe), - _ => panic!("write should have failed"), + let mut buf = vec![0u8; 1024]; + match substream.read(&mut buf).await { + Ok(0) => (), + other => panic!("Unexpected result: {:?}", other), } - - // Read side should also be closed (RecvClosed event was sent) - // The substream's rx channel should have RecvClosed - assert!(matches!(substream.rx.try_recv(), Ok(Event::RecvClosed))); + assert!(substream.shutdown().await.is_ok()); } #[tokio::test] async fn fin_ack_does_not_trigger_other_flag() { - let (mut substream, handle) = Substream::new(); + let (mut substream, mut handle) = Substream::new(); // Spawn shutdown since it waits for FIN_ACK let shutdown_task = tokio::spawn(async move { substream.shutdown().await.unwrap(); }); - // Wait a bit for FIN to be sent - tokio::task::yield_now().await; - - // Verify we're in FinSent state - assert!(matches!(*handle.state.lock(), State::FinSent)); + assert_eq!( + handle.next().await, + Some(Message { + payload: vec![], + flag: Some(Flag::Fin) + }), + ); + // Verify we're in Fin state + assert!(matches!(handle.writer_state.get(), WriterState::Fin)); // Now simulate receiving FIN_ACK (value = 3) // This should NOT trigger STOP_SENDING (value = 1) or RESET_STREAM (value = 2) @@ -1044,7 +1261,7 @@ mod tests { .unwrap(); // Should transition to FinAcked, not SendClosed - assert!(matches!(*handle.state.lock(), State::FinAcked)); + assert!(matches!(handle.writer_state.get(), WriterState::FinAck)); // Shutdown should complete shutdown_task.await.unwrap(); @@ -1056,7 +1273,7 @@ mod tests { #[tokio::test] async fn flags_are_mutually_exclusive() { - let (_substream, handle) = Substream::new(); + let (_substream, mut handle) = Substream::new(); // Test that STOP_SENDING (1) is handled correctly handle @@ -1067,10 +1284,13 @@ mod tests { .await .unwrap(); - assert!(matches!(*handle.state.lock(), State::SendClosed)); + assert!(matches!( + handle.writer_state.get(), + WriterState::StopSending + )); // Create a new substream for RESET_STREAM test - let (_substream2, handle2) = Substream::new(); + let (_substream2, mut handle2) = Substream::new(); // Test that RESET_STREAM (2) is handled correctly let result = handle2 @@ -1081,17 +1301,25 @@ mod tests { .await; assert!(matches!(result, Err(Error::ConnectionClosed))); + assert!(matches!(handle2.channel_state.get(), ChannelState::Reset)); // Create a new substream for FIN test - let (mut substream3, handle3) = Substream::new(); + let (mut substream3, mut handle3) = Substream::new(); // Spawn shutdown since it waits for FIN_ACK let shutdown_task3 = tokio::spawn(async move { substream3.shutdown().await.unwrap(); }); - // Wait a bit for FIN to be sent - tokio::task::yield_now().await; + assert_eq!( + handle3.next().await, + Some(Message { + payload: vec![], + flag: Some(Flag::Fin) + }), + ); + // Verify we're in Fin state + assert!(matches!(handle3.writer_state.get(), WriterState::Fin)); // Test that FIN_ACK (3) is handled correctly handle3 @@ -1102,7 +1330,7 @@ mod tests { .await .unwrap(); - assert!(matches!(*handle3.state.lock(), State::FinAcked)); + assert!(matches!(handle3.writer_state.get(), WriterState::FinAck)); // Shutdown should complete shutdown_task3.await.unwrap(); @@ -1111,7 +1339,7 @@ mod tests { #[tokio::test] async fn stop_sending_wakes_blocked_writer() { use tokio::io::AsyncWriteExt; - let (mut substream, handle) = Substream::new(); + let (mut substream, mut handle) = Substream::new(); // Fill up the channel to cause poll_write to return Pending // Channel capacity is 256 @@ -1150,7 +1378,7 @@ mod tests { #[tokio::test] async fn reset_stream_wakes_blocked_writer() { use tokio::io::AsyncWriteExt; - let (mut substream, handle) = Substream::new(); + let (mut substream, mut handle) = Substream::new(); // Fill up the channel to cause poll_write to return Pending // Channel capacity is 256 @@ -1203,21 +1431,21 @@ mod tests { // Wait for data and FIN to be sent assert_eq!( handle.next().await, - Some(Event::Message { + Some(Message { payload: vec![1u8; 100], flag: None, }) ); assert_eq!( handle.next().await, - Some(Event::Message { + Some(Message { payload: vec![], flag: Some(Flag::Fin) }) ); - // Verify we transitioned through Closing to FinSent - assert!(matches!(*handle.state.lock(), State::FinSent)); + // Verify we transitioned through Closing to Fin + assert!(matches!(handle.writer_state.get(), WriterState::Fin)); // Send FIN_ACK to complete shutdown handle @@ -1246,12 +1474,12 @@ mod tests { // Wait for FIN to be sent assert_eq!( handle.next().await, - Some(Event::Message { + Some(Message { payload: vec![], flag: Some(Flag::Fin) }) ); - assert!(matches!(*handle.state.lock(), State::FinSent)); + assert!(matches!(handle.writer_state.get(), WriterState::Fin)); // Send FIN_ACK to complete first shutdown handle @@ -1267,7 +1495,7 @@ mod tests { // Second shutdown should succeed without error (already in FinAcked state) substream.shutdown().await.unwrap(); - assert!(matches!(*handle.state.lock(), State::FinAcked)); + assert!(matches!(handle.writer_state.get(), WriterState::FinAck)); } #[tokio::test] @@ -1284,43 +1512,29 @@ mod tests { // Wait for FIN to be sent assert_eq!( handle.next().await, - Some(Event::Message { + Some(Message { payload: vec![], flag: Some(Flag::Fin) }) ); - // Verify we're in FinSent state - assert!(matches!(*handle.state.lock(), State::FinSent)); + // Verify we're in Fin state + assert!(matches!(handle.writer_state.get(), WriterState::Fin)); // DON'T send FIN_ACK - let it timeout - // The shutdown should complete after FIN_ACK_TIMEOUT (5 seconds) + // The shutdown should complete after FIN_ACK_TIMEOUT (10 seconds) // Add a bit of buffer to the timeout - let result = timeout(Duration::from_secs(7), shutdown_task).await; + let _ = timeout(Duration::from_secs(11), shutdown_task).await; - assert!(result.is_ok(), "Shutdown should complete after timeout"); - assert!( - result.unwrap().is_ok(), - "Shutdown should succeed after timeout" + // The timeout branch surfaces a RESET_STREAM event before signalling closure. + assert_eq!( + handle.next().await, + Some(Message { + payload: vec![], + flag: Some(Flag::ResetStream) + }), ); - - // Should have transitioned to FinAcked after timeout - assert!(matches!(*handle.state.lock(), State::FinAcked)); - } - - #[tokio::test] - async fn closing_state_blocks_writes() { - use tokio::io::AsyncWriteExt; - - let (mut substream, handle) = Substream::new(); - - // Manually transition to Closing state - *handle.state.lock() = State::Closing; - - // Attempt to write should fail - let result = substream.write_all(&[1u8; 100]).await; - assert!(result.is_err()); - assert_eq!(result.unwrap_err().kind(), std::io::ErrorKind::BrokenPipe); + assert!(matches!(handle.channel_state.get(), ChannelState::Reset)); } #[tokio::test] @@ -1338,7 +1552,7 @@ mod tests { // Receive FIN assert_eq!( handle.next().await, - Some(Event::Message { + Some(Message { payload: vec![], flag: Some(Flag::Fin) }) @@ -1356,6 +1570,9 @@ mod tests { // Wait for shutdown to complete and Substream to drop shutdown_task.await.unwrap(); + assert!(matches!(handle.writer_state.get(), WriterState::FinAck)); + // If reader_state.inner.state is also closed, then we can expect a None. + handle.reader_state.set(ReaderState::FinAck); // Verify handle signals closure (returns None) assert_eq!( handle.next().await, @@ -1375,9 +1592,7 @@ mod tests { let mut buf = vec![0u8; 1024]; // This should fail because we receive RecvClosed match substream.read(&mut buf).await { - Err(e) if e.kind() == std::io::ErrorKind::BrokenPipe => { - // Expected - read half closed by FIN - } + Ok(0) => (), other => panic!("Unexpected result: {:?}", other), } // Substream dropped here (server closes after receiving FIN) @@ -1395,7 +1610,7 @@ mod tests { // Verify FIN_ACK was sent back assert_eq!( handle.next().await, - Some(Event::Message { + Some(Message { payload: vec![], flag: Some(Flag::FinAck) }) @@ -1404,11 +1619,13 @@ mod tests { // Wait for server to close substream server_task.await.unwrap(); - // Verify handle signals closure (returns None) - this is the key fix! + // Verify handle signals closure (returns Fin) assert_eq!( handle.next().await, - None, - "SubstreamHandle should signal closure after server receives FIN and drops Substream" + Some(Message { + payload: vec![], + flag: Some(Flag::Fin) + }) ); } @@ -1416,13 +1633,13 @@ mod tests { async fn simultaneous_close() { // Test simultaneous close where both sides send FIN at the same time. // This verifies that: - // 1. Both sides can be in FinSent state simultaneously - // 2. Both sides correctly respond to FIN with FIN_ACK even when in FinSent state + // 1. Both sides can be in Fin state simultaneously + // 2. Both sides correctly respond to FIN with FIN_ACK even when in Fin state // 3. Both sides eventually transition to FinAcked let (mut substream, mut handle) = Substream::new(); - // Local side initiates shutdown (sends FIN, transitions to FinSent) + // Local side initiates shutdown (sends FIN, transitions to Fin) let shutdown_task = tokio::spawn(async move { substream.shutdown().await.unwrap(); }); @@ -1430,17 +1647,17 @@ mod tests { // Wait for local FIN to be sent assert_eq!( handle.next().await, - Some(Event::Message { + Some(Message { payload: vec![], flag: Some(Flag::Fin) }) ); - // Verify local is in FinSent state - assert!(matches!(*handle.state.lock(), State::FinSent)); + // Verify local is in Fin state + assert!(matches!(handle.writer_state.get(), WriterState::Fin)); // Now simulate remote also sending FIN (simultaneous close) - // This should trigger FIN_ACK response even though we're in FinSent state + // This should trigger FIN_ACK response even though we're in Fin state handle .on_message(WebRtcMessage { payload: None, @@ -1452,14 +1669,15 @@ mod tests { // Local should send FIN_ACK in response to remote's FIN assert_eq!( handle.next().await, - Some(Event::Message { + Some(Message { payload: vec![], flag: Some(Flag::FinAck) }) ); - // Local should still be in FinSent (waiting for FIN_ACK from remote) - assert!(matches!(*handle.state.lock(), State::FinSent)); + // Local should still be in Fin (waiting for FIN_ACK from remote) + assert!(matches!(handle.writer_state.get(), WriterState::Fin)); + assert!(matches!(handle.reader_state.get(), ReaderState::FinAck)); // Now remote sends FIN_ACK (completing their side of the handshake) handle @@ -1471,7 +1689,7 @@ mod tests { .unwrap(); // Local should now transition to FinAcked - assert!(matches!(*handle.state.lock(), State::FinAcked)); + assert!(matches!(handle.writer_state.get(), WriterState::FinAck)); // Shutdown should complete successfully shutdown_task.await.unwrap(); @@ -1483,7 +1701,7 @@ mod tests { // to the substream before the RecvClosed event. This is important because // the spec allows a FIN message to contain final data. - let (mut substream, handle) = Substream::new(); + let (mut substream, mut handle) = Substream::new(); // Simulate receiving FIN with payload from remote handle @@ -1501,9 +1719,7 @@ mod tests { // Then, subsequent read should fail with BrokenPipe (RecvClosed) match substream.read(&mut buf).await { - Err(e) if e.kind() == std::io::ErrorKind::BrokenPipe => { - // Expected - read half closed after FIN - } + Ok(0) => (), other => panic!("Expected BrokenPipe error, got: {:?}", other), } }