Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 51 additions & 44 deletions src/brpc/rdma/rdma_endpoint.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ RdmaEndpoint::~RdmaEndpoint() {
void RdmaEndpoint::Reset() {
DeallocateResources();

_state = UNINIT;
_state.store(UNINIT, butil::memory_order_relaxed);
_resource = NULL;
_send_cq_events = 0;
_recv_cq_events = 0;
Expand Down Expand Up @@ -195,7 +195,8 @@ void RdmaConnect::StartConnect(const Socket* socket,
return;
}
if (!IsRdmaAvailable()) {
rdma_transport->_rdma_ep->_state = RdmaEndpoint::FALLBACK_TCP;
rdma_transport->_rdma_ep->_state.store(RdmaEndpoint::FALLBACK_TCP,
butil::memory_order_relaxed);
rdma_transport->_rdma_state = RdmaTransport::RDMA_OFF;
done(0, data);
return;
Expand All @@ -206,7 +207,8 @@ void RdmaConnect::StartConnect(const Socket* socket,
bthread_attr_t attr = BTHREAD_ATTR_NORMAL;
bthread_attr_set_name(&attr, "RdmaProcessHandshakeAtClient");
if (bthread_start_background(&tid, &attr,
RdmaEndpoint::ProcessHandshakeAtClient, rdma_transport->_rdma_ep) < 0) {
RdmaEndpoint::ProcessHandshakeAtClient,
rdma_transport->_rdma_ep) < 0) {
LOG(FATAL) << "Fail to start handshake bthread";
Run();
} else {
Expand All @@ -230,7 +232,7 @@ static void TryReadOnTcpDuringRdmaEst(Socket* s) {
const int saved_errno = errno;
PLOG(WARNING) << "Fail to read from " << s;
s->SetFailed(saved_errno, "Fail to read from %s: %s",
s->description().c_str(), berror(saved_errno));
s->description().c_str(), berror(saved_errno));
return;
}
if (!s->MoreReadEvents(&progress)) {
Expand All @@ -255,22 +257,22 @@ void RdmaEndpoint::OnNewDataFromTcp(Socket* m) {

int progress = Socket::PROGRESS_INIT;
while (true) {
if (ep->_state == UNINIT) {
State state = ep->_state.load(butil::memory_order_acquire);
if (state == UNINIT) {
if (!m->CreatedByConnect()) {
if (!IsRdmaAvailable()) {
ep->_state = FALLBACK_TCP;
rdma_transport->_rdma_state = RdmaTransport::RDMA_OFF;
ep->_state.store(FALLBACK_TCP, butil::memory_order_relaxed);
continue;
}
bthread_t tid;
ep->_state = S_HELLO_WAIT;
ep->_state.store(S_HELLO_WAIT, butil::memory_order_relaxed);
SocketUniquePtr s;
m->ReAddress(&s);
bthread_attr_t attr = BTHREAD_ATTR_NORMAL;
bthread_attr_set_name(&attr, "RdmaProcessHandshakeAtServer");
if (bthread_start_background(&tid, &attr,
ProcessHandshakeAtServer, ep) < 0) {
ep->_state = UNINIT;
if (bthread_start_background(&tid, &attr, ProcessHandshakeAtServer, ep) < 0) {
ep->_state.store(UNINIT, butil::memory_order_relaxed);
LOG(FATAL) << "Fail to start handshake bthread";
} else {
s.release();
Expand All @@ -280,13 +282,13 @@ void RdmaEndpoint::OnNewDataFromTcp(Socket* m) {
// starts handshake. This will be handled by client handshake.
// Ignore the exception here.
}
} else if (ep->_state < ESTABLISHED) { // during handshake
} else if (state < ESTABLISHED) { // during handshake
ep->_read_butex->fetch_add(1, butil::memory_order_release);
bthread::butex_wake(ep->_read_butex);
} else if (ep->_state == FALLBACK_TCP){ // handshake finishes
} else if (state == FALLBACK_TCP){ // handshake finishes
InputMessenger::OnNewMessages(m);
return;
} else if (ep->_state == ESTABLISHED) {
} else if (state == ESTABLISHED) {
TryReadOnTcpDuringRdmaEst(ep->_socket);
return;
}
Expand Down Expand Up @@ -422,9 +424,10 @@ int RdmaEndpoint::WriteToFd(butil::IOBuf* data) {

inline void RdmaEndpoint::TryReadOnTcp() {
if (_socket->_nevent.fetch_add(1, butil::memory_order_acq_rel) == 0) {
if (_state == FALLBACK_TCP) {
State state = _state.load(butil::memory_order_acquire);
if (state == FALLBACK_TCP) {
InputMessenger::OnNewMessages(_socket);
} else if (_state == ESTABLISHED) {
} else if (state == ESTABLISHED) {
TryReadOnTcpDuringRdmaEst(_socket);
}
}
Expand Down Expand Up @@ -475,28 +478,28 @@ void* RdmaEndpoint::ProcessHandshakeAtClient(void* arg) {
ep->_handshake_version = handshake->ProtocolVersion();

// First initialize CQ and QP resources.
ep->_state = C_ALLOC_QPCQ;
ep->_state.store(C_ALLOC_QPCQ, butil::memory_order_relaxed);
if (ep->AllocateResources() < 0) {
LOG(WARNING) << "Fallback to tcp:" << s->description();
rdma_transport->_rdma_state = RdmaTransport::RDMA_OFF;
ep->_state = FALLBACK_TCP;
ep->_state.store(FALLBACK_TCP, butil::memory_order_release);
return NULL;
}

// Send hello message to server
ep->_state = C_HELLO_SEND;
ep->_state.store(C_HELLO_SEND, butil::memory_order_relaxed);
if (handshake->SendLocalHello() < 0) {
int saved_errno = errno;
PLOG(WARNING) << "Fail to send hello message to server:"
<< s->description();
s->SetFailed(saved_errno, "Fail to complete rdma handshake from %s: %s",
s->description().c_str(), berror(saved_errno));
ep->_state = FAILED;
ep->_state.store(FAILED, butil::memory_order_relaxed);
return NULL;
}

// Receive and parse remote hello.
ep->_state = C_HELLO_WAIT;
ep->_state.store(C_HELLO_WAIT, butil::memory_order_relaxed);
ParsedHello remote{};
bool negotiated = false;
if (handshake->ReceiveAndParseRemoteHello(&remote, &negotiated) < 0) {
Expand All @@ -505,7 +508,7 @@ void* RdmaEndpoint::ProcessHandshakeAtClient(void* arg) {
<< s->description();
s->SetFailed(saved_errno, "Fail to complete rdma handshake from %s: %s",
s->description().c_str(), berror(saved_errno));
ep->_state = FAILED;
ep->_state.store(FAILED, butil::memory_order_relaxed);
return NULL;
}

Expand All @@ -515,7 +518,7 @@ void* RdmaEndpoint::ProcessHandshakeAtClient(void* arg) {
rdma_transport->_rdma_state = RdmaTransport::RDMA_OFF;
} else {
ep->ApplyRemoteHello(remote);
ep->_state = C_BRINGUP_QP;
ep->_state.store(C_BRINGUP_QP, butil::memory_order_relaxed);
if (ep->BringUpQp(remote.lid, remote.gid, remote.qp_num) < 0) {
LOG(WARNING) << "Fail to bringup QP, fallback to tcp:"
<< s->description();
Expand All @@ -526,26 +529,27 @@ void* RdmaEndpoint::ProcessHandshakeAtClient(void* arg) {
}

// Send ACK message to server
ep->_state = C_ACK_SEND;
uint32_t flags = rdma_transport->_rdma_state != RdmaTransport::RDMA_OFF ? HELLO_ACK_RDMA_OK : 0;
ep->_state.store(C_ACK_SEND, butil::memory_order_relaxed);
bool rdma_on = rdma_transport->_rdma_state == RdmaTransport::RDMA_ON;
uint32_t flags = rdma_on ? HELLO_ACK_RDMA_OK : 0;
uint32_t flags_be = butil::HostToNet32(flags);
if (ep->WriteToFd(&flags_be, HELLO_ACK_LEN) < 0) {
int saved_errno = errno;
PLOG(WARNING) << "Fail to send Ack Message to server:"
<< s->description();
s->SetFailed(saved_errno, "Fail to complete rdma handshake from %s: %s",
s->description().c_str(), berror(saved_errno));
ep->_state = FAILED;
ep->_state.store(FAILED, butil::memory_order_relaxed);
return NULL;
}

if (rdma_transport->_rdma_state == RdmaTransport::RDMA_ON) {
ep->_state = ESTABLISHED;
ep->_state.store(ESTABLISHED, butil::memory_order_release);
LOG_IF(INFO, FLAGS_rdma_trace_verbose)
<< "Client handshake ends (use rdma v" << ep->_handshake_version
<< ") on " << s->description();
} else {
ep->_state = FALLBACK_TCP;
ep->_state.store(FALLBACK_TCP, butil::memory_order_release);
LOG_IF(INFO, FLAGS_rdma_trace_verbose)
<< "Client handshake ends (use tcp) on " << s->description();
}
Expand Down Expand Up @@ -578,15 +582,15 @@ void* RdmaEndpoint::ProcessHandshakeAtServer(void* arg) {
LOG_IF(INFO, FLAGS_rdma_trace_verbose)
<< "Start handshake on " << s->description();

ep->_state = S_HELLO_WAIT;
ep->_state.store(S_HELLO_WAIT, butil::memory_order_relaxed);
uint8_t magic[MAGIC_STR_LEN];
if (ep->ReadFromFd(magic, MAGIC_STR_LEN) < 0) {
int saved_errno = errno;
PLOG(WARNING) << "Fail to read Hello Message from client:"
<< s->description() << " " << s->_remote_side;
s->SetFailed(saved_errno, "Fail to complete rdma handshake from %s: %s",
s->description().c_str(), berror(saved_errno));
ep->_state = FAILED;
ep->_state.store(FAILED, butil::memory_order_relaxed);
return NULL;
}

Expand All @@ -598,8 +602,11 @@ void* RdmaEndpoint::ProcessHandshakeAtServer(void* arg) {
<< s->description();
// We need to copy data read back to _socket->_read_buf.
s->_read_buf.append(magic, MAGIC_STR_LEN);
ep->_state = FALLBACK_TCP;
rdma_transport->_rdma_state = RdmaTransport::RDMA_OFF;
// Use release memory order to publish the magic bytes appended
// above to whoever reads `_state == FALLBACK_TCP` (the event
// thread in OnNewDataFromTcp).
ep->_state.store(FALLBACK_TCP, butil::memory_order_release);
ep->TryReadOnTcp();
return NULL;
}
Expand All @@ -614,7 +621,7 @@ void* RdmaEndpoint::ProcessHandshakeAtServer(void* arg) {
<< s->description();
s->SetFailed(saved_errno, "Fail to complete rdma handshake from %s: %s",
s->description().c_str(), berror(saved_errno));
ep->_state = FAILED;
ep->_state.store(FAILED, butil::memory_order_relaxed);
return NULL;
}

Expand All @@ -624,13 +631,13 @@ void* RdmaEndpoint::ProcessHandshakeAtServer(void* arg) {
rdma_transport->_rdma_state = RdmaTransport::RDMA_OFF;
} else {
ep->ApplyRemoteHello(remote);
ep->_state = S_ALLOC_QPCQ;
ep->_state.store(S_ALLOC_QPCQ, butil::memory_order_relaxed);
if (ep->AllocateResources() < 0) {
LOG(WARNING) << "Fail to allocate rdma resources, fallback to tcp:"
<< s->description();
rdma_transport->_rdma_state = RdmaTransport::RDMA_OFF;
} else {
ep->_state = S_BRINGUP_QP;
ep->_state.store(S_BRINGUP_QP, butil::memory_order_relaxed);
if (ep->BringUpQp(remote.lid, remote.gid, remote.qp_num) < 0) {
LOG(WARNING) << "Fail to bringup QP, fallback to tcp:"
<< s->description();
Expand All @@ -639,31 +646,30 @@ void* RdmaEndpoint::ProcessHandshakeAtServer(void* arg) {
}
}

ep->_state = S_HELLO_SEND;
ep->_state.store(S_HELLO_SEND, butil::memory_order_relaxed);
if (handshake->SendLocalHello() < 0) {
int saved_errno = errno;
PLOG(WARNING) << "Fail to send Hello Message to client:"
<< s->description();
s->SetFailed(saved_errno, "Fail to complete rdma handshake from %s: %s",
s->description().c_str(), berror(saved_errno));
ep->_state = FAILED;
ep->_state.store(FAILED, butil::memory_order_relaxed);
return NULL;
}

ep->_state = S_ACK_WAIT;
ep->_state.store(S_ACK_WAIT, butil::memory_order_relaxed);
uint32_t flags_be = 0;
if (ep->ReadFromFd(&flags_be, HELLO_ACK_LEN) < 0) {
int saved_errno = errno;
PLOG(WARNING) << "Fail to read ack message from client:"
<< s->description();
s->SetFailed(saved_errno, "Fail to complete rdma handshake from %s: %s",
s->description().c_str(), berror(saved_errno));
ep->_state = FAILED;
ep->_state.store(FAILED, butil::memory_order_relaxed);
return NULL;
}
uint32_t flags = butil::NetToHost32(flags_be);
bool client_ack_ok = (flags & HELLO_ACK_RDMA_OK) != 0;

if (client_ack_ok) {
if (rdma_transport->_rdma_state == RdmaTransport::RDMA_OFF) {
// Client asked for RDMA but we are falling back: protocol
Expand All @@ -673,17 +679,17 @@ void* RdmaEndpoint::ProcessHandshakeAtServer(void* arg) {
<< "RDMA_OFF state: " << s->description();
s->SetFailed(EPROTO, "Fail to complete rdma handshake from %s: %s",
s->description().c_str(), berror(EPROTO));
ep->_state = FAILED;
ep->_state.store(FAILED, butil::memory_order_relaxed);
return NULL;
}
rdma_transport->_rdma_state = RdmaTransport::RDMA_ON;
ep->_state = ESTABLISHED;
ep->_state.store(ESTABLISHED, butil::memory_order_release);
LOG_IF(INFO, FLAGS_rdma_trace_verbose)
<< "Server handshake ends (use rdma v" << ep->_handshake_version
<< ") on " << s->description();
} else {
rdma_transport->_rdma_state = RdmaTransport::RDMA_OFF;
ep->_state = FALLBACK_TCP;
ep->_state.store(FALLBACK_TCP, butil::memory_order_release);
LOG_IF(INFO, FLAGS_rdma_trace_verbose)
<< "Server handshake ends (use tcp) on " << s->description();
}
Expand Down Expand Up @@ -712,7 +718,8 @@ friend class RdmaEndpoint;
// blocks or first max_len bytes.
// Return: the bytes included in the sglist, or -1 if failed
ssize_t cut_into_sglist_and_iobuf(ibv_sge* sglist, size_t* sge_index,
butil::IOBuf* to, size_t max_sge, size_t max_len) {
butil::IOBuf* to, size_t max_sge,
size_t max_len) {
size_t len = 0;
while (*sge_index < max_sge) {
if (len == max_len || _ref_num() == 0) {
Expand Down Expand Up @@ -967,7 +974,7 @@ ssize_t RdmaEndpoint::HandleCompletion(ibv_wc& wc) {
if (wc.byte_len < (uint32_t)FLAGS_rdma_zerocopy_min_size) {
zerocopy = false;
}
CHECK(_state != FALLBACK_TCP);
CHECK_NE(_state.load(butil::memory_order_acquire), FALLBACK_TCP);
if (zerocopy) {
_rbuf[_rq_received].cutn(&_socket->_read_buf, wc.byte_len);
} else {
Expand Down Expand Up @@ -1586,7 +1593,7 @@ void RdmaEndpoint::PollCq(Socket* m) {
}

std::string RdmaEndpoint::GetStateStr() const {
switch (_state) {
switch (_state.load(butil::memory_order_relaxed)) {
case UNINIT: return "UNINIT";
case C_ALLOC_QPCQ: return "C_ALLOC_QPCQ";
case C_HELLO_SEND: return "C_HELLO_SEND";
Expand Down
4 changes: 2 additions & 2 deletions src/brpc/rdma/rdma_endpoint.h
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ friend int v3_wire::WriteV3Hello(RdmaEndpoint*, const RdmaHello&);
std::string GetStateStr() const;

// Try to read data on TCP fd in _socket
inline void TryReadOnTcp();
void TryReadOnTcp();

// Add cq socket id to poller
void PollerAddCqSid();
Expand All @@ -262,7 +262,7 @@ friend int v3_wire::WriteV3Hello(RdmaEndpoint*, const RdmaHello&);
Socket* _socket;

// State of Handshake
State _state;
butil::atomic<State> _state;

// Wire-level handshake protocol version (set by dispatch in
// ProcessHandshakeAtClient/Server). Aligned with the protocol code:
Expand Down
Loading