diff --git a/pkg/executor/BUILD.bazel b/pkg/executor/BUILD.bazel index 60400e4525ea4..ccb921c17ed3e 100644 --- a/pkg/executor/BUILD.bazel +++ b/pkg/executor/BUILD.bazel @@ -20,6 +20,7 @@ go_library( "brie.go", "brie_utils.go", "builder.go", + "cached_result_exec.go", "check_table_index.go", "checksum.go", "compact_table.go", @@ -349,6 +350,7 @@ go_test( "brie_test.go", "brie_utils_test.go", "builder_index_join_cleanup_test.go", + "cached_result_exec_test.go", "checksum_test.go", "chunk_size_control_test.go", "cluster_table_test.go", @@ -379,6 +381,7 @@ go_test( "inspection_summary_test.go", "join_pkg_test.go", "main_test.go", + "mem_reader_test.go", "memtable_reader_test.go", "metrics_reader_test.go", "parallel_apply_test.go", @@ -521,6 +524,7 @@ go_test( "//pkg/util/paging", "//pkg/util/ranger", "//pkg/util/regionsplit", + "//pkg/util/rowcodec", "//pkg/util/sem", "//pkg/util/sem/v2:sem", "//pkg/util/set", diff --git a/pkg/executor/adapter.go b/pkg/executor/adapter.go index 698f9e5ceb3b0..8173c704b2d2b 100644 --- a/pkg/executor/adapter.go +++ b/pkg/executor/adapter.go @@ -1409,6 +1409,10 @@ func (a *ExecStmt) buildExecutor() (exec.Executor, error) { ctx.GetSessionVars().StmtCtx.Priority = kv.PriorityLow } e = executorExec.stmtExec + } else { + // For non-prepared queries, wrap with result set cache here. + // Prepared statements are wrapped inside ExecuteExec.Build(). + e = b.wrapWithResultCache(e, a.StmtNode, a.Plan) } a.isSelectForUpdate = b.hasLock && (!stmtCtx.InDeleteStmt && !stmtCtx.InUpdateStmt && !stmtCtx.InInsertStmt) return e, nil diff --git a/pkg/executor/adapter_slow_log.go b/pkg/executor/adapter_slow_log.go index 50647d048cc75..dacb4c481bf3b 100644 --- a/pkg/executor/adapter_slow_log.go +++ b/pkg/executor/adapter_slow_log.go @@ -260,6 +260,7 @@ func SetSlowLogItems(a *ExecStmt, txnTS uint64, hasMoreResults bool, items *vari items.ResultRows = stmtCtx.GetResultRowsCount() items.IsExplicitTxn = sessVars.TxnCtx.IsExplicit items.IsWriteCacheTable = stmtCtx.WaitLockLeaseTime > 0 + items.ResultCacheHit = stmtCtx.ReadFromResultCache items.UsedStats = stmtCtx.GetUsedStatsInfo(false) items.IsSyncStatsFailed = stmtCtx.IsSyncStatsFailed items.Warnings = variable.CollectWarningsForSlowLog(stmtCtx) diff --git a/pkg/executor/builder.go b/pkg/executor/builder.go index a0bff5d4ea9fe..0648a4838cc73 100644 --- a/pkg/executor/builder.go +++ b/pkg/executor/builder.go @@ -128,6 +128,16 @@ type executorBuilder struct { // Used when building MPPGather. encounterUnionScan bool + // cachedTbl is set when a cached table's KV cache is hit during building. + // Used to attach a result set cache wrapper around the final executor. + cachedTbl table.CachedTable + // cachedTblID records the first cached table ID observed in the plan. + // Used to disable result cache when multiple cached tables are involved. + cachedTblID int64 + // disableResultCache indicates the plan touches multiple cached tables. + // Result set cache entries are scoped to a single cached table instance. + disableResultCache bool + // stmtCtxLock guards statement context and telemetry updates when executor building happens concurrently. // It is only set for dataReaderBuilder instances used by index join inner workers. stmtCtxLock *sync.Mutex @@ -1492,6 +1502,61 @@ func collectColIdxFromByItems(byItems []*plannerutil.ByItems, cols []*model.Colu return colIdxs, nil } +// removeRedundantAccessConditions removes access conditions from allConds that are +// already satisfied by the index kvRanges and don't need to be reserved (i.e., the +// condition references only full-length index columns). This mirrors the ranger's +// shouldReserve logic: a condition is safe to remove only when all its referenced +// columns are full-length index columns (IdxColLens[i] == types.UnspecifiedLength +// or IdxColLens[i] == col.GetFlen()). +func removeRedundantAccessConditions( + allConds []expression.Expression, + accessConds []expression.Expression, + idxCols []*expression.Column, + idxColLens []int, + evalCtx expression.EvalContext, +) []expression.Expression { + // Build a set of full-length index column UniqueIDs. + fullLenColIDs := make(map[int64]struct{}, len(idxCols)) + for i, col := range idxCols { + if i < len(idxColLens) { + length := idxColLens[i] + if length == types.UnspecifiedLength || length == col.GetType(evalCtx).GetFlen() { + fullLenColIDs[col.UniqueID] = struct{}{} + } + } + } + + // Build the set of canonical hash codes for access conditions that are safe to remove. + safeToRemove := make(map[string]struct{}, len(accessConds)) + for _, ac := range accessConds { + cols := expression.ExtractColumns(ac) + allFullLen := true + for _, col := range cols { + if _, ok := fullLenColIDs[col.UniqueID]; !ok { + allFullLen = false + break + } + } + if allFullLen { + // CanonicalHashCode caches inside expression objects and is not goroutine-safe. + // Clone before hashing to avoid data races when building executors concurrently. + safeToRemove[string(ac.Clone().CanonicalHashCode())] = struct{}{} + } + } + if len(safeToRemove) == 0 { + return allConds + } + + // Filter out conditions whose canonical hash matches a safe-to-remove access condition. + result := make([]expression.Expression, 0, len(allConds)) + for _, cond := range allConds { + if _, ok := safeToRemove[string(cond.Clone().CanonicalHashCode())]; !ok { + result = append(result, cond) + } + } + return result +} + // buildUnionScanFromReader builds union scan executor from child executor. // Note that this function may be called by inner workers of index lookup join concurrently. // Be careful to avoid data race. @@ -1577,6 +1642,18 @@ func (b *executorBuilder) buildUnionScanFromReader(reader exec.Executor, v *phys } } us.conditions, us.conditionsWithVirCol = physicalop.SplitSelCondsWithVirtualColumn(v.Conditions) + // Remove access conditions already satisfied by kvRanges to avoid redundant EvalBool. + if idxReader, ok := v.Children()[0].(*physicalop.PhysicalIndexReader); ok { + if idxScan, ok := idxReader.IndexPlans[0].(*physicalop.PhysicalIndexScan); ok { + if len(idxScan.AccessCondition) > 0 && len(us.conditions) > 0 { + us.conditions = removeRedundantAccessConditions( + us.conditions, idxScan.AccessCondition, + idxScan.IdxCols, idxScan.IdxColLens, + b.ctx.GetExprCtx().GetEvalCtx(), + ) + } + } + } us.columns = x.columns us.partitionIDMap = x.partitionIDMap us.table = x.table @@ -1648,6 +1725,7 @@ type bypassDataSourceExecutor interface { func (us *UnionScanExec) handleCachedTable(b *executorBuilder, x bypassDataSourceExecutor, vars *variable.SessionVars, startTS uint64) { tbl := x.Table() if tbl.Meta().TableCacheStatusType == model.TableCacheStatusEnable { + b.observeCachedTable(tbl.Meta().ID) cachedTable := tbl.(table.CachedTable) // Determine whether the cache can be used. leaseDuration := time.Duration(vardef.TableCacheLease.Load()) * time.Second @@ -1656,6 +1734,41 @@ func (us *UnionScanExec) handleCachedTable(b *executorBuilder, x bypassDataSourc vars.StmtCtx.ReadFromTableCache = true x.setDummy() us.cacheTable = cacheData + // Prefer caches pinned to the same cacheData generation as cacheTable. + if dcp, ok := cachedTable.(interface { + GetCachedDatumDataForMemBuffer(kv.MemBuffer) *tables.CachedDatumData + }); ok { + us.datumCache = dcp.GetCachedDatumDataForMemBuffer(cacheData) + } else if dcp, ok := cachedTable.(interface { + GetCachedDatumData() *tables.CachedDatumData + }); ok { + us.datumCache = dcp.GetCachedDatumData() + } + // Prefer index caches pinned to the same cacheData generation as cacheTable. + if icp, ok := cachedTable.(interface { + GetCachedIndexDatumDataForMemBuffer(kv.MemBuffer, int64) *tables.CachedIndexDatumData + }); ok { + for _, idx := range tbl.Meta().Indices { + if dc := icp.GetCachedIndexDatumDataForMemBuffer(cacheData, idx.ID); dc != nil { + if us.indexDatumCaches == nil { + us.indexDatumCaches = make(map[int64]*tables.CachedIndexDatumData) + } + us.indexDatumCaches[idx.ID] = dc + } + } + } else if icp, ok := cachedTable.(interface { + GetCachedIndexDatumData(int64) *tables.CachedIndexDatumData + }); ok { + for _, idx := range tbl.Meta().Indices { + if dc := icp.GetCachedIndexDatumData(idx.ID); dc != nil { + if us.indexDatumCaches == nil { + us.indexDatumCaches = make(map[int64]*tables.CachedIndexDatumData) + } + us.indexDatumCaches[idx.ID] = dc + } + } + } + b.recordCachedTable(cachedTable) } else if loading { return } else if !b.inUpdateStmt && !b.inDeleteStmt && !b.inInsertStmt && !vars.StmtCtx.InExplainStmt { @@ -6292,20 +6405,81 @@ func (b *executorBuilder) getCacheTable(tblInfo *model.TableInfo, startTS uint64 return nil } sessVars := b.ctx.GetSessionVars() + b.observeCachedTable(tblInfo.ID) leaseDuration := time.Duration(vardef.TableCacheLease.Load()) * time.Second - cacheData, loading := tbl.(table.CachedTable).TryReadFromCache(startTS, leaseDuration) + cachedTable := tbl.(table.CachedTable) + cacheData, loading := cachedTable.TryReadFromCache(startTS, leaseDuration) if cacheData != nil { sessVars.StmtCtx.ReadFromTableCache = true + b.recordCachedTable(cachedTable) return cacheData } else if loading { return nil } if !b.ctx.GetSessionVars().StmtCtx.InExplainStmt && !b.inDeleteStmt && !b.inUpdateStmt { - tbl.(table.CachedTable).UpdateLockForRead(context.Background(), b.ctx.GetStore(), startTS, leaseDuration) + cachedTable.UpdateLockForRead(context.Background(), b.ctx.GetStore(), startTS, leaseDuration) } return nil } +func (b *executorBuilder) observeCachedTable(tableID int64) { + if b.disableResultCache || tableID == 0 { + return + } + if b.cachedTblID == 0 { + b.cachedTblID = tableID + return + } + if b.cachedTblID != tableID { + b.disableResultCache = true + b.cachedTbl = nil + } +} + +func (b *executorBuilder) recordCachedTable(cachedTable table.CachedTable) { + if cachedTable == nil { + return + } + meta := cachedTable.Meta() + if meta == nil { + b.disableResultCache = true + b.cachedTbl = nil + return + } + b.observeCachedTable(meta.ID) + if b.disableResultCache || b.cachedTbl != nil { + return + } + b.cachedTbl = cachedTable +} + +// wrapWithResultCache wraps the top-level executor with CachedResultExec when +// the query is eligible for result set caching on a cached table. +func (b *executorBuilder) wrapWithResultCache(e exec.Executor, stmtNode ast.StmtNode, plan base.Plan) exec.Executor { + if b.cachedTbl == nil || b.disableResultCache { + return e + } + inDML := b.inUpdateStmt || b.inDeleteStmt || b.inInsertStmt + physPlan, ok := plan.(base.PhysicalPlan) + if !ok { + return e + } + if !plannercore.CanCacheResultSet(stmtNode, physPlan, inDML) { + return e + } + key, paramBytes, ok := plannercore.BuildResultCacheKey(b.ctx) + if !ok { + return e + } + return &CachedResultExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, e.Schema(), physPlan.ID(), e), + original: e, + cachedTable: b.cachedTbl, + cacheKey: key, + paramBytes: paramBytes, + } +} + func (b *executorBuilder) buildCompactTable(v *plannercore.CompactTable) exec.Executor { if v.ReplicaKind != ast.CompactReplicaKindTiFlash && v.ReplicaKind != ast.CompactReplicaKindAll { b.err = errors.Errorf("compact %v replica is not supported", strings.ToLower(string(v.ReplicaKind))) diff --git a/pkg/executor/cached_result_exec.go b/pkg/executor/cached_result_exec.go new file mode 100644 index 0000000000000..406e8953fb7e5 --- /dev/null +++ b/pkg/executor/cached_result_exec.go @@ -0,0 +1,168 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package executor + +import ( + "context" + + "github.com/pingcap/tidb/pkg/executor/internal/exec" + "github.com/pingcap/tidb/pkg/metrics" + "github.com/pingcap/tidb/pkg/table" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/execdetails" +) + +// CachedResultExec wraps an existing executor and adds result set caching +// for cached table queries. On Open, it checks the result cache; on hit it +// serves chunks directly from memory. On miss it delegates to the original +// executor and collects the result for cache back-fill on Close. +type CachedResultExec struct { + exec.BaseExecutor + + original exec.Executor + cachedTable table.CachedTable + cacheKey table.ResultCacheKey + paramBytes []byte + + // cache hit state + hitCache bool + cachedChunks []*chunk.Chunk + chunkIdx int + + // cache miss state: collect results for back-fill + collecting bool + collectedChunks []*chunk.Chunk + resultSchema []*types.FieldType +} + +// Open checks the result cache before opening the wrapped executor. +func (e *CachedResultExec) Open(ctx context.Context) error { + // Reset state in case Open is called multiple times. + e.hitCache = false + e.cachedChunks = nil + e.chunkIdx = 0 + e.collecting = false + e.collectedChunks = nil + e.resultSchema = nil + + chunks, fieldTypes, ok := e.cachedTable.GetCachedResult(e.cacheKey, e.paramBytes) + if ok && schemaMatch(fieldTypes, e.RetFieldTypes()) { + e.hitCache = true + e.cachedChunks = chunks + e.chunkIdx = 0 + metrics.ResultCacheHitCounter.Inc() + e.Ctx().GetSessionVars().StmtCtx.ReadFromResultCache = true + // Register runtime stats for EXPLAIN ANALYZE. + if coll := e.Ctx().GetSessionVars().StmtCtx.RuntimeStatsColl; coll != nil { + var cachedRows int64 + for _, chk := range chunks { + cachedRows += int64(chk.NumRows()) + } + coll.RegisterStats(e.ID(), &execdetails.ResultCacheRuntimeStats{ + HitCache: true, + CachedRows: cachedRows, + }) + } + return nil + } + + // Cache miss — open the original executor. + metrics.ResultCacheMissCounter.Inc() + if coll := e.Ctx().GetSessionVars().StmtCtx.RuntimeStatsColl; coll != nil { + coll.RegisterStats(e.ID(), &execdetails.ResultCacheRuntimeStats{HitCache: false}) + } + if err := e.original.Open(ctx); err != nil { + return err + } + e.collecting = true + e.resultSchema = e.RetFieldTypes() + return nil +} + +// Next returns the next chunk of results, either from cache or the original executor. +func (e *CachedResultExec) Next(ctx context.Context, req *chunk.Chunk) error { + if e.hitCache { + return e.nextFromCache(req) + } + + err := e.original.Next(ctx, req) + if err != nil { + e.collecting = false + return err + } + + if e.collecting && req.NumRows() > 0 { + // Deep copy: the session reuses the chunk memory between Next calls. + copied := req.CopyConstruct() + e.collectedChunks = append(e.collectedChunks, copied) + } + + return nil +} + +// Close back-fills the result cache on miss and closes the original executor. +func (e *CachedResultExec) Close() error { + if e.hitCache { + // Original executor was never opened. + return nil + } + + if err := e.original.Close(); err != nil { + e.collecting = false + return err + } + + // Back-fill the cache only after the wrapped executor closes successfully. + if e.collecting { + e.cachedTable.PutCachedResult(e.cacheKey, e.paramBytes, e.collectedChunks, e.resultSchema) + } + + return nil +} + +// nextFromCache serves chunks from the cached result set. +func (e *CachedResultExec) nextFromCache(req *chunk.Chunk) error { + req.Reset() + if e.chunkIdx >= len(e.cachedChunks) { + return nil // EOF + } + src := e.cachedChunks[e.chunkIdx] + e.chunkIdx++ + // Copy into req so we don't hand out shared cache memory to the session. + req.Append(src, 0, src.NumRows()) + return nil +} + +// schemaMatch returns true when the cached field types are compatible with +// the current executor's output schema. This guards against schema changes +// (e.g. DDL altering a column type) that would make a cached result invalid. +func schemaMatch(cached, current []*types.FieldType) bool { + if len(cached) != len(current) { + return false + } + for i := range cached { + if cached[i] == nil || current[i] == nil { + if cached[i] != current[i] { + return false + } + continue + } + if !cached[i].Equal(current[i]) { + return false + } + } + return true +} diff --git a/pkg/executor/cached_result_exec_test.go b/pkg/executor/cached_result_exec_test.go new file mode 100644 index 0000000000000..543c959c9207a --- /dev/null +++ b/pkg/executor/cached_result_exec_test.go @@ -0,0 +1,129 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package executor + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/pingcap/tidb/pkg/executor/internal/exec" + "github.com/pingcap/tidb/pkg/expression" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/meta/model" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/table" + "github.com/pingcap/tidb/pkg/table/tables" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/mock" + "github.com/pingcap/tidb/pkg/util/sqlexec" + "github.com/stretchr/testify/require" +) + +type mockCachedResultTable struct { + table.Table + putCount int +} + +func (*mockCachedResultTable) Init(sqlexec.SQLExecutor) error { + return nil +} + +func (*mockCachedResultTable) TryReadFromCache(uint64, time.Duration) (kv.MemBuffer, bool) { + return nil, false +} + +func (*mockCachedResultTable) UpdateLockForRead(context.Context, kv.Storage, uint64, time.Duration) { +} + +func (*mockCachedResultTable) WriteLockAndKeepAlive(context.Context, chan struct{}, *uint64, chan error) { +} + +func (*mockCachedResultTable) GetCachedResult(table.ResultCacheKey, []byte) ([]*chunk.Chunk, []*types.FieldType, bool) { + return nil, nil, false +} + +func (t *mockCachedResultTable) PutCachedResult(table.ResultCacheKey, []byte, []*chunk.Chunk, []*types.FieldType) bool { + t.putCount++ + return true +} + +type mockResultSourceExec struct { + exec.BaseExecutor + rows [][]types.Datum + cursor int + closeErr error +} + +func (e *mockResultSourceExec) Open(context.Context) error { + e.cursor = 0 + return nil +} + +func (e *mockResultSourceExec) Next(_ context.Context, req *chunk.Chunk) error { + req.GrowAndReset(e.MaxChunkSize()) + if e.cursor >= len(e.rows) { + return nil + } + for i := range e.rows[e.cursor] { + req.AppendDatum(i, &e.rows[e.cursor][i]) + } + e.cursor++ + return nil +} + +func (e *mockResultSourceExec) Close() error { + return e.closeErr +} + +func TestCachedResultExecCloseErrorDoesNotBackfill(t *testing.T) { + sctx := mock.NewContext() + ft := types.NewFieldType(mysql.TypeLonglong) + col := &expression.Column{UniqueID: 1, Index: 0, RetType: ft} + schema := expression.NewSchema(col) + + tblInfo := &model.TableInfo{ID: 1, Columns: []*model.ColumnInfo{{ID: 1, Offset: 0, FieldType: *ft}}} + cachedTbl := &mockCachedResultTable{Table: tables.MockTableFromMeta(tblInfo)} + + original := &mockResultSourceExec{ + BaseExecutor: exec.NewBaseExecutor(sctx, schema, 1), + rows: [][]types.Datum{{types.NewIntDatum(7)}}, + closeErr: errors.New("close failed"), + } + + wrapped := &CachedResultExec{ + BaseExecutor: exec.NewBaseExecutor(sctx, schema, 2, original), + original: original, + cachedTable: cachedTbl, + cacheKey: table.ResultCacheKey{ParamHash: 1}, + } + + require.NoError(t, wrapped.Open(context.Background())) + + req := wrapped.NewChunk() + require.NoError(t, wrapped.Next(context.Background(), req)) + require.Equal(t, 1, req.NumRows()) + require.Equal(t, int64(7), req.GetRow(0).GetInt64(0)) + + require.NoError(t, wrapped.Next(context.Background(), req)) + require.Zero(t, req.NumRows()) + + err := wrapped.Close() + require.EqualError(t, err, "close failed") + require.Zero(t, cachedTbl.putCount) + require.False(t, sctx.GetSessionVars().StmtCtx.ReadFromResultCache) +} diff --git a/pkg/executor/mem_reader.go b/pkg/executor/mem_reader.go index ca106821bb772..dc8c8f5b4dea7 100644 --- a/pkg/executor/mem_reader.go +++ b/pkg/executor/mem_reader.go @@ -15,9 +15,12 @@ package executor import ( + "bytes" "context" + "encoding/binary" "math" "slices" + "time" "github.com/pingcap/errors" "github.com/pingcap/tidb/pkg/distsql" @@ -60,14 +63,20 @@ type memIndexReader struct { retFieldTypes []*types.FieldType outputOffset []int cacheTable kv.MemBuffer + indexDatumCache *tables.CachedIndexDatumData // pre-decoded index datum cache keepOrder bool physTblIDIdx int partitionIDMap map[int64]struct{} compareExec - buf [16]byte - decodeBuff [][]byte - resultRows []types.Datum + buf [16]byte + decodeBuff [][]byte + resultRows []types.Datum + restoredDec *tablecodec.IndexRestoredDecoder // cached restored values decoder + + hdStatus tablecodec.HandleStatus // cached, computed once per scan + loc *time.Location // cached from session vars + evalCtx expression.EvalContext // cached from expr context } func buildMemIndexReader(ctx context.Context, us *UnionScanExec, idxReader *IndexReaderExecutor) *memIndexReader { @@ -80,20 +89,25 @@ func buildMemIndexReader(ctx context.Context, us *UnionScanExec, idxReader *Inde if us.desc { slices.Reverse(kvRanges) } + var indexDatumCache *tables.CachedIndexDatumData + if us.indexDatumCaches != nil { + indexDatumCache = us.indexDatumCaches[idxReader.index.ID] + } return &memIndexReader{ - ctx: us.Ctx(), - index: idxReader.index, - table: idxReader.table.Meta(), - kvRanges: kvRanges, - conditions: us.conditions, - retFieldTypes: exec.RetTypes(us), - outputOffset: outputOffset, - cacheTable: us.cacheTable, - keepOrder: us.keepOrder, - compareExec: us.compareExec, - physTblIDIdx: us.physTblIDIdx, - partitionIDMap: us.partitionIDMap, - resultRows: make([]types.Datum, 0, len(outputOffset)), + ctx: us.Ctx(), + index: idxReader.index, + table: idxReader.table.Meta(), + kvRanges: kvRanges, + conditions: us.conditions, + retFieldTypes: exec.RetTypes(us), + outputOffset: outputOffset, + cacheTable: us.cacheTable, + indexDatumCache: indexDatumCache, + keepOrder: us.keepOrder, + compareExec: us.compareExec, + physTblIDIdx: us.physTblIDIdx, + partitionIDMap: us.partitionIDMap, + resultRows: make([]types.Datum, 0, len(outputOffset)), } } @@ -114,6 +128,9 @@ func (m *memIndexReader) getMemRowsIter(ctx context.Context) (memRowsIter, error tps := m.getTypes() colInfos := tables.BuildRowcodecColInfoForIndexColumns(m.index, m.table) colInfos = tables.TryAppendCommonHandleRowcodecColInfos(colInfos, m.table) + if m.evalCtx == nil { + m.evalCtx = m.ctx.GetExprCtx().GetEvalCtx() + } return &memRowsIterForIndex{ kvIter: kvIter, tps: tps, @@ -155,17 +172,22 @@ func (m *memIndexReader) getMemRows(ctx context.Context) ([][]types.Datum, error colInfos := tables.BuildRowcodecColInfoForIndexColumns(m.index, m.table) colInfos = tables.TryAppendCommonHandleRowcodecColInfos(colInfos, m.table) + if m.evalCtx == nil { + m.evalCtx = m.ctx.GetExprCtx().GetEvalCtx() + } mutableRow := chunk.MutRowFromTypes(m.retFieldTypes) err := iterTxnMemBuffer(m.ctx, m.cacheTable, m.kvRanges, m.desc, func(key, value []byte) error { - data, err := m.decodeIndexKeyValue(key, value, tps, colInfos) + data, err := m.decodeIndexKeyValue(key, value, tps, colInfos, true) if err != nil { return err } - mutableRow.SetDatums(data...) - matched, _, err := expression.EvalBool(m.ctx.GetExprCtx().GetEvalCtx(), m.conditions, mutableRow.ToRow()) - if err != nil || !matched { - return err + if len(m.conditions) > 0 { + mutableRow.SetDatums(data...) + matched, _, err := expression.EvalBool(m.evalCtx, m.conditions, mutableRow.ToRow()) + if err != nil || !matched { + return err + } } m.addedRows = append(m.addedRows, data) m.resultRows = make([]types.Datum, 0, len(data)) @@ -189,13 +211,28 @@ func (m *memIndexReader) getMemRows(ctx context.Context) ([][]types.Datum, error return m.addedRows, nil } -func (m *memIndexReader) decodeIndexKeyValue(key, value []byte, tps []*types.FieldType, colInfos []rowcodec.ColInfo) ([]types.Datum, error) { - hdStatus := tablecodec.HandleDefault - // `HandleIsUnsigned` only affects IntHandle which always has one column. - if mysql.HasUnsignedFlag(tps[len(m.index.Columns)].GetFlag()) { - hdStatus = tablecodec.HandleIsUnsigned +// decodeIndexKeyValue decodes index key/value into datums. +// durableBytes controls whether the encoded bytes produced from restored values must remain valid +// after subsequent DecodeIndexKVEx calls (e.g. when the returned datums are cached or accumulated). +func (m *memIndexReader) decodeIndexKeyValue(key, value []byte, tps []*types.FieldType, colInfos []rowcodec.ColInfo, durableBytes bool) ([]types.Datum, error) { + // Lazy init cached per-scan invariants. + if m.loc == nil { + m.loc = m.ctx.GetSessionVars().Location() + m.hdStatus = tablecodec.HandleDefault + // `HandleIsUnsigned` only affects IntHandle which always has one column. + if mysql.HasUnsignedFlag(tps[len(m.index.Columns)].GetFlag()) { + m.hdStatus = tablecodec.HandleIsUnsigned + } } + // Fast path: use pre-decoded index cache. + if m.indexDatumCache != nil { + if cachedDatums, ok := m.indexDatumCache.Entries[string(key)]; ok { + return m.projectCachedIndexDatums(cachedDatums, key), nil + } + } + + // Slow path: full decode. colsLen := len(m.index.Columns) if m.decodeBuff == nil { m.decodeBuff = make([][]byte, colsLen, colsLen+len(colInfos)) @@ -203,7 +240,13 @@ func (m *memIndexReader) decodeIndexKeyValue(key, value []byte, tps []*types.Fie m.decodeBuff = m.decodeBuff[: colsLen : colsLen+len(colInfos)] } buf := m.buf[:0] - values, err := tablecodec.DecodeIndexKVEx(key, value, colsLen, hdStatus, colInfos, buf, m.decodeBuff) + if m.restoredDec == nil { + m.restoredDec = tablecodec.NewIndexRestoredDecoder(colInfos[:colsLen]) + } + // Restored values are decoded into bytes backed by restoredDec's arena. Some Datum kinds (e.g. JSON/vector) + // decode zero-copy from the encoded bytes, so we must avoid arena reuse if the datums will outlive this call. + m.restoredDec.SetReuseArena(!durableBytes) + values, err := tablecodec.DecodeIndexKVEx(key, value, colsLen, m.hdStatus, colInfos, buf, m.decodeBuff, m.restoredDec) if err != nil { return nil, errors.Trace(err) } @@ -226,15 +269,53 @@ func (m *memIndexReader) decodeIndexKeyValue(key, value []byte, tps []*types.Fie if offset > physTblIDColumnIdx { offset = offset - 1 } - d, err := tablecodec.DecodeColumnValue(values[offset], tps[offset], m.ctx.GetSessionVars().Location()) - if err != nil { + ds = append(ds, types.Datum{}) + if err := tablecodec.DecodeColumnValueWithDatum(values[offset], tps[offset], m.loc, &ds[len(ds)-1]); err != nil { return nil, err } - ds = append(ds, d) } return ds, nil } +// projectCachedIndexDatums applies outputOffset projection to cached datums, +// handles physTblIDIdx, and converts TIMESTAMP columns from UTC to session timezone. +func (m *memIndexReader) projectCachedIndexDatums(cachedDatums []types.Datum, key []byte) []types.Datum { + physTblIDColumnIdx := math.MaxInt64 + if m.physTblIDIdx >= 0 { + physTblIDColumnIdx = m.outputOffset[m.physTblIDIdx] + } + + ds := m.resultRows[:0] + for i, offset := range m.outputOffset { + if m.physTblIDIdx == i { + tid, _, _, _ := tablecodec.DecodeKeyHead(key) + ds = append(ds, types.NewIntDatum(tid)) + continue + } + if offset > physTblIDColumnIdx { + offset = offset - 1 + } + d := cachedDatums[offset] + // Convert TIMESTAMP from UTC to session timezone. + if m.loc != time.UTC && len(m.indexDatumCache.TsColIndices) > 0 { + for _, tsIdx := range m.indexDatumCache.TsColIndices { + if offset == tsIdx && !d.IsNull() { + t := d.GetMysqlTime() + if !t.IsZero() { + // ConvertTimeZone modifies t in place; safe because GetMysqlTime returns a copy. + _ = t.ConvertTimeZone(time.UTC, m.loc) + d.SetMysqlTime(t) + } + break + } + } + } + ds = append(ds, d) + } + m.resultRows = ds + return ds +} + type memTableReader struct { ctx sessionctx.Context table *model.TableInfo @@ -247,6 +328,7 @@ type memTableReader struct { buffer allocBuf pkColIDs []int64 cacheTable kv.MemBuffer + datumCache *tables.CachedDatumData offsets []int keepOrder bool compareExec @@ -310,6 +392,7 @@ func buildMemTableReader(ctx context.Context, us *UnionScanExec, kvRanges []kv.K }, pkColIDs: pkColIDs, cacheTable: us.cacheTable, + datumCache: us.datumCache, keepOrder: us.keepOrder, compareExec: us.compareExec, } @@ -422,15 +505,50 @@ func (m *memTableReader) getMemRowsIter(ctx context.Context) (memRowsIter, error if err != nil { return nil, errors.Trace(err) } + intHandle := !m.table.IsCommonHandle + if m.cacheTable != nil { + // Try pre-decoded datum cache fast path to skip KV decode. + // The datum cache is built from cacheTable only and currently only supports full-table scans. + // If txn membuffer has any key in the scan range or the scan only covers a subset of record keys, + // fall back to KV iteration so pushed-down handle ranges are preserved. + if m.kvRangesCoverFullTable() && !memBufferHasAnyEntryInRanges(kvIter.txn.GetMemBuffer(), m.kvRanges) { + if iter := m.buildDatumCacheIter(); iter != nil { + kvIter.Close() + return iter, nil + } + } + batchChk := chunk.New(m.retFieldTypes, cachedTableBatchSize, cachedTableBatchSize) + return &memRowsBatchIterForTable{ + kvIter: kvIter, + cd: m.buffer.cd, + batchChk: batchChk, + batchIt: chunk.NewIterator4Chunk(batchChk), + sel: make([]int, 0, cachedTableBatchSize), + datumRow: make([]types.Datum, len(m.retFieldTypes)), + retFieldTypes: m.retFieldTypes, + intHandle: intHandle, + memTableReader: m, + }, nil + } return &memRowsIterForTable{ kvIter: kvIter, cd: m.buffer.cd, chk: chunk.New(m.retFieldTypes, 1, 1), datumRow: make([]types.Datum, len(m.retFieldTypes)), + intHandle: intHandle, memTableReader: m, }, nil } +func (m *memTableReader) kvRangesCoverFullTable() bool { + if len(m.kvRanges) != 1 { + return false + } + recordPrefix := tablecodec.GenTableRecordPrefix(m.table.ID) + rg := m.kvRanges[0] + return bytes.Equal(rg.StartKey, recordPrefix) && bytes.Equal(rg.EndKey, recordPrefix.PrefixNext()) +} + func (m *memTableReader) getMemRows(ctx context.Context) ([][]types.Datum, error) { defer tracing.StartRegion(ctx, "memTableReader.getMemRows").End() mutableRow := chunk.MutRowFromTypes(m.retFieldTypes) @@ -446,10 +564,12 @@ func (m *memTableReader) getMemRows(ctx context.Context) ([][]types.Datum, error return err } - mutableRow.SetDatums(resultRows...) - matched, _, err := expression.EvalBool(m.ctx.GetExprCtx().GetEvalCtx(), m.conditions, mutableRow.ToRow()) - if err != nil || !matched { - return err + if len(m.conditions) > 0 { + mutableRow.SetDatums(resultRows...) + matched, _, err := expression.EvalBool(m.ctx.GetExprCtx().GetEvalCtx(), m.conditions, mutableRow.ToRow()) + if err != nil || !matched { + return err + } } m.addedRows = append(m.addedRows, resultRows) resultRows = make([]types.Datum, len(m.columns)) @@ -473,7 +593,7 @@ func (m *memTableReader) getMemRows(ctx context.Context) ([][]types.Datum, error } func (m *memTableReader) decodeRecordKeyValue(key, value []byte, resultRows *[]types.Datum) ([]types.Datum, error) { - handle, err := tablecodec.DecodeRowKey(key) + handle, err := decodeHandleFromRowKey(key, !m.table.IsCommonHandle) if err != nil { return nil, errors.Trace(err) } @@ -558,8 +678,9 @@ func (m *memTableReader) getRowData(handle kv.Handle, value []byte) ([][]byte, e // getMemRowsHandle is called when memIndexMergeReader.partialPlans[i] is TableScan. func (m *memTableReader) getMemRowsHandle() ([]kv.Handle, error) { handles := make([]kv.Handle, 0, 16) + intHandle := !m.table.IsCommonHandle err := iterTxnMemBuffer(m.ctx, m.cacheTable, m.kvRanges, m.desc, func(key, _ []byte) error { - handle, err := tablecodec.DecodeRowKey(key) + handle, err := decodeHandleFromRowKey(key, intHandle) if err != nil { return err } @@ -674,7 +795,7 @@ func (m *memIndexReader) getMemRowsHandle() ([]kv.Handle, error) { handle = newHandle } } - // filter key/value by partitition id + // filter key/value by partition id if ph, ok := handle.(kv.PartitionHandle); ok { if _, exist := m.partitionIDMap[ph.PartitionID]; !exist { return nil @@ -906,16 +1027,19 @@ func (*defaultRowsIter) Close() {} // memRowsIterForTable combine a kv.Iterator and a kv decoder to get a memRowsIter. type memRowsIterForTable struct { - kvIter *txnMemBufferIter // txnMemBufferIter is the kv.Iterator - cd *rowcodec.ChunkDecoder - chk *chunk.Chunk - datumRow []types.Datum + kvIter *txnMemBufferIter // txnMemBufferIter is the kv.Iterator + cd *rowcodec.ChunkDecoder + chk *chunk.Chunk + datumRow []types.Datum + intHandle bool *memTableReader } func (iter *memRowsIterForTable) Next() ([]types.Datum, error) { curr := iter.kvIter var ret []types.Datum + evalCtx := iter.ctx.GetExprCtx().GetEvalCtx() + hasConds := len(iter.conditions) > 0 for curr.Valid() { key := curr.Key() value := curr.Value() @@ -927,7 +1051,7 @@ func (iter *memRowsIterForTable) Next() ([]types.Datum, error) { if len(value) == 0 { continue } - handle, err := tablecodec.DecodeRowKey(key) + handle, err := decodeHandleFromRowKey(key, iter.intHandle) if err != nil { return nil, errors.Trace(err) } @@ -936,19 +1060,21 @@ func (iter *memRowsIterForTable) Next() ([]types.Datum, error) { if !rowcodec.IsNewFormat(value) { // TODO: remove the legacy code! // fallback to the old way. - iter.datumRow, err = iter.memTableReader.decodeRecordKeyValue(key, value, &iter.datumRow) + iter.datumRow, err = iter.memTableReader.decodeRowData(handle, value, &iter.datumRow) if err != nil { return nil, errors.Trace(err) } - mutableRow := chunk.MutRowFromTypes(iter.retFieldTypes) - mutableRow.SetDatums(iter.datumRow...) - matched, _, err := expression.EvalBool(iter.ctx.GetExprCtx().GetEvalCtx(), iter.conditions, mutableRow.ToRow()) - if err != nil { - return nil, errors.Trace(err) - } - if !matched { - continue + if hasConds { + mutableRow := chunk.MutRowFromTypes(iter.retFieldTypes) + mutableRow.SetDatums(iter.datumRow...) + matched, _, err := expression.EvalBool(evalCtx, iter.conditions, mutableRow.ToRow()) + if err != nil { + return nil, errors.Trace(err) + } + if !matched { + continue + } } return iter.datumRow, nil } @@ -959,12 +1085,14 @@ func (iter *memRowsIterForTable) Next() ([]types.Datum, error) { } row := iter.chk.GetRow(0) - matched, _, err := expression.EvalBool(iter.ctx.GetExprCtx().GetEvalCtx(), iter.conditions, row) - if err != nil { - return nil, errors.Trace(err) - } - if !matched { - continue + if hasConds { + matched, _, err := expression.EvalBool(evalCtx, iter.conditions, row) + if err != nil { + return nil, errors.Trace(err) + } + if !matched { + continue + } } ret = row.GetDatumRowWithBuffer(iter.retFieldTypes, iter.datumRow) break @@ -978,6 +1106,333 @@ func (iter *memRowsIterForTable) Close() { } } +const cachedTableBatchSize = 64 + +// memRowsBatchIterForTable batch decodes cached table rows into a chunk and applies filters in vectorized mode when possible. +type memRowsBatchIterForTable struct { + kvIter *txnMemBufferIter + cd *rowcodec.ChunkDecoder + batchChk *chunk.Chunk + batchIt *chunk.Iterator4Chunk + selected []bool + sel []int + cursor int + datumRow []types.Datum + retFieldTypes []*types.FieldType + intHandle bool + *memTableReader +} + +func (iter *memRowsBatchIterForTable) Next() ([]types.Datum, error) { + curr := iter.kvIter + evalCtx := iter.ctx.GetExprCtx().GetEvalCtx() + vecEnabled := iter.ctx.GetSessionVars().EnableVectorizedExpression + + for { + // Return remaining matched rows in current batch. + if iter.cursor < len(iter.sel) { + row := iter.batchChk.GetRow(iter.sel[iter.cursor]) + iter.cursor++ + return row.GetDatumRowWithBuffer(iter.retFieldTypes, iter.datumRow), nil + } + + // Fill a new batch. + iter.batchChk.Reset() + iter.sel = iter.sel[:0] + iter.cursor = 0 + + for iter.batchChk.NumRows() < cachedTableBatchSize && curr.Valid() { + key := curr.Key() + value := curr.Value() + + // check whether the key was been deleted. + if len(value) == 0 { + if err := curr.Next(); err != nil { + return nil, errors.Trace(err) + } + continue + } + + handle, err := decodeHandleFromRowKey(key, iter.intHandle) + if err != nil { + return nil, errors.Trace(err) + } + + if !rowcodec.IsNewFormat(value) { + // Keep the scan order: flush current batch first. + if iter.batchChk.NumRows() > 0 { + break + } + + if err := curr.Next(); err != nil { + return nil, errors.Trace(err) + } + + // TODO: remove the legacy code! + iter.datumRow, err = iter.memTableReader.decodeRowData(handle, value, &iter.datumRow) + if err != nil { + return nil, errors.Trace(err) + } + mutableRow := chunk.MutRowFromTypes(iter.retFieldTypes) + mutableRow.SetDatums(iter.datumRow...) + matched, _, err := expression.EvalBool(evalCtx, iter.conditions, mutableRow.ToRow()) + if err != nil { + return nil, errors.Trace(err) + } + if !matched { + continue + } + return iter.datumRow, nil + } + + if err := curr.Next(); err != nil { + return nil, errors.Trace(err) + } + + err = iter.cd.DecodeToChunk(value, 0, handle, iter.batchChk) + if err != nil { + return nil, errors.Trace(err) + } + } + + if iter.batchChk.NumRows() == 0 { + return nil, nil + } + + if len(iter.conditions) == 0 { + for i := 0; i < iter.batchChk.NumRows(); i++ { + iter.sel = append(iter.sel, i) + } + continue + } + + iter.batchIt.ResetChunk(iter.batchChk) + iter.selected = iter.selected[:0] + var err error + iter.selected, err = expression.VectorizedFilter(evalCtx, vecEnabled, iter.conditions, iter.batchIt, iter.selected) + if err != nil { + return nil, errors.Trace(err) + } + for i := range iter.selected { + if iter.selected[i] { + iter.sel = append(iter.sel, i) + } + } + } +} + +func decodeHandleFromRowKey(key kv.Key, intHandle bool) (kv.Handle, error) { + if !intHandle { + return tablecodec.DecodeRowKey(key) + } + // Int handle record keys are fixed-length. Decode handle directly from the last 8 bytes to + // avoid the additional checks in tablecodec.DecodeRowKey. + if len(key) != tablecodec.RecordRowKeyLen { + return tablecodec.DecodeRowKey(key) + } + u := binary.BigEndian.Uint64(key[len(key)-8:]) + return kv.IntHandle(codec.DecodeCmpUintToInt(u)), nil +} + +func memBufferHasAnyEntryInRanges(mb kv.MemBuffer, kvRanges []kv.KeyRange) bool { + if mb == nil || len(kvRanges) == 0 { + return false + } + for _, rg := range kvRanges { + it := mb.SnapshotIter(rg.StartKey, rg.EndKey) + hasAny := it.Valid() + it.Close() + if hasAny { + return true + } + } + return false +} + +func (iter *memRowsBatchIterForTable) Close() { + if iter.kvIter != nil { + iter.kvIter.Close() + } + iter.batchChk = nil + iter.batchIt = nil + iter.selected = nil + iter.sel = nil + iter.cursor = 0 + iter.datumRow = nil + iter.retFieldTypes = nil + iter.memTableReader = nil +} + +// memCachedDatumIter directly iterates pre-decoded CachedDatumData, skipping KV decode. +type memCachedDatumIter struct { + data *tables.CachedDatumData + chunkIdx int + rowIdx int + desc bool + + // Column projection: maps query column index to datum cache column index. + colProjection []int + cacheFieldTypes []*types.FieldType + datumRow []types.Datum + retFieldTypes []*types.FieldType + mutableRow chunk.MutRow + + // filter + conditions []expression.Expression + evalCtx expression.EvalContext + + // TIMESTAMP timezone conversion + needTZConvert bool + sessionLoc *time.Location + tsColProjected []int // indices in projected (query) columns that are TIMESTAMP +} + +func (iter *memCachedDatumIter) Next() ([]types.Datum, error) { + for { + // Check chunk bounds. + if !iter.desc { + if iter.chunkIdx >= len(iter.data.Chunks) { + return nil, nil + } + } else { + if iter.chunkIdx < 0 { + return nil, nil + } + } + + chk := iter.data.Chunks[iter.chunkIdx] + + // Check row bounds and advance to next/prev chunk. + if !iter.desc { + if iter.rowIdx >= chk.NumRows() { + iter.chunkIdx++ + iter.rowIdx = 0 + continue + } + } else { + if iter.rowIdx < 0 { + iter.chunkIdx-- + if iter.chunkIdx >= 0 { + iter.rowIdx = iter.data.Chunks[iter.chunkIdx].NumRows() - 1 + } + continue + } + } + + row := chk.GetRow(iter.rowIdx) + if !iter.desc { + iter.rowIdx++ + } else { + iter.rowIdx-- + } + + // Project: extract only the columns the query needs. + for i, cacheIdx := range iter.colProjection { + iter.datumRow[i] = row.GetDatum(cacheIdx, iter.cacheFieldTypes[cacheIdx]) + } + + // TIMESTAMP timezone conversion (on projected datum copy, safe). + // Datum cache stores TIMESTAMP in UTC; convert to session timezone on read. + // Skip NULL values and zero timestamps (consistent with KV decode path in decoder.go). + if iter.needTZConvert { + for _, idx := range iter.tsColProjected { + if !iter.datumRow[idx].IsNull() { + t := iter.datumRow[idx].GetMysqlTime() + if !t.IsZero() { + if err := t.ConvertTimeZone(time.UTC, iter.sessionLoc); err != nil { + return nil, err + } + iter.datumRow[idx].SetMysqlTime(t) + } + } + } + } + + // Apply filter conditions. + if len(iter.conditions) > 0 { + if iter.mutableRow.ToRow().Chunk() == nil { + iter.mutableRow = chunk.MutRowFromTypes(iter.retFieldTypes) + } + iter.mutableRow.SetDatums(iter.datumRow...) + matched, _, err := expression.EvalBool(iter.evalCtx, iter.conditions, iter.mutableRow.ToRow()) + if err != nil { + return nil, err + } + if !matched { + continue + } + } + + return iter.datumRow, nil + } +} + +func (*memCachedDatumIter) Close() { + // CachedDatumData is managed by cacheData; not released here. +} + +// buildDatumCacheIter tries to build a memCachedDatumIter from the pre-decoded datum cache. +// Returns nil if the datum cache is not available or column projection fails. +func (m *memTableReader) buildDatumCacheIter() *memCachedDatumIter { + if m.datumCache == nil { + return nil + } + + // Build column ID -> datum cache index mapping from the table's public columns. + allCols := m.table.Cols() + colIDToCacheIdx := make(map[int64]int, len(allCols)) + for i, col := range allCols { + colIDToCacheIdx[col.ID] = i + } + + // Build projection: for each query column, find its datum cache index. + colProjection := make([]int, len(m.columns)) + for i, col := range m.columns { + idx, ok := colIDToCacheIdx[col.ID] + if !ok { + // Column not found in datum cache, fallback to KV decode. + return nil + } + colProjection[i] = idx + } + + // Find TIMESTAMP columns in the projected (query) columns for TZ conversion. + sessionLoc := m.ctx.GetSessionVars().Location() + var tsColProjected []int + for i, col := range m.columns { + if col.GetType() == mysql.TypeTimestamp { + tsColProjected = append(tsColProjected, i) + } + } + needTZConvert := len(tsColProjected) > 0 && sessionLoc.String() != time.UTC.String() + + iter := &memCachedDatumIter{ + data: m.datumCache, + desc: m.desc, + colProjection: colProjection, + cacheFieldTypes: m.datumCache.FieldTypes, + datumRow: make([]types.Datum, len(m.columns)), + retFieldTypes: m.retFieldTypes, + mutableRow: chunk.MutRowFromTypes(m.retFieldTypes), + conditions: m.conditions, + evalCtx: m.ctx.GetExprCtx().GetEvalCtx(), + needTZConvert: needTZConvert, + sessionLoc: sessionLoc, + tsColProjected: tsColProjected, + } + + // For descending scan, start from the last row of the last chunk. + if m.desc { + lastIdx := len(m.datumCache.Chunks) - 1 + iter.chunkIdx = lastIdx + if lastIdx >= 0 { + iter.rowIdx = m.datumCache.Chunks[lastIdx].NumRows() - 1 + } + } + + return iter +} + type memRowsIterForIndex struct { kvIter *txnMemBufferIter tps []*types.FieldType @@ -1000,7 +1455,7 @@ func (iter *memRowsIterForIndex) Next() ([]types.Datum, error) { continue } - // filter key/value by partitition id + // filter key/value by partition id if iter.index.Global { _, pid, err := codec.DecodeInt(tablecodec.SplitIndexValue(value).PartitionID) if err != nil { @@ -1011,18 +1466,20 @@ func (iter *memRowsIterForIndex) Next() ([]types.Datum, error) { } } - data, err := iter.memIndexReader.decodeIndexKeyValue(key, value, iter.tps, iter.colInfos) + data, err := iter.memIndexReader.decodeIndexKeyValue(key, value, iter.tps, iter.colInfos, false) if err != nil { return nil, err } - iter.mutableRow.SetDatums(data...) - matched, _, err := expression.EvalBool(iter.memIndexReader.ctx.GetExprCtx().GetEvalCtx(), iter.memIndexReader.conditions, iter.mutableRow.ToRow()) - if err != nil { - return nil, errors.Trace(err) - } - if !matched { - continue + if len(iter.memIndexReader.conditions) > 0 { + iter.mutableRow.SetDatums(data...) + matched, _, err := expression.EvalBool(iter.memIndexReader.evalCtx, iter.memIndexReader.conditions, iter.mutableRow.ToRow()) + if err != nil { + return nil, errors.Trace(err) + } + if !matched { + continue + } } ret = data break diff --git a/pkg/executor/mem_reader_test.go b/pkg/executor/mem_reader_test.go new file mode 100644 index 0000000000000..c45b456114dcd --- /dev/null +++ b/pkg/executor/mem_reader_test.go @@ -0,0 +1,862 @@ +package executor + +import ( + "testing" + "time" + + "github.com/pingcap/tidb/pkg/executor/internal/exec" + "github.com/pingcap/tidb/pkg/expression" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/meta/model" + "github.com/pingcap/tidb/pkg/parser/ast" + pmodel "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" + "github.com/pingcap/tidb/pkg/store/mockstore" + "github.com/pingcap/tidb/pkg/table" + "github.com/pingcap/tidb/pkg/table/tables" + "github.com/pingcap/tidb/pkg/tablecodec" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/codec" + "github.com/pingcap/tidb/pkg/util/mock" + "github.com/pingcap/tidb/pkg/util/rowcodec" + "github.com/pingcap/tidb/pkg/util/sqlexec" + "github.com/stretchr/testify/require" + "github.com/tikv/client-go/v2/tikv" +) + +var memReaderBenchSink int64 + +func TestMemRowsIterFastDecodeRowKey(t *testing.T) { + const tableID int64 = 1 + + intKey := tablecodec.EncodeRowKeyWithHandle(tableID, kv.IntHandle(-123)) + fastHandle, err := decodeHandleFromRowKey(intKey, true) + require.NoError(t, err) + slowHandle, err := tablecodec.DecodeRowKey(intKey) + require.NoError(t, err) + require.True(t, fastHandle.IsInt()) + require.Equal(t, int64(-123), fastHandle.IntValue()) + require.True(t, fastHandle.Equal(slowHandle)) + + encoded, err := codec.EncodeKey(stmtctx.NewStmtCtx().TimeZone(), nil, types.MakeDatums(int64(100), "abc")...) + require.NoError(t, err) + commonHandle, err := kv.NewCommonHandle(encoded) + require.NoError(t, err) + commonKey := tablecodec.EncodeRowKeyWithHandle(tableID, commonHandle) + fastHandle, err = decodeHandleFromRowKey(commonKey, false) + require.NoError(t, err) + slowHandle, err = tablecodec.DecodeRowKey(commonKey) + require.NoError(t, err) + require.False(t, fastHandle.IsInt()) + require.True(t, fastHandle.Equal(slowHandle)) +} + +func BenchmarkMemRowsIterForTable(b *testing.B) { + b.ReportAllocs() + + store, err := mockstore.NewMockStore() + if err != nil { + b.Fatal(err) + } + b.Cleanup(func() { + _ = store.Close() + }) + + sctx := mock.NewContext() + sctx.Store = store + sctx.GetSessionVars().EnableVectorizedExpression = true + + const ( + tableID int64 = 1 + numRows = cachedTableBatchSize * 512 + ) + + intTp1 := types.NewFieldType(mysql.TypeLonglong) + intTp2 := types.NewFieldType(mysql.TypeLonglong) + retFieldTypes := []*types.FieldType{intTp1, intTp2} + + col1Expr := &expression.Column{ID: 1, UniqueID: 1, Index: 0, RetType: intTp1} + col2Expr := &expression.Column{ID: 2, UniqueID: 2, Index: 1, RetType: intTp2} + schema := expression.NewSchema(col1Expr, col2Expr) + + col1Info := &model.ColumnInfo{ID: 1, Offset: 0, FieldType: *intTp1} + col2Info := &model.ColumnInfo{ID: 2, Offset: 1, FieldType: *intTp2} + tblInfo := &model.TableInfo{ID: tableID, Columns: []*model.ColumnInfo{col1Info, col2Info}} + + cd := NewRowDecoder(sctx, schema, tblInfo) + + buffTxn, err := store.Begin(tikv.WithStartTS(0)) + if err != nil { + b.Fatal(err) + } + cacheTable := buffTxn.GetMemBuffer() + + var encoder rowcodec.Encoder + colIDs := []int64{1, 2} + datums := make([]types.Datum, 2) + buf := make([]byte, 0, 64) + loc := sctx.GetSessionVars().Location() + for i := 0; i < numRows; i++ { + key := tablecodec.EncodeRowKeyWithHandle(tableID, kv.IntHandle(i)) + datums[0] = types.NewIntDatum(int64(i)) + datums[1] = types.NewIntDatum(int64(i & 63)) + buf = buf[:0] + value, err := encoder.Encode(loc, colIDs, datums, nil, buf) + if err != nil { + b.Fatal(err) + } + if err := cacheTable.Set(key, value); err != nil { + b.Fatal(err) + } + } + + recordPrefix := tablecodec.GenTableRecordPrefix(tableID) + kvRanges := []kv.KeyRange{{StartKey: recordPrefix, EndKey: recordPrefix.PrefixNext()}} + + tinyTp := types.NewFieldType(mysql.TypeTiny) + constExpr := &expression.Constant{Value: types.NewIntDatum(32), RetType: intTp2} + filter, err := expression.NewFunction(sctx.GetExprCtx(), ast.LT, tinyTp, col2Expr, constExpr) + if err != nil { + b.Fatal(err) + } + + memTblReader := &memTableReader{ + ctx: sctx, + table: tblInfo, + columns: []*model.ColumnInfo{col1Info, col2Info}, + kvRanges: kvRanges, + conditions: []expression.Expression{filter}, + retFieldTypes: retFieldTypes, + colIDs: map[int64]int{1: 0, 2: 1}, + buffer: allocBuf{ + handleBytes: make([]byte, 0, 16), + cd: cd, + }, + cacheTable: cacheTable, + } + memTblReader.offsets = []int{0, 1} + + newRowIter := func(tb testing.TB) memRowsIter { + kvIter, err := newTxnMemBufferIter(sctx, cacheTable, kvRanges, false) + if err != nil { + tb.Fatal(err) + } + return &memRowsIterForTable{ + kvIter: kvIter, + cd: cd, + chk: chunk.New(retFieldTypes, 1, 1), + datumRow: make([]types.Datum, len(retFieldTypes)), + intHandle: true, + memTableReader: memTblReader, + } + } + + newBatchIter := func(tb testing.TB) memRowsIter { + kvIter, err := newTxnMemBufferIter(sctx, cacheTable, kvRanges, false) + if err != nil { + tb.Fatal(err) + } + batchChk := chunk.New(retFieldTypes, cachedTableBatchSize, cachedTableBatchSize) + return &memRowsBatchIterForTable{ + kvIter: kvIter, + cd: cd, + batchChk: batchChk, + batchIt: chunk.NewIterator4Chunk(batchChk), + sel: make([]int, 0, cachedTableBatchSize), + datumRow: make([]types.Datum, len(retFieldTypes)), + retFieldTypes: retFieldTypes, + intHandle: true, + memTableReader: memTblReader, + } + } + + scanOnce := func(tb testing.TB, it memRowsIter) (matched int, sum int64) { + defer it.Close() + for { + row, err := it.Next() + if err != nil { + tb.Fatal(err) + } + if row == nil { + return matched, sum + } + matched++ + sum += row[0].GetInt64() + } + } + + expected := numRows / 2 + b.StopTimer() + if got, _ := scanOnce(b, newRowIter(b)); got != expected { + b.Fatalf("unexpected matched rows by memRowsIterForTable: got %d, want %d", got, expected) + } + if got, _ := scanOnce(b, newBatchIter(b)); got != expected { + b.Fatalf("unexpected matched rows by memRowsBatchIterForTable: got %d, want %d", got, expected) + } + b.StartTimer() + + b.Run("row", func(b *testing.B) { + var sum int64 + for i := 0; i < b.N; i++ { + it := newRowIter(b) + rows, s := scanOnce(b, it) + if rows != expected { + b.Fatalf("unexpected matched rows: got %d, want %d", rows, expected) + } + sum += s + } + memReaderBenchSink = sum + }) + b.Run("batch", func(b *testing.B) { + var sum int64 + for i := 0; i < b.N; i++ { + it := newBatchIter(b) + rows, s := scanOnce(b, it) + if rows != expected { + b.Fatalf("unexpected matched rows: got %d, want %d", rows, expected) + } + sum += s + } + memReaderBenchSink = sum + }) +} + +// buildTestDatumCache builds a CachedDatumData with 3 columns: int64 (id), varchar (name), int64 (val). +// rows is a list of (id, name, val) tuples. +func buildTestDatumCache(rows [][]any) *tables.CachedDatumData { + ftInt := types.NewFieldType(mysql.TypeLonglong) + ftStr := types.NewFieldType(mysql.TypeVarchar) + ftStr.SetFlen(64) + ftVal := types.NewFieldType(mysql.TypeLonglong) + fieldTypes := []*types.FieldType{ftInt, ftStr, ftVal} + + chk := chunk.New(fieldTypes, 1024, 1024) + for _, row := range rows { + chk.AppendInt64(0, row[0].(int64)) + chk.AppendString(1, row[1].(string)) + chk.AppendInt64(2, row[2].(int64)) + } + + return &tables.CachedDatumData{ + Chunks: []*chunk.Chunk{chk}, + FieldTypes: fieldTypes, + TotalRows: len(rows), + } +} + +func TestMemCachedDatumIterProjection(t *testing.T) { + data := buildTestDatumCache([][]any{ + {int64(1), "alice", int64(100)}, + {int64(2), "bob", int64(200)}, + {int64(3), "charlie", int64(300)}, + }) + + // SELECT id, val (skip name column) — project columns 0 and 2 from cache. + ftInt := types.NewFieldType(mysql.TypeLonglong) + ftVal := types.NewFieldType(mysql.TypeLonglong) + iter := &memCachedDatumIter{ + data: data, + colProjection: []int{0, 2}, // query col 0 -> cache col 0 (id), query col 1 -> cache col 2 (val) + cacheFieldTypes: data.FieldTypes, + datumRow: make([]types.Datum, 2), + retFieldTypes: []*types.FieldType{ftInt, ftVal}, + } + + row, err := iter.Next() + require.NoError(t, err) + require.NotNil(t, row) + require.Len(t, row, 2) + require.Equal(t, int64(1), row[0].GetInt64()) + require.Equal(t, int64(100), row[1].GetInt64()) + + row, err = iter.Next() + require.NoError(t, err) + require.Equal(t, int64(2), row[0].GetInt64()) + require.Equal(t, int64(200), row[1].GetInt64()) + + row, err = iter.Next() + require.NoError(t, err) + require.Equal(t, int64(3), row[0].GetInt64()) + require.Equal(t, int64(300), row[1].GetInt64()) + + row, err = iter.Next() + require.NoError(t, err) + require.Nil(t, row) +} + +func TestMemCachedDatumIterProjectionAllCols(t *testing.T) { + data := buildTestDatumCache([][]any{ + {int64(1), "alice", int64(100)}, + {int64(2), "bob", int64(200)}, + }) + + // SELECT * — all columns in same order. + iter := &memCachedDatumIter{ + data: data, + colProjection: []int{0, 1, 2}, + cacheFieldTypes: data.FieldTypes, + datumRow: make([]types.Datum, 3), + retFieldTypes: data.FieldTypes, + } + + row, err := iter.Next() + require.NoError(t, err) + require.Len(t, row, 3) + require.Equal(t, int64(1), row[0].GetInt64()) + require.Equal(t, "alice", row[1].GetString()) + require.Equal(t, int64(100), row[2].GetInt64()) + + row, err = iter.Next() + require.NoError(t, err) + require.Equal(t, int64(2), row[0].GetInt64()) + require.Equal(t, "bob", row[1].GetString()) + require.Equal(t, int64(200), row[2].GetInt64()) + + row, err = iter.Next() + require.NoError(t, err) + require.Nil(t, row) +} + +func TestMemCachedDatumIterDesc(t *testing.T) { + data := buildTestDatumCache([][]any{ + {int64(1), "alice", int64(100)}, + {int64(2), "bob", int64(200)}, + {int64(3), "charlie", int64(300)}, + }) + + iter := &memCachedDatumIter{ + data: data, + desc: true, + chunkIdx: len(data.Chunks) - 1, + rowIdx: data.Chunks[len(data.Chunks)-1].NumRows() - 1, + colProjection: []int{0, 1, 2}, + cacheFieldTypes: data.FieldTypes, + datumRow: make([]types.Datum, 3), + retFieldTypes: data.FieldTypes, + } + + // Descending: should yield rows 3, 2, 1. + row, err := iter.Next() + require.NoError(t, err) + require.Equal(t, int64(3), row[0].GetInt64()) + require.Equal(t, "charlie", row[1].GetString()) + + row, err = iter.Next() + require.NoError(t, err) + require.Equal(t, int64(2), row[0].GetInt64()) + + row, err = iter.Next() + require.NoError(t, err) + require.Equal(t, int64(1), row[0].GetInt64()) + + row, err = iter.Next() + require.NoError(t, err) + require.Nil(t, row) +} + +func TestMemCachedDatumIterDescWithFilter(t *testing.T) { + sctx := mock.NewContext() + + data := buildTestDatumCache([][]any{ + {int64(1), "alice", int64(100)}, + {int64(2), "bob", int64(200)}, + {int64(3), "charlie", int64(300)}, + {int64(4), "dave", int64(400)}, + }) + + // Filter: val < 300 (should match rows 1 and 2). + ftVal := types.NewFieldType(mysql.TypeLonglong) + // Projected columns: id (idx 0), val (idx 1). + col := &expression.Column{UniqueID: 1, Index: 1, RetType: ftVal} // val is at projected index 1 + constVal := &expression.Constant{Value: types.NewIntDatum(300), RetType: ftVal} + tinyTp := types.NewFieldType(mysql.TypeTiny) + filter, err := expression.NewFunction(sctx.GetExprCtx(), ast.LT, tinyTp, col, constVal) + require.NoError(t, err) + + ftInt := types.NewFieldType(mysql.TypeLonglong) + iter := &memCachedDatumIter{ + data: data, + desc: true, + chunkIdx: len(data.Chunks) - 1, + rowIdx: data.Chunks[len(data.Chunks)-1].NumRows() - 1, + colProjection: []int{0, 2}, // query col 0 -> id, query col 1 -> val + cacheFieldTypes: data.FieldTypes, + datumRow: make([]types.Datum, 2), + retFieldTypes: []*types.FieldType{ftInt, ftVal}, + conditions: []expression.Expression{filter}, + evalCtx: sctx.GetExprCtx().GetEvalCtx(), + } + + // DESC + filter: row 4 (val=400) skipped, row 3 (val=300) skipped, row 2 (val=200) matched, row 1 (val=100) matched. + row, err := iter.Next() + require.NoError(t, err) + require.NotNil(t, row) + require.Equal(t, int64(2), row[0].GetInt64()) // id=2 + require.Equal(t, int64(200), row[1].GetInt64()) // val=200 + + row, err = iter.Next() + require.NoError(t, err) + require.NotNil(t, row) + require.Equal(t, int64(1), row[0].GetInt64()) // id=1 + require.Equal(t, int64(100), row[1].GetInt64()) // val=100 + + row, err = iter.Next() + require.NoError(t, err) + require.Nil(t, row) +} + +func TestMemCachedDatumIterProjectionWithTimestamp(t *testing.T) { + ftInt := types.NewFieldType(mysql.TypeLonglong) + ftStr := types.NewFieldType(mysql.TypeVarchar) + ftStr.SetFlen(64) + ftTs := types.NewFieldType(mysql.TypeTimestamp) + ftTs.SetDecimal(0) + cacheFieldTypes := []*types.FieldType{ftInt, ftStr, ftTs} + + // Build chunk with TIMESTAMP stored in UTC. + chk := chunk.New(cacheFieldTypes, 1024, 1024) + utcTime1, err := types.ParseTimestamp(types.DefaultStmtNoWarningContext, "2025-01-15 10:00:00") + require.NoError(t, err) + utcTime2, err := types.ParseTimestamp(types.DefaultStmtNoWarningContext, "2025-06-20 18:30:00") + require.NoError(t, err) + + chk.AppendInt64(0, 1) + chk.AppendString(1, "alice") + chk.AppendTime(2, utcTime1) + chk.AppendInt64(0, 2) + chk.AppendString(1, "bob") + chk.AppendTime(2, utcTime2) + + data := &tables.CachedDatumData{ + Chunks: []*chunk.Chunk{chk}, + FieldTypes: cacheFieldTypes, + TotalRows: 2, + } + + sessionLoc, err := time.LoadLocation("Asia/Shanghai") + require.NoError(t, err) + + // SELECT id, ts (skip name) — project columns 0 and 2. + // TIMESTAMP at projected index 1. + iter := &memCachedDatumIter{ + data: data, + colProjection: []int{0, 2}, + cacheFieldTypes: cacheFieldTypes, + datumRow: make([]types.Datum, 2), + retFieldTypes: []*types.FieldType{ftInt, ftTs}, + needTZConvert: true, + sessionLoc: sessionLoc, + tsColProjected: []int{1}, // projected index 1 is the TIMESTAMP column + } + + row, err := iter.Next() + require.NoError(t, err) + require.NotNil(t, row) + require.Equal(t, int64(1), row[0].GetInt64()) + // UTC 10:00:00 → Asia/Shanghai +8 → 18:00:00 + ts := row[1].GetMysqlTime() + require.Equal(t, "2025-01-15 18:00:00", ts.String()) + + row, err = iter.Next() + require.NoError(t, err) + require.Equal(t, int64(2), row[0].GetInt64()) + // UTC 18:30:00 → Asia/Shanghai +8 → next day 02:30:00 + ts = row[1].GetMysqlTime() + require.Equal(t, "2025-06-21 02:30:00", ts.String()) + + row, err = iter.Next() + require.NoError(t, err) + require.Nil(t, row) +} + +func TestMemCachedDatumIterDescMultiChunk(t *testing.T) { + ftInt := types.NewFieldType(mysql.TypeLonglong) + ftStr := types.NewFieldType(mysql.TypeVarchar) + ftStr.SetFlen(64) + ftVal := types.NewFieldType(mysql.TypeLonglong) + fieldTypes := []*types.FieldType{ftInt, ftStr, ftVal} + + // Create 2 chunks to test cross-chunk descending iteration. + chk1 := chunk.New(fieldTypes, 1024, 1024) + chk1.AppendInt64(0, 1) + chk1.AppendString(1, "a") + chk1.AppendInt64(2, 10) + chk1.AppendInt64(0, 2) + chk1.AppendString(1, "b") + chk1.AppendInt64(2, 20) + + chk2 := chunk.New(fieldTypes, 1024, 1024) + chk2.AppendInt64(0, 3) + chk2.AppendString(1, "c") + chk2.AppendInt64(2, 30) + + data := &tables.CachedDatumData{ + Chunks: []*chunk.Chunk{chk1, chk2}, + FieldTypes: fieldTypes, + TotalRows: 3, + } + + iter := &memCachedDatumIter{ + data: data, + desc: true, + chunkIdx: 1, // last chunk + rowIdx: 0, // last chunk has 1 row, index 0 + colProjection: []int{0, 2}, // id and val only + cacheFieldTypes: fieldTypes, + datumRow: make([]types.Datum, 2), + retFieldTypes: []*types.FieldType{ftInt, ftVal}, + } + + // DESC across chunks: 3, 2, 1. + row, err := iter.Next() + require.NoError(t, err) + require.Equal(t, int64(3), row[0].GetInt64()) + require.Equal(t, int64(30), row[1].GetInt64()) + + row, err = iter.Next() + require.NoError(t, err) + require.Equal(t, int64(2), row[0].GetInt64()) + require.Equal(t, int64(20), row[1].GetInt64()) + + row, err = iter.Next() + require.NoError(t, err) + require.Equal(t, int64(1), row[0].GetInt64()) + require.Equal(t, int64(10), row[1].GetInt64()) + + row, err = iter.Next() + require.NoError(t, err) + require.Nil(t, row) +} + +func TestMemTableReaderDatumCacheFallbackOnTxnOverride(t *testing.T) { + store, err := mockstore.NewMockStore() + require.NoError(t, err) + t.Cleanup(func() { + _ = store.Close() + }) + + sctx := mock.NewContext() + sctx.Store = store + + const tableID int64 = 1 + + ftID := types.NewFieldType(mysql.TypeLonglong) + ftName := types.NewFieldType(mysql.TypeVarchar) + ftName.SetFlen(64) + ftVal := types.NewFieldType(mysql.TypeLonglong) + retFieldTypes := []*types.FieldType{ftID, ftName, ftVal} + + colIDExpr := &expression.Column{ID: 1, UniqueID: 1, Index: 0, RetType: ftID} + colNameExpr := &expression.Column{ID: 2, UniqueID: 2, Index: 1, RetType: ftName} + colValExpr := &expression.Column{ID: 3, UniqueID: 3, Index: 2, RetType: ftVal} + schema := expression.NewSchema(colIDExpr, colNameExpr, colValExpr) + + col1Info := &model.ColumnInfo{ID: 1, Offset: 0, State: model.StatePublic, FieldType: *ftID} + col2Info := &model.ColumnInfo{ID: 2, Offset: 1, State: model.StatePublic, FieldType: *ftName} + col3Info := &model.ColumnInfo{ID: 3, Offset: 2, State: model.StatePublic, FieldType: *ftVal} + tblInfo := &model.TableInfo{ID: tableID, Columns: []*model.ColumnInfo{col1Info, col2Info, col3Info}} + + cd := NewRowDecoder(sctx, schema, tblInfo) + + buffTxn, err := store.Begin(tikv.WithStartTS(0)) + require.NoError(t, err) + t.Cleanup(func() { + _ = buffTxn.Rollback() + }) + cacheTable := buffTxn.GetMemBuffer() + + var encoder rowcodec.Encoder + colIDs := []int64{1, 2, 3} + loc := sctx.GetSessionVars().Location() + + key := tablecodec.EncodeRowKeyWithHandle(tableID, kv.IntHandle(1)) + datums := []types.Datum{types.NewIntDatum(1), types.NewStringDatum("alice"), types.NewIntDatum(100)} + valSnap, err := encoder.Encode(loc, colIDs, datums, nil, nil) + require.NoError(t, err) + require.NoError(t, cacheTable.Set(key, append([]byte(nil), valSnap...))) + + txn, err := sctx.Txn(true) + require.NoError(t, err) + datums[2] = types.NewIntDatum(200) + valTxn, err := encoder.Encode(loc, colIDs, datums, nil, nil) + require.NoError(t, err) + require.NoError(t, txn.GetMemBuffer().Set(key, append([]byte(nil), valTxn...))) + + recordPrefix := tablecodec.GenTableRecordPrefix(tableID) + kvRanges := []kv.KeyRange{{StartKey: recordPrefix, EndKey: recordPrefix.PrefixNext()}} + + memTblReader := &memTableReader{ + ctx: sctx, + table: tblInfo, + columns: []*model.ColumnInfo{col1Info, col2Info, col3Info}, + kvRanges: kvRanges, + retFieldTypes: retFieldTypes, + colIDs: map[int64]int{1: 0, 2: 1, 3: 2}, + buffer: allocBuf{ + handleBytes: make([]byte, 0, 16), + cd: cd, + }, + cacheTable: cacheTable, + datumCache: buildTestDatumCache([][]any{{int64(1), "alice", int64(100)}}), + } + + it, err := memTblReader.getMemRowsIter(context.Background()) + require.NoError(t, err) + defer it.Close() + + _, isDatumIter := it.(*memCachedDatumIter) + require.False(t, isDatumIter) + + row, err := it.Next() + require.NoError(t, err) + require.NotNil(t, row) + require.Equal(t, int64(1), row[0].GetInt64()) + require.Equal(t, "alice", row[1].GetString()) + require.Equal(t, int64(200), row[2].GetInt64()) + + row, err = it.Next() + require.NoError(t, err) + require.Nil(t, row) +} + +func TestMemTableReaderDatumCacheFallbackOnPartialRange(t *testing.T) { + store, err := mockstore.NewMockStore() + require.NoError(t, err) + t.Cleanup(func() { + _ = store.Close() + }) + + sctx := mock.NewContext() + sctx.Store = store + + const tableID int64 = 1 + + ftID := types.NewFieldType(mysql.TypeLonglong) + ftName := types.NewFieldType(mysql.TypeVarchar) + ftName.SetFlen(64) + ftVal := types.NewFieldType(mysql.TypeLonglong) + retFieldTypes := []*types.FieldType{ftID, ftName, ftVal} + + colIDExpr := &expression.Column{ID: 1, UniqueID: 1, Index: 0, RetType: ftID} + colNameExpr := &expression.Column{ID: 2, UniqueID: 2, Index: 1, RetType: ftName} + colValExpr := &expression.Column{ID: 3, UniqueID: 3, Index: 2, RetType: ftVal} + schema := expression.NewSchema(colIDExpr, colNameExpr, colValExpr) + + col1Info := &model.ColumnInfo{ID: 1, Offset: 0, State: model.StatePublic, FieldType: *ftID} + col2Info := &model.ColumnInfo{ID: 2, Offset: 1, State: model.StatePublic, FieldType: *ftName} + col3Info := &model.ColumnInfo{ID: 3, Offset: 2, State: model.StatePublic, FieldType: *ftVal} + tblInfo := &model.TableInfo{ID: tableID, Columns: []*model.ColumnInfo{col1Info, col2Info, col3Info}} + + cd := NewRowDecoder(sctx, schema, tblInfo) + + buffTxn, err := store.Begin(tikv.WithStartTS(0)) + require.NoError(t, err) + t.Cleanup(func() { + _ = buffTxn.Rollback() + }) + cacheTable := buffTxn.GetMemBuffer() + + var encoder rowcodec.Encoder + colIDs := []int64{1, 2, 3} + loc := sctx.GetSessionVars().Location() + rows := [][]types.Datum{ + {types.NewIntDatum(1), types.NewStringDatum("alice"), types.NewIntDatum(100)}, + {types.NewIntDatum(2), types.NewStringDatum("bob"), types.NewIntDatum(200)}, + {types.NewIntDatum(3), types.NewStringDatum("charlie"), types.NewIntDatum(300)}, + } + for _, row := range rows { + key := tablecodec.EncodeRowKeyWithHandle(tableID, kv.IntHandle(row[0].GetInt64())) + val, err := encoder.Encode(loc, colIDs, row, nil, nil) + require.NoError(t, err) + require.NoError(t, cacheTable.Set(key, append([]byte(nil), val...))) + } + + kvRanges := []kv.KeyRange{{ + StartKey: tablecodec.EncodeRowKeyWithHandle(tableID, kv.IntHandle(2)), + EndKey: tablecodec.EncodeRowKeyWithHandle(tableID, kv.IntHandle(3)), + }} + + memTblReader := &memTableReader{ + ctx: sctx, + table: tblInfo, + columns: []*model.ColumnInfo{col1Info, col2Info, col3Info}, + kvRanges: kvRanges, + retFieldTypes: retFieldTypes, + colIDs: map[int64]int{1: 0, 2: 1, 3: 2}, + buffer: allocBuf{ + handleBytes: make([]byte, 0, 16), + cd: cd, + }, + cacheTable: cacheTable, + datumCache: buildTestDatumCache([][]any{ + {int64(1), "alice", int64(100)}, + {int64(2), "bob", int64(200)}, + {int64(3), "charlie", int64(300)}, + }), + } + + it, err := memTblReader.getMemRowsIter(context.Background()) + require.NoError(t, err) + defer it.Close() + + _, isDatumIter := it.(*memCachedDatumIter) + require.False(t, isDatumIter) + + row, err := it.Next() + require.NoError(t, err) + require.NotNil(t, row) + require.Equal(t, int64(2), row[0].GetInt64()) + require.Equal(t, "bob", row[1].GetString()) + require.Equal(t, int64(200), row[2].GetInt64()) + + row, err = it.Next() + require.NoError(t, err) + require.Nil(t, row) +} + +type mockCachedTableWithPinnedDatumCache struct { + table.Table + cacheData kv.MemBuffer + + pinnedDatumCache *tables.CachedDatumData + latestDatumCache *tables.CachedDatumData + + pinnedIndexDatumCaches map[int64]*tables.CachedIndexDatumData + latestIndexDatumCaches map[int64]*tables.CachedIndexDatumData +} + +func (*mockCachedTableWithPinnedDatumCache) Init(sqlexec.SQLExecutor) error { + return nil +} + +func (t *mockCachedTableWithPinnedDatumCache) TryReadFromCache(uint64, time.Duration) (kv.MemBuffer, bool) { + return t.cacheData, false +} + +func (*mockCachedTableWithPinnedDatumCache) UpdateLockForRead(context.Context, kv.Storage, uint64, time.Duration) { +} + +func (*mockCachedTableWithPinnedDatumCache) WriteLockAndKeepAlive(context.Context, chan struct{}, *uint64, chan error) { +} + +func (*mockCachedTableWithPinnedDatumCache) GetCachedResult(table.ResultCacheKey, []byte) ([]*chunk.Chunk, []*types.FieldType, bool) { + return nil, nil, false +} + +func (*mockCachedTableWithPinnedDatumCache) PutCachedResult(table.ResultCacheKey, []byte, []*chunk.Chunk, []*types.FieldType) bool { + return false +} + +func (t *mockCachedTableWithPinnedDatumCache) GetCachedDatumData() *tables.CachedDatumData { + return t.latestDatumCache +} + +func (t *mockCachedTableWithPinnedDatumCache) GetCachedDatumDataForMemBuffer(mb kv.MemBuffer) *tables.CachedDatumData { + if mb != t.cacheData { + return nil + } + return t.pinnedDatumCache +} + +func (t *mockCachedTableWithPinnedDatumCache) GetCachedIndexDatumData(indexID int64) *tables.CachedIndexDatumData { + return t.latestIndexDatumCaches[indexID] +} + +func (t *mockCachedTableWithPinnedDatumCache) GetCachedIndexDatumDataForMemBuffer(mb kv.MemBuffer, indexID int64) *tables.CachedIndexDatumData { + if mb != t.cacheData { + return nil + } + return t.pinnedIndexDatumCaches[indexID] +} + +type mockBypassDataSourceExecutor struct { + exec.BaseExecutor + tbl table.Table + dummy bool +} + +func (e *mockBypassDataSourceExecutor) Table() table.Table { + return e.tbl +} + +func (e *mockBypassDataSourceExecutor) setDummy() { + e.dummy = true +} + +func TestHandleCachedTablePinsDatumCachesToMemBufferGeneration(t *testing.T) { + store, err := mockstore.NewMockStore() + require.NoError(t, err) + t.Cleanup(func() { + _ = store.Close() + }) + + cacheTxn, err := store.Begin(tikv.WithStartTS(0)) + require.NoError(t, err) + t.Cleanup(func() { + _ = cacheTxn.Rollback() + }) + cacheData := cacheTxn.GetMemBuffer() + + ft := types.NewFieldType(mysql.TypeLonglong) + tblInfo := &model.TableInfo{ + ID: 42, + Name: pmodel.NewCIStr("t_cache"), + TableCacheStatusType: model.TableCacheStatusEnable, + Columns: []*model.ColumnInfo{{ + ID: 1, + Name: pmodel.NewCIStr("a"), + Offset: 0, + State: model.StatePublic, + FieldType: *ft, + }}, + Indices: []*model.IndexInfo{{ + ID: 7, + Name: pmodel.NewCIStr("idx_a"), + State: model.StatePublic, + Columns: []*model.IndexColumn{{ + Name: pmodel.NewCIStr("a"), + Offset: 0, + }}, + }}, + } + + pinnedDatumCache := buildTestDatumCache([][]any{{int64(1), "pinned", int64(10)}}) + latestDatumCache := buildTestDatumCache([][]any{{int64(2), "stale", int64(20)}}) + pinnedIndexCache := &tables.CachedIndexDatumData{ + Entries: map[string][]types.Datum{"pinned": {types.NewIntDatum(1)}}, + } + latestIndexCache := &tables.CachedIndexDatumData{ + Entries: map[string][]types.Datum{"stale": {types.NewIntDatum(2)}}, + } + + cachedTbl := &mockCachedTableWithPinnedDatumCache{ + Table: tables.MockTableFromMeta(tblInfo), + cacheData: cacheData, + pinnedDatumCache: pinnedDatumCache, + latestDatumCache: latestDatumCache, + pinnedIndexDatumCaches: map[int64]*tables.CachedIndexDatumData{7: pinnedIndexCache}, + latestIndexDatumCaches: map[int64]*tables.CachedIndexDatumData{7: latestIndexCache}, + } + + sctx := mock.NewContext() + sctx.Store = store + builder := &executorBuilder{ctx: sctx} + us := &UnionScanExec{} + reader := &mockBypassDataSourceExecutor{ + BaseExecutor: exec.NewBaseExecutor(sctx, nil, 1), + tbl: cachedTbl, + } + + us.handleCachedTable(builder, reader, sctx.GetSessionVars(), 123) + + require.True(t, reader.dummy) + require.Same(t, cacheData, us.cacheTable) + require.Same(t, pinnedDatumCache, us.datumCache) + require.NotSame(t, latestDatumCache, us.datumCache) + require.Len(t, us.indexDatumCaches, 1) + require.Same(t, pinnedIndexCache, us.indexDatumCaches[7]) + require.NotSame(t, latestIndexCache, us.indexDatumCaches[7]) + require.Same(t, cachedTbl, builder.cachedTbl) +} diff --git a/pkg/executor/prepared.go b/pkg/executor/prepared.go index 9d40e395b8a9e..727411ca6706e 100644 --- a/pkg/executor/prepared.go +++ b/pkg/executor/prepared.go @@ -199,6 +199,8 @@ func (e *ExecuteExec) Build(b *executorBuilder) error { log.Warn("rebuild plan in EXECUTE statement failed", zap.String("labelName of PREPARE statement", e.name)) return errors.Trace(b.err) } + // Wrap with result set cache for cached table queries. + stmtExec = b.wrapWithResultCache(stmtExec, e.stmt, e.plan) e.stmtExec = stmtExec if e.Ctx().GetSessionVars().StmtCtx.Priority == mysql.NoPriority { e.lowerPriority = needLowerPriority(e.plan) diff --git a/pkg/executor/slow_query.go b/pkg/executor/slow_query.go index 17fdcf294fdff..994825ca1b591 100644 --- a/pkg/executor/slow_query.go +++ b/pkg/executor/slow_query.go @@ -1168,7 +1168,8 @@ func getColumnValueFactoryByName(colName string, columnIdx int) (slowQueryColumn }, nil case variable.SlowLogPrepared, variable.SlowLogSucc, variable.SlowLogPlanFromCache, variable.SlowLogPlanFromBinding, variable.SlowLogIsInternalStr, variable.SlowLogIsExplicitTxn, variable.SlowLogIsWriteCacheTable, variable.SlowLogHasMoreResults, - variable.SlowLogStorageFromKV, variable.SlowLogStorageFromMPP: + variable.SlowLogStorageFromKV, variable.SlowLogStorageFromMPP, + variable.SlowLogResultCacheHit: return func(row []types.Datum, value string, _ *time.Location, _ *slowLogChecker) (valid bool, err error) { v, err := strconv.ParseBool(value) if err != nil { diff --git a/pkg/executor/union_scan.go b/pkg/executor/union_scan.go index f367406f5c917..543a5432c34f4 100644 --- a/pkg/executor/union_scan.go +++ b/pkg/executor/union_scan.go @@ -27,6 +27,7 @@ import ( plannerutil "github.com/pingcap/tidb/pkg/planner/util" "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" "github.com/pingcap/tidb/pkg/table" + "github.com/pingcap/tidb/pkg/table/tables" "github.com/pingcap/tidb/pkg/tablecodec" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/chunk" @@ -58,6 +59,11 @@ type UnionScanExec struct { // cacheTable not nil means it's reading from cached table. cacheTable kv.MemBuffer + // datumCache holds pre-decoded datum cache for cached tables. + // When non-nil, memCachedDatumIter is used to skip KV decode. + datumCache *tables.CachedDatumData + // indexDatumCaches holds pre-decoded datum caches for index scans, keyed by index ID. + indexDatumCaches map[int64]*tables.CachedIndexDatumData // If partitioned table and the physical table id is encoded in the chuck at this column index // used with dynamic prune mode @@ -132,6 +138,7 @@ func (us *UnionScanExec) open(ctx context.Context) error { return err } us.snapshotChunkBuffer = exec.TryNewCacheChunk(us) + us.mutableRow = chunk.MutRowFromTypes(exec.RetTypes(us)) return nil } @@ -144,7 +151,13 @@ func (us *UnionScanExec) Next(ctx context.Context, req *chunk.Chunk) error { // the for-loop may exit without read one single row! req.GrowAndReset(us.MaxChunkSize()) - mutableRow := chunk.MutRowFromTypes(exec.RetTypes(us)) + // For cached tables, getSnapshotRow() always returns nil, so the merge + // logic in getOneRow() is unnecessary. Use a dedicated fast path that + // reads directly from addedRowsIter. + if us.cacheTable != nil { + return us.nextForCacheTable(req) + } + for batchSize := req.Capacity(); req.NumRows() < batchSize; { row, err := us.getOneRow(ctx) if err != nil { @@ -154,11 +167,11 @@ func (us *UnionScanExec) Next(ctx context.Context, req *chunk.Chunk) error { if row == nil { return nil } - mutableRow.SetDatums(row...) + us.mutableRow.SetDatums(row...) sctx := us.Ctx() for _, idx := range us.virtualColumnIndex { - datum, err := us.Schema().Columns[idx].EvalVirtualColumn(sctx.GetExprCtx().GetEvalCtx(), mutableRow.ToRow()) + datum, err := us.Schema().Columns[idx].EvalVirtualColumn(sctx.GetExprCtx().GetEvalCtx(), us.mutableRow.ToRow()) if err != nil { return err } @@ -172,16 +185,71 @@ func (us *UnionScanExec) Next(ctx context.Context, req *chunk.Chunk) error { if (mysql.HasNotNullFlag(us.columns[idx].GetFlag()) || mysql.HasPreventNullInsertFlag(us.columns[idx].GetFlag())) && castDatum.IsNull() { castDatum = table.GetZeroValue(us.columns[idx]) } - mutableRow.SetDatum(idx, castDatum) + us.mutableRow.SetDatum(idx, castDatum) } - matched, _, err := expression.EvalBool(us.Ctx().GetExprCtx().GetEvalCtx(), us.conditionsWithVirCol, mutableRow.ToRow()) + matched, _, err := expression.EvalBool(us.Ctx().GetExprCtx().GetEvalCtx(), us.conditionsWithVirCol, us.mutableRow.ToRow()) if err != nil { return err } if matched { - req.AppendRow(mutableRow.ToRow()) + req.AppendRow(us.mutableRow.ToRow()) + } + } + return nil +} + +// nextForCacheTable is a fast path for cached tables. Since cached tables have +// no snapshot rows (getSnapshotRow always returns nil), we skip the merge logic +// and read directly from addedRowsIter. +func (us *UnionScanExec) nextForCacheTable(req *chunk.Chunk) error { + needsMutableRow := len(us.virtualColumnIndex) > 0 || len(us.conditionsWithVirCol) > 0 + for batchSize := req.Capacity(); req.NumRows() < batchSize; { + row, err := us.addedRowsIter.Next() + if err != nil { + return err + } + if row == nil { + return nil + } + + if !needsMutableRow { + // Fast path: no virtual columns and no conditions, + // write datums directly to req without the intermediate mutableRow copy. + for i := range row { + req.AppendDatum(i, &row[i]) + } + continue + } + + us.mutableRow.SetDatums(row...) + + sctx := us.Ctx() + for _, idx := range us.virtualColumnIndex { + datum, err := us.Schema().Columns[idx].EvalVirtualColumn(sctx.GetExprCtx().GetEvalCtx(), us.mutableRow.ToRow()) + if err != nil { + return err + } + castDatum, err := table.CastValue(us.Ctx(), datum, us.columns[idx], false, true) + if err != nil { + return err + } + if (mysql.HasNotNullFlag(us.columns[idx].GetFlag()) || mysql.HasPreventNullInsertFlag(us.columns[idx].GetFlag())) && castDatum.IsNull() { + castDatum = table.GetZeroValue(us.columns[idx]) + } + us.mutableRow.SetDatum(idx, castDatum) + } + + if len(us.conditionsWithVirCol) > 0 { + matched, _, err := expression.EvalBool(us.Ctx().GetExprCtx().GetEvalCtx(), us.conditionsWithVirCol, us.mutableRow.ToRow()) + if err != nil { + return err + } + if !matched { + continue + } } + req.AppendRow(us.mutableRow.ToRow()) } return nil } diff --git a/pkg/expression/builtin_compare.go b/pkg/expression/builtin_compare.go index 74fa13d33ef0d..6ec099f99dda8 100644 --- a/pkg/expression/builtin_compare.go +++ b/pkg/expression/builtin_compare.go @@ -3405,7 +3405,13 @@ func CompareInt(sctx EvalContext, lhsArg, rhsArg Expression, lhsRow, rhsRow chun return compareNull(isNull0, isNull1), true, nil } - isUnsigned0, isUnsigned1 := mysql.HasUnsignedFlag(lhsArg.GetType(sctx).GetFlag()), mysql.HasUnsignedFlag(rhsArg.GetType(sctx).GetFlag()) + getTypeFlag := func(arg Expression) uint { + if c, ok := arg.(*Constant); ok && c.ParamMarker == nil { + return c.RetType.GetFlag() + } + return arg.GetType(sctx).GetFlag() + } + isUnsigned0, isUnsigned1 := mysql.HasUnsignedFlag(getTypeFlag(lhsArg)), mysql.HasUnsignedFlag(getTypeFlag(rhsArg)) return int64(types.CompareInt(arg0, isUnsigned0, arg1, isUnsigned1)), false, nil } diff --git a/pkg/expression/constant.go b/pkg/expression/constant.go index 3dc6d5325ec8d..f522cbf80172e 100644 --- a/pkg/expression/constant.go +++ b/pkg/expression/constant.go @@ -243,6 +243,21 @@ func (c *Constant) GetType(ctx EvalContext) *types.FieldType { return c.RetType } +// getTypeNonAlloc fills buf for ParamMarker constants without heap allocation, +// or returns c.RetType for non-ParamMarker constants. +func (c *Constant) getTypeNonAlloc(ctx EvalContext, buf *types.FieldType) *types.FieldType { + if c.ParamMarker != nil { + types.InitUnspecifiedFieldType(buf) + dt, err := c.ParamMarker.GetUserVar(ctx) + if err != nil { + return nil + } + types.InferParamTypeFromDatum(&dt, buf) + return buf + } + return c.RetType +} + // VecEvalInt evaluates this expression in a vectorized manner. func (c *Constant) VecEvalInt(ctx EvalContext, input *chunk.Chunk, result *chunk.Column) error { if c.DeferredExpr == nil { @@ -363,12 +378,14 @@ func (c *Constant) EvalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) if !lazy { dt = c.Value } - if c.GetType(ctx).GetType() == mysql.TypeNull || dt.IsNull() { + var tpBuf types.FieldType + tp := c.getTypeNonAlloc(ctx, &tpBuf) + if tp.GetType() == mysql.TypeNull || dt.IsNull() { return 0, true, nil } else if dt.Kind() == types.KindBinaryLiteral { val, err := dt.GetBinaryLiteral().ToInt(typeCtx(ctx)) return int64(val), err != nil, err - } else if c.GetType(ctx).Hybrid() || dt.Kind() == types.KindString { + } else if tp.Hybrid() || dt.Kind() == types.KindString { res, err := dt.ToInt64(typeCtx(ctx)) return res, false, err } else if dt.Kind() == types.KindMysqlBit { @@ -387,10 +404,12 @@ func (c *Constant) EvalReal(ctx EvalContext, row chunk.Row) (float64, bool, erro if !lazy { dt = c.Value } - if c.GetType(ctx).GetType() == mysql.TypeNull || dt.IsNull() { + var tpBuf types.FieldType + tp := c.getTypeNonAlloc(ctx, &tpBuf) + if tp.GetType() == mysql.TypeNull || dt.IsNull() { return 0, true, nil } - if c.GetType(ctx).Hybrid() || dt.Kind() == types.KindBinaryLiteral || dt.Kind() == types.KindString { + if tp.Hybrid() || dt.Kind() == types.KindBinaryLiteral || dt.Kind() == types.KindString { res, err := dt.ToFloat64(typeCtx(ctx)) return res, false, err } @@ -406,7 +425,9 @@ func (c *Constant) EvalString(ctx EvalContext, row chunk.Row) (string, bool, err if !lazy { dt = c.Value } - if c.GetType(ctx).GetType() == mysql.TypeNull || dt.IsNull() { + var tpBuf types.FieldType + tp := c.getTypeNonAlloc(ctx, &tpBuf) + if tp.GetType() == mysql.TypeNull || dt.IsNull() { return "", true, nil } res, err := dt.ToString() @@ -422,14 +443,16 @@ func (c *Constant) EvalDecimal(ctx EvalContext, row chunk.Row) (*types.MyDecimal if !lazy { dt = c.Value } - if c.GetType(ctx).GetType() == mysql.TypeNull || dt.IsNull() { + var tpBuf types.FieldType + tp := c.getTypeNonAlloc(ctx, &tpBuf) + if tp.GetType() == mysql.TypeNull || dt.IsNull() { return nil, true, nil } res, err := dt.ToDecimal(typeCtx(ctx)) if err != nil { return nil, false, err } - if err := c.adjustDecimal(ctx, res); err != nil { + if err := c.adjustDecimalWithType(tp, res); err != nil { return nil, false, err } return res, false, nil @@ -444,6 +467,14 @@ func (c *Constant) adjustDecimal(ctx EvalContext, d *types.MyDecimal) error { return nil } +func (c *Constant) adjustDecimalWithType(tp *types.FieldType, d *types.MyDecimal) error { + _, frac := d.PrecisionAndFrac() + if frac < tp.GetDecimal() { + return d.Round(d, tp.GetDecimal(), types.ModeHalfUp) + } + return nil +} + // EvalTime returns DATE/DATETIME/TIMESTAMP representation of Constant. func (c *Constant) EvalTime(ctx EvalContext, row chunk.Row) (val types.Time, isNull bool, err error) { dt, lazy, err := c.getLazyDatum(ctx, row) @@ -453,7 +484,9 @@ func (c *Constant) EvalTime(ctx EvalContext, row chunk.Row) (val types.Time, isN if !lazy { dt = c.Value } - if c.GetType(ctx).GetType() == mysql.TypeNull || dt.IsNull() { + var tpBuf types.FieldType + tp := c.getTypeNonAlloc(ctx, &tpBuf) + if tp.GetType() == mysql.TypeNull || dt.IsNull() { return types.ZeroTime, true, nil } return dt.GetMysqlTime(), false, nil @@ -468,7 +501,9 @@ func (c *Constant) EvalDuration(ctx EvalContext, row chunk.Row) (val types.Durat if !lazy { dt = c.Value } - if c.GetType(ctx).GetType() == mysql.TypeNull || dt.IsNull() { + var tpBuf types.FieldType + tp := c.getTypeNonAlloc(ctx, &tpBuf) + if tp.GetType() == mysql.TypeNull || dt.IsNull() { return types.Duration{}, true, nil } return dt.GetMysqlDuration(), false, nil @@ -483,7 +518,9 @@ func (c *Constant) EvalJSON(ctx EvalContext, row chunk.Row) (types.BinaryJSON, b if !lazy { dt = c.Value } - if c.GetType(ctx).GetType() == mysql.TypeNull || dt.IsNull() { + var tpBuf types.FieldType + tp := c.getTypeNonAlloc(ctx, &tpBuf) + if tp.GetType() == mysql.TypeNull || dt.IsNull() { return types.BinaryJSON{}, true, nil } return dt.GetMysqlJSON(), false, nil @@ -498,7 +535,9 @@ func (c *Constant) EvalVectorFloat32(ctx EvalContext, row chunk.Row) (types.Vect if !lazy { dt = c.Value } - if c.GetType(ctx).GetType() == mysql.TypeNull || dt.IsNull() { + var tpBuf types.FieldType + tp := c.getTypeNonAlloc(ctx, &tpBuf) + if tp.GetType() == mysql.TypeNull || dt.IsNull() { return types.ZeroVectorFloat32, true, nil } return dt.GetVectorFloat32(), false, nil diff --git a/pkg/expression/util.go b/pkg/expression/util.go index 474865b5ebd33..39627f2d4b987 100644 --- a/pkg/expression/util.go +++ b/pkg/expression/util.go @@ -1560,9 +1560,10 @@ func IsImmutableFunc(expr Expression) bool { } } -// RemoveDupExprs removes identical exprs. Not that if expr contains functions which -// are mutable or have side effects, we cannot remove it even if it has duplicates; -// if the plan is going to be cached, we cannot remove expressions containing `?` neither. +// RemoveDupExprs removes identical expressions. Note that if exprs contain functions +// that are mutable or have side effects, we cannot remove them even if they duplicate +// other expressions. If the plan is going to be cached, we cannot remove expressions +// containing `?` either. func RemoveDupExprs(exprs []Expression) []Expression { if len(exprs) <= 1 { return exprs diff --git a/pkg/infoschema/tables.go b/pkg/infoschema/tables.go index 3ce9897994e71..43d8e7588bbc1 100644 --- a/pkg/infoschema/tables.go +++ b/pkg/infoschema/tables.go @@ -968,6 +968,7 @@ var slowQueryCols = []columnInfo{ {name: variable.SlowLogSucc, tp: mysql.TypeTiny, size: 1}, {name: variable.SlowLogIsExplicitTxn, tp: mysql.TypeTiny, size: 1}, {name: variable.SlowLogIsWriteCacheTable, tp: mysql.TypeTiny, size: 1}, + {name: variable.SlowLogResultCacheHit, tp: mysql.TypeTiny, size: 1}, {name: variable.SlowLogPlanFromCache, tp: mysql.TypeTiny, size: 1}, {name: variable.SlowLogPlanFromBinding, tp: mysql.TypeTiny, size: 1}, {name: variable.SlowLogHasMoreResults, tp: mysql.TypeTiny, size: 1}, diff --git a/pkg/metrics/executor.go b/pkg/metrics/executor.go index 6c06293d8d56c..3148e39a50333 100644 --- a/pkg/metrics/executor.go +++ b/pkg/metrics/executor.go @@ -79,6 +79,18 @@ var ( // IndexLookUpCopTaskCount records the number of cop tasks in index look up executor IndexLookUpCopTaskCount *prometheus.CounterVec + + // ResultCacheHitCounter records the number of result cache hits on cached tables. + ResultCacheHitCounter prometheus.Counter + + // ResultCacheMissCounter records the number of result cache misses on cached tables. + ResultCacheMissCounter prometheus.Counter + + // ResultCacheMemoryGauge records the memory usage of result cache on cached tables. + ResultCacheMemoryGauge prometheus.Gauge + + // ResultCacheEvictCounter records the number of result cache evictions (lease expiry). + ResultCacheEvictCounter prometheus.Counter ) // InitExecutorMetrics initializes excutor metrics. @@ -200,4 +212,40 @@ func InitExecutorMetrics() { Name: "index_lookup_cop_task_count", Help: "Counter for index lookup cop tasks", }, []string{LblType}) + + ResultCacheHitCounter = metricscommon.NewCounter( + prometheus.CounterOpts{ + Namespace: "tidb", + Subsystem: "executor", + Name: "result_cache_hit_total", + Help: "Total number of result cache hits on cached tables.", + }, + ) + + ResultCacheMissCounter = metricscommon.NewCounter( + prometheus.CounterOpts{ + Namespace: "tidb", + Subsystem: "executor", + Name: "result_cache_miss_total", + Help: "Total number of result cache misses on cached tables.", + }, + ) + + ResultCacheMemoryGauge = metricscommon.NewGauge( + prometheus.GaugeOpts{ + Namespace: "tidb", + Subsystem: "executor", + Name: "result_cache_memory_bytes", + Help: "Memory usage of result cache on cached tables.", + }, + ) + + ResultCacheEvictCounter = metricscommon.NewCounter( + prometheus.CounterOpts{ + Namespace: "tidb", + Subsystem: "executor", + Name: "result_cache_evict_total", + Help: "Total number of result cache evictions (lease expiry).", + }, + ) } diff --git a/pkg/metrics/metrics.go b/pkg/metrics/metrics.go index 547d2f5bba0b7..02b3437601eff 100644 --- a/pkg/metrics/metrics.go +++ b/pkg/metrics/metrics.go @@ -321,6 +321,11 @@ func RegisterMetrics() { prometheus.MustRegister(InfoSchemaV2CacheObjCnt) prometheus.MustRegister(TableByNameDuration) + prometheus.MustRegister(ResultCacheHitCounter) + prometheus.MustRegister(ResultCacheMissCounter) + prometheus.MustRegister(ResultCacheMemoryGauge) + prometheus.MustRegister(ResultCacheEvictCounter) + prometheus.MustRegister(BindingCacheHitCounter) prometheus.MustRegister(BindingCacheMissCounter) prometheus.MustRegister(BindingCacheMemUsage) diff --git a/pkg/parser/ast/misc.go b/pkg/parser/ast/misc.go index cc257f7851f7b..4f0dad3fdfb22 100644 --- a/pkg/parser/ast/misc.go +++ b/pkg/parser/ast/misc.go @@ -4102,7 +4102,7 @@ func (n *TableOptimizerHint) Restore(ctx *format.RestoreCtx) error { } // Hints without args except query block. switch n.HintName.L { - case "mpp_1phase_agg", "mpp_2phase_agg", "hash_agg", "stream_agg", "agg_to_cop", "read_consistent_replica", "no_index_merge", "ignore_plan_cache", "limit_to_cop", "straight_join", "merge", "no_decorrelate": + case "mpp_1phase_agg", "mpp_2phase_agg", "hash_agg", "stream_agg", "agg_to_cop", "read_consistent_replica", "no_index_merge", "ignore_plan_cache", "use_plan_cache", "limit_to_cop", "straight_join", "merge", "no_decorrelate": ctx.WritePlain(")") return nil } diff --git a/pkg/planner/core/BUILD.bazel b/pkg/planner/core/BUILD.bazel index f977988f2a680..4a9edac1ff629 100644 --- a/pkg/planner/core/BUILD.bazel +++ b/pkg/planner/core/BUILD.bazel @@ -46,6 +46,8 @@ go_library( "property_cols_prune.go", "recheck_cte.go", "resolve_indices.go", + "result_cache_check.go", + "result_cache_key.go", "rule_aggregation_elimination.go", "rule_aggregation_push_down.go", "rule_aggregation_skew_rewrite.go", @@ -230,6 +232,8 @@ go_test( "plan_to_pb_test.go", "planbuilder_test.go", "preprocess_test.go", + "result_cache_check_test.go", + "result_cache_key_test.go", "rule_generate_column_substitute_test.go", "rule_join_reorder_dp_test.go", "runtime_filter_generator_test.go", @@ -285,6 +289,7 @@ go_test( "//pkg/session/sessionapi", "//pkg/session/sessmgr", "//pkg/sessionctx", + "//pkg/sessionctx/stmtctx", "//pkg/sessionctx/vardef", "//pkg/sessionctx/variable", "//pkg/sessiontxn", diff --git a/pkg/planner/core/casetest/plancache/BUILD.bazel b/pkg/planner/core/casetest/plancache/BUILD.bazel index 25f20e76c21e6..5875bbd4cbb68 100644 --- a/pkg/planner/core/casetest/plancache/BUILD.bazel +++ b/pkg/planner/core/casetest/plancache/BUILD.bazel @@ -17,7 +17,7 @@ go_test( "//pkg/planner/core:plan_clone_utils.go", ], flaky = True, - shard_count = 45, + shard_count = 47, deps = [ "//pkg/expression", "//pkg/infoschema", diff --git a/pkg/planner/core/casetest/plancache/plan_cache_suite_test.go b/pkg/planner/core/casetest/plancache/plan_cache_suite_test.go index 6c7dbd47527a3..c57acda82de8f 100644 --- a/pkg/planner/core/casetest/plancache/plan_cache_suite_test.go +++ b/pkg/planner/core/casetest/plancache/plan_cache_suite_test.go @@ -1811,6 +1811,95 @@ func runPreparedPlanCacheForUpdateInTxn(t *testing.T, tk *testkit.TestKit) { tk.MustExec(`deallocate prepare st`) } +func TestPreparedPlanCacheHintOnlyWithoutUsePlanCacheHint(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + + preparedCache := tk.MustQuery("select @@session.tidb_enable_prepared_plan_cache").Rows()[0][0] + planCachePolicy := tk.MustQuery("select @@session.tidb_plan_cache_policy").Rows()[0][0] + defer func() { + tk.MustExec(fmt.Sprintf("set @@session.tidb_enable_prepared_plan_cache=%v", preparedCache)) + tk.MustExec(fmt.Sprintf("set @@session.tidb_plan_cache_policy=%q", planCachePolicy)) + }() + tk.MustExec(`set @@session.tidb_enable_prepared_plan_cache=1`) + tk.MustExec(`set @@session.tidb_plan_cache_policy='hint_only'`) + + tableName := "t_prepare_hint_only_without_use_hint" + tk.MustExec(fmt.Sprintf("drop table if exists %s", tableName)) + tk.MustExec(fmt.Sprintf("create table %s (a int)", tableName)) + tk.MustExec(fmt.Sprintf("insert into %s values (1)", tableName)) + + tk.MustExec(fmt.Sprintf("prepare st from 'select 1 from %s where a = ?'", tableName)) + tk.MustExec("set @a=1") + tk.MustExec("execute st using @a") + tk.MustQuery(`select @@last_plan_from_cache`).Check(testkit.Rows("0")) + tk.MustExec("execute st using @a") + tk.MustQuery(`select @@last_plan_from_cache`).Check(testkit.Rows("0")) + + tk.MustExec(fmt.Sprintf( + "create global binding for select 1 from %s where a = ? using select /*+ use_plan_cache() */ 1 from %s where a = ?", + tableName, tableName, + )) + tk.MustExec("execute st using @a") + tk.MustQuery("select @@last_plan_from_binding, @@last_plan_from_cache").Check(testkit.Rows("1 0")) + tk.MustExec("execute st using @a") + tk.MustQuery("select @@last_plan_from_binding, @@last_plan_from_cache").Check(testkit.Rows("1 1")) + tk.MustExec("execute st using @a") + tk.MustQuery("select @@last_plan_from_binding, @@last_plan_from_cache").Check(testkit.Rows("1 1")) +} + +func TestPreparedPlanCacheHintOnlyWithBinding(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + + preparedCache := tk.MustQuery("select @@session.tidb_enable_prepared_plan_cache").Rows()[0][0] + planCachePolicy := tk.MustQuery("select @@session.tidb_plan_cache_policy").Rows()[0][0] + defer func() { + tk.MustExec(fmt.Sprintf("set @@session.tidb_enable_prepared_plan_cache=%v", preparedCache)) + tk.MustExec(fmt.Sprintf("set @@session.tidb_plan_cache_policy=%q", planCachePolicy)) + }() + tk.MustExec(`set @@session.tidb_enable_prepared_plan_cache=1`) + tk.MustExec(`set @@session.tidb_plan_cache_policy='hint_only'`) + + tableName := "t_prepare_hint_only_binding" + tk.MustExec(fmt.Sprintf("drop table if exists %s", tableName)) + tk.MustExec(fmt.Sprintf("create table %s (pk int, a int, primary key(pk))", tableName)) + tk.MustExec(fmt.Sprintf("insert into %s values (1, 1), (2, 2)", tableName)) + + tk.MustExec(fmt.Sprintf("prepare st from 'select * from %s where pk >= ?'", tableName)) + tk.MustExec("set @a=1") + tk.MustExec("execute st using @a") + tk.MustQuery(`select @@last_plan_from_cache`).Check(testkit.Rows("0")) + tk.MustExec("execute st using @a") + tk.MustQuery(`select @@last_plan_from_cache`).Check(testkit.Rows("0")) + tk.MustExec("execute st using @a") + tk.MustQuery(`select @@last_plan_from_cache`).Check(testkit.Rows("0")) + + tk.MustExec(fmt.Sprintf( + "CREATE BINDING FOR select * from %s where pk >= ? USING select /*+ use_plan_cache() */ * from %s where pk >= ?", + tableName, tableName, + )) + tk.MustExec("execute st using @a") + tk.MustQuery("select @@last_plan_from_binding, @@last_plan_from_cache").Check(testkit.Rows("1 0")) + tk.MustExec("execute st using @a") + tk.MustQuery("select @@last_plan_from_binding, @@last_plan_from_cache").Check(testkit.Rows("1 1")) + tk.MustExec("execute st using @a") + tk.MustQuery(`select @@last_plan_from_cache`).Check(testkit.Rows("1")) + + tk.MustExec(fmt.Sprintf( + "CREATE BINDING FOR select * from %s where pk >= ? USING select /*+ ignore_plan_cache() */ * from %s where pk >= ?", + tableName, tableName, + )) + tk.MustExec("execute st using @a") + tk.MustQuery(`select @@last_plan_from_cache`).Check(testkit.Rows("0")) + tk.MustExec("execute st using @a") + tk.MustQuery(`select @@last_plan_from_cache`).Check(testkit.Rows("0")) + tk.MustExec("execute st using @a") + tk.MustQuery(`select @@last_plan_from_cache`).Check(testkit.Rows("0")) +} + func TestNonPreparedPlanCacheSupportsFeatures(t *testing.T) { store := testkit.CreateMockStore(t) tk := testkit.NewTestKit(t, store) diff --git a/pkg/planner/core/plan_cache.go b/pkg/planner/core/plan_cache.go index 23ce7512b8f3d..8486762497903 100644 --- a/pkg/planner/core/plan_cache.go +++ b/pkg/planner/core/plan_cache.go @@ -20,6 +20,7 @@ import ( "time" "github.com/pingcap/errors" + "github.com/pingcap/tidb/pkg/bindinfo" "github.com/pingcap/tidb/pkg/domain" "github.com/pingcap/tidb/pkg/expression" "github.com/pingcap/tidb/pkg/infoschema" @@ -194,6 +195,7 @@ func GetPlanFromPlanCache(ctx context.Context, sctx sessionctx.Context, sessVars := sctx.GetSessionVars() stmtCtx := sessVars.StmtCtx + var matchedBinding *bindinfo.Binding cacheEnabled := false if isNonPrepared { stmtCtx.SetCacheType(contextutil.SessionNonPrepared) @@ -203,7 +205,14 @@ func GetPlanFromPlanCache(ctx context.Context, sctx sessionctx.Context, cacheEnabled = sessVars.EnablePreparedPlanCache } if stmt.StmtCacheable && cacheEnabled { - stmtCtx.EnablePlanCache() + if sessVars.PlanCachePolicy == vardef.PlanCachePolicyHintOnly && !stmt.UsePlanCacheHint { + matchedBinding = matchSQLBindingWithCache(sctx, stmt) + } + if allowPlanCacheByPolicy(sctx, stmt, matchedBinding) { + stmtCtx.EnablePlanCache() + } else { + stmtCtx.WarnSkipPlanCache("the switch 'tidb_plan_cache_policy' is set to hint_only and no USE_PLAN_CACHE() hint is found") + } } if stmt.UncacheableReason != "" { stmtCtx.WarnSkipPlanCache(stmt.UncacheableReason) @@ -212,7 +221,10 @@ func GetPlanFromPlanCache(ctx context.Context, sctx sessionctx.Context, var cacheKey, binding, reason string var cacheable bool if stmtCtx.UseCache() { - cacheKey, binding, cacheable, reason, err = NewPlanCacheKey(sctx, stmt) + if matchedBinding == nil { + matchedBinding = matchSQLBindingWithCache(sctx, stmt) + } + cacheKey, binding, cacheable, reason, err = newPlanCacheKeyWithMatchedBinding(sctx, stmt, matchedBinding) if err != nil { return nil, nil, err } @@ -238,6 +250,30 @@ func GetPlanFromPlanCache(ctx context.Context, sctx sessionctx.Context, return generateNewPlan(ctx, sctx, isNonPrepared, is, stmt, cacheKey, binding, paramTypes) } +func allowPlanCacheByPolicy(sctx sessionctx.Context, stmt *PlanCacheStmt, matchedBinding *bindinfo.Binding) bool { + if sctx.GetSessionVars().PlanCachePolicy != vardef.PlanCachePolicyHintOnly { + return true + } + if stmt.UsePlanCacheHint { + return true + } + return bindingHasUsePlanCacheHint(matchedBinding) +} + +func matchSQLBindingWithCache(sctx sessionctx.Context, stmt *PlanCacheStmt) *bindinfo.Binding { + matchedBinding, matched, _ := bindinfo.MatchSQLBindingWithCache(sctx, stmt.PreparedAst.Stmt, &stmt.BindingInfo) + if !matched { + return nil + } + return matchedBinding +} + +func bindingHasUsePlanCacheHint(matchedBinding *bindinfo.Binding) bool { + return matchedBinding != nil && + matchedBinding.Hint != nil && + matchedBinding.Hint.ContainTableHint(hint.HintUsePlanCache) +} + func clonePlanForInstancePlanCache(ctx context.Context, sctx sessionctx.Context, stmt *PlanCacheStmt, plan base.Plan) (clonedPlan base.Plan, ok bool) { defer func(begin time.Time) { diff --git a/pkg/planner/core/plan_cache_utils.go b/pkg/planner/core/plan_cache_utils.go index a8df9017af6f4..5fc7e4d4687fe 100644 --- a/pkg/planner/core/plan_cache_utils.go +++ b/pkg/planner/core/plan_cache_utils.go @@ -92,6 +92,7 @@ func (e *paramMarkerExtractor) Leave(in ast.Node) (ast.Node, bool) { func GeneratePlanCacheStmtWithAST(ctx context.Context, sctx sessionctx.Context, isPrepStmt bool, paramSQL string, paramStmt ast.StmtNode, is infoschema.InfoSchema) (*PlanCacheStmt, base.Plan, int, error) { vars := sctx.GetSessionVars() + usePlanCacheHint := hasUsePlanCacheHint(paramStmt) var extractor paramMarkerExtractor paramStmt.Accept(&extractor) @@ -222,6 +223,7 @@ func GeneratePlanCacheStmtWithAST(ctx context.Context, sctx sessionctx.Context, SnapshotTSEvaluator: ret.SnapshotTSEvaluator, StmtCacheable: cacheable, UncacheableReason: reason, + UsePlanCacheHint: usePlanCacheHint, dbName: dbName, tbls: tbls, SchemaVersion: ret.InfoSchema.SchemaMetaVersion(), @@ -238,6 +240,15 @@ func GeneratePlanCacheStmtWithAST(ctx context.Context, sctx sessionctx.Context, return preparedObj, p, paramCount, nil } +func hasUsePlanCacheHint(stmt ast.StmtNode) bool { + for _, h := range hint.ExtractTableHintsFromStmtNode(stmt, nil) { + if h.HintName.L == hint.HintUsePlanCache { + return true + } + } + return false +} + // tableIDSlicePool is a pool for int64 slices used in hashInt64Uint64Map. var tableIDSlicePool = zeropool.New[[]int64](func() []int64 { return make([]int64, 0, 8) @@ -310,7 +321,15 @@ func hashInt64Uint64Map(b []byte, m map[int64]uint64) []byte { // differentiate the cache key. In other cases, it will be 0. // All information that might affect the plan should be considered in this function. func NewPlanCacheKey(sctx sessionctx.Context, stmt *PlanCacheStmt) (key, binding string, cacheable bool, reason string, err error) { - if matchedBinding, matched, _ := bindinfo.MatchSQLBindingWithCache(sctx, stmt.PreparedAst.Stmt, &stmt.BindingInfo); matched { + matchedBinding, matched, _ := bindinfo.MatchSQLBindingWithCache(sctx, stmt.PreparedAst.Stmt, &stmt.BindingInfo) + if !matched { + matchedBinding = nil + } + return newPlanCacheKeyWithMatchedBinding(sctx, stmt, matchedBinding) +} + +func newPlanCacheKeyWithMatchedBinding(sctx sessionctx.Context, stmt *PlanCacheStmt, matchedBinding *bindinfo.Binding) (key, binding string, cacheable bool, reason string, err error) { + if matchedBinding != nil { // Record the matched binding SQL so the plan cache key reflects the effective hints. binding = matchedBinding.BindSQL } @@ -736,6 +755,7 @@ type PlanCacheStmt struct { StmtCacheable bool // Whether this stmt is cacheable. UncacheableReason string // Why this stmt is uncacheable. + UsePlanCacheHint bool // Whether this stmt contains the USE_PLAN_CACHE() hint. limits []*ast.Limit hasSubquery bool diff --git a/pkg/planner/core/result_cache_check.go b/pkg/planner/core/result_cache_check.go new file mode 100644 index 0000000000000..86afc604edb94 --- /dev/null +++ b/pkg/planner/core/result_cache_check.go @@ -0,0 +1,273 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package core + +import ( + "slices" + "strings" + + "github.com/pingcap/tidb/pkg/expression" + "github.com/pingcap/tidb/pkg/meta/model" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/planner/core/base" + "github.com/pingcap/tidb/pkg/planner/core/operator/physicalop" +) + +// resultCacheNonDeterministicFuncs lists functions whose results are +// non-deterministic, session-dependent, or have side effects, making query +// results non-cacheable. +// This mirrors expression.mutableEffectsFunctions and extends it with some +// session-dependent/side-effect functions that may be constant-folded during +// planning, making them undetectable in the physical plan. +var resultCacheNonDeterministicFuncs = map[string]struct{}{ + ast.Now: {}, ast.CurrentTimestamp: {}, ast.UTCTime: {}, + ast.Curtime: {}, ast.CurrentTime: {}, ast.UTCTimestamp: {}, + ast.UnixTimestamp: {}, ast.Sysdate: {}, ast.Curdate: {}, + ast.CurrentDate: {}, ast.UTCDate: {}, + ast.LocalTime: {}, ast.LocalTimestamp: {}, + ast.Rand: {}, ast.RandomBytes: {}, + ast.UUID: {}, ast.UUIDShort: {}, + ast.Sleep: {}, ast.SetVar: {}, ast.GetVar: {}, + ast.AnyValue: {}, + ast.ConnectionID: {}, ast.CurrentUser: {}, ast.User: {}, + ast.SessionUser: {}, ast.SystemUser: {}, + ast.Database: {}, ast.Schema: {}, ast.CurrentRole: {}, ast.CurrentResourceGroup: {}, + ast.LastInsertId: {}, ast.RowCount: {}, ast.FoundRows: {}, + ast.NextVal: {}, ast.LastVal: {}, ast.SetVal: {}, + ast.GetLock: {}, ast.ReleaseLock: {}, ast.IsFreeLock: {}, ast.IsUsedLock: {}, + ast.ReleaseAllLocks: {}, +} + +// CanCacheResultSet checks whether a physical plan's result set can be cached +// in the result set cache of cached tables. +// The stmtNode is the original AST node, needed because the optimizer folds +// non-deterministic functions (NOW, RAND, etc.) into constants, making them +// undetectable in the physical plan. +// Returns true only when ALL of the following conditions are met: +// 1. Not in DML context (inDML is false) +// 2. No FOR UPDATE / FOR SHARE lock +// 3. No non-deterministic or side-effect functions (NOW, RAND, UUID, SLEEP, etc.) +// 4. No user/session variable references (@var, @@var) +// 5. All accessed tables are cached tables +func CanCacheResultSet(stmtNode ast.StmtNode, plan base.PhysicalPlan, inDML bool) bool { + if inDML { + return false + } + // Point get plans bypass non-prepared plan cache parameterization. + // Skip result set caching for them. + switch plan.(type) { + case *physicalop.PointGetPlan, *physicalop.BatchPointGetPlan: + return false + } + // Check AST for non-deterministic functions and variable references + // that get folded to constants during optimization. + if hasMutableExprInAST(stmtNode) { + return false + } + return checkPlanTreeCacheable(plan) +} + +// hasMutableExprInAST walks the AST to detect non-deterministic functions +// and variable references that would make the result set non-cacheable. +func hasMutableExprInAST(node ast.Node) bool { + if node == nil { + return true + } + checker := &mutableExprChecker{} + node.Accept(checker) + return checker.found +} + +// mutableExprChecker is an AST visitor that detects non-deterministic +// functions and variable references. +type mutableExprChecker struct { + found bool +} + +func (c *mutableExprChecker) Enter(in ast.Node) (ast.Node, bool) { + if c.found { + return in, true + } + switch node := in.(type) { + case *ast.FuncCallExpr: + if _, ok := resultCacheNonDeterministicFuncs[node.FnName.L]; ok { + c.found = true + return in, true + } + case *ast.AggregateFuncExpr: + if _, ok := resultCacheNonDeterministicFuncs[strings.ToLower(node.F)]; ok { + c.found = true + return in, true + } + case *ast.VariableExpr: + // User variables (@var) and session variables (@@var). + c.found = true + return in, true + } + return in, false +} + +func (c *mutableExprChecker) Leave(in ast.Node) (ast.Node, bool) { + return in, !c.found +} + +// checkPlanTreeCacheable recursively checks whether the plan tree is cacheable. +func checkPlanTreeCacheable(plan base.PhysicalPlan) bool { + if plan == nil { + return false + } + if !checkNodeCacheable(plan) { + return false + } + // Recurse into normal children. + for _, child := range plan.Children() { + if !checkPlanTreeCacheable(child) { + return false + } + } + // For reader types, the inner plans (tablePlan/indexPlan) are NOT exposed + // via Children(). We must traverse them explicitly. + switch x := plan.(type) { + case *physicalop.PhysicalTableReader: + if x.TablePlan != nil && !checkPlanTreeCacheable(x.TablePlan) { + return false + } + case *physicalop.PhysicalIndexReader: + if x.IndexPlan != nil && !checkPlanTreeCacheable(x.IndexPlan) { + return false + } + case *physicalop.PhysicalIndexLookUpReader: + if x.IndexPlan != nil && !checkPlanTreeCacheable(x.IndexPlan) { + return false + } + if x.TablePlan != nil && !checkPlanTreeCacheable(x.TablePlan) { + return false + } + case *physicalop.PhysicalIndexMergeReader: + for _, partial := range x.PartialPlansRaw { + if !checkPlanTreeCacheable(partial) { + return false + } + } + if x.TablePlan != nil && !checkPlanTreeCacheable(x.TablePlan) { + return false + } + } + return true +} + +// checkNodeCacheable checks a single plan node for cacheability. +func checkNodeCacheable(plan base.PhysicalPlan) bool { + if plan == nil { + return false + } + // 1. Check FOR UPDATE / FOR SHARE lock. + if lock, ok := plan.(*physicalop.PhysicalLock); ok { + if lock.Lock != nil && lock.Lock.LockType != ast.SelectLockNone { + return false + } + } + // 2. Check that scanned tables are cached tables. + switch x := plan.(type) { + case *physicalop.PhysicalUnionScan: + // Cached-table reads also use UnionScan machinery. Only reject the + // session-local case where UnionScan merges dirty rows from the current + // transaction. + for _, dirty := range plan.SCtx().GetSessionVars().StmtCtx.TblInfo2UnionScan { + if dirty { + return false + } + } + case *physicalop.PhysicalMemTable: + return false + case *physicalop.PhysicalTableScan: + if x.Table.TableCacheStatusType != model.TableCacheStatusEnable { + return false + } + case *physicalop.PhysicalIndexScan: + if x.Table.TableCacheStatusType != model.TableCacheStatusEnable { + return false + } + case *physicalop.PointGetPlan: + if x.TblInfo.TableCacheStatusType != model.TableCacheStatusEnable { + return false + } + if x.Lock { + return false + } + case *physicalop.BatchPointGetPlan: + if x.TblInfo.TableCacheStatusType != model.TableCacheStatusEnable { + return false + } + if x.Lock { + return false + } + } + // 3. Check expressions for non-deterministic / side-effect functions + // that survived optimization (e.g., in WHERE conditions pushed to scans). + if slices.ContainsFunc(collectNodeExprs(plan), expression.IsMutableEffectsExpr) { + return false + } + return true +} + +// collectNodeExprs gathers all user-visible expressions from a plan node. +func collectNodeExprs(plan base.PhysicalPlan) []expression.Expression { + var exprs []expression.Expression + switch x := plan.(type) { + case *physicalop.PhysicalSelection: + exprs = append(exprs, x.Conditions...) + case *physicalop.PhysicalProjection: + exprs = append(exprs, x.Exprs...) + case *physicalop.PhysicalTableScan: + exprs = append(exprs, x.AccessCondition...) + exprs = append(exprs, x.FilterCondition...) + case *physicalop.PhysicalIndexScan: + exprs = append(exprs, x.AccessCondition...) + case *physicalop.PhysicalUnionScan: + exprs = append(exprs, x.Conditions...) + case *physicalop.PhysicalSort: + for _, item := range x.ByItems { + exprs = append(exprs, item.Expr) + } + case *physicalop.PhysicalTopN: + for _, item := range x.ByItems { + exprs = append(exprs, item.Expr) + } + case *physicalop.PhysicalHashAgg: + exprs = append(exprs, x.GroupByItems...) + for _, f := range x.AggFuncs { + exprs = append(exprs, f.Args...) + } + case *physicalop.PhysicalStreamAgg: + exprs = append(exprs, x.GroupByItems...) + for _, f := range x.AggFuncs { + exprs = append(exprs, f.Args...) + } + case *physicalop.PhysicalHashJoin: + exprs = append(exprs, x.LeftConditions...) + exprs = append(exprs, x.RightConditions...) + exprs = append(exprs, x.OtherConditions...) + case *physicalop.PhysicalMergeJoin: + exprs = append(exprs, x.LeftConditions...) + exprs = append(exprs, x.RightConditions...) + exprs = append(exprs, x.OtherConditions...) + case *physicalop.PointGetPlan: + exprs = append(exprs, x.AccessConditions...) + case *physicalop.BatchPointGetPlan: + exprs = append(exprs, x.AccessConditions...) + } + return exprs +} diff --git a/pkg/planner/core/result_cache_check_test.go b/pkg/planner/core/result_cache_check_test.go new file mode 100644 index 0000000000000..f68cbaf6c3e9d --- /dev/null +++ b/pkg/planner/core/result_cache_check_test.go @@ -0,0 +1,321 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package core_test + +import ( + "context" + "testing" + + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/planner" + "github.com/pingcap/tidb/pkg/planner/core" + "github.com/pingcap/tidb/pkg/planner/core/base" + "github.com/pingcap/tidb/pkg/planner/core/operator/physicalop" + "github.com/pingcap/tidb/pkg/planner/core/resolve" + "github.com/pingcap/tidb/pkg/session" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/testkit" + "github.com/stretchr/testify/require" +) + +type compiledQuery struct { + stmtNode ast.StmtNode + plan base.PhysicalPlan +} + +// compileQuery parses and optimizes a SQL query, returning the AST and physical plan. +func compileQuery(t *testing.T, tk *testkit.TestKit, sql string) compiledQuery { + t.Helper() + ctx := tk.Session().(sessionctx.Context) + stmts, err := session.Parse(ctx, sql) + require.NoError(t, err) + require.Len(t, stmts, 1) + nodeW := resolve.NewNodeW(stmts[0]) + ret := &core.PreprocessorReturn{} + err = core.Preprocess(context.Background(), ctx, nodeW, core.WithPreprocessorReturn(ret)) + require.NoError(t, err) + p, _, err := planner.Optimize(context.TODO(), ctx, nodeW, ret.InfoSchema) + require.NoError(t, err) + pp, ok := p.(base.PhysicalPlan) + require.True(t, ok, "expected PhysicalPlan, got %T", p) + return compiledQuery{stmtNode: stmts[0], plan: pp} +} + +func hasPhysicalUnionScan(plan base.PhysicalPlan) bool { + if plan == nil { + return false + } + flat := core.FlattenPhysicalPlan(plan, false) + if flat == nil { + return false + } + for _, op := range flat.Main { + if _, ok := op.Origin.(*physicalop.PhysicalUnionScan); ok { + return true + } + } + return false +} + +func TestCanCache_PointGet(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("create table cached_t (id int primary key, v int)") + tk.MustExec("alter table cached_t cache") + + // PointGetPlan should not be cacheable (bypasses plan cache parameterization). + q := compileQuery(t, tk, "select * from cached_t where id = 1") + switch q.plan.(type) { + case *core.PointGetPlan: + require.False(t, core.CanCacheResultSet(q.stmtNode, q.plan, false)) + default: + // If the optimizer didn't choose PointGet, this test is not applicable + // but the simple select test below covers non-PointGet plans. + t.Logf("optimizer did not choose PointGetPlan for this query (got %T), skipping", q.plan) + } +} + +func TestCanCache_BatchPointGet(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("create table cached_t (id int primary key, v int)") + tk.MustExec("alter table cached_t cache") + + // BatchPointGetPlan should not be cacheable. + q := compileQuery(t, tk, "select * from cached_t where id in (1, 2, 3)") + switch q.plan.(type) { + case *core.BatchPointGetPlan: + require.False(t, core.CanCacheResultSet(q.stmtNode, q.plan, false)) + default: + t.Logf("optimizer did not choose BatchPointGetPlan for this query (got %T), skipping", q.plan) + } +} + +func TestCanCache_SimpleSelect(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("create table cached_t (id int primary key, v int)") + tk.MustExec("alter table cached_t cache") + + // Use a non-PK filter so the optimizer picks a table scan, not a PointGetPlan. + q := compileQuery(t, tk, "select * from cached_t where v = 1") + require.True(t, core.CanCacheResultSet(q.stmtNode, q.plan, false)) +} + +func TestCanCache_WithNow(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("create table cached_t (id int primary key, v int)") + tk.MustExec("alter table cached_t cache") + + q := compileQuery(t, tk, "select now(), id from cached_t") + require.False(t, core.CanCacheResultSet(q.stmtNode, q.plan, false)) +} + +func TestCanCache_WithRand(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("create table cached_t (id int primary key, v int)") + tk.MustExec("alter table cached_t cache") + + q := compileQuery(t, tk, "select * from cached_t where v > rand()") + require.False(t, core.CanCacheResultSet(q.stmtNode, q.plan, false)) +} + +func TestCanCache_ForUpdate(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("create table cached_t (id int primary key, v int)") + tk.MustExec("alter table cached_t cache") + + // FOR UPDATE needs a transaction context. + tk.MustExec("begin") + q := compileQuery(t, tk, "select * from cached_t where id = 1 for update") + require.False(t, core.CanCacheResultSet(q.stmtNode, q.plan, false)) + tk.MustExec("rollback") +} + +func TestCanCache_JoinAllCached(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("create table cached_a (id int primary key, v int)") + tk.MustExec("create table cached_b (id int primary key, v int)") + tk.MustExec("alter table cached_a cache") + tk.MustExec("alter table cached_b cache") + + q := compileQuery(t, tk, "select * from cached_a a join cached_b b on a.id = b.id") + require.True(t, core.CanCacheResultSet(q.stmtNode, q.plan, false)) +} + +func TestCanCache_JoinMixed(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("create table cached_t (id int primary key, v int)") + tk.MustExec("create table normal_t (id int primary key, v int)") + tk.MustExec("alter table cached_t cache") + + q := compileQuery(t, tk, "select * from cached_t a join normal_t b on a.id = b.id") + require.False(t, core.CanCacheResultSet(q.stmtNode, q.plan, false)) +} + +func TestCanCache_SubqueryNonCached(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("create table cached_t (id int primary key, v int)") + tk.MustExec("create table normal_t (id int primary key, v int)") + tk.MustExec("alter table cached_t cache") + + q := compileQuery(t, tk, "select * from cached_t where id in (select id from normal_t)") + require.False(t, core.CanCacheResultSet(q.stmtNode, q.plan, false)) +} + +func TestCanCache_InDML(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("create table cached_t (id int primary key, v int)") + tk.MustExec("alter table cached_t cache") + + q := compileQuery(t, tk, "select * from cached_t where id = 1") + require.False(t, core.CanCacheResultSet(q.stmtNode, q.plan, true)) +} + +func TestCanCache_SessionVar(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("create table cached_t (id int primary key, v int)") + tk.MustExec("alter table cached_t cache") + + // User variable @a uses GET_VAR which is in mutableEffectsFunctions. + q := compileQuery(t, tk, "select @a, id from cached_t") + require.False(t, core.CanCacheResultSet(q.stmtNode, q.plan, false)) +} + +func TestCanCache_UUID(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("create table cached_t (id int primary key, v int)") + tk.MustExec("alter table cached_t cache") + + q := compileQuery(t, tk, "select uuid(), id from cached_t") + require.False(t, core.CanCacheResultSet(q.stmtNode, q.plan, false)) +} + +func TestCanCache_NonCachedTable(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("create table normal_t (id int primary key, v int)") + + q := compileQuery(t, tk, "select * from normal_t where id = 1") + require.False(t, core.CanCacheResultSet(q.stmtNode, q.plan, false)) +} + +func TestCanCache_SystemVar(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("create table cached_t (id int primary key, v int)") + tk.MustExec("alter table cached_t cache") + + // System variable @@sql_mode. + q := compileQuery(t, tk, "select @@sql_mode, id from cached_t") + require.False(t, core.CanCacheResultSet(q.stmtNode, q.plan, false)) +} + +func TestCanCache_DatabaseFunc(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("create table cached_t (id int primary key, v int)") + tk.MustExec("alter table cached_t cache") + + q := compileQuery(t, tk, "select database(), id from cached_t") + require.False(t, core.CanCacheResultSet(q.stmtNode, q.plan, false)) +} + +func TestCanCache_SchemaFunc(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("create table cached_t (id int primary key, v int)") + tk.MustExec("alter table cached_t cache") + + q := compileQuery(t, tk, "select schema(), id from cached_t") + require.False(t, core.CanCacheResultSet(q.stmtNode, q.plan, false)) +} + +func TestCanCache_SessionUser(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("create table cached_t (id int primary key, v int)") + tk.MustExec("alter table cached_t cache") + + q := compileQuery(t, tk, "select session_user(), id from cached_t") + require.False(t, core.CanCacheResultSet(q.stmtNode, q.plan, false)) +} + +func TestCanCache_ReleaseAllLocks(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("create table cached_t (id int primary key, v int)") + tk.MustExec("alter table cached_t cache") + + q := compileQuery(t, tk, "select release_all_locks(), id from cached_t") + require.False(t, core.CanCacheResultSet(q.stmtNode, q.plan, false)) +} + +func TestCanCache_InfoSchema(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("create table cached_t (id int primary key, v int)") + tk.MustExec("alter table cached_t cache") + + // information_schema tables are represented by PhysicalMemTable and must not be cached. + q := compileQuery(t, tk, "select * from cached_t, information_schema.tables limit 1") + require.False(t, core.CanCacheResultSet(q.stmtNode, q.plan, false)) +} + +func TestCanCache_UnionScan(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("create table cached_t (id int primary key, v int)") + tk.MustExec("insert into cached_t values (1, 1)") + tk.MustExec("alter table cached_t cache") + + tk.MustExec("begin") + tk.MustExec("insert into cached_t values (2, 2)") + + q := compileQuery(t, tk, "select * from cached_t where v >= 1") + require.True(t, hasPhysicalUnionScan(q.plan), "expected UnionScan in plan, got %T", q.plan) + require.False(t, core.CanCacheResultSet(q.stmtNode, q.plan, false)) + + tk.MustExec("rollback") +} diff --git a/pkg/planner/core/result_cache_key.go b/pkg/planner/core/result_cache_key.go new file mode 100644 index 0000000000000..beaee06c10c3f --- /dev/null +++ b/pkg/planner/core/result_cache_key.go @@ -0,0 +1,151 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package core + +import ( + "encoding/binary" + "hash/fnv" + "time" + + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/table" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util/codec" +) + +// BuildResultCacheKey constructs a ResultCacheKey from the current session +// context. It uses the plan digest to identify the query shape and hashes +// prepared-statement parameter values (or non-prep plan cache literal values) +// together with session-sensitive fields so that semantically different +// executions do not share the same cache entry. +// +// The returned []byte is the structured verification payload stored alongside +// the cache entry and compared again on lookup to guard against hash +// collisions. +// +// Returns false when no plan digest is available (e.g. the plan has not been +// finalized yet). +func BuildResultCacheKey(sctx sessionctx.Context) (table.ResultCacheKey, []byte, bool) { + vars := sctx.GetSessionVars() + stmtCtx := vars.StmtCtx + + // Obtain the plan digest. It is set after optimization. + _, planDigest := stmtCtx.GetPlanDigest() + if planDigest == nil { + return table.ResultCacheKey{}, nil, false + } + + var key table.ResultCacheKey + digestBytes := planDigest.Bytes() + copy(key.PlanDigest[:], digestBytes) + + verificationBytes := appendLengthPrefixedBytes(nil, digestBytes) + + // Hash parameter values when present. For prepared statements these are the + // EXECUTE parameters; for non-prepared plan cache they are the extracted + // literal values. + if params := vars.PlanCacheParams.AllParamValues(); len(params) > 0 { + encodedParams, ok := encodeParams(vars.Location(), params) + if !ok { + return table.ResultCacheKey{}, nil, false + } + verificationBytes = appendLengthPrefixedBytes(verificationBytes, encodedParams) + } else if len(stmtCtx.OriginalSQL) > 0 { + // PlanCacheParams is empty when non-prepared plan cache is disabled or + // the query bypasses plan cache parameterization. Fall back to hashing + // the original SQL text to distinguish queries with different literals. + verificationBytes = appendLengthPrefixedBytes(verificationBytes, []byte(stmtCtx.OriginalSQL)) + } + + // Include session fields that can change result bytes without changing the + // physical plan shape. TIMESTAMP values are rendered in session timezone, and + // connection charset/collation affect literal string semantics. + verificationBytes = appendTZOffset(verificationBytes, vars.Location()) + charset, collation := vars.GetCharsetInfo() + verificationBytes = appendLengthPrefixedBytes(verificationBytes, []byte(charset)) + verificationBytes = appendLengthPrefixedBytes(verificationBytes, []byte(collation)) + + key.ParamHash = hashBytes(verificationBytes) + return key, verificationBytes, true +} + +// HashParamsForTest is exported for testing only. +var HashParamsForTest = hashParams + +// EncodeParamsForTest is exported for testing only. +var EncodeParamsForTest = encodeParamsForTest + +// hashBytes computes a 64-bit FNV-1a hash over a byte slice. +func hashBytes(b []byte) uint64 { + h := fnv.New64a() + h.Write(b) + return h.Sum64() +} + +// encodeParams encodes a slice of Datum values into a single byte slice using +// codec.EncodeKey, concatenating the encoded bytes for each parameter. +func encodeParams(loc *time.Location, params []types.Datum) ([]byte, bool) { + if loc == nil { + loc = time.UTC + } + var buf []byte + for _, p := range params { + var err error + buf, err = codec.EncodeKey(loc, buf, p) + if err != nil { + return nil, false + } + } + return buf, true +} + +func encodeParamsForTest(params []types.Datum) []byte { + buf, _ := encodeParams(time.UTC, params) + return buf +} + +// hashParams computes a 64-bit FNV-1a hash over a slice of Datum values. +// It uses codec.EncodeKey to produce a stable, type-aware byte representation +// of each parameter before feeding it to the hasher. +func hashParams(params []types.Datum) uint64 { + encoded, ok := encodeParams(time.UTC, params) + if !ok { + return 0 + } + return hashBytes(encoded) +} + +func appendLengthPrefixedBytes(buf []byte, payload []byte) []byte { + var b [4]byte + binary.BigEndian.PutUint32(b[:], uint32(len(payload))) + buf = append(buf, b[:]...) + return append(buf, payload...) +} + +// appendTZOffset appends the timezone UTC offset (in seconds) to buf so that +// different session timezones produce distinct cache keys. +// +// It also appends the timezone name to distinguish locations that share the +// same UTC offset but have different DST rules (e.g. "UTC" vs "Europe/London"). +func appendTZOffset(buf []byte, loc *time.Location) []byte { + if loc == nil { + loc = time.UTC + } + _, offset := time.Date(2000, 1, 1, 0, 0, 0, 0, loc).Zone() + var b [4]byte + binary.BigEndian.PutUint32(b[:], uint32(int32(offset))) + buf = append(buf, b[:]...) + return appendLengthPrefixedBytes(buf, []byte(loc.String())) +} diff --git a/pkg/planner/core/result_cache_key_test.go b/pkg/planner/core/result_cache_key_test.go new file mode 100644 index 0000000000000..4a12321323918 --- /dev/null +++ b/pkg/planner/core/result_cache_key_test.go @@ -0,0 +1,276 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package core_test + +import ( + "testing" + + "github.com/pingcap/tidb/pkg/parser" + "github.com/pingcap/tidb/pkg/planner/core" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" + "github.com/pingcap/tidb/pkg/table" + "github.com/pingcap/tidb/pkg/testkit" + "github.com/pingcap/tidb/pkg/types" + "github.com/stretchr/testify/require" +) + +// execAndBuildKey executes a SQL statement and builds the result cache key +// from the session context afterward (plan digest is set after execution). +func execAndBuildKey(t *testing.T, tk *testkit.TestKit, sql string) (table.ResultCacheKey, []byte, bool) { + t.Helper() + tk.MustQuery(sql) + sctx := tk.Session().(sessionctx.Context) + return core.BuildResultCacheKey(sctx) +} + +func TestBuildResultCacheKey_NoPlan(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + + // Replace StmtCtx with a fresh one that has no plan digest. + sctx := tk.Session().(sessionctx.Context) + sctx.GetSessionVars().StmtCtx = stmtctx.NewStmtCtx() + _, _, ok := core.BuildResultCacheKey(sctx) + require.False(t, ok) +} + +func TestBuildResultCacheKey_NonPrepared(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("create table cached_t (id int primary key, v int)") + tk.MustExec("alter table cached_t cache") + tk.MustExec("set tidb_enable_non_prepared_plan_cache = ON") + + // Two different queries with same shape but different constants. + // Use a non-PK filter to avoid PointGetPlan and keep the plan digest stable. + key1, pb1, ok1 := execAndBuildKey(t, tk, "select * from cached_t where v = 1") + require.True(t, ok1) + require.NotNil(t, pb1) + + key2, pb2, ok2 := execAndBuildKey(t, tk, "select * from cached_t where v = 2") + require.True(t, ok2) + require.NotNil(t, pb2) + + // With the same plan shape, the PlanDigest portion should be equal. + require.Equal(t, key1.PlanDigest, key2.PlanDigest) + // Different literal values must produce different ParamHash. + require.NotEqual(t, key1.ParamHash, key2.ParamHash) + // Different param bytes for different values. + require.NotEqual(t, pb1, pb2) + + // A structurally different query should produce a different plan digest. + key3, _, ok3 := execAndBuildKey(t, tk, "select v from cached_t") + require.True(t, ok3) + require.NotEqual(t, key1.PlanDigest, key3.PlanDigest) +} + +func TestBuildResultCacheKey_NonPrepared_DifferentValues(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("create table cached_t (id int primary key, v int)") + tk.MustExec("alter table cached_t cache") + // Disable non-prepared plan cache to trigger the SQL fallback path. + tk.MustExec("set tidb_enable_non_prepared_plan_cache = OFF") + + key1, pb1, ok1 := execAndBuildKey(t, tk, "select * from cached_t where v = 1") + require.True(t, ok1) + require.NotNil(t, pb1) + + key2, pb2, ok2 := execAndBuildKey(t, tk, "select * from cached_t where v = 2") + require.True(t, ok2) + require.NotNil(t, pb2) + + // Same plan shape → same digest. + require.Equal(t, key1.PlanDigest, key2.PlanDigest) + // Different literals → different ParamHash (the bug was both being 0). + require.NotEqual(t, key1.ParamHash, key2.ParamHash) + // Different param bytes for different values. + require.NotEqual(t, pb1, pb2) + + // Same query twice must produce the same key and param bytes. + key3, pb3, ok3 := execAndBuildKey(t, tk, "select * from cached_t where v = 1") + require.True(t, ok3) + require.Equal(t, key1, key3) + require.Equal(t, pb1, pb3) +} + +func TestBuildResultCacheKey_Prepared(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("create table cached_t (id int primary key, v int)") + tk.MustExec("alter table cached_t cache") + + // Use prepared statements through the SQL interface. + tk.MustExec("prepare stmt from 'select * from cached_t where id = ?'") + + // Execute with param = 1 + tk.MustExec("set @a = 1") + tk.MustQuery("execute stmt using @a") + sctx := tk.Session().(sessionctx.Context) + key1, pb1, ok1 := core.BuildResultCacheKey(sctx) + require.True(t, ok1) + require.NotNil(t, pb1) + + // Execute with param = 2 + tk.MustExec("set @a = 2") + tk.MustQuery("execute stmt using @a") + key2, pb2, ok2 := core.BuildResultCacheKey(sctx) + require.True(t, ok2) + require.NotNil(t, pb2) + + // Same plan digest but different param hash. + require.Equal(t, key1.PlanDigest, key2.PlanDigest) + require.NotEqual(t, key1.ParamHash, key2.ParamHash) + // Different param bytes for different values. + require.NotEqual(t, pb1, pb2) +} + +func TestBuildResultCacheKey_SameParams(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("create table cached_t (id int primary key, v int)") + tk.MustExec("alter table cached_t cache") + + tk.MustExec("prepare stmt from 'select * from cached_t where id = ?'") + + // Execute twice with the same param value. + tk.MustExec("set @a = 42") + tk.MustQuery("execute stmt using @a") + sctx := tk.Session().(sessionctx.Context) + key1, pb1, ok1 := core.BuildResultCacheKey(sctx) + require.True(t, ok1) + + tk.MustQuery("execute stmt using @a") + key2, pb2, ok2 := core.BuildResultCacheKey(sctx) + require.True(t, ok2) + + // Identical keys and param bytes. + require.Equal(t, key1, key2) + require.Equal(t, pb1, pb2) +} + +func TestHashParams_DifferentTypes(t *testing.T) { + // Verify that hashParams produces different hashes for different Datum types + // even when the "value" might look similar (e.g., int 1 vs string "1"). + intDatum := types.NewIntDatum(1) + strDatum := types.NewStringDatum("1") + floatDatum := types.NewFloat64Datum(1.0) + + h1 := core.HashParamsForTest([]types.Datum{intDatum}) + h2 := core.HashParamsForTest([]types.Datum{strDatum}) + h3 := core.HashParamsForTest([]types.Datum{floatDatum}) + + // All three should be different since they have different type encodings. + require.NotEqual(t, h1, h2) + require.NotEqual(t, h1, h3) + require.NotEqual(t, h2, h3) + + // Same input should produce the same hash. + h1Again := core.HashParamsForTest([]types.Datum{intDatum}) + require.Equal(t, h1, h1Again) +} + +func TestHashParams_MultipleParams(t *testing.T) { + // Different ordering of params should produce different hashes. + d1 := types.NewIntDatum(1) + d2 := types.NewIntDatum(2) + + h12 := core.HashParamsForTest([]types.Datum{d1, d2}) + h21 := core.HashParamsForTest([]types.Datum{d2, d1}) + require.NotEqual(t, h12, h21) +} + +func TestBuildResultCacheKey_DifferentTimezones(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("create table cached_tz (id int primary key, ts timestamp)") + tk.MustExec("alter table cached_tz cache") + + // Same query executed under different timezones must produce different keys. + tk.MustExec("set @@time_zone = '+00:00'") + key1, pb1, ok1 := execAndBuildKey(t, tk, "select * from cached_tz") + require.True(t, ok1) + + tk.MustExec("set @@time_zone = '+08:00'") + key2, pb2, ok2 := execAndBuildKey(t, tk, "select * from cached_tz") + require.True(t, ok2) + + // Same plan digest (same query shape). + require.Equal(t, key1.PlanDigest, key2.PlanDigest) + // Different ParamHash because timezone is included. + require.NotEqual(t, key1.ParamHash, key2.ParamHash) + // Different param bytes. + require.NotEqual(t, pb1, pb2) + + // Same timezone should produce same key. + tk.MustExec("set @@time_zone = '+00:00'") + key3, pb3, ok3 := execAndBuildKey(t, tk, "select * from cached_tz") + require.True(t, ok3) + require.Equal(t, key1, key3) + require.Equal(t, pb1, pb3) +} + +func TestBuildResultCacheKey_PlanDigestVerificationBytes(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("create table cached_t (id int primary key, v int)") + tk.MustExec("alter table cached_t cache") + + key1, pb1, ok1 := execAndBuildKey(t, tk, "select * from cached_t where v = 1") + require.True(t, ok1) + + sctx := tk.Session().(sessionctx.Context) + stmtCtx := sctx.GetSessionVars().StmtCtx + normalized, planDigest := stmtCtx.GetPlanDigest() + require.NotNil(t, planDigest) + + altDigestBytes := append([]byte(nil), planDigest.Bytes()...) + altDigestBytes[len(altDigestBytes)-1] ^= 0xff + stmtCtx.SetPlanDigest(normalized, parser.NewDigest(altDigestBytes)) + + key2, pb2, ok2 := core.BuildResultCacheKey(sctx) + require.True(t, ok2) + require.Equal(t, key1.PlanDigest, key2.PlanDigest) + require.NotEqual(t, key1.ParamHash, key2.ParamHash) + require.NotEqual(t, pb1, pb2) +} + +func TestBuildResultCacheKey_DifferentCollations(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("create table cached_t_coll (id int primary key, v varchar(10))") + tk.MustExec("alter table cached_t_coll cache") + + tk.MustExec("set names utf8mb4 collate utf8mb4_bin") + key1, pb1, ok1 := execAndBuildKey(t, tk, "select * from cached_t_coll where v = 'a'") + require.True(t, ok1) + + tk.MustExec("set names utf8mb4 collate utf8mb4_general_ci") + key2, pb2, ok2 := execAndBuildKey(t, tk, "select * from cached_t_coll where v = 'a'") + require.True(t, ok2) + + require.Equal(t, key1.PlanDigest, key2.PlanDigest) + require.NotEqual(t, key1.ParamHash, key2.ParamHash) + require.NotEqual(t, pb1, pb2) +} diff --git a/pkg/server/internal/column/column.go b/pkg/server/internal/column/column.go index 5be0e28c06c4b..82729b0e55e66 100644 --- a/pkg/server/internal/column/column.go +++ b/pkg/server/internal/column/column.go @@ -152,7 +152,7 @@ func DumpTextRow(buffer []byte, columns []*Info, row chunk.Row, d *ResultEncoder if d == nil { d = NewResultEncoder(charset.CharsetUTF8MB4) } - tmp := make([]byte, 0, 20) + tmp := make([]byte, 0, 32) for i, col := range columns { if row.IsNull(i) { buffer = append(buffer, 0xfb) @@ -199,7 +199,8 @@ func DumpTextRow(buffer []byte, columns []*Info, row chunk.Row, d *ResultEncoder d.UpdateDataEncoding(col.Charset) buffer = dump.LengthEncodedString(buffer, d.EncodeData(row.GetBytes(i))) case mysql.TypeDate, mysql.TypeDatetime, mysql.TypeTimestamp: - buffer = dump.LengthEncodedString(buffer, hack.Slice(row.GetTime(i).String())) + tmp = row.GetTime(i).AppendString(tmp[:0]) + buffer = dump.LengthEncodedString(buffer, tmp) case mysql.TypeDuration: dur := row.GetDuration(i, int(col.Decimal)) buffer = dump.LengthEncodedString(buffer, hack.Slice(dur.String())) diff --git a/pkg/sessionctx/stmtctx/stmtctx.go b/pkg/sessionctx/stmtctx/stmtctx.go index 1284eddc81ea0..1170858abf3fc 100644 --- a/pkg/sessionctx/stmtctx/stmtctx.go +++ b/pkg/sessionctx/stmtctx/stmtctx.go @@ -404,6 +404,9 @@ type StatementContext struct { // If the statement read from table cache, this flag is set. ReadFromTableCache bool + // ReadFromResultCache indicates whether the result was served from the result set cache. + ReadFromResultCache bool + // cache is used to reduce object allocation. cache struct { execdetails.RuntimeStatsColl diff --git a/pkg/sessionctx/vardef/tidb_vars.go b/pkg/sessionctx/vardef/tidb_vars.go index 26414cb98d6a6..f461f40c40797 100644 --- a/pkg/sessionctx/vardef/tidb_vars.go +++ b/pkg/sessionctx/vardef/tidb_vars.go @@ -953,6 +953,8 @@ const ( TiDBNonPreparedPlanCacheSize = "tidb_non_prepared_plan_cache_size" // TiDBPlanCacheMaxPlanSize controls the maximum size of a plan that can be cached. TiDBPlanCacheMaxPlanSize = "tidb_plan_cache_max_plan_size" + // TiDBPlanCachePolicy controls how plan cache is enabled. + TiDBPlanCachePolicy = "tidb_plan_cache_policy" // TiDBPlanCacheInvalidationOnFreshStats controls if plan cache will be invalidated automatically when // related stats are analyzed after the plan cache is generated. TiDBPlanCacheInvalidationOnFreshStats = "tidb_plan_cache_invalidation_on_fresh_stats" @@ -1670,6 +1672,7 @@ const ( DefTiDBEnableNonPreparedPlanCacheForDML = true DefTiDBNonPreparedPlanCacheSize = 100 DefTiDBPlanCacheMaxPlanSize = 2 * size.MB + DefTiDBPlanCachePolicy = PlanCachePolicyAll DefTiDBInstancePlanCacheMaxMemSize = 100 * size.MB MinTiDBInstancePlanCacheMemSize = 100 * size.MB DefTiDBInstancePlanCacheReservedPercentage = 0.1 @@ -2121,6 +2124,11 @@ const ( // StrategyCustom is a choice of variable TiDBPipelinedDmlResourcePolicy, StrategyCustom = "custom" + // PlanCachePolicyAll means all cacheable statements can use plan cache. + PlanCachePolicyAll = "all" + // PlanCachePolicyHintOnly means only statements with the USE_PLAN_CACHE() hint can use plan cache. + PlanCachePolicyHintOnly = "hint_only" + // IndexLookUpPushDownPolicyHintOnly indicates only use the hint to decide whether to push down the index lookup or not. IndexLookUpPushDownPolicyHintOnly = "hint-only" // IndexLookUpPushDownPolicyAffinityForce indicates to force push down the index lookup for table with affinity options. diff --git a/pkg/sessionctx/variable/session.go b/pkg/sessionctx/variable/session.go index 90107e5d227e6..4b15d57c1721a 100644 --- a/pkg/sessionctx/variable/session.go +++ b/pkg/sessionctx/variable/session.go @@ -1642,6 +1642,9 @@ type SessionVars struct { // PlanCacheMaxPlanSize controls the maximum size of a plan that can be cached. PlanCacheMaxPlanSize uint64 + // PlanCachePolicy controls how plan cache is enabled. + PlanCachePolicy string + // SessionPlanCacheSize controls the size of session plan cache. SessionPlanCacheSize uint64 @@ -2455,6 +2458,7 @@ func NewSessionVars(hctx HookContext) *SessionVars { RegardNULLAsPoint: vardef.DefTiDBRegardNULLAsPoint, AllowProjectionPushDown: vardef.DefOptEnableProjectionPushDown, SkipMissingPartitionStats: vardef.DefTiDBSkipMissingPartitionStats, + PlanCachePolicy: vardef.DefTiDBPlanCachePolicy, IndexLookUpPushDownPolicy: vardef.DefTiDBIndexLookUpPushDownPolicy, OptPartialOrderedIndexForTopN: vardef.DefTiDBOptPartialOrderedIndexForTopN, } diff --git a/pkg/sessionctx/variable/setvar_affect.go b/pkg/sessionctx/variable/setvar_affect.go index 258de00d355dd..596ae1342a44f 100644 --- a/pkg/sessionctx/variable/setvar_affect.go +++ b/pkg/sessionctx/variable/setvar_affect.go @@ -114,6 +114,7 @@ var isHintUpdatableVerified = map[string]struct{}{ "tidb_enable_prepared_plan_cache": {}, "tidb_enable_non_prepared_plan_cache": {}, "tidb_plan_cache_max_plan_size": {}, + "tidb_plan_cache_policy": {}, "tidb_opt_range_max_size": {}, "tidb_opt_advanced_join_hint": {}, "tidb_opt_prefix_index_single_scan": {}, diff --git a/pkg/sessionctx/variable/slow_log.go b/pkg/sessionctx/variable/slow_log.go index b3a742a2b1e3d..3a0d236056f69 100644 --- a/pkg/sessionctx/variable/slow_log.go +++ b/pkg/sessionctx/variable/slow_log.go @@ -128,6 +128,8 @@ const ( SlowLogIsExplicitTxn = "IsExplicitTxn" // SlowLogIsWriteCacheTable is used to indicate whether writing to the cache table need to wait for the read lock to expire. SlowLogIsWriteCacheTable = "IsWriteCacheTable" + // SlowLogResultCacheHit is used to indicate whether the result was served from the result set cache. + SlowLogResultCacheHit = "Result_cache_hit" // SlowLogIsSyncStatsFailed is used to indicate whether any failure happen during sync stats SlowLogIsSyncStatsFailed = "IsSyncStatsFailed" // SlowLogRRU is the read request_unit(RU) cost @@ -279,6 +281,7 @@ type SlowQueryLogItems struct { Succ bool IsExplicitTxn bool IsWriteCacheTable bool + ResultCacheHit bool IsSyncStatsFailed bool Prepared bool // plan information @@ -532,6 +535,9 @@ func (s *SessionVars) SlowLogFormat(logItems *SlowQueryLogItems) string { if s.StmtCtx.WaitLockLeaseTime > 0 { writeSlowLogItem(&buf, SlowLogIsWriteCacheTable, strconv.FormatBool(logItems.IsWriteCacheTable)) } + if logItems.ResultCacheHit { + writeSlowLogItem(&buf, SlowLogResultCacheHit, strconv.FormatBool(logItems.ResultCacheHit)) + } if len(logItems.Plan) != 0 { writeSlowLogItem(&buf, SlowLogPlan, logItems.Plan) } diff --git a/pkg/sessionctx/variable/sysvar.go b/pkg/sessionctx/variable/sysvar.go index e8ffc994bb9a2..250efd87a111f 100644 --- a/pkg/sessionctx/variable/sysvar.go +++ b/pkg/sessionctx/variable/sysvar.go @@ -1587,6 +1587,10 @@ var defaultSysVars = []*SysVar{ } return err }}, + {Scope: vardef.ScopeGlobal | vardef.ScopeSession, Name: vardef.TiDBPlanCachePolicy, Value: vardef.DefTiDBPlanCachePolicy, Type: vardef.TypeEnum, PossibleValues: []string{vardef.PlanCachePolicyAll, vardef.PlanCachePolicyHintOnly}, SetSession: func(s *SessionVars, val string) error { + s.PlanCachePolicy = val + return nil + }}, {Scope: vardef.ScopeGlobal | vardef.ScopeSession, Name: vardef.TiDBSessionPlanCacheSize, Aliases: []string{vardef.TiDBPrepPlanCacheSize}, Value: strconv.FormatUint(uint64(vardef.DefTiDBSessionPlanCacheSize), 10), Type: vardef.TypeUnsigned, MinValue: 1, MaxValue: 100000, SetSession: func(s *SessionVars, val string) error { uVal, err := strconv.ParseUint(val, 10, 64) if err == nil { diff --git a/pkg/table/table.go b/pkg/table/table.go index 575f2e18c7a24..cbe01f9f6411a 100644 --- a/pkg/table/table.go +++ b/pkg/table/table.go @@ -534,6 +534,12 @@ var TableFromMeta func(allocators autoid.Allocators, tblInfo *model.TableInfo) ( // MockTableFromMeta only serves for test. var MockTableFromMeta func(tableInfo *model.TableInfo) Table +// ResultCacheKey is the lookup key for result set caching on cached tables. +type ResultCacheKey struct { + PlanDigest [16]byte // digest of the normalized plan + ParamHash uint64 // hash distinguishing queries with same plan shape but different values +} + // CachedTable is a Table, and it has a UpdateLockForRead() method // UpdateLockForRead() according to the reasons for not meeting the read conditions, update the lock information, // And at the same time reload data from the original table. @@ -553,6 +559,17 @@ type CachedTable interface { // 'exit' is a channel to tell the keep alive goroutine to exit. // The result is sent to the 'wg' channel. WriteLockAndKeepAlive(ctx context.Context, exit chan struct{}, leasePtr *uint64, wg chan error) + + // GetCachedResult looks up a previously cached query result set. + // paramBytes is the raw encoded parameter bytes for secondary verification + // against hash collisions. Returns the cached chunks, their field types, + // and whether the lookup was a hit. + GetCachedResult(key ResultCacheKey, paramBytes []byte) ([]*chunk.Chunk, []*types.FieldType, bool) + + // PutCachedResult stores a query result set in the cache. + // paramBytes is stored alongside the entry for secondary verification on future lookups. + // Returns false if the cache rejects the entry (e.g. memory/entry limits exceeded). + PutCachedResult(key ResultCacheKey, paramBytes []byte, chunks []*chunk.Chunk, fieldTypes []*types.FieldType) bool } // CheckRowConstraint verify row check constraints. diff --git a/pkg/table/tables/BUILD.bazel b/pkg/table/tables/BUILD.bazel index eb6dac9730867..593583fc816db 100644 --- a/pkg/table/tables/BUILD.bazel +++ b/pkg/table/tables/BUILD.bazel @@ -5,9 +5,11 @@ go_library( srcs = [ "assertion.go", "cache.go", + "cached_datum.go", "index.go", "mutation_checker.go", "partition.go", + "result_cache.go", "state_remote.go", "tables.go", "testutil.go", @@ -71,16 +73,19 @@ go_test( "assertion_test.go", "bench_test.go", "cache_test.go", + "cache_tz_test.go", + "cached_datum_test.go", "export_test.go", "index_test.go", "main_test.go", "mutation_checker_test.go", + "result_cache_test.go", "state_remote_test.go", "tables_test.go", ], embed = [":tables"], flaky = True, - shard_count = 43, + shard_count = 50, deps = [ "//pkg/ddl", "//pkg/domain", @@ -115,6 +120,7 @@ go_test( "//pkg/types", "//pkg/util", "//pkg/util/benchdaily", + "//pkg/util/chunk", "//pkg/util/codec", "//pkg/util/collate", "//pkg/util/context", diff --git a/pkg/table/tables/cache.go b/pkg/table/tables/cache.go index c8fe53a824475..40bdff5971da0 100644 --- a/pkg/table/tables/cache.go +++ b/pkg/table/tables/cache.go @@ -22,12 +22,17 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/failpoint" "github.com/pingcap/log" + "github.com/pingcap/tidb/pkg/expression/exprstatic" "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/meta/model" "github.com/pingcap/tidb/pkg/metrics" + "github.com/pingcap/tidb/pkg/parser/mysql" "github.com/pingcap/tidb/pkg/table" "github.com/pingcap/tidb/pkg/tablecodec" "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util/chunk" "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/rowcodec" "github.com/pingcap/tidb/pkg/util/sqlexec" "github.com/tikv/client-go/v2/oracle" "github.com/tikv/client-go/v2/tikv" @@ -40,8 +45,10 @@ var ( type cachedTable struct { TableCommon - cacheData atomic.Pointer[cacheData] - totalSize int64 + cacheData atomic.Pointer[cacheData] + resultCache atomic.Pointer[resultSetCache] + resultCacheMem atomic.Int64 + totalSize int64 // StateRemote is not thread-safe, this tokenLimit is used to keep only one visitor. tokenLimit } @@ -71,6 +78,8 @@ type cacheData struct { Start uint64 Lease uint64 kv.MemBuffer + datumCache *CachedDatumData // pre-decoded datum cache for table scans + indexDatumCaches map[int64]*CachedIndexDatumData // pre-decoded datum caches for index scans, keyed by index ID } func leaseFromTS(ts uint64, leaseDuration time.Duration) uint64 { @@ -205,14 +214,15 @@ func (c *cachedTable) updateLockForRead(ctx context.Context, handle StateRemote, return } if succ { + c.invalidateResultCache() // Data is about to be reloaded, old result sets are stale. c.cacheData.Store(&cacheData{ Start: ts, Lease: lease, MemBuffer: nil, // Async loading, this will be set later. }) - // Make the load data process async, in case that loading data takes longer the - // lease duration, then the loaded data get staled and that process repeats forever. + // Make the load data process async, in case that loading data takes longer than the + // lease duration, then the loaded data becomes stale and that process repeats forever. go func() { start := time.Now() mb, startTS, totalSize, err := c.loadDataFromOriginalTable(store) @@ -222,12 +232,17 @@ func (c *cachedTable) updateLockForRead(ctx context.Context, handle StateRemote, return } + datumCache := c.buildDatumCache(mb) + indexDatumCaches := c.buildIndexDatumCaches(mb) + tmp := c.cacheData.Load() if tmp != nil && tmp.Start == ts { c.cacheData.Store(&cacheData{ - Start: startTS, - Lease: tmp.Lease, - MemBuffer: mb, + Start: startTS, + Lease: tmp.Lease, + MemBuffer: mb, + datumCache: datumCache, + indexDatumCaches: indexDatumCaches, }) atomic.StoreInt64(&c.totalSize, totalSize) } @@ -236,7 +251,7 @@ func (c *cachedTable) updateLockForRead(ctx context.Context, handle StateRemote, // Current status is not suitable to cache. } -const cachedTableSizeLimit = 64 * (1 << 20) +const cachedTableSizeLimit = 256 * (1 << 20) // AddRecord implements the AddRecord method for the table.Table interface. func (c *cachedTable) AddRecord(sctx table.MutateContext, txn kv.Transaction, r []types.Datum, opts ...table.AddRecordOption) (recordID kv.Handle, err error) { @@ -255,7 +270,7 @@ func txnCtxAddCachedTable(sctx table.MutateContext, tid int64, handle *cachedTab // UpdateRecord implements table.Table func (c *cachedTable) UpdateRecord(ctx table.MutateContext, txn kv.Transaction, h kv.Handle, oldData, newData []types.Datum, touched []bool, opts ...table.UpdateRecordOption) error { - // Prevent furthur writing when the table is already too large. + // Prevent further writing when the table is already too large. if atomic.LoadInt64(&c.totalSize) > cachedTableSizeLimit { return table.ErrOptOnCacheTable.GenWithStackByArgs("table too large") } @@ -288,14 +303,19 @@ func (c *cachedTable) renewLease(handle StateRemote, ts uint64, data *cacheData, if !kv.IsTxnRetryableError(err) { log.Warn("Renew read lease error", zap.Error(err)) } + c.invalidateResultCache() // Renewal failed, data may have changed. return } if newLease > 0 { c.cacheData.Store(&cacheData{ - Start: data.Start, - Lease: newLease, - MemBuffer: data.MemBuffer, + Start: data.Start, + Lease: newLease, + MemBuffer: data.MemBuffer, + datumCache: data.datumCache, + indexDatumCaches: data.indexDatumCaches, }) + } else { + c.invalidateResultCache() // Lease not renewed, data may have changed. } failpoint.Inject("mockRenewLeaseABA2", func(_ failpoint.Value) { @@ -303,6 +323,179 @@ func (c *cachedTable) renewLease(handle StateRemote, ts uint64, data *cacheData, }) } +// buildDatumCache builds a CachedDatumData from the given MemBuffer. +// It constructs the decoder parameters from the table schema. +// Returns nil if building fails (non-fatal, the cache simply won't be available). +func (c *cachedTable) buildDatumCache(mb kv.MemBuffer) *CachedDatumData { + cols := c.Cols() + tblMeta := c.Meta() + defaultExprCtx := exprstatic.NewExprContext( + exprstatic.WithEvalCtx(exprstatic.NewEvalContext(exprstatic.WithLocation(time.UTC))), + ) + + colInfo := make([]rowcodec.ColInfo, len(cols)) + fieldTypes := make([]*types.FieldType, len(cols)) + for i, col := range cols { + ft := rowcodec.FieldTypeFromModelColumn(col.ColumnInfo) + colInfo[i] = rowcodec.ColInfo{ + ID: col.ID, + IsPKHandle: tblMeta.PKIsHandle && mysql.HasPriKeyFlag(col.GetFlag()), + Ft: ft, + } + fieldTypes[i] = ft + } + + pkColIDs := TryGetCommonPkColumnIds(tblMeta) + if len(pkColIDs) == 0 { + if tblMeta.PKIsHandle { + // For PKIsHandle tables, the PK value is stored in the row key (handle), + // not in the row value. The ChunkDecoder needs the actual column ID to + // match it in tryAppendHandleColumn and write the handle value into the chunk. + for _, col := range cols { + if mysql.HasPriKeyFlag(col.GetFlag()) { + pkColIDs = []int64{col.ID} + break + } + } + } + if len(pkColIDs) == 0 { + pkColIDs = []int64{-1} + } + } + + defDatum := func(i int, chk *chunk.Chunk) error { + d, err := table.GetColOriginDefaultValueWithoutStrictSQLMode(defaultExprCtx, cols[i].ColumnInfo) + if err != nil { + return err + } + chk.AppendDatum(i, &d) + return nil + } + + dc, err := BuildCachedDatumData(mb, c.tableID, colInfo, pkColIDs, defDatum, fieldTypes) + if err != nil { + log.Warn("build datum cache failed", zap.Error(err)) + return nil + } + return dc +} + +func (c *cachedTable) GetCachedDatumData() *CachedDatumData { + data := c.cacheData.Load() + if data == nil { + return nil + } + return data.datumCache +} + +// GetCachedDatumDataForMemBuffer returns the datum cache only when it belongs to +// the same cacheData generation as mb. This prevents mixing a MemBuffer from one +// lease generation with pre-decoded datums from a later reload. +func (c *cachedTable) GetCachedDatumDataForMemBuffer(mb kv.MemBuffer) *CachedDatumData { + data := c.cacheData.Load() + if data == nil || data.MemBuffer != mb { + return nil + } + return data.datumCache +} + +// buildIndexDatumCaches builds CachedIndexDatumData for all public indexes. +// Returns nil if no indexes can be cached (non-fatal). +func (c *cachedTable) buildIndexDatumCaches(mb kv.MemBuffer) map[int64]*CachedIndexDatumData { + tblMeta := c.Meta() + indices := tblMeta.Indices + if len(indices) == 0 { + return nil + } + + caches := make(map[int64]*CachedIndexDatumData, len(indices)) + for _, idx := range indices { + if idx.State != model.StatePublic { + continue + } + dc, err := BuildCachedIndexDatumData(mb, c.tableID, idx, tblMeta) + if err != nil { + log.Warn("build index datum cache failed", + zap.String("index", idx.Name.O), + zap.Error(err)) + continue + } + caches[idx.ID] = dc + } + if len(caches) == 0 { + return nil + } + return caches +} + +// GetCachedIndexDatumData returns the pre-decoded index datum cache for the given index ID. +func (c *cachedTable) GetCachedIndexDatumData(indexID int64) *CachedIndexDatumData { + data := c.cacheData.Load() + if data == nil || data.indexDatumCaches == nil { + return nil + } + return data.indexDatumCaches[indexID] +} + +// GetCachedIndexDatumDataForMemBuffer returns the index datum cache only when it +// belongs to the same cacheData generation as mb. +func (c *cachedTable) GetCachedIndexDatumDataForMemBuffer(mb kv.MemBuffer, indexID int64) *CachedIndexDatumData { + data := c.cacheData.Load() + if data == nil || data.MemBuffer != mb || data.indexDatumCaches == nil { + return nil + } + return data.indexDatumCaches[indexID] +} + +func (c *cachedTable) getResultCache() *resultSetCache { + return c.resultCache.Load() +} + +func (c *cachedTable) getOrCreateResultCache() *resultSetCache { + if rc := c.resultCache.Load(); rc != nil { + return rc + } + rc := newResultSetCache() + if c.resultCache.CompareAndSwap(nil, rc) { + return rc + } + return c.resultCache.Load() +} + +func (c *cachedTable) invalidateResultCache() { + old := c.resultCache.Swap(nil) + if accounted := c.resultCacheMem.Swap(0); accounted > 0 { + metrics.ResultCacheMemoryGauge.Sub(float64(accounted)) + } + if old != nil { + if n := old.Len(); n > 0 { + metrics.ResultCacheEvictCounter.Add(float64(n)) + } + } +} + +func (c *cachedTable) GetCachedResult(key table.ResultCacheKey, paramBytes []byte) ([]*chunk.Chunk, []*types.FieldType, bool) { + rc := c.getResultCache() + if rc == nil { + return nil, nil, false + } + return rc.Get(key, paramBytes) +} + +func (c *cachedTable) PutCachedResult(key table.ResultCacheKey, paramBytes []byte, chunks []*chunk.Chunk, fieldTypes []*types.FieldType) bool { + rc := c.getOrCreateResultCache() + paramCopy := append([]byte(nil), paramBytes...) + ok, memSize := rc.put(key, paramCopy, chunks, fieldTypes) + if ok && memSize > 0 { + // If the result cache is invalidated concurrently, don't account its memory in metrics. + if c.resultCache.Load() == rc { + c.resultCacheMem.Add(memSize) + metrics.ResultCacheMemoryGauge.Add(float64(memSize)) + } + } + return ok +} + const cacheTableWriteLease = 5 * time.Second func (c *cachedTable) WriteLockAndKeepAlive(ctx context.Context, exit chan struct{}, leasePtr *uint64, wg chan error) { @@ -350,6 +543,6 @@ func (c *cachedTable) renew(ctx context.Context, leasePtr *uint64) error { func (c *cachedTable) lockForWrite(ctx context.Context) (uint64, error) { handle := c.TakeStateRemoteHandle() defer c.PutStateRemoteHandle(handle) - + c.invalidateResultCache() // Write incoming, result cache is no longer valid. return handle.LockForWrite(ctx, c.Meta().ID, cacheTableWriteLease) } diff --git a/pkg/table/tables/cache_test.go b/pkg/table/tables/cache_test.go index 4157182c7d049..7044097395c08 100644 --- a/pkg/table/tables/cache_test.go +++ b/pkg/table/tables/cache_test.go @@ -324,7 +324,7 @@ func TestBeginSleepABA(t *testing.T) { } require.True(t, cacheUsed) - // tk1 should not use the staled cache, because the data is changed. + // tk1 should not use the stale cache, because the data is changed. tk1.MustQuery("select * from aba").Check(testkit.Rows("1 1")) require.False(t, lastReadFromCache(tk1)) } @@ -548,8 +548,117 @@ func TestRenewLeaseABAFailPoint(t *testing.T) { require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/table/tables/mockRenewLeaseABA1")) require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/table/tables/mockRenewLeaseABA2")) - // The renew lease operation should not success, - // And the session should not read from a staled cache data. + // The renew lease operation should not succeed, + // and the session should not read from a stale cache data. tk.MustQuery("select * from t_lease").Check(testkit.Rows("1 2")) require.False(t, lastReadFromCache(tk)) } + +func lastReadFromResultCache(tk *testkit.TestKit) bool { + return tk.Session().GetSessionVars().StmtCtx.ReadFromResultCache +} + +func TestResultCacheStmtCtxFlag(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("create table t_rc_flag (id int primary key, v int)") + tk.MustExec("insert into t_rc_flag values (1, 10), (2, 20)") + tk.MustExec("alter table t_rc_flag cache") + + // Wait until the table is cached. + cached := false + for i := 0; i < 20; i++ { + tk.MustQuery("select * from t_rc_flag") + if lastReadFromCache(tk) { + cached = true + break + } + time.Sleep(50 * time.Millisecond) + } + require.True(t, cached) + + // First query: cache miss, should populate the result cache. + tk.MustQuery("select * from t_rc_flag where v = 10").Check(testkit.Rows("1 10")) + require.False(t, lastReadFromResultCache(tk)) + + // Second query with the same plan: should hit the result cache. + tk.MustQuery("select * from t_rc_flag where v = 10").Check(testkit.Rows("1 10")) + require.True(t, lastReadFromResultCache(tk)) +} + +func TestResultCacheMetrics(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("create table t_rc_metrics (id int primary key, v int)") + tk.MustExec("insert into t_rc_metrics values (1, 10)") + tk.MustExec("alter table t_rc_metrics cache") + + // Wait until the table is cached. + cached := false + for i := 0; i < 20; i++ { + tk.MustQuery("select * from t_rc_metrics") + if lastReadFromCache(tk) { + cached = true + break + } + time.Sleep(50 * time.Millisecond) + } + require.True(t, cached) + + hitCounter := metrics.ResultCacheHitCounter + missCounter := metrics.ResultCacheMissCounter + hitPB := &dto.Metric{} + missPB := &dto.Metric{} + + require.NoError(t, hitCounter.Write(hitPB)) + require.NoError(t, missCounter.Write(missPB)) + hitBefore := hitPB.GetCounter().GetValue() + missBefore := missPB.GetCounter().GetValue() + + // First query: should be a miss (no result cache entry yet). + tk.MustQuery("select * from t_rc_metrics where v = 10").Check(testkit.Rows("1 10")) + + // Second query: should hit the result cache. + tk.MustQuery("select * from t_rc_metrics where v = 10").Check(testkit.Rows("1 10")) + + require.NoError(t, hitCounter.Write(hitPB)) + require.NoError(t, missCounter.Write(missPB)) + hitAfter := hitPB.GetCounter().GetValue() + missAfter := missPB.GetCounter().GetValue() + + // We should see at least one hit and one miss. + require.Greater(t, hitAfter, hitBefore) + require.Greater(t, missAfter, missBefore) +} + +func TestResultCacheSlowLog(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("create table t_rc_slow (id int primary key, v int)") + tk.MustExec("insert into t_rc_slow values (1, 10)") + tk.MustExec("alter table t_rc_slow cache") + + // Wait until the table is cached. + cached := false + for i := 0; i < 20; i++ { + tk.MustQuery("select * from t_rc_slow") + if lastReadFromCache(tk) { + cached = true + break + } + time.Sleep(50 * time.Millisecond) + } + require.True(t, cached) + + // First query populates the result cache. + tk.MustQuery("select * from t_rc_slow where v = 10").Check(testkit.Rows("1 10")) + // Second query should hit. + tk.MustQuery("select * from t_rc_slow where v = 10").Check(testkit.Rows("1 10")) + require.True(t, lastReadFromResultCache(tk)) + + // Verify the StmtCtx flag is set, which feeds into slow log. + require.True(t, tk.Session().GetSessionVars().StmtCtx.ReadFromResultCache) +} diff --git a/pkg/table/tables/cache_tz_test.go b/pkg/table/tables/cache_tz_test.go new file mode 100644 index 0000000000000..c0132859e9b09 --- /dev/null +++ b/pkg/table/tables/cache_tz_test.go @@ -0,0 +1,291 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tables_test + +import ( + "fmt" + "testing" + "time" + + "github.com/pingcap/tidb/pkg/testkit" + "github.com/stretchr/testify/require" +) + +// waitForCache repeatedly queries the table until the cache is populated. +// Returns true if cache was used within maxAttempts. +func waitForCache(tk *testkit.TestKit, query string, maxAttempts int) bool { + for i := 0; i < maxAttempts; i++ { + tk.MustQuery(query) + if lastReadFromCache(tk) { + return true + } + time.Sleep(50 * time.Millisecond) + } + return false +} + +// TestCachedTableTimestampTZConvert verifies that TIMESTAMP values are correctly +// converted to the session timezone when reading from the datum cache. +func TestCachedTableTimestampTZConvert(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("set @@time_zone = '+00:00'") + tk.MustExec("create table t_ts (id int, ts timestamp)") + tk.MustExec("insert into t_ts values (1, '2024-01-01 12:00:00')") + tk.MustExec("insert into t_ts values (2, '2024-06-15 00:00:00')") + tk.MustExec("alter table t_ts cache") + + // Wait for cache to be populated. + require.True(t, waitForCache(tk, "select * from t_ts", 100)) + + // Test different timezones on the same cached data. + tests := []struct { + tz string + id int + expected string + }{ + {"+00:00", 1, "2024-01-01 12:00:00"}, + {"+08:00", 1, "2024-01-01 20:00:00"}, + {"-05:00", 1, "2024-01-01 07:00:00"}, + {"+05:30", 1, "2024-01-01 17:30:00"}, + {"+00:00", 2, "2024-06-15 00:00:00"}, + {"+08:00", 2, "2024-06-15 08:00:00"}, + {"-05:00", 2, "2024-06-14 19:00:00"}, + } + + for _, tt := range tests { + tk.MustExec(fmt.Sprintf("set @@time_zone = '%s'", tt.tz)) + result := tk.MustQuery(fmt.Sprintf("select ts from t_ts where id = %d", tt.id)) + result.Check(testkit.Rows(tt.expected)) + require.True(t, lastReadFromCache(tk), + "expected to read from cache for tz=%s id=%d", tt.tz, tt.id) + } +} + +// TestCachedTableTimestampNULL verifies that NULL TIMESTAMP values are not +// affected by timezone conversion. +func TestCachedTableTimestampNULL(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("set @@time_zone = '+00:00'") + tk.MustExec("create table t_ts_null (id int, ts timestamp null)") + tk.MustExec("insert into t_ts_null values (1, '2024-01-01 12:00:00')") + tk.MustExec("insert into t_ts_null values (2, null)") + tk.MustExec("alter table t_ts_null cache") + + require.True(t, waitForCache(tk, "select * from t_ts_null", 100)) + + // NULL should remain NULL regardless of timezone. + for _, tz := range []string{"+00:00", "+08:00", "-05:00"} { + tk.MustExec(fmt.Sprintf("set @@time_zone = '%s'", tz)) + tk.MustQuery("select ts from t_ts_null where id = 2").Check(testkit.Rows("")) + require.True(t, lastReadFromCache(tk), "expected cache read for tz=%s", tz) + } + + // Non-NULL row should still convert correctly. + tk.MustExec("set @@time_zone = '+08:00'") + tk.MustQuery("select ts from t_ts_null where id = 1").Check(testkit.Rows("2024-01-01 20:00:00")) + require.True(t, lastReadFromCache(tk)) +} + +// TestCachedTableTimestampZero verifies that zero timestamps (0000-00-00 00:00:00) +// are not converted across timezones, consistent with the KV decode path. +func TestCachedTableTimestampZero(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("set @@time_zone = '+00:00'") + tk.MustExec("set @@sql_mode = 'ALLOW_INVALID_DATES'") + tk.MustExec("create table t_ts_zero (id int, ts timestamp null)") + tk.MustExec("insert into t_ts_zero values (1, '2024-01-01 12:00:00')") + tk.MustExec("insert into t_ts_zero values (2, '0000-00-00 00:00:00')") + tk.MustExec("alter table t_ts_zero cache") + + require.True(t, waitForCache(tk, "select * from t_ts_zero", 100)) + + // Zero timestamp should remain zero regardless of timezone. + for _, tz := range []string{"+00:00", "+08:00", "-05:00"} { + tk.MustExec(fmt.Sprintf("set @@time_zone = '%s'", tz)) + tk.MustQuery("select ts from t_ts_zero where id = 2").Check(testkit.Rows("0000-00-00 00:00:00")) + require.True(t, lastReadFromCache(tk), "expected cache read for tz=%s", tz) + } +} + +// TestCachedTableTimestampFilter verifies that WHERE conditions on TIMESTAMP +// columns work correctly with timezone conversion on the datum cache path. +func TestCachedTableTimestampFilter(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("set @@time_zone = '+00:00'") + tk.MustExec("create table t_ts_filter (id int, ts timestamp)") + tk.MustExec("insert into t_ts_filter values (1, '2024-01-01 10:00:00')") + tk.MustExec("insert into t_ts_filter values (2, '2024-01-01 12:00:00')") + tk.MustExec("insert into t_ts_filter values (3, '2024-01-01 14:00:00')") + tk.MustExec("insert into t_ts_filter values (4, '2024-01-01 16:00:00')") + tk.MustExec("alter table t_ts_filter cache") + + require.True(t, waitForCache(tk, "select * from t_ts_filter", 100)) + + // In +08:00, the values become 18:00, 20:00, 22:00, 00:00(+1day). + // Filter ts > '2024-01-01 21:00:00' should match id=3 (22:00) and id=4 (00:00 next day). + tk.MustExec("set @@time_zone = '+08:00'") + tk.MustQuery("select id, ts from t_ts_filter where ts > '2024-01-01 21:00:00' order by id").Check(testkit.Rows( + "3 2024-01-01 22:00:00", + "4 2024-01-02 00:00:00", + )) + require.True(t, lastReadFromCache(tk)) + + // In -05:00, the values become 05:00, 07:00, 09:00, 11:00. + // Filter ts < '2024-01-01 08:00:00' should match id=1 (05:00) and id=2 (07:00). + tk.MustExec("set @@time_zone = '-05:00'") + tk.MustQuery("select id, ts from t_ts_filter where ts < '2024-01-01 08:00:00' order by id").Check(testkit.Rows( + "1 2024-01-01 05:00:00", + "2 2024-01-01 07:00:00", + )) + require.True(t, lastReadFromCache(tk)) +} + +// TestCachedTableMultiTimestampCols verifies timezone conversion works correctly +// when a table has multiple TIMESTAMP columns. +func TestCachedTableMultiTimestampCols(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("set @@time_zone = '+00:00'") + tk.MustExec("create table t_multi_ts (id int, created_at timestamp, updated_at timestamp)") + tk.MustExec("insert into t_multi_ts values (1, '2024-01-01 12:00:00', '2024-06-15 18:00:00')") + tk.MustExec("alter table t_multi_ts cache") + + require.True(t, waitForCache(tk, "select * from t_multi_ts", 100)) + + tk.MustExec("set @@time_zone = '+08:00'") + tk.MustQuery("select created_at, updated_at from t_multi_ts where id = 1").Check(testkit.Rows( + "2024-01-01 20:00:00 2024-06-16 02:00:00", + )) + require.True(t, lastReadFromCache(tk)) + + tk.MustExec("set @@time_zone = '-05:00'") + tk.MustQuery("select created_at, updated_at from t_multi_ts where id = 1").Check(testkit.Rows( + "2024-01-01 07:00:00 2024-06-15 13:00:00", + )) + require.True(t, lastReadFromCache(tk)) +} + +// TestCachedTableNoTimestamp verifies that tables without TIMESTAMP columns +// are not affected by timezone settings (no conversion overhead). +func TestCachedTableNoTimestamp(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("create table t_no_ts (id int, name varchar(100), dt datetime)") + tk.MustExec("insert into t_no_ts values (1, 'alice', '2024-01-01 12:00:00')") + tk.MustExec("alter table t_no_ts cache") + + require.True(t, waitForCache(tk, "select * from t_no_ts", 100)) + + // DATETIME is not timezone-aware; result should be the same regardless of time_zone. + for _, tz := range []string{"+00:00", "+08:00", "-05:00"} { + tk.MustExec(fmt.Sprintf("set @@time_zone = '%s'", tz)) + tk.MustQuery("select dt from t_no_ts where id = 1").Check(testkit.Rows("2024-01-01 12:00:00")) + require.True(t, lastReadFromCache(tk), "expected cache read for tz=%s", tz) + } +} + +// TestCachedTableTimestampConsistency verifies that full table scans through +// the datum cache produce correct results across multiple timezone switches +// and that switching timezones does not corrupt cached data. +func TestCachedTableTimestampConsistency(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("set @@time_zone = '+00:00'") + tk.MustExec("create table t_ts_consist (id int, ts timestamp, val varchar(20))") + tk.MustExec("insert into t_ts_consist values (1, '2024-01-01 00:00:00', 'midnight')") + tk.MustExec("insert into t_ts_consist values (2, '2024-01-01 12:00:00', 'noon')") + tk.MustExec("insert into t_ts_consist values (3, '2024-12-31 23:59:59', 'nye')") + tk.MustExec("alter table t_ts_consist cache") + + require.True(t, waitForCache(tk, "select * from t_ts_consist", 100)) + + // Full scan at +08:00: all timestamps should shift by +8 hours. + tk.MustExec("set @@time_zone = '+08:00'") + tk.MustQuery("select id, ts, val from t_ts_consist order by id").Check(testkit.Rows( + "1 2024-01-01 08:00:00 midnight", + "2 2024-01-01 20:00:00 noon", + "3 2025-01-01 07:59:59 nye", + )) + require.True(t, lastReadFromCache(tk)) + + // Full scan at +00:00: timestamps should be in UTC (no corruption from +08:00 read). + tk.MustExec("set @@time_zone = '+00:00'") + tk.MustQuery("select id, ts, val from t_ts_consist order by id").Check(testkit.Rows( + "1 2024-01-01 00:00:00 midnight", + "2 2024-01-01 12:00:00 noon", + "3 2024-12-31 23:59:59 nye", + )) + require.True(t, lastReadFromCache(tk)) + + // Full scan at -05:00: timestamps should shift by -5 hours. + tk.MustExec("set @@time_zone = '-05:00'") + tk.MustQuery("select id, ts, val from t_ts_consist order by id").Check(testkit.Rows( + "1 2023-12-31 19:00:00 midnight", + "2 2024-01-01 07:00:00 noon", + "3 2024-12-31 18:59:59 nye", + )) + require.True(t, lastReadFromCache(tk)) + + // Back to +00:00 again: verify no corruption after multiple timezone switches. + tk.MustExec("set @@time_zone = '+00:00'") + tk.MustQuery("select id, ts, val from t_ts_consist order by id").Check(testkit.Rows( + "1 2024-01-01 00:00:00 midnight", + "2 2024-01-01 12:00:00 noon", + "3 2024-12-31 23:59:59 nye", + )) + require.True(t, lastReadFromCache(tk)) +} + +// TestCachedTableIndexTimestampTZConvert verifies that the pre-decoded index cache +// preserves UTC storage and converts TIMESTAMP values to the session timezone on read. +func TestCachedTableIndexTimestampTZConvert(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("set @@time_zone = '+00:00'") + tk.MustExec("create table t_idx_ts (id int primary key, ts timestamp, index idx_ts (ts, id))") + tk.MustExec("insert into t_idx_ts values (1, '2024-01-01 12:00:00')") + tk.MustExec("insert into t_idx_ts values (2, '2024-06-15 00:00:00')") + tk.MustExec("alter table t_idx_ts cache") + + query := "select id, ts from t_idx_ts force index (idx_ts) order by ts, id" + require.True(t, waitForCache(tk, query, 100)) + + tests := []struct { + tz string + rows []string + }{ + {"+00:00", []string{"1 2024-01-01 12:00:00", "2 2024-06-15 00:00:00"}}, + {"+08:00", []string{"1 2024-01-01 20:00:00", "2 2024-06-15 08:00:00"}}, + {"-05:00", []string{"1 2024-01-01 07:00:00", "2 2024-06-14 19:00:00"}}, + } + + for _, tt := range tests { + tk.MustExec(fmt.Sprintf("set @@time_zone = '%s'", tt.tz)) + tk.MustQuery(query).Check(testkit.Rows(tt.rows...)) + require.True(t, lastReadFromCache(tk), "expected cache read for tz=%s", tt.tz) + } +} diff --git a/pkg/table/tables/cached_datum.go b/pkg/table/tables/cached_datum.go new file mode 100644 index 0000000000000..015ee523aeca5 --- /dev/null +++ b/pkg/table/tables/cached_datum.go @@ -0,0 +1,219 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tables + +import ( + "time" + + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/meta/model" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/tablecodec" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/rowcodec" +) + +const datumCacheChunkSize = 1024 + +// CachedDatumData holds pre-decoded cached table data as chunks. +// TIMESTAMP columns are stored in UTC; readers must convert to session timezone. +type CachedDatumData struct { + Chunks []*chunk.Chunk + FieldTypes []*types.FieldType + TsColIndices []int // indices of TIMESTAMP columns in FieldTypes; empty means no conversion needed + TotalRows int +} + +// BuildCachedDatumData iterates all rows for tableID in membuf and decodes them +// into chunks using a ChunkDecoder. +// +// TIMESTAMP values are decoded with time.UTC so they remain timezone-neutral in the cache. +// Readers must convert to session timezone on read. +func BuildCachedDatumData( + membuf kv.MemBuffer, + tableID int64, + columns []rowcodec.ColInfo, + handleColIDs []int64, + defDatum func(i int, chk *chunk.Chunk) error, + fieldTypes []*types.FieldType, +) (*CachedDatumData, error) { + cd := rowcodec.NewChunkDecoder(columns, handleColIDs, defDatum, time.UTC) + + tsColIndices := findTimestampColumns(fieldTypes) + + var chunks []*chunk.Chunk + curChk := chunk.New(fieldTypes, datumCacheChunkSize, datumCacheChunkSize) + chunks = append(chunks, curChk) + totalRows := 0 + + prefix := tablecodec.GenTablePrefix(tableID) + it, err := membuf.Iter(prefix, prefix.PrefixNext()) + if err != nil { + return nil, err + } + defer it.Close() + + for it.Valid() { + key := it.Key() + value := it.Value() + + if !tablecodec.IsRecordKey(key) || len(value) == 0 { + if err := it.Next(); err != nil { + return nil, err + } + continue + } + + handle, err := tablecodec.DecodeRowKey(key) + if err != nil { + return nil, err + } + + if curChk.NumRows() >= datumCacheChunkSize { + curChk = chunk.New(fieldTypes, datumCacheChunkSize, datumCacheChunkSize) + chunks = append(chunks, curChk) + } + + if err := cd.DecodeToChunk(value, 0, handle, curChk); err != nil { + return nil, err + } + totalRows++ + + if err := it.Next(); err != nil { + return nil, err + } + } + + return &CachedDatumData{ + Chunks: chunks, + FieldTypes: fieldTypes, + TsColIndices: tsColIndices, + TotalRows: totalRows, + }, nil +} + +func findTimestampColumns(fieldTypes []*types.FieldType) []int { + var indices []int + for i, ft := range fieldTypes { + if ft != nil && ft.GetType() == mysql.TypeTimestamp { + indices = append(indices, i) + } + } + return indices +} + +// CachedIndexDatumData holds pre-decoded index entries for a single index. +// TIMESTAMP columns are stored in UTC; readers must convert to session timezone. +type CachedIndexDatumData struct { + Entries map[string][]types.Datum // raw KV key → decoded datums (all index cols + handle cols) + TsColIndices []int // indices of TIMESTAMP columns in the datum slice +} + +// BuildCachedIndexDatumData iterates all index entries for the given index in membuf +// and decodes them into a map keyed by raw KV key. +// +// TIMESTAMP values are decoded with time.UTC so they remain timezone-neutral in the cache. +// Readers must convert to session timezone on read. +func BuildCachedIndexDatumData( + membuf kv.MemBuffer, + tableID int64, + indexInfo *model.IndexInfo, + tblInfo *model.TableInfo, +) (*CachedIndexDatumData, error) { + // Build field types for all decoded columns (index cols + handle cols). + tps := make([]*types.FieldType, 0, len(indexInfo.Columns)+1) + cols := tblInfo.Columns + for _, col := range indexInfo.Columns { + tps = append(tps, &cols[col.Offset].FieldType) + } + switch { + case tblInfo.PKIsHandle: + for _, col := range tblInfo.Columns { + if mysql.HasPriKeyFlag(col.GetFlag()) { + tps = append(tps, &(col.FieldType)) + break + } + } + case tblInfo.IsCommonHandle: + pkIdx := FindPrimaryIndex(tblInfo) + for _, pkCol := range pkIdx.Columns { + colInfo := tblInfo.Columns[pkCol.Offset] + tps = append(tps, &colInfo.FieldType) + } + default: // ExtraHandle Column tp. + tps = append(tps, types.NewFieldType(mysql.TypeLonglong)) + } + + colInfos := BuildRowcodecColInfoForIndexColumns(indexInfo, tblInfo) + colInfos = TryAppendCommonHandleRowcodecColInfos(colInfos, tblInfo) + + // Determine handle status. + colsLen := len(indexInfo.Columns) + hdStatus := tablecodec.HandleDefault + if mysql.HasUnsignedFlag(tps[colsLen].GetFlag()) { + hdStatus = tablecodec.HandleIsUnsigned + } + + tsColIndices := findTimestampColumns(tps) + loc := time.UTC + + prefix := tablecodec.EncodeTableIndexPrefix(tableID, indexInfo.ID) + it, err := membuf.Iter(prefix, prefix.PrefixNext()) + if err != nil { + return nil, err + } + defer it.Close() + + entries := make(map[string][]types.Datum) + restoredDec := tablecodec.NewIndexRestoredDecoder(colInfos[:colsLen]) + decodeBuff := make([][]byte, colsLen, colsLen+len(colInfos)) + var buf [16]byte + + for it.Valid() { + key := it.Key() + value := it.Value() + + if len(value) == 0 { + if err := it.Next(); err != nil { + return nil, err + } + continue + } + + values, err := tablecodec.DecodeIndexKVEx(key, value, colsLen, hdStatus, colInfos, buf[:0], decodeBuff[:colsLen], restoredDec) + if err != nil { + return nil, err + } + + datums := make([]types.Datum, len(values)) + for i, val := range values { + if err := tablecodec.DecodeColumnValueWithDatum(val, tps[i], loc, &datums[i]); err != nil { + return nil, err + } + } + + entries[string(key)] = datums + + if err := it.Next(); err != nil { + return nil, err + } + } + + return &CachedIndexDatumData{ + Entries: entries, + TsColIndices: tsColIndices, + }, nil +} diff --git a/pkg/table/tables/cached_datum_test.go b/pkg/table/tables/cached_datum_test.go new file mode 100644 index 0000000000000..ba658b66529b9 --- /dev/null +++ b/pkg/table/tables/cached_datum_test.go @@ -0,0 +1,489 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tables + +import ( + "bytes" + "context" + "encoding/binary" + "sort" + "testing" + "time" + + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/meta/model" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/tablecodec" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/codec" + "github.com/pingcap/tidb/pkg/util/rowcodec" + "github.com/stretchr/testify/require" +) + +// testKV is a key-value pair for the test mock. +type testKV struct { + key kv.Key + value []byte +} + +// testMemBuffer is a minimal mock of kv.MemBuffer that supports Set and Iter. +type testMemBuffer struct { + kvs []testKV +} + +func newTestMemBuf() *testMemBuffer { + return &testMemBuffer{} +} + +func (m *testMemBuffer) Set(k kv.Key, v []byte) error { + keyCopy := make([]byte, len(k)) + copy(keyCopy, k) + valCopy := make([]byte, len(v)) + copy(valCopy, v) + m.kvs = append(m.kvs, testKV{key: keyCopy, value: valCopy}) + sort.Slice(m.kvs, func(i, j int) bool { + return bytes.Compare(m.kvs[i].key, m.kvs[j].key) < 0 + }) + return nil +} + +func (m *testMemBuffer) Iter(k kv.Key, upperBound kv.Key) (kv.Iterator, error) { + var filtered []testKV + for _, pair := range m.kvs { + if bytes.Compare(pair.key, k) >= 0 && (len(upperBound) == 0 || bytes.Compare(pair.key, upperBound) < 0) { + filtered = append(filtered, pair) + } + } + return &testMemBufIter{kvs: filtered, idx: 0}, nil +} + +// Unused interface methods — only Set and Iter are needed by BuildCachedDatumData. +func (m *testMemBuffer) Delete(kv.Key) error { panic("unused") } +func (m *testMemBuffer) Get(_ context.Context, _ kv.Key) ([]byte, error) { panic("unused") } +func (m *testMemBuffer) IterReverse(kv.Key, kv.Key) (kv.Iterator, error) { panic("unused") } +func (m *testMemBuffer) RLock() {} +func (m *testMemBuffer) RUnlock() {} +func (m *testMemBuffer) GetFlags(kv.Key) (kv.KeyFlags, error) { panic("unused") } +func (m *testMemBuffer) SetWithFlags(kv.Key, []byte, ...kv.FlagsOp) error { panic("unused") } +func (m *testMemBuffer) UpdateFlags(kv.Key, ...kv.FlagsOp) { panic("unused") } +func (m *testMemBuffer) DeleteWithFlags(kv.Key, ...kv.FlagsOp) error { panic("unused") } +func (m *testMemBuffer) Staging() kv.StagingHandle { panic("unused") } +func (m *testMemBuffer) Release(kv.StagingHandle) { panic("unused") } +func (m *testMemBuffer) Cleanup(kv.StagingHandle) { panic("unused") } +func (m *testMemBuffer) InspectStage(kv.StagingHandle, func(kv.Key, kv.KeyFlags, []byte)) { + panic("unused") +} +func (m *testMemBuffer) SnapshotGetter() kv.Getter { panic("unused") } +func (m *testMemBuffer) SnapshotIter(kv.Key, kv.Key) kv.Iterator { panic("unused") } +func (m *testMemBuffer) SnapshotIterReverse(kv.Key, kv.Key) kv.Iterator { panic("unused") } +func (m *testMemBuffer) Len() int { return len(m.kvs) } +func (m *testMemBuffer) Size() int { panic("unused") } +func (m *testMemBuffer) RemoveFromBuffer(kv.Key) { panic("unused") } +func (m *testMemBuffer) GetLocal(context.Context, []byte) ([]byte, error) { panic("unused") } +func (m *testMemBuffer) BatchGet(context.Context, [][]byte) (map[string][]byte, error) { + panic("unused") +} + +type testMemBufIter struct { + kvs []testKV + idx int +} + +func (it *testMemBufIter) Valid() bool { return it.idx < len(it.kvs) } +func (it *testMemBufIter) Key() kv.Key { return it.kvs[it.idx].key } +func (it *testMemBufIter) Value() []byte { return it.kvs[it.idx].value } +func (it *testMemBufIter) Next() error { it.idx++; return nil } +func (it *testMemBufIter) Close() {} + +// --- test helpers --- + +const testTableID = int64(42) + +type testDatumCacheSetup struct { + colIDs []int64 + cols []rowcodec.ColInfo + fieldTypes []*types.FieldType + pkColIDs []int64 +} + +func newTestSetup() *testDatumCacheSetup { + ftInt := types.NewFieldType(mysql.TypeLonglong) + ftStr := types.NewFieldType(mysql.TypeVarchar) + ftTs := types.NewFieldType(mysql.TypeTimestamp) + ftTs.SetDecimal(0) + + return &testDatumCacheSetup{ + colIDs: []int64{1, 2, 3}, + cols: []rowcodec.ColInfo{ + {ID: 1, IsPKHandle: true, Ft: ftInt}, + {ID: 2, Ft: ftStr}, + {ID: 3, Ft: ftTs}, + }, + fieldTypes: []*types.FieldType{ftInt, ftStr, ftTs}, + pkColIDs: []int64{-1}, + } +} + +func encodeAndSetRow(t *testing.T, mb kv.MemBuffer, tableID int64, handle kv.Handle, colIDs []int64, values []types.Datum) { + var encoder rowcodec.Encoder + rowBytes, err := encoder.Encode(time.UTC, colIDs, values, nil, nil) + require.NoError(t, err) + key := tablecodec.EncodeRowKeyWithHandle(tableID, handle) + require.NoError(t, mb.Set(key, rowBytes)) +} + +func nilDefDatum(i int, chk *chunk.Chunk) error { + chk.AppendNull(i) + return nil +} + +// --- tests --- + +func TestBuildCachedDatumDataBasic(t *testing.T) { + setup := newTestSetup() + mb := newTestMemBuf() + + ts := types.NewTime(types.FromGoTime(time.Date(2025, 1, 15, 10, 30, 0, 0, time.UTC)), mysql.TypeTimestamp, 0) + + encodeAndSetRow(t, mb, testTableID, kv.IntHandle(1), setup.colIDs, []types.Datum{ + types.NewIntDatum(1), + types.NewStringDatum("hello"), + types.NewTimeDatum(ts), + }) + encodeAndSetRow(t, mb, testTableID, kv.IntHandle(2), setup.colIDs, []types.Datum{ + types.NewIntDatum(2), + types.NewStringDatum("world"), + types.NewTimeDatum(ts), + }) + + cd, err := BuildCachedDatumData(mb, testTableID, setup.cols, setup.pkColIDs, nilDefDatum, setup.fieldTypes) + require.NoError(t, err) + require.Equal(t, 2, cd.TotalRows) + require.Len(t, cd.Chunks, 1) + require.Equal(t, 2, cd.Chunks[0].NumRows()) + + // Verify values. + row0 := cd.Chunks[0].GetRow(0) + require.Equal(t, int64(1), row0.GetInt64(0)) + require.Equal(t, "hello", row0.GetString(1)) + + row1 := cd.Chunks[0].GetRow(1) + require.Equal(t, int64(2), row1.GetInt64(0)) + require.Equal(t, "world", row1.GetString(1)) + + // TIMESTAMP should be stored as-is (UTC, no timezone conversion). + tsVal := row0.GetTime(2) + require.Equal(t, ts.String(), tsVal.String()) + + // TsColIndices should identify column 2 (0-indexed). + require.Equal(t, []int{2}, cd.TsColIndices) + + // FieldTypes should match. + require.Len(t, cd.FieldTypes, 3) +} + +func TestBuildCachedDatumDataCommonHandleTimestamp(t *testing.T) { + mb := newTestMemBuf() + + ftTs := types.NewFieldType(mysql.TypeTimestamp) + ftTs.SetDecimal(0) + ftInt := types.NewFieldType(mysql.TypeLonglong) + + ts := types.NewTime(types.FromGoTime(time.Date(2025, 1, 15, 10, 30, 0, 0, time.UTC)), mysql.TypeTimestamp, 0) + handleEncoded, err := codec.EncodeKey(time.UTC, nil, types.NewTimeDatum(ts)) + require.NoError(t, err) + handle, err := kv.NewCommonHandle(handleEncoded) + require.NoError(t, err) + + var encoder rowcodec.Encoder + rowBytes, err := encoder.Encode(time.UTC, []int64{2}, []types.Datum{types.NewIntDatum(123)}, nil, nil) + require.NoError(t, err) + key := tablecodec.EncodeRowKeyWithHandle(testTableID, handle) + require.NoError(t, mb.Set(key, rowBytes)) + + colInfos := []rowcodec.ColInfo{ + {ID: 1, Ft: ftTs}, + {ID: 2, Ft: ftInt}, + } + fieldTypes := []*types.FieldType{ftTs, ftInt} + handleColIDs := []int64{1} + + cd, err := BuildCachedDatumData(mb, testTableID, colInfos, handleColIDs, nilDefDatum, fieldTypes) + require.NoError(t, err) + require.Equal(t, 1, cd.TotalRows) + require.Equal(t, []int{0}, cd.TsColIndices) + + row := cd.Chunks[0].GetRow(0) + require.Equal(t, ts.String(), row.GetTime(0).String()) + require.Equal(t, int64(123), row.GetInt64(1)) +} + +func TestBuildCachedIndexDatumDataRestoredDecoderDurability(t *testing.T) { + mb := newTestMemBuf() + + ftPK := types.NewFieldType(mysql.TypeLonglong) + ftPK.AddFlag(mysql.PriKeyFlag) + ftJSON := types.NewFieldType(mysql.TypeJSON) + + tblInfo := &model.TableInfo{ + ID: testTableID, + PKIsHandle: true, + Columns: []*model.ColumnInfo{ + {ID: 1, Offset: 0, FieldType: *ftPK}, + {ID: 2, Offset: 1, FieldType: *ftJSON}, + }, + } + idxInfo := &model.IndexInfo{ + ID: 1, + Unique: true, + Columns: []*model.IndexColumn{ + {Offset: 1, Length: types.UnspecifiedLength}, + }, + } + + buildIndexKV := func(handle int64, j types.BinaryJSON) (kv.Key, []byte) { + encodedCols, err := codec.EncodeKey(time.UTC, nil, types.NewJSONDatum(j)) + require.NoError(t, err) + key := tablecodec.EncodeIndexSeekKey(testTableID, idxInfo.ID, encodedCols) + + rd := rowcodec.Encoder{Enable: true} + restoredBytes, err := rd.Encode(time.UTC, []int64{tblInfo.Columns[1].ID}, []types.Datum{types.NewJSONDatum(j)}, nil, nil) + require.NoError(t, err) + + val := make([]byte, 0, 1+len(restoredBytes)+8) + val = append(val, 8) // tailLen = 8 (int handle) + val = append(val, restoredBytes...) + var hBuf [8]byte + binary.BigEndian.PutUint64(hBuf[:], uint64(handle)) + val = append(val, hBuf[:]...) + return key, val + } + + j1, err := types.ParseBinaryJSONFromString(`{"a":1}`) + require.NoError(t, err) + j2, err := types.ParseBinaryJSONFromString(`{"a":2}`) + require.NoError(t, err) + + key1, val1 := buildIndexKV(1, j1) + key2, val2 := buildIndexKV(2, j2) + require.NoError(t, mb.Set(key1, val1)) + require.NoError(t, mb.Set(key2, val2)) + + data, err := BuildCachedIndexDatumData(mb, testTableID, idxInfo, tblInfo) + require.NoError(t, err) + require.Len(t, data.Entries, 2) + + row1, ok := data.Entries[string(key1)] + require.True(t, ok) + require.Len(t, row1, 2) + require.Equal(t, j1, row1[0].GetMysqlJSON()) + require.Equal(t, int64(1), row1[1].GetInt64()) + + row2, ok := data.Entries[string(key2)] + require.True(t, ok) + require.Len(t, row2, 2) + require.Equal(t, j2, row2[0].GetMysqlJSON()) + require.Equal(t, int64(2), row2[1].GetInt64()) +} + +func TestCachedTableBuildDatumCacheUsesOriginDefaults(t *testing.T) { + ftPK := types.NewFieldType(mysql.TypeLonglong) + ftPK.AddFlag(mysql.PriKeyFlag) + ftVal := types.NewFieldType(mysql.TypeLonglong) + ftMissing := types.NewFieldType(mysql.TypeVarchar) + + pkCol := &model.ColumnInfo{ID: 1, Offset: 0, State: model.StatePublic, FieldType: *ftPK} + valCol := &model.ColumnInfo{ID: 2, Offset: 1, State: model.StatePublic, FieldType: *ftVal} + missingCol := &model.ColumnInfo{ID: 3, Offset: 2, State: model.StatePublic, FieldType: *ftMissing} + require.NoError(t, missingCol.SetOriginDefaultValue("filled-by-default")) + require.NoError(t, missingCol.SetDefaultValue("filled-by-default")) + + tblInfo := &model.TableInfo{ + ID: testTableID, + PKIsHandle: true, + TableCacheStatusType: model.TableCacheStatusEnable, + Columns: []*model.ColumnInfo{pkCol, valCol, missingCol}, + } + tbl := MockTableFromMeta(tblInfo) + ct, ok := tbl.(*cachedTable) + require.True(t, ok) + + mb := newTestMemBuf() + encodeAndSetRow(t, mb, testTableID, kv.IntHandle(1), []int64{2}, []types.Datum{ + types.NewIntDatum(7), + }) + + cd := ct.buildDatumCache(mb) + require.NotNil(t, cd) + require.Equal(t, 1, cd.TotalRows) + + row := cd.Chunks[0].GetRow(0) + require.Equal(t, int64(1), row.GetInt64(0)) + require.Equal(t, int64(7), row.GetInt64(1)) + require.Equal(t, "filled-by-default", row.GetString(2)) +} + +func TestCachedTablePutCachedResultNoDoubleAccount(t *testing.T) { + ct := &cachedTable{} + chk := makeTestChunk() + ft := types.NewFieldType(mysql.TypeLonglong) + fts := []*types.FieldType{ft} + key := ResultCacheKey{PlanDigest: [16]byte{7}, ParamHash: 7} + paramBytes := []byte("pb") + expectedMem := estimateChunksMemory([]*chunk.Chunk{chk}) + int64(len(paramBytes)) + + require.True(t, ct.PutCachedResult(key, paramBytes, []*chunk.Chunk{chk}, fts)) + require.Equal(t, expectedMem, ct.resultCacheMem.Load()) + + // A duplicate fill for the same cache entry should be a no-op for memory accounting. + require.True(t, ct.PutCachedResult(key, paramBytes, []*chunk.Chunk{chk}, fts)) + require.Equal(t, expectedMem, ct.resultCacheMem.Load()) + + // A hash collision should still be rejected without changing the accounted memory. + require.False(t, ct.PutCachedResult(key, []byte("other"), []*chunk.Chunk{chk}, fts)) + require.Equal(t, expectedMem, ct.resultCacheMem.Load()) + + ct.invalidateResultCache() + require.Zero(t, ct.resultCacheMem.Load()) +} + +func TestCachedTablePinnedDatumCacheAccessors(t *testing.T) { + mb1 := newTestMemBuf() + mb2 := newTestMemBuf() + ft := types.NewFieldType(mysql.TypeLonglong) + + datumCache1 := &CachedDatumData{FieldTypes: []*types.FieldType{ft}} + datumCache2 := &CachedDatumData{FieldTypes: []*types.FieldType{ft}} + indexCache1 := &CachedIndexDatumData{Entries: map[string][]types.Datum{"k1": {types.NewIntDatum(1)}}} + indexCache2 := &CachedIndexDatumData{Entries: map[string][]types.Datum{"k2": {types.NewIntDatum(2)}}} + + ct := &cachedTable{} + ct.cacheData.Store(&cacheData{ + MemBuffer: mb1, + datumCache: datumCache1, + indexDatumCaches: map[int64]*CachedIndexDatumData{1: indexCache1}, + }) + + require.Same(t, datumCache1, ct.GetCachedDatumDataForMemBuffer(mb1)) + require.Same(t, indexCache1, ct.GetCachedIndexDatumDataForMemBuffer(mb1, 1)) + require.Nil(t, ct.GetCachedDatumDataForMemBuffer(mb2)) + require.Nil(t, ct.GetCachedIndexDatumDataForMemBuffer(mb2, 1)) + + ct.cacheData.Store(&cacheData{ + MemBuffer: mb2, + datumCache: datumCache2, + indexDatumCaches: map[int64]*CachedIndexDatumData{1: indexCache2}, + }) + + // Unpinned accessors expose the latest generation, but the pinned variants must + // reject the stale MemBuffer from the earlier generation. + require.Same(t, datumCache2, ct.GetCachedDatumData()) + require.Same(t, indexCache2, ct.GetCachedIndexDatumData(1)) + require.Nil(t, ct.GetCachedDatumDataForMemBuffer(mb1)) + require.Nil(t, ct.GetCachedIndexDatumDataForMemBuffer(mb1, 1)) + require.Same(t, datumCache2, ct.GetCachedDatumDataForMemBuffer(mb2)) + require.Same(t, indexCache2, ct.GetCachedIndexDatumDataForMemBuffer(mb2, 1)) +} + +func TestBuildCachedDatumDataEmpty(t *testing.T) { + setup := newTestSetup() + mb := newTestMemBuf() + + cd, err := BuildCachedDatumData(mb, testTableID, setup.cols, setup.pkColIDs, nilDefDatum, setup.fieldTypes) + require.NoError(t, err) + require.Equal(t, 0, cd.TotalRows) + // One initial (empty) chunk is always allocated. + require.Len(t, cd.Chunks, 1) + require.Equal(t, 0, cd.Chunks[0].NumRows()) +} + +func TestBuildCachedDatumDataChunkSplit(t *testing.T) { + // Use int-only columns for simplicity. + ft := types.NewFieldType(mysql.TypeLonglong) + cols := []rowcodec.ColInfo{{ID: 1, IsPKHandle: true, Ft: ft}} + fts := []*types.FieldType{ft} + colIDs := []int64{1} + pkColIDs := []int64{-1} + + mb := newTestMemBuf() + + totalRows := datumCacheChunkSize + 100 + for i := 1; i <= totalRows; i++ { + encodeAndSetRow(t, mb, testTableID, kv.IntHandle(int64(i)), colIDs, []types.Datum{ + types.NewIntDatum(int64(i)), + }) + } + + cd, err := BuildCachedDatumData(mb, testTableID, cols, pkColIDs, nilDefDatum, fts) + require.NoError(t, err) + require.Equal(t, totalRows, cd.TotalRows) + require.Len(t, cd.Chunks, 2) + require.Equal(t, datumCacheChunkSize, cd.Chunks[0].NumRows()) + require.Equal(t, 100, cd.Chunks[1].NumRows()) +} + +func TestBuildCachedDatumDataSkipDeleted(t *testing.T) { + setup := newTestSetup() + mb := newTestMemBuf() + + ts := types.NewTime(types.FromGoTime(time.Date(2025, 1, 15, 10, 30, 0, 0, time.UTC)), mysql.TypeTimestamp, 0) + + // Insert a valid row. + encodeAndSetRow(t, mb, testTableID, kv.IntHandle(1), setup.colIDs, []types.Datum{ + types.NewIntDatum(1), + types.NewStringDatum("keep"), + types.NewTimeDatum(ts), + }) + + // Insert a "deleted" row (empty value). + deletedKey := tablecodec.EncodeRowKeyWithHandle(testTableID, kv.IntHandle(2)) + require.NoError(t, mb.Set(deletedKey, []byte{})) + + // Insert another valid row. + encodeAndSetRow(t, mb, testTableID, kv.IntHandle(3), setup.colIDs, []types.Datum{ + types.NewIntDatum(3), + types.NewStringDatum("also keep"), + types.NewTimeDatum(ts), + }) + + cd, err := BuildCachedDatumData(mb, testTableID, setup.cols, setup.pkColIDs, nilDefDatum, setup.fieldTypes) + require.NoError(t, err) + require.Equal(t, 2, cd.TotalRows) + + row0 := cd.Chunks[0].GetRow(0) + require.Equal(t, int64(1), row0.GetInt64(0)) + row1 := cd.Chunks[0].GetRow(1) + require.Equal(t, int64(3), row1.GetInt64(0)) +} + +func TestFindTimestampColumns(t *testing.T) { + fts := []*types.FieldType{ + types.NewFieldType(mysql.TypeLonglong), + types.NewFieldType(mysql.TypeTimestamp), + types.NewFieldType(mysql.TypeVarchar), + types.NewFieldType(mysql.TypeTimestamp), + } + indices := findTimestampColumns(fts) + require.Equal(t, []int{1, 3}, indices) + + // No timestamp columns. + fts2 := []*types.FieldType{ + types.NewFieldType(mysql.TypeLonglong), + types.NewFieldType(mysql.TypeVarchar), + } + indices2 := findTimestampColumns(fts2) + require.Nil(t, indices2) +} diff --git a/pkg/table/tables/result_cache.go b/pkg/table/tables/result_cache.go new file mode 100644 index 0000000000000..3a5dc58b62843 --- /dev/null +++ b/pkg/table/tables/result_cache.go @@ -0,0 +1,131 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tables + +import ( + "bytes" + "sync" + "sync/atomic" + + "github.com/pingcap/tidb/pkg/table" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util/chunk" +) + +// ResultCacheKey is an alias for table.ResultCacheKey. +type ResultCacheKey = table.ResultCacheKey + +// cachedResult is a single cached result set entry. +type cachedResult struct { + chunks []*chunk.Chunk + fieldTypes []*types.FieldType // used for schema compatibility check + paramBytes []byte // raw encoded params for secondary hash collision verification + memSize int64 + hitCount atomic.Int64 +} + +// resultSetCache is attached to a cachedTable; its lifetime is bound to the cacheData lease. +type resultSetCache struct { + mu sync.RWMutex + items map[ResultCacheKey]*cachedResult + totalMem int64 + + maxEntries int + maxMemory int64 +} + +const ( + defaultMaxResultCacheEntries = 256 + defaultMaxResultCacheMemory = 64 << 20 // 64MB +) + +func newResultSetCache() *resultSetCache { + return &resultSetCache{ + items: make(map[ResultCacheKey]*cachedResult), + maxEntries: defaultMaxResultCacheEntries, + maxMemory: defaultMaxResultCacheMemory, + } +} + +// Get looks up the cache. On hit it verifies paramBytes to guard against hash +// collisions, then increments hitCount. +func (c *resultSetCache) Get(key ResultCacheKey, paramBytes []byte) ([]*chunk.Chunk, []*types.FieldType, bool) { + c.mu.RLock() + defer c.mu.RUnlock() + if r, ok := c.items[key]; ok { + if !bytes.Equal(r.paramBytes, paramBytes) { + return nil, nil, false + } + r.hitCount.Add(1) + return r.chunks, r.fieldTypes, true + } + return nil, nil, false +} + +// put inserts into the cache and returns whether the entry was accepted plus +// the memory delta for a newly inserted entry. +func (c *resultSetCache) put(key ResultCacheKey, paramBytes []byte, chunks []*chunk.Chunk, fieldTypes []*types.FieldType) (bool, int64) { + memSize := estimateChunksMemory(chunks) + int64(len(paramBytes)) + c.mu.Lock() + defer c.mu.Unlock() + if r, ok := c.items[key]; ok { + // Same hash key but different param bytes indicates a hash collision. + // Keep the existing entry and reject the new one to avoid cache thrash. + return bytes.Equal(r.paramBytes, paramBytes), 0 + } + if len(c.items) >= c.maxEntries || c.totalMem+memSize > c.maxMemory { + return false, 0 + } + c.items[key] = &cachedResult{ + chunks: chunks, + fieldTypes: fieldTypes, + paramBytes: paramBytes, + memSize: memSize, + } + c.totalMem += memSize + return true, memSize +} + +// Put inserts into the cache. If limits are exceeded the entry is rejected +// (no eviction — the entire cache is cleared when the lease expires). +func (c *resultSetCache) Put(key ResultCacheKey, paramBytes []byte, chunks []*chunk.Chunk, fieldTypes []*types.FieldType) bool { + ok, _ := c.put(key, paramBytes, chunks, fieldTypes) + return ok +} + +// Len returns the number of cached entries. +func (c *resultSetCache) Len() int { + c.mu.RLock() + defer c.mu.RUnlock() + return len(c.items) +} + +// MemoryUsage returns the total estimated memory used by cached chunks. +func (c *resultSetCache) MemoryUsage() int64 { + c.mu.RLock() + defer c.mu.RUnlock() + return c.totalMem +} + +func estimateChunksMemory(chunks []*chunk.Chunk) int64 { + var total int64 + for _, chk := range chunks { + if chk == nil { + continue + } + total += chk.MemoryUsage() + } + return total +} diff --git a/pkg/table/tables/result_cache_test.go b/pkg/table/tables/result_cache_test.go new file mode 100644 index 0000000000000..e05bc2e5bf32a --- /dev/null +++ b/pkg/table/tables/result_cache_test.go @@ -0,0 +1,179 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tables + +import ( + "sync" + "testing" + + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/stretchr/testify/require" +) + +func makeTestChunk() *chunk.Chunk { + ft := types.NewFieldType(mysql.TypeLonglong) + chk := chunk.NewChunkWithCapacity([]*types.FieldType{ft}, 4) + chk.AppendInt64(0, 1) + chk.AppendInt64(0, 2) + return chk +} + +func TestResultCacheGetMiss(t *testing.T) { + c := newResultSetCache() + key := ResultCacheKey{ParamHash: 42} + chunks, fts, ok := c.Get(key, []byte("params")) + require.False(t, ok) + require.Nil(t, chunks) + require.Nil(t, fts) +} + +func TestResultCachePutAndGet(t *testing.T) { + c := newResultSetCache() + chk := makeTestChunk() + ft := types.NewFieldType(mysql.TypeLonglong) + fts := []*types.FieldType{ft} + key := ResultCacheKey{PlanDigest: [16]byte{1}, ParamHash: 100} + pb := []byte("param-100") + + ok := c.Put(key, pb, []*chunk.Chunk{chk}, fts) + require.True(t, ok) + require.Equal(t, 1, c.Len()) + + gotChunks, gotFts, hit := c.Get(key, pb) + require.True(t, hit) + require.Len(t, gotChunks, 1) + require.Equal(t, chk, gotChunks[0]) + require.Equal(t, fts, gotFts) +} + +func TestResultCacheHitCount(t *testing.T) { + c := newResultSetCache() + chk := makeTestChunk() + ft := types.NewFieldType(mysql.TypeLonglong) + key := ResultCacheKey{PlanDigest: [16]byte{2}} + pb := []byte("pb") + + c.Put(key, pb, []*chunk.Chunk{chk}, []*types.FieldType{ft}) + + for i := 0; i < 5; i++ { + c.Get(key, pb) + } + + c.mu.RLock() + r := c.items[key] + c.mu.RUnlock() + require.Equal(t, int64(5), r.hitCount.Load()) +} + +func TestResultCacheMaxEntries(t *testing.T) { + c := newResultSetCache() + c.maxEntries = 2 + + chk := makeTestChunk() + ft := types.NewFieldType(mysql.TypeLonglong) + fts := []*types.FieldType{ft} + + require.True(t, c.Put(ResultCacheKey{ParamHash: 1}, []byte("p1"), []*chunk.Chunk{chk}, fts)) + require.True(t, c.Put(ResultCacheKey{ParamHash: 2}, []byte("p2"), []*chunk.Chunk{chk}, fts)) + require.False(t, c.Put(ResultCacheKey{ParamHash: 3}, []byte("p3"), []*chunk.Chunk{chk}, fts)) + require.Equal(t, 2, c.Len()) +} + +func TestResultCacheMaxMemory(t *testing.T) { + c := newResultSetCache() + chk := makeTestChunk() + pb := []byte("p1") + mem := estimateChunksMemory([]*chunk.Chunk{chk}) + int64(len(pb)) + // Allow room for exactly one entry. + c.maxMemory = mem + + ft := types.NewFieldType(mysql.TypeLonglong) + fts := []*types.FieldType{ft} + + require.True(t, c.Put(ResultCacheKey{ParamHash: 1}, pb, []*chunk.Chunk{chk}, fts)) + require.False(t, c.Put(ResultCacheKey{ParamHash: 2}, []byte("p2"), []*chunk.Chunk{chk}, fts)) + require.Equal(t, mem, c.MemoryUsage()) +} + +func TestResultCachePutSameKey(t *testing.T) { + c := newResultSetCache() + chk := makeTestChunk() + ft := types.NewFieldType(mysql.TypeLonglong) + fts := []*types.FieldType{ft} + key := ResultCacheKey{PlanDigest: [16]byte{9}, ParamHash: 9} + pb := []byte("pb") + + require.True(t, c.Put(key, pb, []*chunk.Chunk{chk}, fts)) + mem := c.MemoryUsage() + + // Put with the same key+paramBytes should be a no-op. + require.True(t, c.Put(key, pb, []*chunk.Chunk{chk}, fts)) + require.Equal(t, 1, c.Len()) + require.Equal(t, mem, c.MemoryUsage()) + + // Same hash key but different paramBytes should be rejected (hash collision). + require.False(t, c.Put(key, []byte("other"), []*chunk.Chunk{chk}, fts)) + _, _, hit := c.Get(key, pb) + require.True(t, hit) +} + +func TestResultCacheConcurrency(t *testing.T) { + c := newResultSetCache() + chk := makeTestChunk() + ft := types.NewFieldType(mysql.TypeLonglong) + fts := []*types.FieldType{ft} + + var wg sync.WaitGroup + for i := 0; i < 50; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + key := ResultCacheKey{ParamHash: uint64(i % 10)} + pb := []byte{byte(i % 10)} + c.Put(key, pb, []*chunk.Chunk{chk}, fts) + c.Get(key, pb) + c.Len() + c.MemoryUsage() + }(i) + } + wg.Wait() + require.True(t, c.Len() <= 10) +} + +func TestResultCacheParamBytesMismatch(t *testing.T) { + // Verify that same hash key but different paramBytes results in a cache miss. + c := newResultSetCache() + chk := makeTestChunk() + ft := types.NewFieldType(mysql.TypeLonglong) + fts := []*types.FieldType{ft} + + // Simulate a hash collision: same key but different actual param bytes. + key := ResultCacheKey{PlanDigest: [16]byte{1}, ParamHash: 999} + pbA := []byte("param-value-A") + pbB := []byte("param-value-B") + + ok := c.Put(key, pbA, []*chunk.Chunk{chk}, fts) + require.True(t, ok) + + // Lookup with matching paramBytes should hit. + _, _, hit := c.Get(key, pbA) + require.True(t, hit) + + // Lookup with different paramBytes (hash collision) should miss. + _, _, hit = c.Get(key, pbB) + require.False(t, hit) +} diff --git a/pkg/tablecodec/BUILD.bazel b/pkg/tablecodec/BUILD.bazel index 187b799004c4f..cf578db935a04 100644 --- a/pkg/tablecodec/BUILD.bazel +++ b/pkg/tablecodec/BUILD.bazel @@ -35,7 +35,7 @@ go_test( ], embed = [":tablecodec"], flaky = True, - shard_count = 25, + shard_count = 27, deps = [ "//pkg/kv", "//pkg/meta/model", diff --git a/pkg/tablecodec/bench_test.go b/pkg/tablecodec/bench_test.go index 780428389801a..f1b605ad54ed5 100644 --- a/pkg/tablecodec/bench_test.go +++ b/pkg/tablecodec/bench_test.go @@ -15,13 +15,16 @@ package tablecodec import ( + "encoding/binary" "testing" "time" "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/parser/mysql" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/benchdaily" "github.com/pingcap/tidb/pkg/util/codec" + "github.com/pingcap/tidb/pkg/util/rowcodec" ) func BenchmarkEncodeRowKeyWithHandle(b *testing.B) { @@ -85,6 +88,165 @@ func BenchmarkDecodeIndexKeyCommonHandle(b *testing.B) { } } +func BenchmarkDecodeIndexKVGeneral(b *testing.B) { + // Benchmark version-0 unique int handle index (no restored data). + // This exercises the general path: CutIndexKeyTo + reEncodeHandleTo. + colValues := []types.Datum{types.NewIntDatum(42), types.NewIntDatum(100)} + encodedCols, _ := codec.EncodeKey(time.UTC, nil, colValues...) + key := EncodeIndexSeekKey(1, 1, encodedCols) + + // Build version 0 value with int handle in tail (unique index). + var value []byte + value = append(value, 8) // tailLen = 8 + value = append(value, 0, 0) + var hBuf [8]byte + binary.BigEndian.PutUint64(hBuf[:], uint64(7)) + value = append(value, hBuf[:]...) + + colsLen := 2 + columns := []rowcodec.ColInfo{ + {ID: 1, Ft: types.NewFieldType(mysql.TypeLonglong)}, + {ID: 2, Ft: types.NewFieldType(mysql.TypeLonglong)}, + {ID: 3, Ft: types.NewFieldType(mysql.TypeLonglong)}, + } + + b.Run("WithPreAlloc", func(b *testing.B) { + preAlloc := make([][]byte, colsLen, colsLen+1) + var buf [9]byte + b.ResetTimer() + for i := 0; i < b.N; i++ { + preAlloc = preAlloc[:colsLen:colsLen+1] + _, err := DecodeIndexKVEx(key, value, colsLen, HandleDefault, columns, buf[:0], preAlloc) + if err != nil { + b.Fatal(err) + } + } + }) + + b.Run("WithoutPreAlloc", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := DecodeIndexKV(key, value, colsLen, HandleDefault, columns) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkDecodeIndexKVGeneralNonUnique(b *testing.B) { + // Benchmark version-0 non-unique int handle index (handle in key suffix). + handleVal := int64(7) + colValues := []types.Datum{types.NewIntDatum(42), types.NewIntDatum(100)} + allDatums := append(colValues, types.NewIntDatum(handleVal)) + encodedAll, _ := codec.EncodeKey(time.UTC, nil, allDatums...) + key := EncodeIndexSeekKey(1, 1, encodedAll) + + // Build version 0 non-unique value (no handle in value, padded > 9 bytes). + var value []byte + value = append(value, 0) // tailLen = 0 + value = append(value, make([]byte, 9)...) + + colsLen := 2 + columns := []rowcodec.ColInfo{ + {ID: 1, Ft: types.NewFieldType(mysql.TypeLonglong)}, + {ID: 2, Ft: types.NewFieldType(mysql.TypeLonglong)}, + {ID: 3, Ft: types.NewFieldType(mysql.TypeLonglong)}, + } + + b.Run("WithPreAlloc", func(b *testing.B) { + preAlloc := make([][]byte, colsLen, colsLen+1) + var buf [9]byte + b.ResetTimer() + for i := 0; i < b.N; i++ { + preAlloc = preAlloc[:colsLen:colsLen+1] + _, err := DecodeIndexKVEx(key, value, colsLen, HandleDefault, columns, buf[:0], preAlloc) + if err != nil { + b.Fatal(err) + } + } + }) + + b.Run("WithoutPreAlloc", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := DecodeIndexKV(key, value, colsLen, HandleDefault, columns) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkDecodeRestoredValues(b *testing.B) { + // Build restored values data using rowcodec.Encoder. + colIDs := []int64{1, 2, 3} + datums := []types.Datum{ + types.NewIntDatum(42), + types.NewBytesDatum([]byte("hello world")), + types.NewUintDatum(999), + } + rd := rowcodec.Encoder{Enable: true} + restoredBytes, err := rd.Encode(time.UTC, colIDs, datums, nil, nil) + if err != nil { + b.Fatal(err) + } + + // Build version 0 value with restored data + int handle. + // restoredBytes starts with CodecVer (=RestoreDataFlag), so splitIndexValueForIndexValueVersion0 + // will correctly detect it as restored values. + var value []byte + value = append(value, 8) // tailLen = 8 + value = append(value, restoredBytes...) + var hBuf [8]byte + binary.BigEndian.PutUint64(hBuf[:], uint64(7)) + value = append(value, hBuf[:]...) + + uft := types.NewFieldType(mysql.TypeLonglong) + uft.AddFlag(mysql.UnsignedFlag) + columns := []rowcodec.ColInfo{ + {ID: 1, Ft: types.NewFieldType(mysql.TypeLonglong)}, + {ID: 2, Ft: types.NewFieldType(mysql.TypeVarchar)}, + {ID: 3, Ft: uft}, + {ID: 4, Ft: types.NewFieldType(mysql.TypeLonglong)}, // handle + } + + colValues := []types.Datum{ + types.NewIntDatum(42), + types.NewBytesDatum([]byte("hello world")), + types.NewUintDatum(999), + } + encodedCols, _ := codec.EncodeKey(time.UTC, nil, colValues...) + key := EncodeIndexSeekKey(1, 1, encodedCols) + + colsLen := 3 + + b.Run("Original", func(b *testing.B) { + preAlloc := make([][]byte, colsLen, colsLen+1) + b.ResetTimer() + for i := 0; i < b.N; i++ { + preAlloc = preAlloc[:colsLen:colsLen+1] + _, err := DecodeIndexKVEx(key, value, colsLen, HandleDefault, columns, nil, preAlloc) + if err != nil { + b.Fatal(err) + } + } + }) + + b.Run("WithIndexRestoredDecoder", func(b *testing.B) { + preAlloc := make([][]byte, colsLen, colsLen+1) + dec := NewIndexRestoredDecoder(columns[:colsLen]) + b.ResetTimer() + for i := 0; i < b.N; i++ { + preAlloc = preAlloc[:colsLen:colsLen+1] + _, err := DecodeIndexKVEx(key, value, colsLen, HandleDefault, columns, nil, preAlloc, dec) + if err != nil { + b.Fatal(err) + } + } + }) +} + func TestBenchDaily(t *testing.T) { benchdaily.Run( BenchmarkEncodeRowKeyWithHandle, diff --git a/pkg/tablecodec/tablecodec.go b/pkg/tablecodec/tablecodec.go index d6ef6d97e21fd..6ed620f3f9ae2 100644 --- a/pkg/tablecodec/tablecodec.go +++ b/pkg/tablecodec/tablecodec.go @@ -800,18 +800,6 @@ const ( HandleNotNeeded ) -// reEncodeHandle encodes the handle as a Datum so it can be properly decoded later. -// If it is common handle, it returns the encoded column values. -// If it is int handle, it is encoded as int Datum or uint Datum decided by the unsigned. -func reEncodeHandle(handle kv.Handle, unsigned bool) ([][]byte, error) { - handleColLen := 1 - if !handle.IsInt() { - handleColLen = handle.NumCols() - } - result := make([][]byte, 0, handleColLen) - return reEncodeHandleTo(handle, unsigned, nil, result) -} - func reEncodeHandleTo(handle kv.Handle, unsigned bool, buf []byte, result [][]byte) ([][]byte, error) { if !handle.IsInt() { handleColLen := handle.NumCols() @@ -862,6 +850,65 @@ func decodeRestoredValues(columns []rowcodec.ColInfo, restoredVal []byte) ([][]b return resultValues, nil } +// IndexRestoredDecoder caches the colIDs map and BytesDecoder across rows +// to avoid per-row allocations in decodeRestoredValues. +type IndexRestoredDecoder struct { + colIDs map[int64]int + rd *rowcodec.BytesDecoder + values [][]byte // pre-allocated, reused across rows + arena []byte // arena for encodeOldDatum allocations + + reuseArena bool +} + +// NewIndexRestoredDecoder creates a new IndexRestoredDecoder for the given columns. +// The decoder caches the column ID map and BytesDecoder so they are built only once +// and reused across all rows in a scan. +func NewIndexRestoredDecoder(columns []rowcodec.ColInfo) *IndexRestoredDecoder { + colIDs := make(map[int64]int, len(columns)) + for i, col := range columns { + colIDs[col.ID] = i + } + rd := rowcodec.NewByteDecoder(columns, []int64{-1}, nil, nil) + return &IndexRestoredDecoder{ + colIDs: colIDs, + rd: rd, + values: make([][]byte, len(columns)), + arena: make([]byte, 0, 256), + + reuseArena: true, + } +} + +// SetReuseArena controls whether the decoder reuses the internal arena across Decode calls. +// +// When reuse is enabled (default), the returned encoded bytes may alias the decoder's internal arena +// and are only valid until the next Decode call. +// When reuse is disabled, Decode allocates a fresh arena each call, so returned bytes are durable. +func (d *IndexRestoredDecoder) SetReuseArena(reuse bool) { + d.reuseArena = reuse +} + +// Decode decodes restored values using cached state. The returned slice header is reused +// across calls; callers must consume/copy it before calling Decode again. +func (d *IndexRestoredDecoder) Decode(restoredVal []byte) ([][]byte, error) { + arena := d.arena + if d.reuseArena { + arena = arena[:0] + } else { + arena = make([]byte, 0, 256) + } + values, arena, err := d.rd.DecodeToBytesNoHandleInto(d.colIDs, restoredVal, d.values, arena) + if d.reuseArena { + d.arena = arena + } + if err != nil { + return nil, errors.Trace(err) + } + d.values = values + return values, nil +} + // decodeRestoredValuesV5 decodes index values whose format is introduced in TiDB 5.0. // Unlike the format in TiDB 4.0, the new format is optimized for storage space: // 1. If the index is a composed index, only the non-binary string column's value need to write to value, not all. @@ -926,7 +973,7 @@ func buildRestoredColumn(allCols []rowcodec.ColInfo) []rowcodec.ColInfo { if collate.IsBinCollation(col.Ft.GetCollate()) { // Change the fieldType from string to uint since we store the number of the truncated spaces. // NOTE: the corresponding datum is generated as `types.NewUintDatum(paddingSize)`, and the raw data is - // encoded via `encodeUint`. Thus we should mark the field type as unsigened here so that the BytesDecoder + // encoded via `encodeUint`. Thus we should mark the field type as unsigned here so that the BytesDecoder // can decode it correctly later. Otherwise there might be issues like #47115. copyColInfo.Ft = types.NewFieldType(mysql.TypeLonglong) copyColInfo.Ft.AddFlag(mysql.UnsignedFlag) @@ -980,14 +1027,18 @@ func getIndexVersion(value []byte) int { } // DecodeIndexKVEx looks like DecodeIndexKV, the difference is that it tries to reduce allocations. -func DecodeIndexKVEx(key, value []byte, colsLen int, hdStatus HandleStatus, columns []rowcodec.ColInfo, buf []byte, preAlloc [][]byte) ([][]byte, error) { +func DecodeIndexKVEx(key, value []byte, colsLen int, hdStatus HandleStatus, columns []rowcodec.ColInfo, buf []byte, preAlloc [][]byte, restoredDec ...*IndexRestoredDecoder) ([][]byte, error) { if len(value) <= MaxOldEncodeValueLen { return decodeIndexKvOldCollation(key, value, hdStatus, buf, preAlloc) } if getIndexVersion(value) == 1 { - return decodeIndexKvForClusteredIndexVersion1(key, value, colsLen, hdStatus, columns) + return decodeIndexKvForClusteredIndexVersion1(key, value, colsLen, hdStatus, columns, buf, preAlloc) } - return decodeIndexKvGeneral(key, value, colsLen, hdStatus, columns) + var dec *IndexRestoredDecoder + if len(restoredDec) > 0 { + dec = restoredDec[0] + } + return decodeIndexKvGeneral(key, value, colsLen, hdStatus, columns, buf, preAlloc, dec) } // DecodeIndexKV uses to decode index key values. @@ -995,14 +1046,14 @@ func DecodeIndexKVEx(key, value []byte, colsLen int, hdStatus HandleStatus, colu // `colsLen` is expected to be index columns count. // `columns` is expected to be index columns + handle columns(if hdStatus is not HandleNotNeeded). func DecodeIndexKV(key, value []byte, colsLen int, hdStatus HandleStatus, columns []rowcodec.ColInfo) ([][]byte, error) { + preAlloc := make([][]byte, colsLen, colsLen+len(columns)) if len(value) <= MaxOldEncodeValueLen { - preAlloc := make([][]byte, colsLen, colsLen+len(columns)) return decodeIndexKvOldCollation(key, value, hdStatus, nil, preAlloc) } if getIndexVersion(value) == 1 { - return decodeIndexKvForClusteredIndexVersion1(key, value, colsLen, hdStatus, columns) + return decodeIndexKvForClusteredIndexVersion1(key, value, colsLen, hdStatus, columns, nil, preAlloc) } - return decodeIndexKvGeneral(key, value, colsLen, hdStatus, columns) + return decodeIndexKvGeneral(key, value, colsLen, hdStatus, columns, nil, preAlloc, nil) } // DecodeIndexHandle uses to decode the handle from index key/value. @@ -1903,21 +1954,25 @@ func splitIndexValueForClusteredIndexVersion1(value []byte) (segs IndexValueSegm return } -func decodeIndexKvForClusteredIndexVersion1(key, value []byte, colsLen int, hdStatus HandleStatus, columns []rowcodec.ColInfo) ([][]byte, error) { - var resultValues [][]byte +func decodeIndexKvForClusteredIndexVersion1(key, value []byte, colsLen int, hdStatus HandleStatus, columns []rowcodec.ColInfo, buf []byte, preAlloc [][]byte) ([][]byte, error) { var keySuffix []byte var handle kv.Handle var err error segs := splitIndexValueForClusteredIndexVersion1(value) - resultValues, keySuffix, err = CutIndexKeyNew(key, colsLen) + resultValues := preAlloc[:colsLen] + keySuffix, err = CutIndexKeyTo(key, resultValues) if err != nil { return nil, err } if segs.RestoredValues != nil { - resultValues, err = decodeRestoredValuesV5(columns[:colsLen], resultValues, segs.RestoredValues) + restored, err := decodeRestoredValuesV5(columns[:colsLen], resultValues, segs.RestoredValues) if err != nil { return nil, err } + if len(restored) != colsLen { + return nil, errors.Errorf("unexpected restored values length %d, expected %d", len(restored), colsLen) + } + copy(resultValues, restored) } if hdStatus == HandleNotNeeded { return resultValues, nil @@ -1953,21 +2008,30 @@ func decodeIndexKvForClusteredIndexVersion1(key, value []byte, colsLen int, hdSt } // decodeIndexKvGeneral decodes index key value pair of new layout in an extensible way. -func decodeIndexKvGeneral(key, value []byte, colsLen int, hdStatus HandleStatus, columns []rowcodec.ColInfo) ([][]byte, error) { - var resultValues [][]byte +func decodeIndexKvGeneral(key, value []byte, colsLen int, hdStatus HandleStatus, columns []rowcodec.ColInfo, buf []byte, preAlloc [][]byte, restoredDec *IndexRestoredDecoder) ([][]byte, error) { var keySuffix []byte var handle kv.Handle var err error segs := splitIndexValueForIndexValueVersion0(value) - resultValues, keySuffix, err = CutIndexKeyNew(key, colsLen) + resultValues := preAlloc[:colsLen] + keySuffix, err = CutIndexKeyTo(key, resultValues) if err != nil { return nil, err } if segs.RestoredValues != nil { // new collation - resultValues, err = decodeRestoredValues(columns[:colsLen], segs.RestoredValues) + var restored [][]byte + if restoredDec != nil { + restored, err = restoredDec.Decode(segs.RestoredValues) + } else { + restored, err = decodeRestoredValues(columns[:colsLen], segs.RestoredValues) + } if err != nil { return nil, err } + if len(restored) != colsLen { + return nil, errors.Errorf("unexpected restored values length %d, expected %d", len(restored), colsLen) + } + copy(resultValues, restored) } if hdStatus == HandleNotNeeded { return resultValues, nil @@ -1989,11 +2053,10 @@ func decodeIndexKvGeneral(key, value []byte, colsLen int, hdStatus HandleStatus, return nil, err } } - handleBytes, err := reEncodeHandle(handle, hdStatus == HandleIsUnsigned) + resultValues, err = reEncodeHandleTo(handle, hdStatus == HandleIsUnsigned, buf, resultValues) if err != nil { return nil, err } - resultValues = append(resultValues, handleBytes...) if segs.PartitionID != nil { _, pid, err := codec.DecodeInt(segs.PartitionID) if err != nil { diff --git a/pkg/tablecodec/tablecodec_test.go b/pkg/tablecodec/tablecodec_test.go index b63438a2fae48..1915df219286a 100644 --- a/pkg/tablecodec/tablecodec_test.go +++ b/pkg/tablecodec/tablecodec_test.go @@ -729,6 +729,370 @@ func TestTempIndexValueCodec(t *testing.T) { require.False(t, isUnique) } +func TestDecodeIndexKVExGeneral(t *testing.T) { + tableID := int64(1) + indexID := int64(1) + + // Helper to build a version-0 index value with int handle in tail. + // Format: [tailLen] [body...] [intHandle 8 bytes BE] + buildVersion0ValueIntHandle := func(handleVal int64) []byte { + var val []byte + val = append(val, 8) // tailLen = 8 (for int handle) + // Pad body to make total > MaxOldEncodeValueLen (9). + val = append(val, 0, 0) // padding in body + // Append int handle as 8 bytes big-endian. + var hBuf [8]byte + binary.BigEndian.PutUint64(hBuf[:], uint64(handleVal)) + val = append(val, hBuf[:]...) + return val + } + + // Helper to build a version-0 non-unique index value (handle in key suffix). + // Format: [tailLen=0] [padding...] to make len > 9 + buildVersion0ValueNonUnique := func() []byte { + var val []byte + val = append(val, 0) // tailLen = 0 + // Pad to > MaxOldEncodeValueLen. + val = append(val, make([]byte, 9)...) + return val + } + + t.Run("UniqueIntHandle", func(t *testing.T) { + // Build key: prefix + tableID + indexID + encodedCol1 + encodedCol2 + colValues := []types.Datum{types.NewIntDatum(42), types.NewIntDatum(100)} + encodedCols, err := codec.EncodeKey(time.UTC, nil, colValues...) + require.NoError(t, err) + key := EncodeIndexSeekKey(tableID, indexID, encodedCols) + + handleVal := int64(7) + value := buildVersion0ValueIntHandle(handleVal) + + colsLen := 2 + columns := []rowcodec.ColInfo{ + {ID: 1, Ft: types.NewFieldType(mysql.TypeLonglong)}, + {ID: 2, Ft: types.NewFieldType(mysql.TypeLonglong)}, + {ID: 3, Ft: types.NewFieldType(mysql.TypeLonglong)}, // handle column + } + preAlloc := make([][]byte, colsLen, colsLen+1) + result, err := DecodeIndexKVEx(key, value, colsLen, HandleDefault, columns, nil, preAlloc) + require.NoError(t, err) + require.Len(t, result, 3) + + // Verify index column values match. + for i, expected := range colValues { + got, err := DecodeColumnValue(result[i], columns[i].Ft, nil) + require.NoError(t, err) + require.Equal(t, expected.GetInt64(), got.GetInt64()) + } + // Verify handle. + handleDatum, err := DecodeColumnValue(result[2], columns[2].Ft, nil) + require.NoError(t, err) + require.Equal(t, handleVal, handleDatum.GetInt64()) + }) + + t.Run("UniqueUnsignedIntHandle", func(t *testing.T) { + colValues := []types.Datum{types.NewIntDatum(42)} + encodedCols, err := codec.EncodeKey(time.UTC, nil, colValues...) + require.NoError(t, err) + key := EncodeIndexSeekKey(tableID, indexID, encodedCols) + + handleVal := int64(999) + value := buildVersion0ValueIntHandle(handleVal) + + colsLen := 1 + ft := types.NewFieldType(mysql.TypeLonglong) + ft.AddFlag(mysql.UnsignedFlag) + columns := []rowcodec.ColInfo{ + {ID: 1, Ft: types.NewFieldType(mysql.TypeLonglong)}, + {ID: 2, Ft: ft}, // handle column unsigned + } + preAlloc := make([][]byte, colsLen, colsLen+1) + result, err := DecodeIndexKVEx(key, value, colsLen, HandleIsUnsigned, columns, nil, preAlloc) + require.NoError(t, err) + require.Len(t, result, 2) + + // Verify handle is unsigned. + handleDatum, err := DecodeColumnValue(result[1], columns[1].Ft, nil) + require.NoError(t, err) + require.Equal(t, uint64(handleVal), handleDatum.GetUint64()) + }) + + t.Run("NonUniqueIntHandle", func(t *testing.T) { + // For non-unique index, handle is in key suffix. + handleVal := int64(7) + colValues := []types.Datum{types.NewIntDatum(42), types.NewIntDatum(100)} + handleDatum := types.NewIntDatum(handleVal) + allDatums := append(colValues, handleDatum) + encodedAll, err := codec.EncodeKey(time.UTC, nil, allDatums...) + require.NoError(t, err) + key := EncodeIndexSeekKey(tableID, indexID, encodedAll) + + value := buildVersion0ValueNonUnique() + + colsLen := 2 + columns := []rowcodec.ColInfo{ + {ID: 1, Ft: types.NewFieldType(mysql.TypeLonglong)}, + {ID: 2, Ft: types.NewFieldType(mysql.TypeLonglong)}, + {ID: 3, Ft: types.NewFieldType(mysql.TypeLonglong)}, // handle column + } + preAlloc := make([][]byte, colsLen, colsLen+1) + result, err := DecodeIndexKVEx(key, value, colsLen, HandleDefault, columns, nil, preAlloc) + require.NoError(t, err) + require.Len(t, result, 3) + + // Verify index column values. + for i, expected := range colValues { + got, err := DecodeColumnValue(result[i], columns[i].Ft, nil) + require.NoError(t, err) + require.Equal(t, expected.GetInt64(), got.GetInt64()) + } + // Verify handle. + gotHandle, err := DecodeColumnValue(result[2], columns[2].Ft, nil) + require.NoError(t, err) + require.Equal(t, handleVal, gotHandle.GetInt64()) + }) + + t.Run("HandleNotNeeded", func(t *testing.T) { + colValues := []types.Datum{types.NewIntDatum(42)} + encodedCols, err := codec.EncodeKey(time.UTC, nil, colValues...) + require.NoError(t, err) + key := EncodeIndexSeekKey(tableID, indexID, encodedCols) + + value := buildVersion0ValueIntHandle(7) + + colsLen := 1 + columns := []rowcodec.ColInfo{ + {ID: 1, Ft: types.NewFieldType(mysql.TypeLonglong)}, + } + preAlloc := make([][]byte, colsLen, colsLen+1) + result, err := DecodeIndexKVEx(key, value, colsLen, HandleNotNeeded, columns, nil, preAlloc) + require.NoError(t, err) + require.Len(t, result, 1) + + got, err := DecodeColumnValue(result[0], columns[0].Ft, nil) + require.NoError(t, err) + require.Equal(t, int64(42), got.GetInt64()) + }) + + t.Run("UniqueCommonHandle", func(t *testing.T) { + // Build key with only index columns. + colValues := []types.Datum{types.NewIntDatum(42)} + encodedCols, err := codec.EncodeKey(time.UTC, nil, colValues...) + require.NoError(t, err) + key := EncodeIndexSeekKey(tableID, indexID, encodedCols) + + // Build common handle. + handleDatums := []types.Datum{types.NewIntDatum(10), types.NewIntDatum(20)} + encodedHandle, err := codec.EncodeKey(time.UTC, nil, handleDatums...) + require.NoError(t, err) + ch, err := kv.NewCommonHandle(encodedHandle) + require.NoError(t, err) + + // Build version 0 value with common handle. + var val []byte + val = append(val, 0) // tailLen placeholder + val = encodeCommonHandle(val, ch) + // Pad to >= 10 bytes if needed. + tailLen := 0 + if len(val) < 10 { + padding := 10 - len(val) + tailLen = padding + val = append(val, make([]byte, padding)...) + } + val[0] = byte(tailLen) + + colsLen := 1 + columns := []rowcodec.ColInfo{ + {ID: 1, Ft: types.NewFieldType(mysql.TypeLonglong)}, + {ID: 2, Ft: types.NewFieldType(mysql.TypeLonglong)}, // handle col 1 + {ID: 3, Ft: types.NewFieldType(mysql.TypeLonglong)}, // handle col 2 + } + preAlloc := make([][]byte, colsLen, colsLen+2) + result, err := DecodeIndexKVEx(key, val, colsLen, HandleDefault, columns, nil, preAlloc) + require.NoError(t, err) + require.Len(t, result, 3) // 1 index col + 2 handle cols + + // Verify index column. + got, err := DecodeColumnValue(result[0], columns[0].Ft, nil) + require.NoError(t, err) + require.Equal(t, int64(42), got.GetInt64()) + + // Verify handle columns. + for i, expected := range handleDatums { + got, err := DecodeColumnValue(result[1+i], columns[1+i].Ft, nil) + require.NoError(t, err) + require.Equal(t, expected.GetInt64(), got.GetInt64()) + } + }) + + // Verify that DecodeIndexKV (non-Ex) still works correctly through the general path. + t.Run("DecodeIndexKVNonEx", func(t *testing.T) { + colValues := []types.Datum{types.NewIntDatum(42)} + encodedCols, err := codec.EncodeKey(time.UTC, nil, colValues...) + require.NoError(t, err) + key := EncodeIndexSeekKey(tableID, indexID, encodedCols) + + handleVal := int64(7) + value := buildVersion0ValueIntHandle(handleVal) + + colsLen := 1 + columns := []rowcodec.ColInfo{ + {ID: 1, Ft: types.NewFieldType(mysql.TypeLonglong)}, + {ID: 2, Ft: types.NewFieldType(mysql.TypeLonglong)}, + } + result, err := DecodeIndexKV(key, value, colsLen, HandleDefault, columns) + require.NoError(t, err) + require.Len(t, result, 2) + + got, err := DecodeColumnValue(result[0], columns[0].Ft, nil) + require.NoError(t, err) + require.Equal(t, int64(42), got.GetInt64()) + + handleDatum, err := DecodeColumnValue(result[1], columns[1].Ft, nil) + require.NoError(t, err) + require.Equal(t, handleVal, handleDatum.GetInt64()) + }) +} + +func TestIndexRestoredDecoderCorrectness(t *testing.T) { + // Helper to build version-0 index value with restored data. + // Format: [tailLen=8] [padding] [restoredValues (rowcodec encoded)] [intHandle 8 bytes BE] + buildRestoredValue := func(handleVal int64, colIDs []int64, datums []types.Datum) []byte { + // Encode restored values using rowcodec. + rd := rowcodec.Encoder{Enable: true} + restoredBytes, err := rd.Encode(time.UTC, colIDs, datums, nil, nil) + require.NoError(t, err) + + // Build version 0 value: [tailLen=8] [restoredBytes (starts with CodecVer=RestoreDataFlag)] [intHandle 8 bytes] + var val []byte + val = append(val, 8) // tailLen = 8 (for int handle) + val = append(val, restoredBytes...) + var hBuf [8]byte + binary.BigEndian.PutUint64(hBuf[:], uint64(handleVal)) + val = append(val, hBuf[:]...) + return val + } + + t.Run("MatchesOriginal", func(t *testing.T) { + columns := []rowcodec.ColInfo{ + {ID: 1, Ft: types.NewFieldType(mysql.TypeLonglong)}, + {ID: 2, Ft: types.NewFieldType(mysql.TypeVarchar)}, + {ID: 3, Ft: types.NewFieldType(mysql.TypeLonglong)}, // handle column + } + colsLen := 2 + + colIDs := []int64{1, 2} + datums := []types.Datum{types.NewIntDatum(42), types.NewBytesDatum([]byte("hello"))} + + colValues := []types.Datum{types.NewIntDatum(42), types.NewBytesDatum([]byte("hello"))} + encodedCols, err := codec.EncodeKey(time.UTC, nil, colValues...) + require.NoError(t, err) + key := EncodeIndexSeekKey(1, 1, encodedCols) + + value := buildRestoredValue(7, colIDs, datums) + + // Decode with original path. + preAlloc1 := make([][]byte, colsLen, colsLen+1) + orig, err := DecodeIndexKVEx(key, value, colsLen, HandleDefault, columns, nil, preAlloc1) + require.NoError(t, err) + + // Decode with IndexRestoredDecoder. + dec := NewIndexRestoredDecoder(columns[:colsLen]) + preAlloc2 := make([][]byte, colsLen, colsLen+1) + opt, err := DecodeIndexKVEx(key, value, colsLen, HandleDefault, columns, nil, preAlloc2, dec) + require.NoError(t, err) + + require.Equal(t, len(orig), len(opt)) + for i := range orig { + require.Equal(t, orig[i], opt[i], "mismatch at index %d", i) + } + }) + + t.Run("MultiRowReuse", func(t *testing.T) { + // Verify no state leakage across multiple rows with the same decoder. + columns := []rowcodec.ColInfo{ + {ID: 1, Ft: types.NewFieldType(mysql.TypeLonglong)}, + {ID: 2, Ft: types.NewFieldType(mysql.TypeVarchar)}, + {ID: 3, Ft: types.NewFieldType(mysql.TypeLonglong)}, // handle + } + colsLen := 2 + + dec := NewIndexRestoredDecoder(columns[:colsLen]) + + rows := []struct { + intVal int64 + strVal string + handleVal int64 + }{ + {42, "hello", 1}, + {100, "world", 2}, + {0, "", 3}, + {-999, "long string with spaces ", 4}, + } + + for _, row := range rows { + colIDs := []int64{1, 2} + datums := []types.Datum{types.NewIntDatum(row.intVal), types.NewBytesDatum([]byte(row.strVal))} + + colValues := []types.Datum{types.NewIntDatum(row.intVal), types.NewBytesDatum([]byte(row.strVal))} + encodedCols, err := codec.EncodeKey(time.UTC, nil, colValues...) + require.NoError(t, err) + key := EncodeIndexSeekKey(1, 1, encodedCols) + value := buildRestoredValue(row.handleVal, colIDs, datums) + + // Decode with original path for reference. + preAlloc1 := make([][]byte, colsLen, colsLen+1) + orig, err := DecodeIndexKVEx(key, value, colsLen, HandleDefault, columns, nil, preAlloc1) + require.NoError(t, err) + + // Decode with reused IndexRestoredDecoder. + preAlloc2 := make([][]byte, colsLen, colsLen+1) + opt, err := DecodeIndexKVEx(key, value, colsLen, HandleDefault, columns, nil, preAlloc2, dec) + require.NoError(t, err) + + require.Equal(t, len(orig), len(opt)) + for i := range orig { + require.Equal(t, orig[i], opt[i], "mismatch at row handle=%d, index=%d", row.handleVal, i) + } + } + }) + + t.Run("UintAndNilColumns", func(t *testing.T) { + uft := types.NewFieldType(mysql.TypeLonglong) + uft.AddFlag(mysql.UnsignedFlag) + columns := []rowcodec.ColInfo{ + {ID: 1, Ft: uft}, + {ID: 2, Ft: types.NewFieldType(mysql.TypeLonglong)}, + {ID: 3, Ft: types.NewFieldType(mysql.TypeLonglong)}, // handle + } + colsLen := 2 + + // Only encode col 1, leave col 2 as nil in the rowcodec. + colIDs := []int64{1} + datums := []types.Datum{types.NewUintDatum(999)} + + colValues := []types.Datum{types.NewUintDatum(999), types.NewDatum(nil)} + encodedCols, err := codec.EncodeKey(time.UTC, nil, colValues...) + require.NoError(t, err) + key := EncodeIndexSeekKey(1, 1, encodedCols) + value := buildRestoredValue(5, colIDs, datums) + + preAlloc1 := make([][]byte, colsLen, colsLen+1) + orig, err := DecodeIndexKVEx(key, value, colsLen, HandleDefault, columns, nil, preAlloc1) + require.NoError(t, err) + + dec := NewIndexRestoredDecoder(columns[:colsLen]) + preAlloc2 := make([][]byte, colsLen, colsLen+1) + opt, err := DecodeIndexKVEx(key, value, colsLen, HandleDefault, columns, nil, preAlloc2, dec) + require.NoError(t, err) + + require.Equal(t, len(orig), len(opt)) + for i := range orig { + require.Equal(t, orig[i], opt[i], "mismatch at index %d", i) + } + }) +} + func TestV2TableCodec(t *testing.T) { const tableID int64 = 31415926 key := EncodeTablePrefix(tableID) diff --git a/pkg/types/field_type.go b/pkg/types/field_type.go index 95890c844ab6d..377642af22cf2 100644 --- a/pkg/types/field_type.go +++ b/pkg/types/field_type.go @@ -50,6 +50,18 @@ func NewFieldType(tp byte) *FieldType { BuildP() } +// InitUnspecifiedFieldType initializes a FieldType for TypeUnspecified +// without heap allocation. The caller should pass a stack-allocated FieldType. +func InitUnspecifiedFieldType(tp *FieldType) { + *tp = FieldType{} + tp.SetType(mysql.TypeUnspecified) + tp.SetFlag(0) + tp.SetFlen(UnspecifiedLength) + tp.SetDecimal(UnspecifiedLength) + tp.SetCharset(charset.CharsetBin) + tp.SetCollate(charset.CollationBin) +} + // NewFieldTypeWithCollation returns a FieldType, // with a type and other information about field type. func NewFieldTypeWithCollation(tp byte, collation string, length int) *FieldType { diff --git a/pkg/types/time.go b/pkg/types/time.go index 36693ef5a242e..8d97f91e33ccf 100644 --- a/pkg/types/time.go +++ b/pkg/types/time.go @@ -373,22 +373,9 @@ func (t *Time) ConvertTimeZone(from, to *gotime.Location) error { } func (t Time) String() string { - if t.Type() == mysql.TypeDate { - // We control the format, so no error would occur. - str, err := t.DateFormat("%Y-%m-%d") - terror.Log(errors.Trace(err)) - return str - } - - str, err := t.DateFormat("%Y-%m-%d %H:%i:%s") - terror.Log(errors.Trace(err)) - fsp := t.Fsp() - if fsp > 0 { - tmp := fmt.Sprintf(".%06d", t.Microsecond()) - str = str + tmp[:1+fsp] - } - - return str + buf := make([]byte, 0, 26) + buf = t.AppendString(buf) + return string(buf) } // IsZero returns a boolean indicating whether the time is equal to ZeroCoreTime. @@ -2905,6 +2892,58 @@ func (t Time) convertDateFormat(b rune, buf *bytes.Buffer) error { return nil } +// appendInt2 appends a 2-digit zero-padded integer to buf. +func appendInt2(buf []byte, v int) []byte { + return append(buf, byte('0'+v/10), byte('0'+v%10)) +} + +// appendInt4 appends a 4-digit zero-padded integer to buf. +func appendInt4(buf []byte, v int) []byte { + return append(buf, + byte('0'+v/1000), + byte('0'+(v/100)%10), + byte('0'+(v/10)%10), + byte('0'+v%10), + ) +} + +// AppendString appends the formatted time string to buf and returns the result. +// It avoids fmt/bytes.Buffer; allocations occur only if buf needs to grow. +func (t Time) AppendString(buf []byte) []byte { + buf = appendInt4(buf, t.Year()) + buf = append(buf, '-') + buf = appendInt2(buf, t.Month()) + buf = append(buf, '-') + buf = appendInt2(buf, t.Day()) + + if t.Type() == mysql.TypeDate { + return buf + } + + buf = append(buf, ' ') + buf = appendInt2(buf, t.Hour()) + buf = append(buf, ':') + buf = appendInt2(buf, t.Minute()) + buf = append(buf, ':') + buf = appendInt2(buf, t.Second()) + + fsp := t.Fsp() + if fsp > 0 { + buf = append(buf, '.') + micro := t.Microsecond() + digits := [6]byte{ + byte('0' + micro/100000), + byte('0' + (micro/10000)%10), + byte('0' + (micro/1000)%10), + byte('0' + (micro/100)%10), + byte('0' + (micro/10)%10), + byte('0' + micro%10), + } + buf = append(buf, digits[:fsp]...) + } + return buf +} + // FormatIntWidthN uses to format int with width. Insufficient digits are filled by 0. func FormatIntWidthN(num, n int) string { numString := strconv.FormatInt(int64(num), 10) diff --git a/pkg/types/time_test.go b/pkg/types/time_test.go index 00484a84a0225..4ec0cfdc33093 100644 --- a/pkg/types/time_test.go +++ b/pkg/types/time_test.go @@ -2338,3 +2338,82 @@ func BenchmarkStrToDate(b *testing.B) { benchmarkStrToDate(b, "strToDate %r ddMMyyyy", typeCtx, "04:13:56 AM 13/05/2019", "%r %d/%c/%Y") benchmarkStrToDate(b, "strToDate %T ddMMyyyy", typeCtx, " 4:13:56 13/05/2019", "%T %d/%c/%Y") } + +func TestTimeAppendString(t *testing.T) { + tests := []struct { + name string + tp byte + fsp int + year, month, day, hour, minute, second, micro int + expected string + }{ + {"Date", mysql.TypeDate, 0, 2024, 1, 2, 0, 0, 0, 0, "2024-01-02"}, + {"Datetime FSP=0", mysql.TypeDatetime, 0, 2024, 1, 2, 3, 4, 5, 0, "2024-01-02 03:04:05"}, + {"Timestamp FSP=3", mysql.TypeTimestamp, 3, 2024, 1, 2, 3, 4, 5, 123000, "2024-01-02 03:04:05.123"}, + {"Datetime FSP=6", mysql.TypeDatetime, 6, 2024, 1, 2, 3, 4, 5, 123456, "2024-01-02 03:04:05.123456"}, + {"Zero Date", mysql.TypeDate, 0, 0, 0, 0, 0, 0, 0, 0, "0000-00-00"}, + {"Zero Datetime", mysql.TypeDatetime, 0, 0, 0, 0, 0, 0, 0, 0, "0000-00-00 00:00:00"}, + {"Boundary month=1 day=1", mysql.TypeDatetime, 0, 2024, 1, 1, 0, 0, 0, 0, "2024-01-01 00:00:00"}, + {"Boundary month=12 day=31", mysql.TypeDatetime, 0, 2024, 12, 31, 23, 59, 59, 0, "2024-12-31 23:59:59"}, + {"FSP=1", mysql.TypeDatetime, 1, 2024, 6, 15, 12, 30, 45, 100000, "2024-06-15 12:30:45.1"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ct := types.FromDate(tt.year, tt.month, tt.day, tt.hour, tt.minute, tt.second, tt.micro) + v := types.NewTime(ct, tt.tp, tt.fsp) + + // Verify AppendString output matches expected. + buf := v.AppendString(nil) + require.Equal(t, tt.expected, string(buf)) + + // Verify AppendString and String produce the same result. + require.Equal(t, v.String(), string(buf)) + }) + } +} + +func BenchmarkTimeString(b *testing.B) { + b.Run("Date", func(b *testing.B) { + v := types.NewTime(types.FromDate(2024, 1, 2, 0, 0, 0, 0), mysql.TypeDate, 0) + for i := 0; i < b.N; i++ { + _ = v.String() + } + }) + b.Run("Datetime FSP=0", func(b *testing.B) { + v := types.NewTime(types.FromDate(2024, 1, 2, 3, 4, 5, 0), mysql.TypeDatetime, 0) + for i := 0; i < b.N; i++ { + _ = v.String() + } + }) + b.Run("Datetime FSP=6", func(b *testing.B) { + v := types.NewTime(types.FromDate(2024, 1, 2, 3, 4, 5, 123456), mysql.TypeDatetime, 6) + for i := 0; i < b.N; i++ { + _ = v.String() + } + }) +} + +func BenchmarkTimeAppendString(b *testing.B) { + b.Run("Date", func(b *testing.B) { + v := types.NewTime(types.FromDate(2024, 1, 2, 0, 0, 0, 0), mysql.TypeDate, 0) + buf := make([]byte, 0, 26) + for i := 0; i < b.N; i++ { + buf = v.AppendString(buf[:0]) + } + }) + b.Run("Datetime FSP=0", func(b *testing.B) { + v := types.NewTime(types.FromDate(2024, 1, 2, 3, 4, 5, 0), mysql.TypeDatetime, 0) + buf := make([]byte, 0, 26) + for i := 0; i < b.N; i++ { + buf = v.AppendString(buf[:0]) + } + }) + b.Run("Datetime FSP=6", func(b *testing.B) { + v := types.NewTime(types.FromDate(2024, 1, 2, 3, 4, 5, 123456), mysql.TypeDatetime, 6) + buf := make([]byte, 0, 26) + for i := 0; i < b.N; i++ { + buf = v.AppendString(buf[:0]) + } + }) +} diff --git a/pkg/util/chunk/chunk_test.go b/pkg/util/chunk/chunk_test.go index 14194699ee811..db9c8ad1e2e09 100644 --- a/pkg/util/chunk/chunk_test.go +++ b/pkg/util/chunk/chunk_test.go @@ -724,6 +724,77 @@ func TestToString(t *testing.T) { require.Equal(t, "1, 1, 1, 0000-00-00, 1\n2, 2, 2, 0000-00-00 00:00:00, 2\n", chk.ToString(fieldTypes)) } +func TestColumnAppendFixedValuesConsistency(t *testing.T) { + fieldTypes := []*types.FieldType{ + types.NewFieldType(mysql.TypeLonglong), + types.NewFieldType(mysql.TypeLonglong), // uint64 stored as fixed-len + types.NewFieldType(mysql.TypeFloat), + types.NewFieldType(mysql.TypeDouble), + types.NewFieldType(mysql.TypeDatetime), + types.NewFieldType(mysql.TypeNewDecimal), + } + fieldTypes[1].SetFlag(fieldTypes[1].GetFlag() | mysql.UnsignedFlag) + + chk := NewChunkWithCapacity(fieldTypes, 8) + + expectTimes := []types.Time{ + types.NewTime(types.FromDate(2024, 1, 2, 3, 4, 5, 0), mysql.TypeDatetime, 0), + types.NewTime(types.FromDate(2025, 2, 3, 4, 5, 6, 0), mysql.TypeDatetime, 0), + types.NewTime(types.FromDate(2026, 3, 4, 5, 6, 7, 0), mysql.TypeDatetime, 0), + } + expectDecs := []*types.MyDecimal{ + types.NewDecFromInt(42), + types.NewDecFromInt(-7), + types.NewDecFromUint(123456), + } + + for i := range 3 { + chk.AppendInt64(0, int64(-1*(i+1))) + chk.AppendUint64(1, uint64(100+i)) + chk.AppendFloat32(2, float32(i)+1.25) + chk.AppendFloat64(3, float64(i)*-2.5) + chk.AppendTime(4, expectTimes[i]) + chk.AppendMyDecimal(5, expectDecs[i]) + } + + require.Equal(t, 3, chk.NumRows()) + for i := range 3 { + row := chk.GetRow(i) + require.False(t, row.IsNull(0)) + require.Equal(t, int64(-1*(i+1)), row.GetInt64(0)) + require.False(t, row.IsNull(1)) + require.Equal(t, uint64(100+i), row.GetUint64(1)) + require.False(t, row.IsNull(2)) + require.Equal(t, float32(i)+1.25, row.GetFloat32(2)) + require.False(t, row.IsNull(3)) + require.Equal(t, float64(i)*-2.5, row.GetFloat64(3)) + require.False(t, row.IsNull(4)) + require.Equal(t, 0, row.GetTime(4).Compare(expectTimes[i])) + require.False(t, row.IsNull(5)) + require.Equal(t, 0, row.GetMyDecimal(5).Compare(expectDecs[i])) + } +} + +func TestColumnNullBitmapAfterAppend(t *testing.T) { + fieldTypes := []*types.FieldType{types.NewFieldType(mysql.TypeLonglong)} + chk := NewChunkWithCapacity(fieldTypes, 8) + + chk.AppendInt64(0, 1) + chk.AppendNull(0) + chk.AppendInt64(0, 2) + chk.AppendNull(0) + + require.Equal(t, 4, chk.NumRows()) + col := chk.Column(0) + require.False(t, col.IsNull(0)) + require.True(t, col.IsNull(1)) + require.False(t, col.IsNull(2)) + require.True(t, col.IsNull(3)) + + require.Equal(t, int64(1), chk.GetRow(0).GetInt64(0)) + require.Equal(t, int64(2), chk.GetRow(2).GetInt64(0)) +} + func BenchmarkAppendInt(b *testing.B) { b.ReportAllocs() chk := newChunk(8) @@ -768,6 +839,37 @@ func BenchmarkAppendRow(b *testing.B) { } } +func BenchmarkAppendMixedColumns(b *testing.B) { + b.ReportAllocs() + + const ( + rows = 1024 + bytesLen = 32 + ) + + fieldTypes := []*types.FieldType{ + types.NewFieldType(mysql.TypeLonglong), + types.NewFieldType(mysql.TypeVarchar), + types.NewFieldType(mysql.TypeTimestamp), + } + chk := NewChunkWithCapacity(fieldTypes, rows) + + bs := make([]byte, bytesLen) + for i := range bs { + bs[i] = byte('a' + (i % 26)) + } + tm := types.NewTime(types.FromDate(2024, 1, 2, 3, 4, 5, 0), mysql.TypeTimestamp, 0) + + for range b.N { + chk.Reset() + for i := 0; i < rows; i++ { + chk.AppendInt64(0, int64(i)) + chk.AppendBytes(1, bs) + chk.AppendTime(2, tm) + } + } +} + func appendRow(chk *Chunk, row Row) { chk.Reset() for range 1000 { diff --git a/pkg/util/chunk/column.go b/pkg/util/chunk/column.go index 6b180e5a4a0c7..a4c0c69f80545 100644 --- a/pkg/util/chunk/column.go +++ b/pkg/util/chunk/column.go @@ -39,8 +39,10 @@ func (c *Column) AppendDuration(dur types.Duration) { // AppendMyDecimal appends a MyDecimal value into this Column. func (c *Column) AppendMyDecimal(dec *types.MyDecimal) { - *(*types.MyDecimal)(unsafe.Pointer(&c.elemBuf[0])) = *dec - c.finishAppendFixed() + start := c.extendDataForAppendFixed(types.MyDecimalStructSize) + *(*types.MyDecimal)(unsafe.Pointer(&c.data[start])) = *dec + c.appendNullBitmap(true) + c.length++ } func (c *Column) appendNameValue(name string, val uint64) { @@ -355,13 +357,24 @@ func (c *Column) AppendNNulls(n int) { func (c *Column) AppendNull() { c.appendNullBitmap(false) if c.IsFixed() { - c.data = append(c.data, c.elemBuf...) + _ = c.extendDataForAppendFixed(len(c.elemBuf)) } else { c.offsets = append(c.offsets, c.offsets[c.length]) } c.length++ } +func (c *Column) extendDataForAppendFixed(typeSize int) int { + start := len(c.data) + newLen := start + typeSize + if cap(c.data) >= newLen { + c.data = c.data[:newLen] + return start + } + c.data = append(c.data, emptyBuf[:typeSize]...) + return start +} + func (c *Column) finishAppendFixed() { c.data = append(c.data, c.elemBuf...) c.appendNullBitmap(true) @@ -370,26 +383,34 @@ func (c *Column) finishAppendFixed() { // AppendInt64 appends an int64 value into this Column. func (c *Column) AppendInt64(i int64) { - *(*int64)(unsafe.Pointer(&c.elemBuf[0])) = i - c.finishAppendFixed() + start := c.extendDataForAppendFixed(sizeInt64) + *(*int64)(unsafe.Pointer(&c.data[start])) = i + c.appendNullBitmap(true) + c.length++ } // AppendUint64 appends a uint64 value into this Column. func (c *Column) AppendUint64(u uint64) { - *(*uint64)(unsafe.Pointer(&c.elemBuf[0])) = u - c.finishAppendFixed() + start := c.extendDataForAppendFixed(sizeUint64) + *(*uint64)(unsafe.Pointer(&c.data[start])) = u + c.appendNullBitmap(true) + c.length++ } // AppendFloat32 appends a float32 value into this Column. func (c *Column) AppendFloat32(f float32) { - *(*float32)(unsafe.Pointer(&c.elemBuf[0])) = f - c.finishAppendFixed() + start := c.extendDataForAppendFixed(sizeFloat32) + *(*float32)(unsafe.Pointer(&c.data[start])) = f + c.appendNullBitmap(true) + c.length++ } // AppendFloat64 appends a float64 value into this Column. func (c *Column) AppendFloat64(f float64) { - *(*float64)(unsafe.Pointer(&c.elemBuf[0])) = f - c.finishAppendFixed() + start := c.extendDataForAppendFixed(sizeFloat64) + *(*float64)(unsafe.Pointer(&c.data[start])) = f + c.appendNullBitmap(true) + c.length++ } func (c *Column) finishAppendVar() { @@ -412,8 +433,10 @@ func (c *Column) AppendBytes(b []byte) { // AppendTime appends a time value into this Column. func (c *Column) AppendTime(t types.Time) { - *(*types.Time)(unsafe.Pointer(&c.elemBuf[0])) = t - c.finishAppendFixed() + start := c.extendDataForAppendFixed(sizeTime) + *(*types.Time)(unsafe.Pointer(&c.data[start])) = t + c.appendNullBitmap(true) + c.length++ } // AppendEnum appends a Enum value into this Column. diff --git a/pkg/util/chunk/pool.go b/pkg/util/chunk/pool.go index 30f052968bcde..5484ba5a902c9 100644 --- a/pkg/util/chunk/pool.go +++ b/pkg/util/chunk/pool.go @@ -107,6 +107,16 @@ func (p *Pool) GetChunk(fields []*types.FieldType) *Chunk { func (p *Pool) PutChunk(fields []*types.FieldType, chk *Chunk) { for i, f := range fields { c := chk.columns[i] + // Some chunk columns are references to other columns. Avoid putting the + // same Column back into the pool multiple times. + if c == nil { + continue + } + for j := i + 1; j < len(fields); j++ { + if chk.columns[j] == c { + chk.columns[j] = nil + } + } c.reset() switch elemLen := getFixedLen(f); elemLen { case VarElemLen: diff --git a/pkg/util/chunk/pool_test.go b/pkg/util/chunk/pool_test.go index 105e06ba3ac19..96e555c2bd64b 100644 --- a/pkg/util/chunk/pool_test.go +++ b/pkg/util/chunk/pool_test.go @@ -87,6 +87,26 @@ func TestPoolPutChunk(t *testing.T) { require.Equal(t, 0, len(chk.columns)) } +func TestPoolPutChunkWithRefColumns(t *testing.T) { + initCap := 8 + pool := NewPool(initCap) + + fieldTypes := []*types.FieldType{ + types.NewFieldType(mysql.TypeLonglong), + types.NewFieldType(mysql.TypeLonglong), + } + + chk := pool.GetChunk(fieldTypes) + require.NotSame(t, chk.Column(0), chk.Column(1)) + + chk.MakeRef(0, 1) + require.Same(t, chk.Column(0), chk.Column(1)) + pool.PutChunk(fieldTypes, chk) + + chk2 := pool.GetChunk(fieldTypes) + require.NotSame(t, chk2.Column(0), chk2.Column(1)) +} + func BenchmarkPoolChunkOperation(b *testing.B) { pool := NewPool(1024) diff --git a/pkg/util/codec/codec.go b/pkg/util/codec/codec.go index 00d70f58bd051..1e2852edd2e03 100644 --- a/pkg/util/codec/codec.go +++ b/pkg/util/codec/codec.go @@ -200,7 +200,8 @@ func EncodeMySQLTime(loc *time.Location, t types.Time, tp byte, b []byte) (_ []b if tp == mysql.TypeUnspecified { tp = t.Type() } - if tp == mysql.TypeTimestamp && loc != time.UTC { + // loc can be nil in some callers; treat it as UTC (i.e., skip conversion). + if tp == mysql.TypeTimestamp && loc != nil && loc != time.UTC { err = t.ConvertTimeZone(loc, time.UTC) if err != nil { return nil, err diff --git a/pkg/util/codec/codec_test.go b/pkg/util/codec/codec_test.go index 88f5c6d677ecf..0be7e29e81d95 100644 --- a/pkg/util/codec/codec_test.go +++ b/pkg/util/codec/codec_test.go @@ -1338,3 +1338,17 @@ func TestDatumHashEquals(t *testing.T) { require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64()) require.False(t, tests[len(tests)-1].d1.Equals(tests[len(tests)-1].d2)) } + +func TestEncodeMySQLTimeNilLocForTimestamp(t *testing.T) { + ts := types.NewTime(types.FromDate(2024, 1, 2, 3, 4, 5, 0), mysql.TypeTimestamp, types.DefaultFsp) + var gotNil, gotUTC []byte + require.NotPanics(t, func() { + var err error + gotNil, err = EncodeMySQLTime(nil, ts, mysql.TypeUnspecified, nil) + require.NoError(t, err) + }) + var err error + gotUTC, err = EncodeMySQLTime(time.UTC, ts, mysql.TypeUnspecified, nil) + require.NoError(t, err) + require.Equal(t, gotUTC, gotNil) +} diff --git a/pkg/util/execdetails/execdetails.go b/pkg/util/execdetails/execdetails.go index 4c08514717915..4792458e42c01 100644 --- a/pkg/util/execdetails/execdetails.go +++ b/pkg/util/execdetails/execdetails.go @@ -615,3 +615,38 @@ func (d *CopTasksDetails) ToZapFields() (fields []zap.Field) { fields = append(fields, zap.String("wait_max_addr", d.WaitTimeStats.MaxAddress)) return fields } + + +// ResultCacheRuntimeStats records result cache hit/miss info for EXPLAIN ANALYZE. +type ResultCacheRuntimeStats struct { + HitCache bool + CachedRows int64 +} + +// String implements the RuntimeStats interface. +func (e *ResultCacheRuntimeStats) String() string { + if e.HitCache { + return fmt.Sprintf("result_cache:hit, cached_rows:%d", e.CachedRows) + } + return "result_cache:miss" +} + +// Clone implements the RuntimeStats interface. +func (e *ResultCacheRuntimeStats) Clone() RuntimeStats { + return &ResultCacheRuntimeStats{HitCache: e.HitCache, CachedRows: e.CachedRows} +} + +// Merge implements the RuntimeStats interface. +func (e *ResultCacheRuntimeStats) Merge(other RuntimeStats) { + if tmp, ok := other.(*ResultCacheRuntimeStats); ok { + if tmp.HitCache { + e.HitCache = true + e.CachedRows = tmp.CachedRows + } + } +} + +// Tp implements the RuntimeStats interface. +func (*ResultCacheRuntimeStats) Tp() int { + return TpResultCacheRuntimeStats +} diff --git a/pkg/util/execdetails/execdetails_test.go b/pkg/util/execdetails/execdetails_test.go index aa59c1ed75b4c..b0ab7faa740f4 100644 --- a/pkg/util/execdetails/execdetails_test.go +++ b/pkg/util/execdetails/execdetails_test.go @@ -882,3 +882,24 @@ func TestRURuntimeStatsMergeKeepsExistingRUVersion(t *testing.T) { dst.Merge(src) require.Equal(t, rmclient.RUVersionV1, dst.RUVersion) } + +func TestResultCacheRuntimeStats(t *testing.T) { + // Test hit case. + hit := &ResultCacheRuntimeStats{HitCache: true, CachedRows: 100} + require.Equal(t, "result_cache:hit, cached_rows:100", hit.String()) + require.Equal(t, TpResultCacheRuntimeStats, hit.Tp()) + + // Test miss case. + miss := &ResultCacheRuntimeStats{HitCache: false} + require.Equal(t, "result_cache:miss", miss.String()) + + // Test Clone. + cloned := hit.Clone().(*ResultCacheRuntimeStats) + require.Equal(t, true, cloned.HitCache) + require.Equal(t, int64(100), cloned.CachedRows) + + // Test Merge: miss + hit = hit. + miss.Merge(hit) + require.True(t, miss.HitCache) + require.Equal(t, int64(100), miss.CachedRows) +} diff --git a/pkg/util/execdetails/runtime_stats.go b/pkg/util/execdetails/runtime_stats.go index ee5b2b1be377b..0407f10388c29 100644 --- a/pkg/util/execdetails/runtime_stats.go +++ b/pkg/util/execdetails/runtime_stats.go @@ -70,6 +70,8 @@ const ( TpFKCascadeRuntimeStats // TpRURuntimeStats is the tp for RURuntimeStats TpRURuntimeStats + // TpResultCacheRuntimeStats is the tp for ResultCacheRuntimeStats + TpResultCacheRuntimeStats ) // RuntimeStats is used to express the executor runtime information. diff --git a/pkg/util/hint/hint.go b/pkg/util/hint/hint.go index 5d4af787ac81d..9cfdebfe01cb3 100644 --- a/pkg/util/hint/hint.go +++ b/pkg/util/hint/hint.go @@ -111,6 +111,8 @@ const ( HintTimeRange = "time_range" // HintIgnorePlanCache is a hint to enforce ignoring plan cache HintIgnorePlanCache = "ignore_plan_cache" + // HintUsePlanCache is a hint to enforce enabling plan cache. + HintUsePlanCache = "use_plan_cache" // HintLimitToCop is a hint enforce pushing limit or topn to coprocessor. HintLimitToCop = "limit_to_cop" // HintMerge is a hint which can switch turning inline for the CTE. @@ -228,7 +230,9 @@ type StmtHints struct { ResourceGroup string // Do not store plan in either plan cache. IgnorePlanCache bool - WriteSlowLog bool + // Force statement to use plan cache. + UsePlanCache bool + WriteSlowLog bool // Hint flags HasAllowInSubqToJoinAndAggHint bool @@ -276,6 +280,7 @@ func (sh *StmtHints) Clone() *StmtHints { ForceNthPlan: sh.ForceNthPlan, ResourceGroup: sh.ResourceGroup, IgnorePlanCache: sh.IgnorePlanCache, + UsePlanCache: sh.UsePlanCache, WriteSlowLog: sh.WriteSlowLog, HasAllowInSubqToJoinAndAggHint: sh.HasAllowInSubqToJoinAndAggHint, HasMemQuotaHint: sh.HasMemQuotaHint, @@ -409,6 +414,8 @@ func ParseStmtHints(hints []*ast.TableOptimizerHint, setVarsOffs = append(setVarsOffs, i) case HintIgnorePlanCache: stmtHints.IgnorePlanCache = true + case HintUsePlanCache: + stmtHints.UsePlanCache = true case HintWriteSlowLog: stmtHints.WriteSlowLog = true } diff --git a/pkg/util/rowcodec/BUILD.bazel b/pkg/util/rowcodec/BUILD.bazel index 0a5b8025276b0..888cb568a9931 100644 --- a/pkg/util/rowcodec/BUILD.bazel +++ b/pkg/util/rowcodec/BUILD.bazel @@ -31,6 +31,7 @@ go_test( srcs = [ "bench_test.go", "common_test.go", + "decoder_test.go", "main_test.go", "rowcodec_test.go", ], diff --git a/pkg/util/rowcodec/bench_test.go b/pkg/util/rowcodec/bench_test.go index eea2c8373fbe4..cbbb576666561 100644 --- a/pkg/util/rowcodec/bench_test.go +++ b/pkg/util/rowcodec/bench_test.go @@ -28,6 +28,96 @@ import ( "github.com/pingcap/tidb/pkg/util/rowcodec" ) +func BenchmarkDecodeWideRowToChunk(b *testing.B) { + makeBytes := func(n int) []byte { + bs := make([]byte, n) + for i := range bs { + bs[i] = byte('a' + (i % 26)) + } + return bs + } + + const ( + intCols = 48 + bytesCols = 12 + timeCols = 4 + totalCols = intCols + bytesCols + timeCols + batchSize = 64 + timestampFsp = 0 + ) + + benchCases := []struct { + name string + bytesLen int + }{ + {name: "small_bytes_32", bytesLen: 32}, + {name: "big_bytes_1024", bytesLen: 1024}, + } + + for _, tc := range benchCases { + b.Run(tc.name, func(b *testing.B) { + b.ReportAllocs() + + colIDs := make([]int64, totalCols) + values := make([]types.Datum, totalCols) + fieldTypes := make([]*types.FieldType, totalCols) + cols := make([]rowcodec.ColInfo, totalCols) + + // Int columns. + for i := 0; i < intCols; i++ { + colIDs[i] = int64(i + 1) + values[i].SetInt64(int64(i)) + ft := types.NewFieldType(mysql.TypeLonglong) + fieldTypes[i] = ft + cols[i] = rowcodec.ColInfo{ID: colIDs[i], Ft: ft} + } + + // Bytes columns. + payload := makeBytes(tc.bytesLen) + for i := 0; i < bytesCols; i++ { + idx := intCols + i + colIDs[idx] = int64(idx + 1) + values[idx].SetBytes(payload) + ft := types.NewFieldType(mysql.TypeVarchar) + fieldTypes[idx] = ft + cols[idx] = rowcodec.ColInfo{ID: colIDs[idx], Ft: ft} + } + + // Timestamp columns (exercise FromPackedUint + optional TZ conversion). + baseCore := types.FromDate(2024, 1, 2, 3, 4, 5, 0) + baseTS := types.NewTime(baseCore, mysql.TypeTimestamp, timestampFsp) + for i := 0; i < timeCols; i++ { + idx := intCols + bytesCols + i + colIDs[idx] = int64(idx + 1) + values[idx].SetMysqlTime(baseTS) + ft := types.NewFieldType(mysql.TypeTimestamp) + ft.SetDecimal(timestampFsp) + fieldTypes[idx] = ft + cols[idx] = rowcodec.ColInfo{ID: colIDs[idx], Ft: ft} + } + + var enc rowcodec.Encoder + rowData, err := enc.Encode(time.Local, colIDs, values, nil, nil) + if err != nil { + b.Fatal(err) + } + + decoder := rowcodec.NewChunkDecoder(cols, []int64{-1}, nil, time.Local) + chk := chunk.NewChunkWithCapacity(fieldTypes, batchSize) + + b.ResetTimer() + for range b.N { + chk.Reset() + for r := 0; r < batchSize; r++ { + if err := decoder.DecodeToChunk(rowData, kv.IntHandle(r), chk); err != nil { + b.Fatal(err) + } + } + } + }) + } +} + func BenchmarkChecksum(b *testing.B) { b.ReportAllocs() datums := types.MakeDatums(1, "abc", 1.1) diff --git a/pkg/util/rowcodec/common.go b/pkg/util/rowcodec/common.go index 4e5273fe9da9d..40b9dfceb1262 100644 --- a/pkg/util/rowcodec/common.go +++ b/pkg/util/rowcodec/common.go @@ -38,6 +38,7 @@ var ( errInvalidCodecVer = errors.New("invalid codec version") errInvalidChecksumVer = errors.New("invalid checksum version") errInvalidChecksumTyp = errors.New("invalid type for checksum") + errInvalidChecksumKey = errors.New("invalid key or handle for checksum") ) // First byte in the encoded value which specifies the encoding type. diff --git a/pkg/util/rowcodec/decoder.go b/pkg/util/rowcodec/decoder.go index ab97a40a69c74..3305dd5418c8f 100644 --- a/pkg/util/rowcodec/decoder.go +++ b/pkg/util/rowcodec/decoder.go @@ -140,7 +140,7 @@ func (decoder *DatumMapDecoder) decodeColDatum(col *ColInfo, colData []byte) (ty if err != nil { return d, err } - if col.Ft.GetType() == mysql.TypeTimestamp && !t.IsZero() { + if col.Ft.GetType() == mysql.TypeTimestamp && decoder.loc != nil && !t.IsZero() { err = t.ConvertTimeZone(time.UTC, decoder.loc) if err != nil { return d, err @@ -189,6 +189,37 @@ func (decoder *DatumMapDecoder) decodeColDatum(col *ColInfo, colData []byte) (ty type ChunkDecoder struct { decoder defDatum func(i int, chk *chunk.Chunk) error + + compiledCols []compiledCol + compiledColsInited bool + + // colMapping caches the column index in not-null columns array for decoder.columns[i]. + // It can skip binary search for stable not-null columns layout. The cached index + // is validated on each row and falls back to findColID when layout changes. + // -1 means not cached and needs to call findColID for each row. + colMapping []int + mappingInited bool + mappingRowCols int +} + +type compiledColKind uint8 + +const ( + compiledColOther compiledColKind = iota + compiledColInt + compiledColUint + compiledColFloat32 + compiledColFloat64 + compiledColBytes + compiledColTime + compiledColDuration +) + +type compiledCol struct { + kind compiledColKind + tp byte + fsp int + needTZConvert bool } // NewChunkDecoder creates a NewChunkDecoder. @@ -203,6 +234,52 @@ func NewChunkDecoder(columns []ColInfo, handleColIDs []int64, defDatum func(i in } } +func (decoder *ChunkDecoder) initCompiledCols() { + if decoder.compiledColsInited { + return + } + if cap(decoder.compiledCols) < len(decoder.columns) { + decoder.compiledCols = make([]compiledCol, len(decoder.columns)) + } else { + decoder.compiledCols = decoder.compiledCols[:len(decoder.columns)] + } + for i := range decoder.columns { + col := &decoder.columns[i] + if col.Ft == nil { + decoder.compiledCols[i] = compiledCol{kind: compiledColOther} + continue + } + tp := col.Ft.GetType() + cc := compiledCol{kind: compiledColOther, tp: tp} + switch tp { + case mysql.TypeLonglong, mysql.TypeLong, mysql.TypeInt24, mysql.TypeShort, mysql.TypeTiny: + if mysql.HasUnsignedFlag(col.Ft.GetFlag()) { + cc.kind = compiledColUint + } else { + cc.kind = compiledColInt + } + case mysql.TypeYear: + cc.kind = compiledColInt + case mysql.TypeFloat: + cc.kind = compiledColFloat32 + case mysql.TypeDouble: + cc.kind = compiledColFloat64 + case mysql.TypeVarString, mysql.TypeVarchar, mysql.TypeString, + mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob: + cc.kind = compiledColBytes + case mysql.TypeDate, mysql.TypeDatetime, mysql.TypeTimestamp: + cc.kind = compiledColTime + cc.fsp = col.Ft.GetDecimal() + cc.needTZConvert = tp == mysql.TypeTimestamp && decoder.loc != nil + case mysql.TypeDuration: + cc.kind = compiledColDuration + cc.fsp = col.Ft.GetDecimal() + } + decoder.compiledCols[i] = cc + } + decoder.compiledColsInited = true +} + // DecodeToChunk decodes a row to chunk. func (decoder *ChunkDecoder) DecodeToChunk(rowData []byte, commitTS uint64, handle kv.Handle, chk *chunk.Chunk) error { err := decoder.fromBytes(rowData) @@ -210,6 +287,18 @@ func (decoder *ChunkDecoder) DecodeToChunk(rowData []byte, commitTS uint64, hand return err } + if !decoder.compiledColsInited { + decoder.initCompiledCols() + } + + // Build (or rebuild) column mapping cache for stable schema. + rowCols := int(decoder.row.numNotNullCols) + int(decoder.row.numNullCols) + if !decoder.mappingInited || decoder.mappingRowCols != rowCols { + decoder.buildColMapping() + decoder.mappingInited = true + decoder.mappingRowCols = rowCols + } + for colIdx := range decoder.columns { col := &decoder.columns[colIdx] if col.ID == model.ExtraCommitTSID { @@ -231,6 +320,16 @@ func (decoder *ChunkDecoder) DecodeToChunk(rowData []byte, commitTS uint64, hand continue } + mappedIdx := decoder.colMapping[colIdx] + if mappedIdx >= 0 && decoder.matchNotNullColID(mappedIdx, col.ID) { + colData := decoder.getData(mappedIdx) + err := decoder.decodeColToChunk(colIdx, col, colData, chk) + if err != nil { + return err + } + continue + } + idx, isNil, notFound := decoder.row.findColID(col.ID) if !notFound && !isNil { colData := decoder.getData(idx) @@ -266,6 +365,43 @@ func (decoder *ChunkDecoder) DecodeToChunk(rowData []byte, commitTS uint64, hand return nil } +func (decoder *ChunkDecoder) buildColMapping() { + if cap(decoder.colMapping) < len(decoder.columns) { + decoder.colMapping = make([]int, len(decoder.columns)) + } else { + decoder.colMapping = decoder.colMapping[:len(decoder.columns)] + } + for i := range decoder.colMapping { + decoder.colMapping[i] = -1 + } + + for i := range decoder.columns { + col := &decoder.columns[i] + if col.VirtualGenCol || col.ID == model.ExtraRowChecksumID { + continue + } + // Only attempt to cache columns declared NOT NULL. Nullable columns can move between + // not-null and null segments, and other columns' NULL changes may also shift indices. + if col.Ft == nil || !mysql.HasNotNullFlag(col.Ft.GetFlag()) { + continue + } + idx, isNil, notFound := decoder.row.findColID(col.ID) + if !notFound && !isNil { + decoder.colMapping[i] = idx + } + } +} + +func (decoder *ChunkDecoder) matchNotNullColID(idx int, colID int64) bool { + if idx < 0 || idx >= int(decoder.row.numNotNullCols) { + return false + } + if decoder.row.large() { + return int64(decoder.row.colIDs32[idx]) == colID + } + return int64(decoder.row.colIDs[idx]) == colID +} + func (decoder *ChunkDecoder) tryAppendHandleColumn(colIdx int, col *ColInfo, handle kv.Handle, chk *chunk.Chunk) bool { if handle == nil { return false @@ -288,64 +424,70 @@ func (decoder *ChunkDecoder) tryAppendHandleColumn(colIdx int, col *ColInfo, han } func (decoder *ChunkDecoder) decodeColToChunk(colIdx int, col *ColInfo, colData []byte, chk *chunk.Chunk) error { - switch col.Ft.GetType() { - case mysql.TypeLonglong, mysql.TypeLong, mysql.TypeInt24, mysql.TypeShort, mysql.TypeTiny: - if mysql.HasUnsignedFlag(col.Ft.GetFlag()) { - chk.AppendUint64(colIdx, decodeUint(colData)) - } else { - chk.AppendInt64(colIdx, decodeInt(colData)) - } - case mysql.TypeYear: + cc := decoder.compiledCols[colIdx] + switch cc.kind { + case compiledColInt: chk.AppendInt64(colIdx, decodeInt(colData)) - case mysql.TypeFloat: + return nil + case compiledColUint: + chk.AppendUint64(colIdx, decodeUint(colData)) + return nil + case compiledColFloat32: _, fVal, err := codec.DecodeFloat(colData) if err != nil { return err } chk.AppendFloat32(colIdx, float32(fVal)) - case mysql.TypeDouble: + return nil + case compiledColFloat64: _, fVal, err := codec.DecodeFloat(colData) if err != nil { return err } chk.AppendFloat64(colIdx, fVal) - case mysql.TypeVarString, mysql.TypeVarchar, mysql.TypeString, - mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob: + return nil + case compiledColBytes: chk.AppendBytes(colIdx, colData) - case mysql.TypeNewDecimal: - _, dec, _, frac, err := codec.DecodeDecimal(colData) - if err != nil { - return err - } - if col.Ft.GetDecimal() != types.UnspecifiedLength && frac > col.Ft.GetDecimal() { - to := new(types.MyDecimal) - err := dec.Round(to, col.Ft.GetDecimal(), types.ModeHalfUp) - if err != nil { - return errors.Trace(err) - } - dec = to - } - chk.AppendMyDecimal(colIdx, dec) - case mysql.TypeDate, mysql.TypeDatetime, mysql.TypeTimestamp: + return nil + case compiledColTime: var t types.Time - t.SetType(col.Ft.GetType()) - t.SetFsp(col.Ft.GetDecimal()) + t.SetType(cc.tp) + t.SetFsp(cc.fsp) err := t.FromPackedUint(decodeUint(colData)) if err != nil { return err } - if col.Ft.GetType() == mysql.TypeTimestamp && decoder.loc != nil && !t.IsZero() { + if cc.needTZConvert && !t.IsZero() { err = t.ConvertTimeZone(time.UTC, decoder.loc) if err != nil { return err } } chk.AppendTime(colIdx, t) - case mysql.TypeDuration: + return nil + case compiledColDuration: var dur types.Duration dur.Duration = time.Duration(decodeInt(colData)) - dur.Fsp = col.Ft.GetDecimal() + dur.Fsp = cc.fsp chk.AppendDuration(colIdx, dur) + return nil + } + + switch cc.tp { + case mysql.TypeNewDecimal: + _, dec, _, frac, err := codec.DecodeDecimal(colData) + if err != nil { + return err + } + if col.Ft.GetDecimal() != types.UnspecifiedLength && frac > col.Ft.GetDecimal() { + to := new(types.MyDecimal) + err := dec.Round(to, col.Ft.GetDecimal(), types.ModeHalfUp) + if err != nil { + return errors.Trace(err) + } + dec = to + } + chk.AppendMyDecimal(colIdx, dec) case mysql.TypeEnum: // ignore error deliberately, to read empty enum value. enum, err := types.ParseEnumValue(col.Ft.GetElems(), decodeUint(colData)) @@ -374,7 +516,7 @@ func (decoder *ChunkDecoder) decodeColToChunk(colIdx int, col *ColInfo, colData } chk.AppendVectorFloat32(colIdx, v) default: - return errors.Errorf("unknown type %d", col.Ft.GetType()) + return errors.Errorf("unknown type %d", cc.tp) } return nil } @@ -507,6 +649,75 @@ func (*BytesDecoder) encodeOldDatum(tp byte, val []byte) []byte { return buf } +// encodeOldDatumToArena is like encodeOldDatum but appends to a caller-provided arena +// instead of allocating a new slice. Returns the encoded sub-slice and the updated arena. +func encodeOldDatumToArena(tp byte, val []byte, arena []byte) (result []byte, newArena []byte) { + start := len(arena) + switch tp { + case BytesFlag: + arena = append(arena, CompactBytesFlag) + arena = codec.EncodeCompactBytes(arena, val) + case IntFlag: + arena = append(arena, VarintFlag) + arena = codec.EncodeVarint(arena, decodeInt(val)) + case UintFlag: + arena = append(arena, VaruintFlag) + arena = codec.EncodeUvarint(arena, decodeUint(val)) + default: + arena = append(arena, tp) + arena = append(arena, val...) + } + return arena[start:len(arena):len(arena)], arena +} + +// DecodeToBytesNoHandleInto is like DecodeToBytesNoHandle but writes into a caller-provided +// values slice and arena instead of allocating new ones. The arena is used for encodeOldDatum +// allocations; caller should reset arena length between rows (arena = arena[:0]). +// The values slice is cleared and reused. +func (decoder *BytesDecoder) DecodeToBytesNoHandleInto( + outputOffset map[int64]int, value []byte, values [][]byte, arena []byte, +) ([][]byte, []byte, error) { + var r row + err := r.fromBytes(value) + if err != nil { + return nil, arena, err + } + for i := range values { + values[i] = nil + } + for i := range decoder.columns { + col := &decoder.columns[i] + tp := fieldType2Flag(col.Ft.ArrayType().GetType(), col.Ft.GetFlag()&mysql.UnsignedFlag == 0) + colID := col.ID + offset := outputOffset[colID] + idx, isNil, notFound := r.findColID(colID) + if !notFound && !isNil { + val := r.getData(idx) + values[offset], arena = encodeOldDatumToArena(tp, val, arena) + continue + } + + if isNil { + values[offset] = []byte{NilFlag} + continue + } + + if decoder.defBytes != nil { + defVal, err := decoder.defBytes(i) + if err != nil { + return nil, arena, err + } + if len(defVal) > 0 { + values[offset] = defVal + continue + } + } + + values[offset] = []byte{NilFlag} + } + return values, arena, nil +} + // fieldType2Flag transforms field type into kv type flag. func fieldType2Flag(tp byte, signed bool) (flag byte) { switch tp { diff --git a/pkg/util/rowcodec/decoder_test.go b/pkg/util/rowcodec/decoder_test.go new file mode 100644 index 0000000000000..3ebefac642ff3 --- /dev/null +++ b/pkg/util/rowcodec/decoder_test.go @@ -0,0 +1,402 @@ +// Copyright 2026 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rowcodec + +import ( + "encoding/binary" + "hash/crc32" + "testing" + "time" + + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/stretchr/testify/require" +) + +func TestChunkDecoderColMapping(t *testing.T) { + ftNullable := types.NewFieldType(mysql.TypeLonglong) + ftNotNull1 := types.NewFieldType(mysql.TypeLonglong) + ftNotNull1.SetFlag(ftNotNull1.GetFlag() | mysql.NotNullFlag) + ftNotNull2 := types.NewFieldType(mysql.TypeLonglong) + ftNotNull2.SetFlag(ftNotNull2.GetFlag() | mysql.NotNullFlag) + + cols := []ColInfo{ + {ID: 1, Ft: ftNullable}, // nullable + {ID: 2, Ft: ftNotNull1}, // not-null + {ID: 3, Ft: ftNotNull2}, // not-null + } + fts := []*types.FieldType{ftNullable, ftNotNull1, ftNotNull2} + + var encoder Encoder + colIDs := []int64{1, 2, 3} + + // Row1: all columns are not null, build the mapping. + row1, err := encoder.Encode(time.UTC, colIDs, []types.Datum{ + types.NewIntDatum(10), + types.NewIntDatum(20), + types.NewIntDatum(30), + }, nil, nil) + require.NoError(t, err) + + // Row2: column 1 becomes NULL, shifting the not-null segment indices. + var nullDatum types.Datum + nullDatum.SetNull() + row2, err := encoder.Encode(time.UTC, colIDs, []types.Datum{ + nullDatum, + types.NewIntDatum(22), + types.NewIntDatum(33), + }, nil, nil) + require.NoError(t, err) + + // Row3: column 1 becomes not null again. + row3, err := encoder.Encode(time.UTC, colIDs, []types.Datum{ + types.NewIntDatum(11), + types.NewIntDatum(23), + types.NewIntDatum(34), + }, nil, nil) + require.NoError(t, err) + + decoder := NewChunkDecoder(cols, []int64{-1}, nil, time.UTC) + chk := chunk.New(fts, 0, 3) + + require.NoError(t, decoder.DecodeToChunk(row1, kv.IntHandle(-1), chk)) + require.True(t, decoder.mappingInited) + require.Equal(t, 3, decoder.mappingRowCols) + require.Len(t, decoder.colMapping, 3) + require.Equal(t, -1, decoder.colMapping[0]) // nullable column should not be cached + require.GreaterOrEqual(t, decoder.colMapping[1], 0) + require.GreaterOrEqual(t, decoder.colMapping[2], 0) + + mappingAfterRow1 := append([]int(nil), decoder.colMapping...) + + require.NoError(t, decoder.DecodeToChunk(row2, kv.IntHandle(-1), chk)) + require.Equal(t, mappingAfterRow1, decoder.colMapping) // no rebuild; rowCols unchanged + + require.NoError(t, decoder.DecodeToChunk(row3, kv.IntHandle(-1), chk)) + + require.Equal(t, 3, chk.NumRows()) + + r1 := chk.GetRow(0) + require.False(t, r1.IsNull(0)) + require.Equal(t, int64(10), r1.GetInt64(0)) + require.Equal(t, int64(20), r1.GetInt64(1)) + require.Equal(t, int64(30), r1.GetInt64(2)) + + r2 := chk.GetRow(1) + require.True(t, r2.IsNull(0)) + require.Equal(t, int64(22), r2.GetInt64(1)) + require.Equal(t, int64(33), r2.GetInt64(2)) + + r3 := chk.GetRow(2) + require.False(t, r3.IsNull(0)) + require.Equal(t, int64(11), r3.GetInt64(0)) + require.Equal(t, int64(23), r3.GetInt64(1)) + require.Equal(t, int64(34), r3.GetInt64(2)) +} + +func TestChunkDecoderColMappingSchemaChange(t *testing.T) { + ft1 := types.NewFieldType(mysql.TypeLonglong) + ft1.SetFlag(ft1.GetFlag() | mysql.NotNullFlag) + ft2 := types.NewFieldType(mysql.TypeLonglong) + ft2.SetFlag(ft2.GetFlag() | mysql.NotNullFlag) + ft3 := types.NewFieldType(mysql.TypeLonglong) + ft3.SetFlag(ft3.GetFlag() | mysql.NotNullFlag) + + cols := []ColInfo{ + {ID: 1, Ft: ft1}, + {ID: 2, Ft: ft2}, + {ID: 3, Ft: ft3}, + } + fts := []*types.FieldType{ft1, ft2, ft3} + + defDatum := func(i int, chk *chunk.Chunk) error { + // Default value for column 3 only. + if i == 2 { + chk.AppendInt64(i, 999) + return nil + } + chk.AppendNull(i) + return nil + } + + var encoder Encoder + // Row1: old schema, missing column 3. + row1, err := encoder.Encode(time.UTC, []int64{1, 2}, []types.Datum{ + types.NewIntDatum(10), + types.NewIntDatum(20), + }, nil, nil) + require.NoError(t, err) + + // Row2: new schema includes column 3. + row2, err := encoder.Encode(time.UTC, []int64{1, 2, 3}, []types.Datum{ + types.NewIntDatum(11), + types.NewIntDatum(22), + types.NewIntDatum(33), + }, nil, nil) + require.NoError(t, err) + + decoder := NewChunkDecoder(cols, []int64{-1}, defDatum, time.UTC) + chk := chunk.New(fts, 0, 2) + + require.NoError(t, decoder.DecodeToChunk(row1, kv.IntHandle(-1), chk)) + require.True(t, decoder.mappingInited) + require.Equal(t, 2, decoder.mappingRowCols) + require.Equal(t, -1, decoder.colMapping[2]) // col3 not found in old rows + + require.NoError(t, decoder.DecodeToChunk(row2, kv.IntHandle(-1), chk)) + require.Equal(t, 3, decoder.mappingRowCols) // mapping rebuilt + require.GreaterOrEqual(t, decoder.colMapping[2], 0) // col3 becomes cacheable + + require.Equal(t, 2, chk.NumRows()) + + r1 := chk.GetRow(0) + require.Equal(t, int64(10), r1.GetInt64(0)) + require.Equal(t, int64(20), r1.GetInt64(1)) + require.Equal(t, int64(999), r1.GetInt64(2)) // default + + r2 := chk.GetRow(1) + require.Equal(t, int64(11), r2.GetInt64(0)) + require.Equal(t, int64(22), r2.GetInt64(1)) + require.Equal(t, int64(33), r2.GetInt64(2)) +} + +func TestChunkDecoderCompiledColsCorrectness(t *testing.T) { + ftInt := types.NewFieldType(mysql.TypeLonglong) + ftUint := types.NewFieldType(mysql.TypeLonglong) + ftUint.SetFlag(ftUint.GetFlag() | mysql.UnsignedFlag) + ftBytes := types.NewFieldType(mysql.TypeVarchar) + ftDT := types.NewFieldType(mysql.TypeDatetime) + ftDT.SetDecimal(0) + ftTS := types.NewFieldType(mysql.TypeTimestamp) + ftTS.SetDecimal(0) + + cols := []ColInfo{ + {ID: 1, Ft: ftInt}, + {ID: 2, Ft: ftUint}, + {ID: 3, Ft: ftBytes}, + {ID: 4, Ft: ftDT}, + {ID: 5, Ft: ftTS}, + } + fts := []*types.FieldType{ftInt, ftUint, ftBytes, ftDT, ftTS} + colIDs := []int64{1, 2, 3, 4, 5} + + dt := types.NewTime(types.FromDate(2024, 1, 2, 3, 4, 5, 0), mysql.TypeDatetime, 0) + ts := types.NewTime(types.FromDate(2024, 2, 3, 4, 5, 6, 0), mysql.TypeTimestamp, 0) + + type testCase struct { + name string + encodeLoc *time.Location + decodeLoc *time.Location + expectTS types.Time + } + testCases := []testCase{ + { + name: "loc_nil", + encodeLoc: nil, + decodeLoc: nil, + expectTS: ts, + }, + { + name: "loc_local", + encodeLoc: time.Local, + decodeLoc: time.Local, + expectTS: ts, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var encoder Encoder + rowData, err := encoder.Encode(tc.encodeLoc, colIDs, []types.Datum{ + types.NewIntDatum(123), + types.NewUintDatum(456), + types.NewBytesDatum([]byte("abc")), + types.NewTimeDatum(dt), + types.NewTimeDatum(ts), + }, nil, nil) + require.NoError(t, err) + + decoder := NewChunkDecoder(cols, []int64{-1}, nil, tc.decodeLoc) + chk := chunk.NewChunkWithCapacity(fts, 1) + require.NoError(t, decoder.DecodeToChunk(rowData, kv.IntHandle(-1), chk)) + require.Equal(t, 1, chk.NumRows()) + + row := chk.GetRow(0) + require.False(t, row.IsNull(0)) + require.Equal(t, int64(123), row.GetInt64(0)) + require.False(t, row.IsNull(1)) + require.Equal(t, uint64(456), row.GetUint64(1)) + require.False(t, row.IsNull(2)) + require.Equal(t, []byte("abc"), row.GetBytes(2)) + require.False(t, row.IsNull(3)) + require.Equal(t, 0, row.GetTime(3).Compare(dt)) + require.False(t, row.IsNull(4)) + require.Equal(t, 0, row.GetTime(4).Compare(tc.expectTS)) + }) + } +} + +func TestEncodeOldDatumArena(t *testing.T) { + // Verify encodeOldDatumToArena produces identical results to encodeOldDatum. + dec := &BytesDecoder{} + + tests := []struct { + name string + tp byte + val []byte + }{ + {"Int", IntFlag, func() []byte { var b [8]byte; binary.BigEndian.PutUint64(b[:], uint64(42)); return b[:] }()}, + {"Uint", UintFlag, func() []byte { var b [8]byte; binary.BigEndian.PutUint64(b[:], 999); return b[:] }()}, + {"Bytes", BytesFlag, []byte("hello world")}, + {"Float", FloatFlag, func() []byte { var b [8]byte; binary.BigEndian.PutUint64(b[:], 0x4059000000000000); return b[:] }()}, + {"EmptyBytes", BytesFlag, []byte{}}, + {"NilFlag", NilFlag, nil}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + expected := dec.encodeOldDatum(tt.tp, tt.val) + arena := make([]byte, 0, 64) + result, newArena := encodeOldDatumToArena(tt.tp, tt.val, arena) + require.Equal(t, expected, []byte(result), "mismatch for %s", tt.name) + // Verify result is a sub-slice of arena. + require.Equal(t, len(result), len(newArena)-len(arena)) + }) + } + + // Verify multiple sequential calls share the same arena. + arena := make([]byte, 0, 256) + var results [][]byte + for _, tt := range tests { + var result []byte + result, arena = encodeOldDatumToArena(tt.tp, tt.val, arena) + results = append(results, result) + } + // Verify each result matches original. + for i, tt := range tests { + expected := dec.encodeOldDatum(tt.tp, tt.val) + require.Equal(t, expected, []byte(results[i]), "arena sequential mismatch for %s", tt.name) + } +} + +func TestRawChecksumRequiresHandle(t *testing.T) { + enc := Encoder{} + _, err := enc.Encode(time.UTC, []int64{1}, []types.Datum{types.NewIntDatum(1)}, RawChecksum{}, nil) + require.ErrorIs(t, err, errInvalidChecksumKey) +} + +func TestCalculateRawChecksumValidationAndCompatibility(t *testing.T) { + enc := Encoder{} + rowData, err := enc.Encode( + time.UTC, + []int64{1}, + []types.Datum{types.NewIntDatum(1)}, + RawChecksum{Handle: kv.IntHandle(1)}, + nil, + ) + require.NoError(t, err) + + var r row + require.NoError(t, r.fromBytes(rowData)) + + datum := types.NewIntDatum(1) + _, err = r.CalculateRawChecksum(time.UTC, []int64{1}, []*types.Datum{&datum}, nil, nil, nil) + require.ErrorIs(t, err, errInvalidChecksumKey) + + rawChecksum, err := r.CalculateRawChecksum(time.UTC, []int64{1}, []*types.Datum{&datum}, nil, kv.IntHandle(1), nil) + require.NoError(t, err) + + expected := r.toBytes(nil) + expected = append(expected, r.checksumHeader) + require.Equal(t, crc32.Update(crc32.Checksum(expected, crc32.IEEETable), crc32.IEEETable, kv.IntHandle(1).Encoded()), rawChecksum) + + r.checksumHeader &^= checksumMaskVersion + r.checksumHeader |= checksumVersionRawKey + rawChecksum, err = r.CalculateRawChecksum(time.UTC, []int64{1}, []*types.Datum{&datum}, kv.Key("k"), nil, nil) + require.NoError(t, err) + + expected = r.toBytes(nil) + expected = append(expected, r.checksumHeader) + require.Equal(t, crc32.Update(crc32.Checksum(expected, crc32.IEEETable), crc32.IEEETable, kv.Key("k")), rawChecksum) + + r.checksumHeader &^= checksumMaskVersion + r.checksumHeader |= checksumVersionColumn + _, err = r.CalculateRawChecksum(time.UTC, []int64{1}, []*types.Datum{&datum}, nil, kv.IntHandle(1), nil) + require.ErrorIs(t, err, errInvalidChecksumVer) +} + +func TestDecodeToBytesNoHandleInto(t *testing.T) { + // Encode a row with multiple column types. + var encoder Encoder + colIDs := []int64{1, 2, 3} + datums := []types.Datum{ + types.NewIntDatum(42), + types.NewBytesDatum([]byte("test")), + types.NewUintDatum(100), + } + rowData, err := encoder.Encode(time.UTC, colIDs, datums, nil, nil) + require.NoError(t, err) + + uft := types.NewFieldType(mysql.TypeLonglong) + uft.SetFlag(uft.GetFlag() | mysql.UnsignedFlag) + columns := []ColInfo{ + {ID: 1, Ft: types.NewFieldType(mysql.TypeLonglong)}, + {ID: 2, Ft: types.NewFieldType(mysql.TypeVarchar)}, + {ID: 3, Ft: uft}, + } + outputOffset := map[int64]int{1: 0, 2: 1, 3: 2} + + dec := NewByteDecoder(columns, []int64{-1}, nil, nil) + + // Original path. + origValues, err := dec.DecodeToBytesNoHandle(outputOffset, rowData) + require.NoError(t, err) + + // Into path. + values := make([][]byte, len(columns)) + arena := make([]byte, 0, 256) + intoValues, newArena, err := dec.DecodeToBytesNoHandleInto(outputOffset, rowData, values, arena) + require.NoError(t, err) + + require.Equal(t, len(origValues), len(intoValues)) + for i := range origValues { + require.Equal(t, origValues[i], intoValues[i], "mismatch at index %d", i) + } + + // Arena should have grown. + require.Greater(t, len(newArena), 0) + + // Second call should reuse the same values slice. + datums2 := []types.Datum{ + types.NewIntDatum(100), + types.NewBytesDatum([]byte("another")), + types.NewUintDatum(200), + } + rowData2, err := encoder.Encode(time.UTC, colIDs, datums2, nil, nil) + require.NoError(t, err) + + origValues2, err := dec.DecodeToBytesNoHandle(outputOffset, rowData2) + require.NoError(t, err) + + arena = arena[:0] + intoValues2, _, err := dec.DecodeToBytesNoHandleInto(outputOffset, rowData2, values, arena) + require.NoError(t, err) + + for i := range origValues2 { + require.Equal(t, origValues2[i], intoValues2[i], "second call mismatch at index %d", i) + } +} diff --git a/pkg/util/rowcodec/encoder.go b/pkg/util/rowcodec/encoder.go index 1730e0b3223fd..fa99bad99e9ba 100644 --- a/pkg/util/rowcodec/encoder.go +++ b/pkg/util/rowcodec/encoder.go @@ -42,7 +42,7 @@ type Encoder struct { // `buf` is not truncated before encoding. // This function may return both a valid encoded bytes and an error (actually `"pingcap/errors".ErrorGroup`). If the caller // expects to handle these errors according to `SQL_MODE` or other configuration, please refer to `pkg/errctx`. -// the caller needs to ensure the key is not nil if checksum is required. +// If row-level raw checksum is required, the caller must provide a non-nil handle. func (encoder *Encoder) Encode(loc *time.Location, colIDs []int64, values []types.Datum, checksum Checksum, buf []byte) ([]byte, error) { encoder.reset() encoder.appendColVals(colIDs, values) @@ -245,12 +245,15 @@ const checksumVersionRawKey byte = 1 // introduced since v8.4.0 const checksumVersionRawHandle byte = 2 -// RawChecksum indicates encode the raw bytes checksum and append it to the raw bytes. +// RawChecksum encodes a handle-based raw checksum (checksum version 2) into the row bytes. type RawChecksum struct { Handle kv.Handle } func (c RawChecksum) encode(encoder *Encoder, buf []byte) ([]byte, error) { + if c.Handle == nil { + return nil, errInvalidChecksumKey + } encoder.flags |= rowFlagChecksum encoder.checksumHeader &^= checksumFlagExtra // revert extra checksum flag encoder.checksumHeader &^= checksumMaskVersion // revert checksum version diff --git a/pkg/util/rowcodec/row.go b/pkg/util/rowcodec/row.go index 9183bb8d0f936..1972d871cc4ab 100644 --- a/pkg/util/rowcodec/row.go +++ b/pkg/util/rowcodec/row.go @@ -58,18 +58,19 @@ const ( // // Checksum // -// 0 1 2 3 4 5 6 7 8 -// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -// | |E| VER | CHECKSUM | EXTRA_CHECKSUM(OPTIONAL) | -// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -// HEADER +// 0 1 2 3 4 5 6 7 8 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | |E| VER | CHECKSUM | EXTRA_CHECKSUM(OPTIONAL) | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// HEADER // -// - HEADER -// - VER: version -// - E: has extra checksum -// - CHECKSUM -// - little-endian CRC32(IEEE) when hdr.ver = 0 (old version, columns-level checksum) -// - little-endian CRC32(IEEE) when hdr.ver = 1 (default, bytes-level checksum) +// - HEADER +// - VER: version +// - E: has extra checksum +// - CHECKSUM +// - little-endian CRC32(IEEE) when hdr.ver = 0 (legacy columns-level checksum; decode-only) +// - little-endian CRC32(IEEE) over raw row bytes and the row key when hdr.ver = 1 +// - little-endian CRC32(IEEE) over raw row bytes and the row handle when hdr.ver = 2 type row struct { flags byte checksumHeader byte @@ -151,7 +152,7 @@ func (r *row) fromBytes(rowData []byte) error { if r.hasChecksum() { r.checksumHeader = rowData[cursor] checksumVersion := r.ChecksumVersion() - // make sure it can be read previous version checksum to support backward compatibility. + // Keep backward compatibility when decoding rows written by older checksum versions. switch checksumVersion { case 0, 1, 2: default: @@ -302,7 +303,7 @@ func (r *row) initOffsets32() { } } -// CalculateRawChecksum calculates the bytes-level checksum by using the given elements. +// CalculateRawChecksum calculates the raw bytes checksum by using the given elements. // this is mainly used by the TiCDC to implement E2E checksum functionality. func (r *row) CalculateRawChecksum( loc *time.Location, colIDs []int64, values []*types.Datum, key kv.Key, handle kv.Handle, buf []byte, @@ -323,11 +324,20 @@ func (r *row) CalculateRawChecksum( buf = r.toBytes(buf) buf = append(buf, r.checksumHeader) rawChecksum := crc32.Checksum(buf, crc32.IEEETable) - // keep backward compatibility to v8.3.0 - if r.ChecksumVersion() == int(checksumVersionRawKey) { + switch r.ChecksumVersion() { + // Keep backward compatibility to v8.3.0. + case int(checksumVersionRawKey): + if key == nil { + return 0, errInvalidChecksumKey + } rawChecksum = crc32.Update(rawChecksum, crc32.IEEETable, key) - } else { + case int(checksumVersionRawHandle): + if handle == nil { + return 0, errInvalidChecksumKey + } rawChecksum = crc32.Update(rawChecksum, crc32.IEEETable, handle.Encoded()) + default: + return 0, errInvalidChecksumVer } return rawChecksum, nil }