diff --git a/srtcore/core.cpp b/srtcore/core.cpp index e6087fd68..24ede852f 100644 --- a/srtcore/core.cpp +++ b/srtcore/core.cpp @@ -8691,6 +8691,8 @@ void srt::CUDT::updateSndLossListOnACK(int32_t ackdata_seqno) void srt::CUDT::processCtrlAck(const CPacket &ctrlpkt, const steady_clock::time_point& currtime) { const int32_t* ackdata = (const int32_t*)ctrlpkt.m_pcData; + + // Note: minimum of one 4-byte field is granted before the call. const int32_t ackdata_seqno = ackdata[ACKD_RCVLASTACK]; // Check the value of ACK in case when it was some rogue peer @@ -8705,7 +8707,7 @@ void srt::CUDT::processCtrlAck(const CPacket &ctrlpkt, const steady_clock::time_ return; } - const bool isLiteAck = ctrlpkt.getLength() == (size_t)SEND_LITE_ACK; + const bool isLiteAck = ctrlpkt.getLength() == size_t(SEND_LITE_ACK); HLOGC(inlog.Debug, log << CONID() << "ACK covers: " << m_iSndLastDataAck << " - " << ackdata_seqno << " [ACK=" << m_iSndLastAck << "]" << (isLiteAck ? "[LITE]" : "[FULL]")); @@ -8727,6 +8729,15 @@ void srt::CUDT::processCtrlAck(const CPacket &ctrlpkt, const steady_clock::time_ return; } + const size_t acksize = ctrlpkt.getLength() / ACKD_FIELD_SIZE; // ACTUAL VALUE + + // Check minimum size acceptable. If less, reject it. + if (acksize < ACKD_TOTAL_SIZE_SMALL) + { + LOGC(inlog.Error, log << CONID() << "EPE: ACK msg received with too small size: " << ctrlpkt.getLength()); + return; + } + // Decide to send ACKACK or not { // Sequence number of the ACK packet @@ -8822,18 +8833,6 @@ void srt::CUDT::processCtrlAck(const CPacket &ctrlpkt, const steady_clock::time_ } #endif - size_t acksize = ctrlpkt.getLength(); // TEMPORARY VALUE FOR CHECKING - bool wrongsize = 0 != (acksize % ACKD_FIELD_SIZE); - acksize = acksize / ACKD_FIELD_SIZE; // ACTUAL VALUE - - if (wrongsize) - { - // Issue a log, but don't do anything but skipping the "odd" bytes from the payload. - LOGC(inlog.Warn, - log << CONID() << "Received UMSG_ACK payload is not evened up to 4-byte based field size - cutting to " - << acksize << " fields"); - } - // Start with checking the base size. if (acksize < ACKD_TOTAL_SIZE_SMALL) { @@ -9052,7 +9051,7 @@ void srt::CUDT::processCtrlAckAck(const CPacket& ctrlpkt, const time_point& tsAr void srt::CUDT::processCtrlLossReport(const CPacket& ctrlpkt) { const int32_t* losslist = (int32_t*)(ctrlpkt.m_pcData); - const size_t losslist_len = ctrlpkt.getLength() / 4; + const size_t losslist_len = ctrlpkt.getLength() / sizeof(int32_t); bool secure = true; @@ -9213,7 +9212,11 @@ void srt::CUDT::processCtrlLossReport(const CPacket& ctrlpkt) void srt::CUDT::processCtrlHS(const CPacket& ctrlpkt) { CHandShake req; - req.load_from(ctrlpkt.m_pcData, ctrlpkt.getLength()); + if (-1 == req.load_from(ctrlpkt.m_pcData, ctrlpkt.getLength())) + { + LOGC(inlog.Error, log << CONID() << "processCtrlHS: EPE: Handshake has wrong size: " << ctrlpkt.getLength()); + return; + } HLOGC(inlog.Debug, log << CONID() << "processCtrl: got HS: " << req.show()); @@ -9448,6 +9451,17 @@ void srt::CUDT::processCtrl(const CPacket &ctrlpkt) const steady_clock::time_point currtime = steady_clock::now(); m_tsLastRspTime = currtime; + // Extra check for the payload size: + // - must be aligned to int32_t + // - cannot be 0 (msgs with no args use 4-byte zero-filled padding). + size_t pktlen = ctrlpkt.getLength(); + if (!pktlen || pktlen % sizeof(int32_t) != 0) + { + LOGC(inlog.Error, log << CONID() << "EPE: incoming UMSG: " << ctrlpkt.getType() << " INVALID SIZE: " << pktlen + << " (expected > 0 and aligned to " << sizeof(int32_t) << " bytes)"); + return; + } + HLOGC(inlog.Debug, log << CONID() << "incoming UMSG:" << ctrlpkt.getType() << " (" << MessageTypeStr(ctrlpkt.getType(), ctrlpkt.getExtendedType()) << ") socket=%" << ctrlpkt.id()); diff --git a/srtcore/crypto.cpp b/srtcore/crypto.cpp index 20c8ba47e..660a1f177 100644 --- a/srtcore/crypto.cpp +++ b/srtcore/crypto.cpp @@ -370,17 +370,16 @@ int srt::CCryptoControl::processSrtMsg_KMREQ( int srt::CCryptoControl::processSrtMsg_KMRSP(const uint32_t* srtdata, size_t len, unsigned srtv) { + uint32_t srtd[SRTDATA_MAXSIZE]; + size_t srtlen = len/sizeof(uint32_t); // Validate the wire-supplied length before using it: // - oversize would overflow the fixed-size stack buffer below; // - non-word-aligned or too-small payloads are malformed by protocol and would // feed uninitialised stack into downstream key-matching logic. - if (len > SRT_CMD_MAXSZ - || len < sizeof(uint32_t) - || (len % sizeof(uint32_t)) != 0) + if (srtlen > SRTDATA_MAXSIZE) { LOGC(cnlog.Error, log << "processSrtMsg_KMRSP: malformed len " << len - << " (must be a non-zero multiple of " << sizeof(uint32_t) - << ", up to " << SRT_CMD_MAXSZ << ") - rejecting"); + << " (must be up to " << SRT_CMD_MAXSZ << ") - rejecting"); return SRT_CMD_NONE; } @@ -388,8 +387,6 @@ int srt::CCryptoControl::processSrtMsg_KMRSP(const uint32_t* srtdata, size_t len * But HaiCrypt expect network order message * Re-swap to cancel it. */ - uint32_t srtd[SRTDATA_MAXSIZE]; - size_t srtlen = len/sizeof(uint32_t); HtoNLA(srtd, srtdata, srtlen); int retstatus = -1; diff --git a/test/test_crypto.cpp b/test/test_crypto.cpp index 466497d45..f7651a85c 100644 --- a/test/test_crypto.cpp +++ b/test/test_crypto.cpp @@ -20,14 +20,6 @@ TEST(CryptoKMRSP, RejectsMalformedLengths) // Oversize: would overflow uint32_t srtd[SRTDATA_MAXSIZE]. EXPECT_EQ(crypt.processSrtMsg_KMRSP(garbage.data(), SRT_CMD_MAXSZ + sizeof(uint32_t), srtv), srt::SRT_CMD_NONE); - - // Non-word-aligned: silently drops bytes and risks misinterpretation. - EXPECT_EQ(crypt.processSrtMsg_KMRSP(garbage.data(), 7, srtv), srt::SRT_CMD_NONE); - - // Empty / under-a-word: HtoNLA writes nothing and downstream code would read - // uninitialised stack from srtd[]. - EXPECT_EQ(crypt.processSrtMsg_KMRSP(garbage.data(), 0, srtv), srt::SRT_CMD_NONE); - EXPECT_EQ(crypt.processSrtMsg_KMRSP(garbage.data(), 3, srtv), srt::SRT_CMD_NONE); } #if defined(SRT_ENABLE_ENCRYPTION) && defined(ENABLE_AEAD_API_PREVIEW)