Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions pkg/server/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -2099,9 +2099,17 @@ func (cc *clientConn) handleStmt(
//nolint: errcheck
rs.Finish()
})
fn := func() bool {
if cc.bufReadConn != nil {
return cc.bufReadConn.IsAlive() != 0
}
return true
}
cc.ctx.GetSessionVars().SQLKiller.IsConnectionAlive.Store(&fn)
cc.ctx.GetSessionVars().SQLKiller.InWriteResultSet.Store(true)
defer cc.ctx.GetSessionVars().SQLKiller.InWriteResultSet.Store(false)
defer cc.ctx.GetSessionVars().SQLKiller.ClearFinishFunc()
defer cc.ctx.GetSessionVars().SQLKiller.IsConnectionAlive.Store(nil)
if retryable, err := cc.writeResultSet(ctx, rs, false, status, 0); err != nil {
return retryable, err
}
Expand Down
11 changes: 11 additions & 0 deletions pkg/server/conn_stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,17 @@ func (cc *clientConn) executePlanCacheStmt(ctx context.Context, stmt any, args [
ctx = context.WithValue(ctx, execdetails.StmtExecDetailKey, &execdetails.StmtExecDetails{})
ctx = context.WithValue(ctx, util.ExecDetailsKey, &util.ExecDetails{})
ctx = context.WithValue(ctx, util.RUDetailsCtxKey, util.NewRUDetails())

fn := func() bool {
if cc.bufReadConn != nil {
return cc.bufReadConn.IsAlive() != 0
}
return true
}
cc.ctx.GetSessionVars().SQLKiller.IsConnectionAlive.Store(&fn)
defer cc.ctx.GetSessionVars().SQLKiller.IsConnectionAlive.Store(nil)

//nolint:forcetypeassert
retryable, err := cc.executePreparedStmtAndWriteResult(ctx, stmt.(PreparedStatement), args, useCursor)
if err != nil {
action, txnErr := sessiontxn.GetTxnManager(&cc.ctx).OnStmtErrorForNextAction(ctx, sessiontxn.StmtErrAfterQuery, err)
Expand Down
46 changes: 46 additions & 0 deletions pkg/server/internal/util/buffered_read_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@ package util

import (
"bufio"
"io"
"net"
"sync"
"time"
)

// DefaultReaderSize is the default size of bufio.Reader.
Expand All @@ -26,11 +29,15 @@ const DefaultReaderSize = 16 * 1024
type BufferedReadConn struct {
net.Conn
rb *bufio.Reader
// `mu` is for `IsAlive()` function.
// We use this to ensure that `SetReadDeadline` is not called concurrently.
mu *sync.Mutex
}

// NewBufferedReadConn creates a BufferedReadConn.
func NewBufferedReadConn(conn net.Conn) *BufferedReadConn {
return &BufferedReadConn{
mu: &sync.Mutex{},
Conn: conn,
rb: bufio.NewReaderSize(conn, DefaultReaderSize),
}
Expand All @@ -40,3 +47,42 @@ func NewBufferedReadConn(conn net.Conn) *BufferedReadConn {
func (conn BufferedReadConn) Read(b []byte) (n int, err error) {
return conn.rb.Read(b)
}

// Peek peeks from the connection.
func (conn BufferedReadConn) Peek(n int) ([]byte, error) {
return conn.rb.Peek(n)
}

// IsAlive detects the connection is alive or not.
// return value < 0, means unknown
// return value = 0, means not alive
// return value = 1, means still alive
func (conn BufferedReadConn) IsAlive() int {
if conn.mu.TryLock() {
defer conn.mu.Unlock()
err := conn.SetReadDeadline(time.Now().Add(30 * time.Microsecond))
if err != nil {
return -1
}
// nolint:errcheck
defer conn.SetReadDeadline(time.Time{})
// At the TCP level, a successful `Peek` operation doesn't guarantee
// the connection remains active. However, in the MySQL protocol,
// clients shouldn't send new data while the server is processing SQL.
// Therefore, we can safely assume `Peek` won't intercept any data
// during this period. Even if `Peek` does capture data, it only means
// the liveness check might be inaccurate - this won't impact the
// actual connection state or its operations.
_, err = conn.Peek(1)
if err == nil {
return 1
}
if err == io.EOF {
return 0
} else if ne, ok := err.(net.Error); ok && ne.Timeout() {
return 1
}
return 0
}
return -1
Comment thread
coderabbitai[bot] marked this conversation as resolved.
}
80 changes: 80 additions & 0 deletions pkg/server/tests/commontest/tidb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"sync/atomic"
"testing"
"time"
"unsafe"

"github.com/go-sql-driver/mysql"
"github.com/pingcap/errors"
Expand Down Expand Up @@ -3707,3 +3708,82 @@ func TestAuditPluginRetrying(t *testing.T) {
runExplicitTransactionRetry(db, true)
})
}

func TestIssue57531(t *testing.T) {
ts := servertestkit.CreateTidbTestSuite(t)

processlistCount := func(dbt *testkit.DBTestKit) int {
rsCnt := 0
rs := dbt.MustQuery("show processlist")
for rs.Next() {
rsCnt++
}
require.NoError(t, rs.Err())
require.NoError(t, rs.Close())
return rsCnt
}

for i := range 2 {
ts.RunTests(t, nil, func(dbt *testkit.DBTestKit) {
var netConn net.Conn
conn, err := dbt.GetDB().Conn(context.Background())
require.NoError(t, err)
defer func() {
_ = conn.Close()
}()

// get the TCP connection
err = conn.Raw(func(driverConn any) error {
v := reflect.ValueOf(driverConn)
if v.Kind() == reflect.Ptr {
v = v.Elem()
}
f := v.FieldByName("netConn")
if f.IsValid() && f.Type().Implements(reflect.TypeOf((*net.Conn)(nil)).Elem()) {
netConn = *(*net.Conn)(unsafe.Pointer(f.UnsafeAddr()))
}
return nil
})
require.NoError(t, err)
require.NotNil(t, netConn)

// execute `select sleep(300)`
queryDone := make(chan struct{})
go func() {
defer close(queryDone)
if i == 0 {
rows, err := conn.QueryContext(context.Background(), "select sleep(300)")
if err == nil {
_ = rows.Close()
}
} else {
stmt, err := conn.PrepareContext(context.Background(), "select sleep(?)")
if err == nil {
defer stmt.Close()
_, _ = stmt.Exec(300)
}
}
}()

// have two sessions
require.Eventually(t, func() bool {
return processlistCount(dbt) == 2
}, time.Second, time.Millisecond*10)

// close tcp connection
require.NoError(t, netConn.Close())

select {
case <-queryDone:
case <-time.After(time.Second * 3):
require.Fail(t, "query did not exit after closing the TCP connection")
}
_ = conn.Close()

// the `select sleep(300)` is killed
require.Eventually(t, func() bool {
return processlistCount(dbt) == 1
}, time.Second*3, time.Millisecond*10)
Comment on lines +3729 to +3786

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Wait for the sleep statement itself, not just a second session.

Line 3769 can become true immediately after Line 3729 opens conn, before select sleep(...) is actually running. That makes this a false-positive-prone regression test: it can pass by opening and closing an idle connection, without ever exercising the long-running query path. Please wait until the target connection shows a non-empty Info/select sleep... in SHOW PROCESSLIST (or information_schema.processlist) before closing the socket, and ideally assert that the goroutine exits because of that disconnect.

Suggested tightening
-			// have two sessions
-			require.Eventually(t, func() bool {
-				return processlistCount(dbt) == 2
-			}, time.Second, time.Millisecond*10)
+			hasRunningSleep := func() bool {
+				rs := dbt.MustQuery("select info from information_schema.processlist where info like 'select sleep%'")
+				defer func() {
+					require.NoError(t, rs.Close())
+				}()
+				ok := rs.Next()
+				require.NoError(t, rs.Err())
+				return ok
+			}
+			require.Eventually(t, hasRunningSleep, time.Second, time.Millisecond*10)

As per coding guidelines, test changes should stay minimal and deterministic.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@pkg/server/tests/commontest/tidb_test.go` around lines 3729 - 3786, The test
currently waits only for processlistCount(dbt)==2 which can be true before the
long-running query starts; modify the setup so that before closing netConn you
poll the processlist and wait until the specific session shows a non-empty Info
containing "sleep" (e.g., query text from SHOW PROCESSLIST or
information_schema.processlist) for the connection created by conn (use its
connection id or other unique marker), then close netConn and assert the
goroutine tied to queryDone exits due to that disconnect; update the
require.Eventually call around processlistCount(dbt) to instead check the
presence of the "sleep" Info for the target session and keep the later
assertions (waiting on queryDone and verifying processlistCount returns to 1)
intact.

})
}
}
13 changes: 13 additions & 0 deletions pkg/util/deeptest/statictesthelper.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,19 @@ func (h *staticTestHelper) assertDeepClonedEqual(t require.TestingT, valA, valB
require.NotEqual(t, valA.Pointer(), valB.Pointer(), path+" should be different")
h.assertDeepClonedEqual(t, valA.Elem(), valB.Elem(), path)
}
case reflect.UnsafePointer:
if valA.IsNil() && valB.IsNil() {
return
}
// both of them are not nil
require.NotEqual(t, 0, valA.Pointer(), path+" should not be nil")
require.NotEqual(t, 0, valB.Pointer(), path+" should not be nil")

if h.shouldComparePointer(path) {
require.Equal(t, valA.Pointer(), valB.Pointer(), path+" should be the same")
} else {
require.NotEqual(t, valA.Pointer(), valB.Pointer(), path+" should be different")
}
case reflect.Slice:
if valA.IsNil() && valB.IsNil() {
return
Expand Down
1 change: 1 addition & 0 deletions pkg/util/sqlkiller/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ go_library(
visibility = ["//visibility:public"],
deps = [
"//pkg/util/dbterror/exeerrors",
"//pkg/util/intest",
"//pkg/util/logutil",
"@com_github_pingcap_failpoint//:failpoint",
"@org_uber_go_zap//:zap",
Expand Down
27 changes: 27 additions & 0 deletions pkg/util/sqlkiller/sqlkiller.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@ import (
"math/rand"
"sync"
"sync/atomic"
"time"

"github.com/pingcap/failpoint"
"github.com/pingcap/tidb/pkg/util/dbterror/exeerrors"
"github.com/pingcap/tidb/pkg/util/intest"
"github.com/pingcap/tidb/pkg/util/logutil"
"go.uber.org/zap"
)
Expand Down Expand Up @@ -51,6 +53,9 @@ type SQLKiller struct {
// InWriteResultSet is used to indicate whether the query is currently calling clientConn.writeResultSet().
// If the query is in writeResultSet and Finish() can acquire rs.finishLock, we can assume the query is waiting for the client to receive data from the server over network I/O.
InWriteResultSet atomic.Bool

lastCheckTime atomic.Pointer[time.Time]
IsConnectionAlive atomic.Pointer[func() bool]
}

// SendKillSignal sends a kill signal to the query.
Expand Down Expand Up @@ -122,6 +127,27 @@ func (killer *SQLKiller) HandleSignal() error {
}
}
})

// Checks if the connection is alive.
// For performance reasons, the check interval should be at least `checkConnectionAliveDur`(1 second).
fn := killer.IsConnectionAlive.Load()
lastCheckTime := killer.lastCheckTime.Load()
if fn != nil {
var checkConnectionAliveDur time.Duration = time.Second
now := time.Now()
if intest.InTest {
checkConnectionAliveDur = time.Millisecond
}
if lastCheckTime == nil {
killer.lastCheckTime.Store(&now)
} else if now.Sub(*lastCheckTime) > checkConnectionAliveDur {
killer.lastCheckTime.Store(&now)
if !(*fn)() {
atomic.CompareAndSwapUint32(&killer.Signal, 0, QueryInterrupted)
}
}
}

status := atomic.LoadUint32(&killer.Signal)
err := killer.getKillError(status)
if status == ServerMemoryExceeded {
Expand All @@ -137,4 +163,5 @@ func (killer *SQLKiller) Reset() {
logutil.BgLogger().Warn("kill finished", zap.Uint64("conn", killer.ConnID.Load()))
}
atomic.StoreUint32(&killer.Signal, 0)
killer.lastCheckTime.Store(nil)
}