Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 73 additions & 2 deletions src/channel.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
//! Data channel related types.

use std::time::Duration;
use std::{fmt, str, time::Instant};

use crate::sctp::RtcSctp;
Expand Down Expand Up @@ -171,6 +172,9 @@ impl fmt::Debug for ChannelData {
pub(crate) struct ChannelHandler {
allocations: Vec<ChannelAllocation>,
next_channel_id: usize,
/// Stream IDs recently closed, with the time they were closed.
/// Excluded from allocation until the cooldown expires.
closed_stream_ids: Vec<(u16, Instant)>,
}

#[derive(Debug)]
Expand All @@ -185,6 +189,8 @@ struct ChannelAllocation {
config: Option<ChannelConfig>,
}

const STREAM_ID_COOLDOWN: Duration = Duration::from_secs(2);

impl ChannelHandler {
pub fn new_channel(&mut self, config: &ChannelConfig) -> ChannelId {
let id = self.next_channel_id();
Expand Down Expand Up @@ -285,6 +291,7 @@ impl ChannelHandler {
.allocations
.iter()
.filter_map(|a| a.sctp_stream_id)
.chain(self.closed_stream_ids.iter().map(|(id, _)| *id))
.collect();

for a in &mut self.allocations {
Expand Down Expand Up @@ -357,7 +364,21 @@ impl ChannelHandler {
}
}

pub fn remove_channel(&mut self, id: ChannelId) {
/// Remove stream IDs from the cooldown list that have expired.
pub fn expire_closed_stream_ids(&mut self, now: Instant) {
self.closed_stream_ids
.retain(|(_, closed_at)| now.duration_since(*closed_at) < STREAM_ID_COOLDOWN);
}

pub fn remove_channel(&mut self, id: ChannelId, now: Instant) {
if let Some(stream_id) = self
.allocations
.iter()
.find(|a| a.id == id)
.and_then(|a| a.sctp_stream_id)
{
self.closed_stream_ids.push((stream_id, now));
}
self.allocations.retain(|a| a.id != id)
}
}
Expand All @@ -368,6 +389,7 @@ mod tests {

#[test]
fn channel_id_allocation() {
let now = Instant::now();
let mut handler = ChannelHandler::default();

// allocate first channel, get unique id
Expand All @@ -378,8 +400,57 @@ mod tests {

// free channel 0, allocate two more channels and verify that the
// new channels have unique IDs.
handler.remove_channel(ChannelId(0));
handler.remove_channel(ChannelId(0), now);
assert_eq!(handler.new_channel(&Default::default()), ChannelId(2));
assert_eq!(handler.new_channel(&Default::default()), ChannelId(3));
}

#[test]
fn stream_id_not_reused_during_cooldown() {
let now = Instant::now();
let mut handler = ChannelHandler::default();

// Simulate two channels with known stream IDs (as if do_allocations ran
// for a client: even IDs 0, 2).
let id0 = handler.new_channel(&Default::default());
let id1 = handler.new_channel(&Default::default());
// Manually set stream IDs as do_allocations would.
handler.allocations[0].sctp_stream_id = Some(0);
handler.allocations[1].sctp_stream_id = Some(2);

// Close channel 0 (stream ID 0). It should enter cooldown.
handler.remove_channel(id0, now);
assert_eq!(handler.closed_stream_ids.len(), 1);
assert_eq!(handler.closed_stream_ids[0].0, 0);

// Allocate a new channel and manually assign a stream ID the way
// do_allocations would — stream 0 should be skipped (in cooldown).
let id2 = handler.new_channel(&Default::default());
// Build the taken list as do_allocations does.
let taken: Vec<u16> = handler
.allocations
.iter()
.filter_map(|a| a.sctp_stream_id)
.chain(handler.closed_stream_ids.iter().map(|(id, _)| *id))
.collect();
// Stream 0 is in cooldown, stream 2 is active, so next available is 4.
assert!(taken.contains(&0), "stream 0 should be in cooldown");
assert!(taken.contains(&2), "stream 2 should be active");

// After cooldown expires, stream 0 should be available again.
let after_cooldown = now + STREAM_ID_COOLDOWN;
handler.expire_closed_stream_ids(after_cooldown);
assert!(handler.closed_stream_ids.is_empty());

let taken_after: Vec<u16> = handler
.allocations
.iter()
.filter_map(|a| a.sctp_stream_id)
.chain(handler.closed_stream_ids.iter().map(|(id, _)| *id))
.collect();
assert!(
!taken_after.contains(&0),
"stream 0 should be available after cooldown"
);
}
}
3 changes: 2 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1592,7 +1592,7 @@ impl Rtc {
warn!("Drop ChannelClose event for id: {:?}", id);
continue;
};
self.chan.remove_channel(id);
self.chan.remove_channel(id, self.last_now);
return Ok(Output::Event(Event::ChannelClose(id)));
}
SctpEvent::Data { id, binary, data } => {
Expand Down Expand Up @@ -1839,6 +1839,7 @@ impl Rtc {
self.last_now = now;
self.ice.handle_timeout(now);
self.sctp.handle_timeout(now);
self.chan.expire_closed_stream_ids(now);
self.chan.handle_timeout(now, &mut self.sctp);
self.session.handle_timeout(now)?;

Expand Down
60 changes: 39 additions & 21 deletions src/sctp/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -598,16 +598,25 @@ impl RtcSctp {
let n = dcep.marshal_to(&mut buf);
buf.truncate(n);

let l = s
.write_with_ppi(&buf, PayloadProtocolIdentifier::Dcep)
.expect("writing dcep open");
assert!(n == l);

entry.set_state(StreamEntryState::AwaitDcepAck);

// Start over with polling, since we might have caused some network traffic by
// writing the DcepOpen.
return self.do_poll();
match s.write_with_ppi(&buf, PayloadProtocolIdentifier::Dcep) {
Ok(l) => {
assert!(n == l);
entry.set_state(StreamEntryState::AwaitDcepAck);

// Start over with polling, since we might have caused
// some network traffic by writing the DcepOpen.
return self.do_poll();
}
Err(e) => {
warn!(
"Failed to write DCEP open on stream {}: {:?}",
entry.id, e
);
entry.do_close = true;
entry.set_state(StreamEntryState::Closed);
return Some(SctpEvent::Close { id: entry.id });
}
}
}

// Continuing means we are opening the stream out-of-band.
Expand Down Expand Up @@ -719,17 +728,26 @@ impl RtcSctp {

let mut obuf = [0];
DcepAck.marshal_to(&mut obuf);
let l = stream
.write_with_ppi(&obuf, PayloadProtocolIdentifier::Dcep)
.expect("writing dcep open");
assert!(obuf.len() == l);

entry.set_state(StreamEntryState::Open);

return Some(SctpEvent::Open {
id: entry.id,
label: dcep.label,
});
match stream.write_with_ppi(&obuf, PayloadProtocolIdentifier::Dcep) {
Ok(l) => {
assert!(obuf.len() == l);
entry.set_state(StreamEntryState::Open);

return Some(SctpEvent::Open {
id: entry.id,
label: dcep.label,
});
}
Err(e) => {
warn!(
"Failed to write DCEP ack on stream {}: {:?}",
entry.id, e
);
entry.do_close = true;
entry.set_state(StreamEntryState::Closed);
return Some(SctpEvent::Close { id: entry.id });
}
}
}
StreamEntryState::AwaitDcepAck => {
let res: Result<DcepAck, _> = buf.as_slice().try_into();
Expand Down
Loading