From fd32d6ffd0efdbc720265ebfa654c9de3a58cdf0 Mon Sep 17 00:00:00 2001 From: chenBright Date: Mon, 15 Jun 2026 22:33:49 +0800 Subject: [PATCH] Make _state atomic to prevent concurrent _read_buf mutation on TCP fallback RdmaEndpoint::_state was a plain enum, written by the handshake bthread and read concurrently by the event-dispatching thread (OnNewDataFromTcp). This is a data race, and on a weak memory model it can let the two threads concurrently mutate _socket->_read_buf. Make _state a butil::atomic<State>: - Terminal-state stores use release and the matching loads use acquire, so data published before a terminal state (the magic bytes put back into _read_buf, and the RDMA window/resource setup before ESTABLISHED) is visible to the reader. - Non-terminal handshake transitions use relaxed. --- src/brpc/rdma/rdma_endpoint.cpp | 95 ++++++++++++++++++--------------- src/brpc/rdma/rdma_endpoint.h | 4 +- 2 files changed, 53 insertions(+), 46 deletions(-) diff --git a/src/brpc/rdma/rdma_endpoint.cpp b/src/brpc/rdma/rdma_endpoint.cpp index 658c7a2fcc..a5016dc710 100644 --- a/src/brpc/rdma/rdma_endpoint.cpp +++ b/src/brpc/rdma/rdma_endpoint.cpp @@ -161,7 +161,7 @@ RdmaEndpoint::~RdmaEndpoint() { void RdmaEndpoint::Reset() { DeallocateResources(); - _state = UNINIT; + _state.store(UNINIT, butil::memory_order_relaxed); _resource = NULL; _send_cq_events = 0; _recv_cq_events = 0; @@ -195,7 +195,8 @@ void RdmaConnect::StartConnect(const Socket* socket, return; } if (!IsRdmaAvailable()) { - rdma_transport->_rdma_ep->_state = RdmaEndpoint::FALLBACK_TCP; + rdma_transport->_rdma_ep->_state.store(RdmaEndpoint::FALLBACK_TCP, + butil::memory_order_relaxed); rdma_transport->_rdma_state = RdmaTransport::RDMA_OFF; done(0, data); return; @@ -206,7 +207,8 @@ void RdmaConnect::StartConnect(const Socket* socket, bthread_attr_t attr = BTHREAD_ATTR_NORMAL; bthread_attr_set_name(&attr, "RdmaProcessHandshakeAtClient"); if (bthread_start_background(&tid, &attr, - 
RdmaEndpoint::ProcessHandshakeAtClient, rdma_transport->_rdma_ep) < 0) { + RdmaEndpoint::ProcessHandshakeAtClient, + rdma_transport->_rdma_ep) < 0) { LOG(FATAL) << "Fail to start handshake bthread"; Run(); } else { @@ -230,7 +232,7 @@ static void TryReadOnTcpDuringRdmaEst(Socket* s) { const int saved_errno = errno; PLOG(WARNING) << "Fail to read from " << s; s->SetFailed(saved_errno, "Fail to read from %s: %s", - s->description().c_str(), berror(saved_errno)); + s->description().c_str(), berror(saved_errno)); return; } if (!s->MoreReadEvents(&progress)) { @@ -255,22 +257,22 @@ void RdmaEndpoint::OnNewDataFromTcp(Socket* m) { int progress = Socket::PROGRESS_INIT; while (true) { - if (ep->_state == UNINIT) { + State state = ep->_state.load(butil::memory_order_acquire); + if (state == UNINIT) { if (!m->CreatedByConnect()) { if (!IsRdmaAvailable()) { - ep->_state = FALLBACK_TCP; rdma_transport->_rdma_state = RdmaTransport::RDMA_OFF; + ep->_state.store(FALLBACK_TCP, butil::memory_order_relaxed); continue; } bthread_t tid; - ep->_state = S_HELLO_WAIT; + ep->_state.store(S_HELLO_WAIT, butil::memory_order_relaxed); SocketUniquePtr s; m->ReAddress(&s); bthread_attr_t attr = BTHREAD_ATTR_NORMAL; bthread_attr_set_name(&attr, "RdmaProcessHandshakeAtServer"); - if (bthread_start_background(&tid, &attr, - ProcessHandshakeAtServer, ep) < 0) { - ep->_state = UNINIT; + if (bthread_start_background(&tid, &attr, ProcessHandshakeAtServer, ep) < 0) { + ep->_state.store(UNINIT, butil::memory_order_relaxed); LOG(FATAL) << "Fail to start handshake bthread"; } else { s.release(); @@ -280,13 +282,13 @@ void RdmaEndpoint::OnNewDataFromTcp(Socket* m) { // starts handshake. This will be handled by client handshake. // Ignore the exception here. 
} - } else if (ep->_state < ESTABLISHED) { // during handshake + } else if (state < ESTABLISHED) { // during handshake ep->_read_butex->fetch_add(1, butil::memory_order_release); bthread::butex_wake(ep->_read_butex); - } else if (ep->_state == FALLBACK_TCP){ // handshake finishes + } else if (state == FALLBACK_TCP){ // handshake finishes InputMessenger::OnNewMessages(m); return; - } else if (ep->_state == ESTABLISHED) { + } else if (state == ESTABLISHED) { TryReadOnTcpDuringRdmaEst(ep->_socket); return; } @@ -422,9 +424,10 @@ int RdmaEndpoint::WriteToFd(butil::IOBuf* data) { inline void RdmaEndpoint::TryReadOnTcp() { if (_socket->_nevent.fetch_add(1, butil::memory_order_acq_rel) == 0) { - if (_state == FALLBACK_TCP) { + State state = _state.load(butil::memory_order_acquire); + if (state == FALLBACK_TCP) { InputMessenger::OnNewMessages(_socket); - } else if (_state == ESTABLISHED) { + } else if (state == ESTABLISHED) { TryReadOnTcpDuringRdmaEst(_socket); } } @@ -475,28 +478,28 @@ void* RdmaEndpoint::ProcessHandshakeAtClient(void* arg) { ep->_handshake_version = handshake->ProtocolVersion(); // First initialize CQ and QP resources. 
- ep->_state = C_ALLOC_QPCQ; + ep->_state.store(C_ALLOC_QPCQ, butil::memory_order_relaxed); if (ep->AllocateResources() < 0) { LOG(WARNING) << "Fallback to tcp:" << s->description(); rdma_transport->_rdma_state = RdmaTransport::RDMA_OFF; - ep->_state = FALLBACK_TCP; + ep->_state.store(FALLBACK_TCP, butil::memory_order_release); return NULL; } // Send hello message to server - ep->_state = C_HELLO_SEND; + ep->_state.store(C_HELLO_SEND, butil::memory_order_relaxed); if (handshake->SendLocalHello() < 0) { int saved_errno = errno; PLOG(WARNING) << "Fail to send hello message to server:" << s->description(); s->SetFailed(saved_errno, "Fail to complete rdma handshake from %s: %s", s->description().c_str(), berror(saved_errno)); - ep->_state = FAILED; + ep->_state.store(FAILED, butil::memory_order_relaxed); return NULL; } // Receive and parse remote hello. - ep->_state = C_HELLO_WAIT; + ep->_state.store(C_HELLO_WAIT, butil::memory_order_relaxed); ParsedHello remote{}; bool negotiated = false; if (handshake->ReceiveAndParseRemoteHello(&remote, &negotiated) < 0) { @@ -505,7 +508,7 @@ void* RdmaEndpoint::ProcessHandshakeAtClient(void* arg) { << s->description(); s->SetFailed(saved_errno, "Fail to complete rdma handshake from %s: %s", s->description().c_str(), berror(saved_errno)); - ep->_state = FAILED; + ep->_state.store(FAILED, butil::memory_order_relaxed); return NULL; } @@ -515,7 +518,7 @@ void* RdmaEndpoint::ProcessHandshakeAtClient(void* arg) { rdma_transport->_rdma_state = RdmaTransport::RDMA_OFF; } else { ep->ApplyRemoteHello(remote); - ep->_state = C_BRINGUP_QP; + ep->_state.store(C_BRINGUP_QP, butil::memory_order_relaxed); if (ep->BringUpQp(remote.lid, remote.gid, remote.qp_num) < 0) { LOG(WARNING) << "Fail to bringup QP, fallback to tcp:" << s->description(); @@ -526,8 +529,9 @@ void* RdmaEndpoint::ProcessHandshakeAtClient(void* arg) { } // Send ACK message to server - ep->_state = C_ACK_SEND; - uint32_t flags = rdma_transport->_rdma_state != 
RdmaTransport::RDMA_OFF ? HELLO_ACK_RDMA_OK : 0; + ep->_state.store(C_ACK_SEND, butil::memory_order_relaxed); + bool rdma_on = rdma_transport->_rdma_state == RdmaTransport::RDMA_ON; + uint32_t flags = rdma_on ? HELLO_ACK_RDMA_OK : 0; uint32_t flags_be = butil::HostToNet32(flags); if (ep->WriteToFd(&flags_be, HELLO_ACK_LEN) < 0) { int saved_errno = errno; @@ -535,17 +539,17 @@ void* RdmaEndpoint::ProcessHandshakeAtClient(void* arg) { << s->description(); s->SetFailed(saved_errno, "Fail to complete rdma handshake from %s: %s", s->description().c_str(), berror(saved_errno)); - ep->_state = FAILED; + ep->_state.store(FAILED, butil::memory_order_relaxed); return NULL; } if (rdma_transport->_rdma_state == RdmaTransport::RDMA_ON) { - ep->_state = ESTABLISHED; + ep->_state.store(ESTABLISHED, butil::memory_order_release); LOG_IF(INFO, FLAGS_rdma_trace_verbose) << "Client handshake ends (use rdma v" << ep->_handshake_version << ") on " << s->description(); } else { - ep->_state = FALLBACK_TCP; + ep->_state.store(FALLBACK_TCP, butil::memory_order_release); LOG_IF(INFO, FLAGS_rdma_trace_verbose) << "Client handshake ends (use tcp) on " << s->description(); } @@ -578,7 +582,7 @@ void* RdmaEndpoint::ProcessHandshakeAtServer(void* arg) { LOG_IF(INFO, FLAGS_rdma_trace_verbose) << "Start handshake on " << s->description(); - ep->_state = S_HELLO_WAIT; + ep->_state.store(S_HELLO_WAIT, butil::memory_order_relaxed); uint8_t magic[MAGIC_STR_LEN]; if (ep->ReadFromFd(magic, MAGIC_STR_LEN) < 0) { int saved_errno = errno; @@ -586,7 +590,7 @@ void* RdmaEndpoint::ProcessHandshakeAtServer(void* arg) { << s->description() << " " << s->_remote_side; s->SetFailed(saved_errno, "Fail to complete rdma handshake from %s: %s", s->description().c_str(), berror(saved_errno)); - ep->_state = FAILED; + ep->_state.store(FAILED, butil::memory_order_relaxed); return NULL; } @@ -598,8 +602,11 @@ void* RdmaEndpoint::ProcessHandshakeAtServer(void* arg) { << s->description(); // We need to copy data read back 
to _socket->_read_buf. s->_read_buf.append(magic, MAGIC_STR_LEN); - ep->_state = FALLBACK_TCP; rdma_transport->_rdma_state = RdmaTransport::RDMA_OFF; + // Use release memory order to publish the magic bytes appended + // above to whoever reads `_state == FALLBACK_TCP` (the event + // thread in OnNewDataFromTcp). + ep->_state.store(FALLBACK_TCP, butil::memory_order_release); ep->TryReadOnTcp(); return NULL; } @@ -614,7 +621,7 @@ void* RdmaEndpoint::ProcessHandshakeAtServer(void* arg) { << s->description(); s->SetFailed(saved_errno, "Fail to complete rdma handshake from %s: %s", s->description().c_str(), berror(saved_errno)); - ep->_state = FAILED; + ep->_state.store(FAILED, butil::memory_order_relaxed); return NULL; } @@ -624,13 +631,13 @@ void* RdmaEndpoint::ProcessHandshakeAtServer(void* arg) { rdma_transport->_rdma_state = RdmaTransport::RDMA_OFF; } else { ep->ApplyRemoteHello(remote); - ep->_state = S_ALLOC_QPCQ; + ep->_state.store(S_ALLOC_QPCQ, butil::memory_order_relaxed); if (ep->AllocateResources() < 0) { LOG(WARNING) << "Fail to allocate rdma resources, fallback to tcp:" << s->description(); rdma_transport->_rdma_state = RdmaTransport::RDMA_OFF; } else { - ep->_state = S_BRINGUP_QP; + ep->_state.store(S_BRINGUP_QP, butil::memory_order_relaxed); if (ep->BringUpQp(remote.lid, remote.gid, remote.qp_num) < 0) { LOG(WARNING) << "Fail to bringup QP, fallback to tcp:" << s->description(); @@ -639,18 +646,18 @@ void* RdmaEndpoint::ProcessHandshakeAtServer(void* arg) { } } - ep->_state = S_HELLO_SEND; + ep->_state.store(S_HELLO_SEND, butil::memory_order_relaxed); if (handshake->SendLocalHello() < 0) { int saved_errno = errno; PLOG(WARNING) << "Fail to send Hello Message to client:" << s->description(); s->SetFailed(saved_errno, "Fail to complete rdma handshake from %s: %s", s->description().c_str(), berror(saved_errno)); - ep->_state = FAILED; + ep->_state.store(FAILED, butil::memory_order_relaxed); return NULL; } - ep->_state = S_ACK_WAIT; + 
ep->_state.store(S_ACK_WAIT, butil::memory_order_relaxed); uint32_t flags_be = 0; if (ep->ReadFromFd(&flags_be, HELLO_ACK_LEN) < 0) { int saved_errno = errno; @@ -658,12 +665,11 @@ void* RdmaEndpoint::ProcessHandshakeAtServer(void* arg) { << s->description(); s->SetFailed(saved_errno, "Fail to complete rdma handshake from %s: %s", s->description().c_str(), berror(saved_errno)); - ep->_state = FAILED; + ep->_state.store(FAILED, butil::memory_order_relaxed); return NULL; } uint32_t flags = butil::NetToHost32(flags_be); bool client_ack_ok = (flags & HELLO_ACK_RDMA_OK) != 0; - if (client_ack_ok) { if (rdma_transport->_rdma_state == RdmaTransport::RDMA_OFF) { // Client asked for RDMA but we are falling back: protocol @@ -673,17 +679,17 @@ void* RdmaEndpoint::ProcessHandshakeAtServer(void* arg) { << "RDMA_OFF state: " << s->description(); s->SetFailed(EPROTO, "Fail to complete rdma handshake from %s: %s", s->description().c_str(), berror(EPROTO)); - ep->_state = FAILED; + ep->_state.store(FAILED, butil::memory_order_relaxed); return NULL; } rdma_transport->_rdma_state = RdmaTransport::RDMA_ON; - ep->_state = ESTABLISHED; + ep->_state.store(ESTABLISHED, butil::memory_order_release); LOG_IF(INFO, FLAGS_rdma_trace_verbose) << "Server handshake ends (use rdma v" << ep->_handshake_version << ") on " << s->description(); } else { rdma_transport->_rdma_state = RdmaTransport::RDMA_OFF; - ep->_state = FALLBACK_TCP; + ep->_state.store(FALLBACK_TCP, butil::memory_order_release); LOG_IF(INFO, FLAGS_rdma_trace_verbose) << "Server handshake ends (use tcp) on " << s->description(); } @@ -712,7 +718,8 @@ friend class RdmaEndpoint; // blocks or first max_len bytes. 
// Return: the bytes included in the sglist, or -1 if failed ssize_t cut_into_sglist_and_iobuf(ibv_sge* sglist, size_t* sge_index, - butil::IOBuf* to, size_t max_sge, size_t max_len) { + butil::IOBuf* to, size_t max_sge, + size_t max_len) { size_t len = 0; while (*sge_index < max_sge) { if (len == max_len || _ref_num() == 0) { @@ -967,7 +974,7 @@ ssize_t RdmaEndpoint::HandleCompletion(ibv_wc& wc) { if (wc.byte_len < (uint32_t)FLAGS_rdma_zerocopy_min_size) { zerocopy = false; } - CHECK(_state != FALLBACK_TCP); + CHECK_NE(_state.load(butil::memory_order_acquire), FALLBACK_TCP); if (zerocopy) { _rbuf[_rq_received].cutn(&_socket->_read_buf, wc.byte_len); } else { @@ -1586,7 +1593,7 @@ void RdmaEndpoint::PollCq(Socket* m) { } std::string RdmaEndpoint::GetStateStr() const { - switch (_state) { + switch (_state.load(butil::memory_order_relaxed)) { case UNINIT: return "UNINIT"; case C_ALLOC_QPCQ: return "C_ALLOC_QPCQ"; case C_HELLO_SEND: return "C_HELLO_SEND"; diff --git a/src/brpc/rdma/rdma_endpoint.h b/src/brpc/rdma/rdma_endpoint.h index 7b6652bc86..41c33824be 100644 --- a/src/brpc/rdma/rdma_endpoint.h +++ b/src/brpc/rdma/rdma_endpoint.h @@ -250,7 +250,7 @@ friend int v3_wire::WriteV3Hello(RdmaEndpoint*, const RdmaHello&); std::string GetStateStr() const; // Try to read data on TCP fd in _socket - inline void TryReadOnTcp(); + void TryReadOnTcp(); // Add cq socket id to poller void PollerAddCqSid(); @@ -262,7 +262,7 @@ friend int v3_wire::WriteV3Hello(RdmaEndpoint*, const RdmaHello&); Socket* _socket; // State of Handshake - State _state; + butil::atomic<State> _state; // Wire-level handshake protocol version (set by dispatch in // ProcessHandshakeAtClient/Server). Aligned with the protocol code: