Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 29 additions & 15 deletions srtcore/core.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8691,6 +8691,8 @@ void srt::CUDT::updateSndLossListOnACK(int32_t ackdata_seqno)
void srt::CUDT::processCtrlAck(const CPacket &ctrlpkt, const steady_clock::time_point& currtime)
{
const int32_t* ackdata = (const int32_t*)ctrlpkt.m_pcData;

// Note: minimum of one 4-byte field is granted before the call.
const int32_t ackdata_seqno = ackdata[ACKD_RCVLASTACK];

// Check the value of ACK in case when it was some rogue peer
Expand All @@ -8705,7 +8707,7 @@ void srt::CUDT::processCtrlAck(const CPacket &ctrlpkt, const steady_clock::time_
return;
}

const bool isLiteAck = ctrlpkt.getLength() == (size_t)SEND_LITE_ACK;
const bool isLiteAck = ctrlpkt.getLength() == size_t(SEND_LITE_ACK);
HLOGC(inlog.Debug,
log << CONID() << "ACK covers: " << m_iSndLastDataAck << " - " << ackdata_seqno << " [ACK=" << m_iSndLastAck
<< "]" << (isLiteAck ? "[LITE]" : "[FULL]"));
Expand All @@ -8727,6 +8729,15 @@ void srt::CUDT::processCtrlAck(const CPacket &ctrlpkt, const steady_clock::time_
return;
}

const size_t acksize = ctrlpkt.getLength() / ACKD_FIELD_SIZE; // ACTUAL VALUE

// Check minimum size acceptable. If less, reject it.
if (acksize < ACKD_TOTAL_SIZE_SMALL)
{
LOGC(inlog.Error, log << CONID() << "EPE: ACK msg received with too small size: " << ctrlpkt.getLength());
return;
}

// Decide to send ACKACK or not
{
// Sequence number of the ACK packet
Expand Down Expand Up @@ -8822,18 +8833,6 @@ void srt::CUDT::processCtrlAck(const CPacket &ctrlpkt, const steady_clock::time_
}
#endif

size_t acksize = ctrlpkt.getLength(); // TEMPORARY VALUE FOR CHECKING
bool wrongsize = 0 != (acksize % ACKD_FIELD_SIZE);
acksize = acksize / ACKD_FIELD_SIZE; // ACTUAL VALUE

if (wrongsize)
{
// Issue a log, but don't do anything but skipping the "odd" bytes from the payload.
LOGC(inlog.Warn,
log << CONID() << "Received UMSG_ACK payload is not evened up to 4-byte based field size - cutting to "
<< acksize << " fields");
}

// Start with checking the base size.
if (acksize < ACKD_TOTAL_SIZE_SMALL)
{
Expand Down Expand Up @@ -9052,7 +9051,7 @@ void srt::CUDT::processCtrlAckAck(const CPacket& ctrlpkt, const time_point& tsAr
void srt::CUDT::processCtrlLossReport(const CPacket& ctrlpkt)
{
const int32_t* losslist = (int32_t*)(ctrlpkt.m_pcData);
const size_t losslist_len = ctrlpkt.getLength() / 4;
const size_t losslist_len = ctrlpkt.getLength() / sizeof(int32_t);

bool secure = true;

Expand Down Expand Up @@ -9213,7 +9212,11 @@ void srt::CUDT::processCtrlLossReport(const CPacket& ctrlpkt)
void srt::CUDT::processCtrlHS(const CPacket& ctrlpkt)
{
CHandShake req;
req.load_from(ctrlpkt.m_pcData, ctrlpkt.getLength());
if (-1 == req.load_from(ctrlpkt.m_pcData, ctrlpkt.getLength()))
{
LOGC(inlog.Error, log << CONID() << "processCtrlHS: EPE: Handshake has wrong size: " << ctrlpkt.getLength());
return;
}

HLOGC(inlog.Debug, log << CONID() << "processCtrl: got HS: " << req.show());

Expand Down Expand Up @@ -9448,6 +9451,17 @@ void srt::CUDT::processCtrl(const CPacket &ctrlpkt)
const steady_clock::time_point currtime = steady_clock::now();
m_tsLastRspTime = currtime;

// Extra check for the payload size:
// - must be aligned to int32_t
// - cannot be 0 (msgs with no args use 4-byte zero-filled padding).
size_t pktlen = ctrlpkt.getLength();
if (!pktlen || pktlen % sizeof(int32_t) != 0)
{
LOGC(inlog.Error, log << CONID() << "EPE: incoming UMSG: " << ctrlpkt.getType() << " INVALID SIZE: " << pktlen
<< " (expected > 0 and aligned to " << sizeof(int32_t) << " bytes)");
return;
}

HLOGC(inlog.Debug,
log << CONID() << "incoming UMSG:" << ctrlpkt.getType() << " ("
<< MessageTypeStr(ctrlpkt.getType(), ctrlpkt.getExtendedType()) << ") socket=%" << ctrlpkt.id());
Expand Down
11 changes: 4 additions & 7 deletions srtcore/crypto.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -370,26 +370,23 @@ int srt::CCryptoControl::processSrtMsg_KMREQ(

int srt::CCryptoControl::processSrtMsg_KMRSP(const uint32_t* srtdata, size_t len, unsigned srtv)
{
uint32_t srtd[SRTDATA_MAXSIZE];
size_t srtlen = len/sizeof(uint32_t);
// Validate the wire-supplied length before using it:
// - oversize would overflow the fixed-size stack buffer below;
// - non-word-aligned or too-small payloads are malformed by protocol and would
// feed uninitialised stack into downstream key-matching logic.
if (len > SRT_CMD_MAXSZ
|| len < sizeof(uint32_t)
|| (len % sizeof(uint32_t)) != 0)
if (srtlen > SRTDATA_MAXSIZE)
{
LOGC(cnlog.Error, log << "processSrtMsg_KMRSP: malformed len " << len
<< " (must be a non-zero multiple of " << sizeof(uint32_t)
<< ", up to " << SRT_CMD_MAXSZ << ") - rejecting");
<< " (must be up to " << SRT_CMD_MAXSZ << ") - rejecting");
return SRT_CMD_NONE;
}

/* All 32-bit msg fields (if present) swapped on reception
* But HaiCrypt expect network order message
* Re-swap to cancel it.
*/
uint32_t srtd[SRTDATA_MAXSIZE];
size_t srtlen = len/sizeof(uint32_t);
HtoNLA(srtd, srtdata, srtlen);

int retstatus = -1;
Expand Down
8 changes: 0 additions & 8 deletions test/test_crypto.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,6 @@ TEST(CryptoKMRSP, RejectsMalformedLengths)
// Oversize: would overflow uint32_t srtd[SRTDATA_MAXSIZE].
EXPECT_EQ(crypt.processSrtMsg_KMRSP(garbage.data(), SRT_CMD_MAXSZ + sizeof(uint32_t), srtv),
srt::SRT_CMD_NONE);

// Non-word-aligned: silently drops bytes and risks misinterpretation.
EXPECT_EQ(crypt.processSrtMsg_KMRSP(garbage.data(), 7, srtv), srt::SRT_CMD_NONE);

// Empty / under-a-word: HtoNLA writes nothing and downstream code would read
// uninitialised stack from srtd[].
EXPECT_EQ(crypt.processSrtMsg_KMRSP(garbage.data(), 0, srtv), srt::SRT_CMD_NONE);
EXPECT_EQ(crypt.processSrtMsg_KMRSP(garbage.data(), 3, srtv), srt::SRT_CMD_NONE);
}

#if defined(SRT_ENABLE_ENCRYPTION) && defined(ENABLE_AEAD_API_PREVIEW)
Expand Down
Loading