diff --git a/pkg/server/conn.go b/pkg/server/conn.go index 9b866d8a38c16..7a6a1dded21ba 100644 --- a/pkg/server/conn.go +++ b/pkg/server/conn.go @@ -2099,9 +2099,17 @@ func (cc *clientConn) handleStmt( //nolint: errcheck rs.Finish() }) + fn := func() bool { + if cc.bufReadConn != nil { + return cc.bufReadConn.IsAlive() != 0 + } + return true + } + cc.ctx.GetSessionVars().SQLKiller.IsConnectionAlive.Store(&fn) cc.ctx.GetSessionVars().SQLKiller.InWriteResultSet.Store(true) defer cc.ctx.GetSessionVars().SQLKiller.InWriteResultSet.Store(false) defer cc.ctx.GetSessionVars().SQLKiller.ClearFinishFunc() + defer cc.ctx.GetSessionVars().SQLKiller.IsConnectionAlive.Store(nil) if retryable, err := cc.writeResultSet(ctx, rs, false, status, 0); err != nil { return retryable, err } diff --git a/pkg/server/conn_stmt.go b/pkg/server/conn_stmt.go index 45c7ba35dee6f..740e0bb6c8ebf 100644 --- a/pkg/server/conn_stmt.go +++ b/pkg/server/conn_stmt.go @@ -233,6 +233,17 @@ func (cc *clientConn) executePlanCacheStmt(ctx context.Context, stmt any, args [ ctx = context.WithValue(ctx, execdetails.StmtExecDetailKey, &execdetails.StmtExecDetails{}) ctx = context.WithValue(ctx, util.ExecDetailsKey, &util.ExecDetails{}) ctx = context.WithValue(ctx, util.RUDetailsCtxKey, util.NewRUDetails()) + + fn := func() bool { + if cc.bufReadConn != nil { + return cc.bufReadConn.IsAlive() != 0 + } + return true + } + cc.ctx.GetSessionVars().SQLKiller.IsConnectionAlive.Store(&fn) + defer cc.ctx.GetSessionVars().SQLKiller.IsConnectionAlive.Store(nil) + + //nolint:forcetypeassert retryable, err := cc.executePreparedStmtAndWriteResult(ctx, stmt.(PreparedStatement), args, useCursor) if err != nil { action, txnErr := sessiontxn.GetTxnManager(&cc.ctx).OnStmtErrorForNextAction(ctx, sessiontxn.StmtErrAfterQuery, err) diff --git a/pkg/server/internal/util/buffered_read_conn.go b/pkg/server/internal/util/buffered_read_conn.go index d6bc1f24c9bc4..9629039638dd1 100644 --- a/pkg/server/internal/util/buffered_read_conn.go +++ b/pkg/server/internal/util/buffered_read_conn.go @@ -16,7 +16,10 @@ package util import ( "bufio" + "io" "net" + "sync" + "time" ) // DefaultReaderSize is the default size of bufio.Reader. @@ -26,11 +29,15 @@ const DefaultReaderSize = 16 * 1024 type BufferedReadConn struct { net.Conn rb *bufio.Reader + // `mu` is for `IsAlive()` function. + // We use this to ensure that `SetReadDeadline` is not called concurrently. + mu *sync.Mutex } // NewBufferedReadConn creates a BufferedReadConn. func NewBufferedReadConn(conn net.Conn) *BufferedReadConn { return &BufferedReadConn{ + mu: &sync.Mutex{}, Conn: conn, rb: bufio.NewReaderSize(conn, DefaultReaderSize), } @@ -40,3 +47,42 @@ func NewBufferedReadConn(conn net.Conn) *BufferedReadConn { func (conn BufferedReadConn) Read(b []byte) (n int, err error) { return conn.rb.Read(b) } + +// Peek peeks from the connection. +func (conn BufferedReadConn) Peek(n int) ([]byte, error) { + return conn.rb.Peek(n) +} + +// IsAlive detects the connection is alive or not. +// return value < 0, means unknown +// return value = 0, means not alive +// return value = 1, means still alive +func (conn BufferedReadConn) IsAlive() int { + if conn.mu.TryLock() { + defer conn.mu.Unlock() + err := conn.SetReadDeadline(time.Now().Add(30 * time.Microsecond)) + if err != nil { + return -1 + } + // nolint:errcheck + defer conn.SetReadDeadline(time.Time{}) + // At the TCP level, a successful `Peek` operation doesn't guarantee + // the connection remains active. However, in the MySQL protocol, + // clients shouldn't send new data while the server is processing SQL. + // Therefore, we can safely assume `Peek` won't intercept any data + // during this period. Even if `Peek` does capture data, it only means + // the liveness check might be inaccurate - this won't impact the + // actual connection state or its operations. + _, err = conn.Peek(1) + if err == nil { + return 1 + } + if err == io.EOF { + return 0 + } else if ne, ok := err.(net.Error); ok && ne.Timeout() { + return 1 + } + return 0 + } + return -1 +} diff --git a/pkg/server/tests/commontest/tidb_test.go b/pkg/server/tests/commontest/tidb_test.go index 8bc38e57af1d1..bcf80f3d6c89e 100644 --- a/pkg/server/tests/commontest/tidb_test.go +++ b/pkg/server/tests/commontest/tidb_test.go @@ -31,6 +31,7 @@ import ( "sync/atomic" "testing" "time" + "unsafe" "github.com/go-sql-driver/mysql" "github.com/pingcap/errors" @@ -3707,3 +3708,82 @@ func TestAuditPluginRetrying(t *testing.T) { runExplicitTransactionRetry(db, true) }) } + +func TestIssue57531(t *testing.T) { + ts := servertestkit.CreateTidbTestSuite(t) + + processlistCount := func(dbt *testkit.DBTestKit) int { + rsCnt := 0 + rs := dbt.MustQuery("show processlist") + for rs.Next() { + rsCnt++ + } + require.NoError(t, rs.Err()) + require.NoError(t, rs.Close()) + return rsCnt + } + + for i := range 2 { + ts.RunTests(t, nil, func(dbt *testkit.DBTestKit) { + var netConn net.Conn + conn, err := dbt.GetDB().Conn(context.Background()) + require.NoError(t, err) + defer func() { + _ = conn.Close() + }() + + // get the TCP connection + err = conn.Raw(func(driverConn any) error { + v := reflect.ValueOf(driverConn) + if v.Kind() == reflect.Ptr { + v = v.Elem() + } + f := v.FieldByName("netConn") + if f.IsValid() && f.Type().Implements(reflect.TypeOf((*net.Conn)(nil)).Elem()) { + netConn = *(*net.Conn)(unsafe.Pointer(f.UnsafeAddr())) + } + return nil + }) + require.NoError(t, err) + require.NotNil(t, netConn) + + // execute `select sleep(300)` + queryDone := make(chan struct{}) + go func() { + defer close(queryDone) + if i == 0 { + rows, err := conn.QueryContext(context.Background(), "select sleep(300)") + if err == nil { + _ = rows.Close() + } + } else { + stmt, err := conn.PrepareContext(context.Background(), "select sleep(?)") + if err == nil { + defer stmt.Close() + _, _ = stmt.Exec(300) + } + } + }() + + // have two sessions + require.Eventually(t, func() bool { + return processlistCount(dbt) == 2 + }, time.Second, time.Millisecond*10) + + // close tcp connection + require.NoError(t, netConn.Close()) + + select { + case <-queryDone: + case <-time.After(time.Second * 3): + require.Fail(t, "query did not exit after closing the TCP connection") + } + _ = conn.Close() + + // the `select sleep(300)` is killed + require.Eventually(t, func() bool { + return processlistCount(dbt) == 1 + }, time.Second*3, time.Millisecond*10) + }) + } +} diff --git a/pkg/util/deeptest/statictesthelper.go b/pkg/util/deeptest/statictesthelper.go index b9493444037ae..9a66d2e4f7e9e 100644 --- a/pkg/util/deeptest/statictesthelper.go +++ b/pkg/util/deeptest/statictesthelper.go @@ -156,6 +156,19 @@ func (h *staticTestHelper) assertDeepClonedEqual(t require.TestingT, valA, valB require.NotEqual(t, valA.Pointer(), valB.Pointer(), path+" should be different") h.assertDeepClonedEqual(t, valA.Elem(), valB.Elem(), path) } + case reflect.UnsafePointer: + if valA.IsNil() && valB.IsNil() { + return + } + // both of them are not nil + require.NotEqual(t, 0, valA.Pointer(), path+" should not be nil") + require.NotEqual(t, 0, valB.Pointer(), path+" should not be nil") + + if h.shouldComparePointer(path) { + require.Equal(t, valA.Pointer(), valB.Pointer(), path+" should be the same") + } else { + require.NotEqual(t, valA.Pointer(), valB.Pointer(), path+" should be different") + } case reflect.Slice: if valA.IsNil() && valB.IsNil() { return diff --git a/pkg/util/sqlkiller/BUILD.bazel b/pkg/util/sqlkiller/BUILD.bazel index 5a4eacf70afd6..5a3a76d5a558e 100644 --- a/pkg/util/sqlkiller/BUILD.bazel +++ b/pkg/util/sqlkiller/BUILD.bazel @@ -7,6 +7,7 @@ go_library( visibility = ["//visibility:public"], deps = [ "//pkg/util/dbterror/exeerrors", + "//pkg/util/intest", "//pkg/util/logutil", "@com_github_pingcap_failpoint//:failpoint", "@org_uber_go_zap//:zap", diff --git a/pkg/util/sqlkiller/sqlkiller.go b/pkg/util/sqlkiller/sqlkiller.go index da81f99ee535d..c674a1c2f02f9 100644 --- a/pkg/util/sqlkiller/sqlkiller.go +++ b/pkg/util/sqlkiller/sqlkiller.go @@ -18,9 +18,11 @@ import ( "math/rand" "sync" "sync/atomic" + "time" "github.com/pingcap/failpoint" "github.com/pingcap/tidb/pkg/util/dbterror/exeerrors" + "github.com/pingcap/tidb/pkg/util/intest" "github.com/pingcap/tidb/pkg/util/logutil" "go.uber.org/zap" ) @@ -51,6 +53,9 @@ type SQLKiller struct { // InWriteResultSet is used to indicate whether the query is currently calling clientConn.writeResultSet(). // If the query is in writeResultSet and Finish() can acquire rs.finishLock, we can assume the query is waiting for the client to receive data from the server over network I/O. InWriteResultSet atomic.Bool + + lastCheckTime atomic.Pointer[time.Time] + IsConnectionAlive atomic.Pointer[func() bool] } // SendKillSignal sends a kill signal to the query. @@ -122,6 +127,27 @@ func (killer *SQLKiller) HandleSignal() error { } } }) + + // Checks if the connection is alive. + // For performance reasons, the check interval should be at least `checkConnectionAliveDur`(1 second). + fn := killer.IsConnectionAlive.Load() + lastCheckTime := killer.lastCheckTime.Load() + if fn != nil { + var checkConnectionAliveDur time.Duration = time.Second + now := time.Now() + if intest.InTest { + checkConnectionAliveDur = time.Millisecond + } + if lastCheckTime == nil { + killer.lastCheckTime.Store(&now) + } else if now.Sub(*lastCheckTime) > checkConnectionAliveDur { + killer.lastCheckTime.Store(&now) + if !(*fn)() { + atomic.CompareAndSwapUint32(&killer.Signal, 0, QueryInterrupted) + } + } + } + status := atomic.LoadUint32(&killer.Signal) err := killer.getKillError(status) if status == ServerMemoryExceeded { @@ -137,4 +163,5 @@ func (killer *SQLKiller) Reset() { logutil.BgLogger().Warn("kill finished", zap.Uint64("conn", killer.ConnID.Load())) } atomic.StoreUint32(&killer.Signal, 0) + killer.lastCheckTime.Store(nil) }