diff --git a/pkg/expression/schema.go b/pkg/expression/schema.go
index 3e2e737024e88..6864672337a0a 100644
--- a/pkg/expression/schema.go
+++ b/pkg/expression/schema.go
@@ -83,6 +83,23 @@ func (s *Schema) Clone() *Schema {
 	return schema
 }
 
+// Equal reports whether s and other hold the same number of columns and every column of s
+// passes EqualColumn against the column at the same position in other. Two nil schemas are equal.
+func (s *Schema) Equal(other *Schema) bool {
+	switch {
+	case s == nil || other == nil:
+		return s == other
+	case len(s.Columns) != len(other.Columns):
+		return false
+	}
+	for idx := range s.Columns {
+		if !s.Columns[idx].EqualColumn(other.Columns[idx]) {
+			return false
+		}
+	}
+	return true
+}
+
 // ExprReferenceSchema checks if any column of this expression are from the schema.
 func ExprReferenceSchema(expr Expression, schema *Schema) bool {
 	switch v := expr.(type) {
diff --git a/pkg/planner/core/BUILD.bazel b/pkg/planner/core/BUILD.bazel
index c8deedacbdec9..7e688e75d2a98 100644
--- a/pkg/planner/core/BUILD.bazel
+++ b/pkg/planner/core/BUILD.bazel
@@ -124,6 +124,7 @@ go_library(
         "//pkg/planner/cascades/base",
         "//pkg/planner/core/base",
         "//pkg/planner/core/cost",
+        "//pkg/planner/core/joinorder",
         "//pkg/planner/core/metrics",
         "//pkg/planner/core/operator/baseimpl",
         "//pkg/planner/core/operator/logicalop",
diff --git a/pkg/planner/core/joinorder/BUILD.bazel b/pkg/planner/core/joinorder/BUILD.bazel
new file mode 100644
index 0000000000000..af9d0ae09f67f
--- /dev/null
+++ b/pkg/planner/core/joinorder/BUILD.bazel
@@ -0,0 +1,16 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+
+go_library(
+    name = "joinorder",
+    srcs = ["util.go"],
+    importpath = "github.com/pingcap/tidb/pkg/planner/core/joinorder",
+    visibility = ["//visibility:public"],
+    deps = [
+        "//pkg/parser/ast",
+        "//pkg/planner/core/base",
+        "//pkg/planner/core/operator/logicalop",
+        "//pkg/planner/util",
+        "//pkg/util/hint",
+        "//pkg/util/intest",
+    ],
+)
diff --git a/pkg/planner/core/joinorder/util.go b/pkg/planner/core/joinorder/util.go
new file mode 100644
index 0000000000000..a39dae31f36c8
--- /dev/null
+++
b/pkg/planner/core/joinorder/util.go
@@ -0,0 +1,333 @@
+// Copyright 2026 PingCAP, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package joinorder
+
+import (
+	"strconv"
+	"strings"
+
+	"github.com/pingcap/tidb/pkg/parser/ast"
+	"github.com/pingcap/tidb/pkg/planner/core/base"
+	"github.com/pingcap/tidb/pkg/planner/core/operator/logicalop"
+	"github.com/pingcap/tidb/pkg/planner/util"
+	"github.com/pingcap/tidb/pkg/util/hint"
+	"github.com/pingcap/tidb/pkg/util/intest"
+)
+
+// JoinMethodHint is the join method preference recorded for one join child; maps of it are keyed by that child's plan ID.
+type JoinMethodHint struct {
+	PreferJoinMethod uint            // copied to LogicalJoin.LeftPreferJoinType / RightPreferJoinType by SetNewJoinWithHint
+	HintInfo         *hint.PlanHints // copied to LogicalJoin.HintInfo by SetNewJoinWithHint
+}
+
+// CheckAndGenerateLeadingHint checks the collected leading hints and returns the one to apply, if any.
+// A join group may use at most one leading hint. When the entries of hintInfo are not all the same
+// *hint.PlanHints, all of them are invalid and the result is (nil, true). Otherwise the result is
+// (hintInfo[0], false), or (nil, false) for an empty hintInfo.
+// For example: select /*+ leading(t3) */ * from (select /*+ leading(t1) */ t2.b from t1 join t2 on t1.a=t2.a) t4 join t3 on t4.b=t3.b
+// The join group {t1, t2, t3} contains two leading hints, leading(t3) and leading(t1); they conflict although
+// they are in different query blocks. In addition, the table alias 't4' cannot be recognized because of the join group.
+func CheckAndGenerateLeadingHint(hintInfo []*hint.PlanHints) (*hint.PlanHints, bool) {
+	if len(hintInfo) == 0 {
+		return nil, false
+	}
+	first := hintInfo[0]
+	// One join group has one leading hint at most, so every collected entry must be that same hint.
+	for _, other := range hintInfo[1:] {
+		if other != first {
+			return nil, true
+		}
+	}
+	return first, false
+}
+
+// LeadingTreeFinder looks up the node that hint refers to in available. On success it returns that
+// node, the nodes that are still available, and true. An implementation may delete in place and so
+// reuse the backing array of available (FindAndRemovePlanByAstHint does).
+type LeadingTreeFinder[T any] func(available []T, hint *ast.HintTable) (T, []T, bool)
+
+// LeadingTreeJoiner joins two nodes in the leading tree. BuildLeadingTreeFromList abandons the hint
+// when the bool result is false and propagates a non-nil error to its caller.
+type LeadingTreeJoiner[T any] func(left, right T) (T, bool, error)
+
+// BuildLeadingTreeFromList recursively constructs a LEADING join order tree.
+// The `leadingList` argument is derived from a LEADING hint in SQL, e.g.:
+//
+//	/*+ LEADING(t1, (t2, t3), (t4, (t5, t6, t7))) */
+//
+// and it is parsed into a nested structure of *ast.LeadingList and *ast.HintTable:
+//
+//	leadingList.Items = [
+//		*ast.HintTable{name: "t1"},
+//		*ast.LeadingList{ // corresponds to (t2, t3)
+//			Items: [
+//				*ast.HintTable{name: "t2"},
+//				*ast.HintTable{name: "t3"},
+//			],
+//		},
+//		*ast.LeadingList{ // corresponds to (t4, (t5, t6, t7))
+//			Items: [
+//				*ast.HintTable{name: "t4"},
+//				*ast.LeadingList{
+//					Items: [
+//						*ast.HintTable{name: "t5"},
+//						*ast.HintTable{name: "t6"},
+//						*ast.HintTable{name: "t7"},
+//					],
+//				},
+//			],
+//		},
+//	]
+//
+// Items are resolved from left to right; each one is joined, as the right input of checkAndJoin, to
+// the tree built so far, and a nested list is built first and then joined as a single node. The
+// results are the joined tree, the nodes that were not consumed, and whether the whole hint could be
+// applied. When it could not (unknown table, rejected join, unsupported item, or an error),
+// availableGroups is returned untouched; warn, if non-nil, is called for an unsupported item.
+func BuildLeadingTreeFromList[T any](
+	leadingList *ast.LeadingList,
+	availableGroups []T,
+	findAndRemoveByHint LeadingTreeFinder[T],
+	checkAndJoin LeadingTreeJoiner[T],
+	warn func(),
+) (T, []T, bool, error) {
+	var zero T
+	if leadingList == nil || len(leadingList.Items) == 0 {
+		return zero, availableGroups, false, nil
+	}
+	// A finder may delete in place; run it on a private copy so availableGroups is intact on failure.
+	scratch := append(make([]T, 0, len(availableGroups)), availableGroups...)
+	tree, rest, ok, err := buildLeadingTree(leadingList, scratch, findAndRemoveByHint, checkAndJoin, warn)
+	if err != nil || !ok {
+		return zero, availableGroups, false, err
+	}
+	return tree, rest, true, nil
+}
+
+// buildLeadingTree is the recursive worker behind BuildLeadingTreeFromList. It hands `available` to
+// findAndRemoveByHint, which may modify it, so callers must not rely on its contents after a failure.
+func buildLeadingTree[T any](
+	leadingList *ast.LeadingList,
+	available []T,
+	findAndRemoveByHint LeadingTreeFinder[T],
+	checkAndJoin LeadingTreeJoiner[T],
+	warn func(),
+) (T, []T, bool, error) {
+	var zero, current T
+	if leadingList == nil || len(leadingList.Items) == 0 {
+		return zero, available, false, nil
+	}
+	remaining := available
+	for i, item := range leadingList.Items {
+		var (
+			node T
+			ok   bool
+			err  error
+		)
+		switch element := item.(type) {
+		case *ast.HintTable:
+			node, remaining, ok = findAndRemoveByHint(remaining, element)
+		case *ast.LeadingList:
+			node, remaining, ok, err = buildLeadingTree(element, remaining, findAndRemoveByHint, checkAndJoin, warn)
+		default:
+			// Neither a table nor a nested list: let the caller report it, then give up below.
+			if warn != nil {
+				warn()
+			}
+		}
+		if err != nil || !ok {
+			return zero, available, false, err
+		}
+		if i == 0 {
+			current = node
+			continue
+		}
+		if current, ok, err = checkAndJoin(current, node); err != nil || !ok {
+			return zero, available, false, err
+		}
+	}
+	return current, remaining, true, nil
+}
+
+// FindAndRemovePlanByAstHint finds the plan in `plans` that matches `astTbl` and removes it, returning the
+// matched element, the shortened slice, and true. When nothing matches it returns the zero value, `plans`, and false.
+// The removal is done in place: the returned slice reuses (and shifts) the backing array of `plans`, so pass a
+// copy if the original contents must survive.
+// Matching rules:
+//  1. Match by regular table name (db/table/*)
+//  2. Match by query-block alias (subquery name, e.g., tx)
+//  3. If multiple join groups belong to the same block alias, mark as ambiguous and skip (consistent with old logic)
+//
+// NOTE: T is usually *Node or base.LogicalPlan; generics let both the old and the new join order code reuse
+// this function.
+func FindAndRemovePlanByAstHint[T any](
+	ctx base.PlanContext,
+	plans []T,
+	astTbl *ast.HintTable,
+	getPlan func(T) base.LogicalPlan,
+) (T, []T, bool) {
+	var zero T
+	var blockNames []ast.HintTable
+	if loaded := ctx.GetSessionVars().PlannerSelectBlockAsName.Load(); loaded != nil {
+		blockNames = *loaded
+	}
+	// removeAt returns plans[idx] and `plans` without it. The element is copied out before the
+	// append because the append shifts the tail of the shared backing array over that slot.
+	removeAt := func(idx int) (T, []T, bool) {
+		picked := plans[idx]
+		return picked, append(plans[:idx], plans[idx+1:]...), true
+	}
+
+	// Pass 1: direct match by table name.
+	for idx, candidate := range plans {
+		logicalPlan := getPlan(candidate)
+		alias := util.ExtractTableAlias(logicalPlan, logicalPlan.QueryBlockOffset())
+		if alias == nil {
+			continue
+		}
+		// An empty or "*" database name in the hint matches every database.
+		if db := astTbl.DBName.L; db != "" && db != "*" && db != alias.DBName.L {
+			continue
+		}
+		if astTbl.TableName.L != alias.TblName.L {
+			continue
+		}
+		// A query block name is only enforced when it parses to a positive offset (sel_<n>);
+		// otherwise it is ignored.
+		if qb := astTbl.QBName.L; qb != "" {
+			if want := extractSelectOffset(qb); want > 0 && alias.SelectOffset != want {
+				continue
+			}
+		}
+		return removeAt(idx)
+	}
+
+	// Pass 2: match by query-block alias (subquery name). Only reached when pass 1 found nothing.
+	found := -1
+	for idx, candidate := range plans {
+		offset := getPlan(candidate).QueryBlockOffset()
+		if offset <= 1 || offset >= len(blockNames) {
+			continue
+		}
+		block := blockNames[offset]
+		if db := astTbl.DBName.L; db != "" && db != block.DBName.L {
+			continue
+		}
+		if astTbl.TableName.L != block.TableName.L {
+			continue
+		}
+		if found != -1 {
+			// A second group with the same alias makes the hint ambiguous: match nothing.
+			intest.Assert(false, "leading subquery alias matches multiple join groups")
+			return zero, plans, false
+		}
+		found = idx
+	}
+	if found == -1 {
+		return zero, plans, false
+	}
+	return removeAt(found)
+}
+
+// extractSelectOffset returns the number x of a query block name of the form 'sel_x', and -1
+// when qbName lacks the prefix or x is not an integer.
+func extractSelectOffset(qbName string) int {
+	const prefix = "sel_"
+	if !strings.HasPrefix(qbName, prefix) {
+		return -1
+	}
+	offset, err := strconv.Atoi(qbName[len(prefix):])
+	if err != nil {
+		return -1
+	}
+	return offset
+}
+
+// IsDerivedTableInLeadingHint reports whether p is a derived table (subquery) whose alias is
+// explicitly referenced in the LEADING hint.
+func IsDerivedTableInLeadingHint(p base.LogicalPlan, leadingHint *hint.PlanHints) bool {
+	if leadingHint == nil || leadingHint.LeadingList == nil {
+		return false
+	}
+
+	// PlannerSelectBlockAsName is indexed by query block offset and holds the alias of each block.
+	namesPtr := p.SCtx().GetSessionVars().PlannerSelectBlockAsName.Load()
+	if namesPtr == nil {
+		return false
+	}
+	blockNames := *namesPtr
+
+	// Only offsets in [2, len(blockNames)-1] can represent subqueries / derived tables.
+	// Offsets 0 and 1 are typically the main query or a CTE, and offsets beyond the end of
+	// blockNames are invalid.
+	offset := p.QueryBlockOffset()
+	if offset <= 1 || offset >= len(blockNames) {
+		return false
+	}
+
+	// A block without an alias cannot be named by the hint.
+	entry := blockNames[offset]
+	if entry.TableName.L == "" {
+		return false
+	}
+	return containsTableInLeadingList(leadingHint.LeadingList, entry.DBName.L, entry.TableName.L)
+}
+
+// containsTableInLeadingList reports whether dbName.tableName is referenced anywhere in the
+// (possibly nested) LEADING list. A hint table whose database is empty or "*" matches any dbName.
+func containsTableInLeadingList(leadingList *ast.LeadingList, dbName, tableName string) bool {
+	if leadingList == nil {
+		return false
+	}
+	for _, item := range leadingList.Items {
+		// Nested list: search it recursively.
+		if nested, isList := item.(*ast.LeadingList); isList {
+			if containsTableInLeadingList(nested, dbName, tableName) {
+				return true
+			}
+			continue
+		}
+		// Direct table reference; anything else is skipped.
+		tbl, isTable := item.(*ast.HintTable)
+		if !isTable || tbl.TableName.L != tableName {
+			continue
+		}
+		if db := tbl.DBName.L; db == "" || db == "*" || db == dbName {
+			return true
+		}
+	}
+	return false
+}
+
+// SetNewJoinWithHint sets the join method hint for the join node.
+// vertexHints maps the plan ID of a join child to the hint recorded for it. A hint found for the
+// left (right) child is stored in LeftPreferJoinType (RightPreferJoinType) together with its
+// HintInfo, and SetPreferredJoinType is called at the end. A nil newJoin is a no-op.
+func SetNewJoinWithHint(newJoin *logicalop.LogicalJoin, vertexHints map[int]*JoinMethodHint) {
+	if newJoin == nil {
+		return
+	}
+	left, right := newJoin.Children()[0], newJoin.Children()[1]
+	// The right child is handled last, so its HintInfo wins when both children carry a hint.
+	if recorded, found := vertexHints[left.ID()]; found {
+		newJoin.LeftPreferJoinType = recorded.PreferJoinMethod
+		newJoin.HintInfo = recorded.HintInfo
+	}
+	if recorded, found := vertexHints[right.ID()]; found {
+		newJoin.RightPreferJoinType = recorded.PreferJoinMethod
+		newJoin.HintInfo = recorded.HintInfo
+	}
+	newJoin.SetPreferredJoinType()
+}
diff --git a/pkg/planner/core/operator/logicalop/logical_projection.go b/pkg/planner/core/operator/logicalop/logical_projection.go
index f2f930f2df7fe..10d418b6748df 100644
--- a/pkg/planner/core/operator/logicalop/logical_projection.go
+++ b/pkg/planner/core/operator/logicalop/logical_projection.go
@@ -673,3 +673,16 @@ func canProjectionBeEliminatedLoose(p *LogicalProjection) bool {
 	}
 	return true
 }
+
+// InjectExpr adds expr to a projection via AppendExpr and returns that projection together with the
+// column AppendExpr returned. If p is already a *LogicalProjection the expression is appended to p
+// itself; otherwise a new projection that passes through all columns of p is created above p first.
+func InjectExpr(p base.LogicalPlan, expr expression.Expression) (base.LogicalPlan, *expression.Column) {
+	if existing, isProj := p.(*LogicalProjection); isProj { // p is a projection already: extend it in place
+		return existing, existing.AppendExpr(expr)
+	}
+	wrapper := LogicalProjection{Exprs: expression.Column2Exprs(p.Schema().Columns)}.Init(p.SCtx(), p.QueryBlockOffset())
+	wrapper.SetSchema(p.Schema().Clone()) // the pass-through projection exposes a copy of p's schema
+	wrapper.SetChildren(p)
+	return wrapper, wrapper.AppendExpr(expr)
+}
diff --git a/pkg/planner/core/plan_cost_ver2.go b/pkg/planner/core/plan_cost_ver2.go
index cc8e8552951bf..37c54ba7133e0 100644
--- a/pkg/planner/core/plan_cost_ver2.go
+++ b/pkg/planner/core/plan_cost_ver2.go
@@ -578,8 +578,8 @@ func (p *PhysicalMergeJoin) GetPlanCostVer2(taskType property.TaskType, option *
 	filterCost := costusage.SumCostVer2(filterCostVer2(option, leftRows, p.LeftConditions, cpuFactor),
 		filterCostVer2(option, rightRows, p.RightConditions, cpuFactor),
 		filterCostVer2(option, leftRows+rightRows, p.OtherConditions, cpuFactor)) // OtherConditions are applied to both sides
-	groupCost := costusage.SumCostVer2(groupCostVer2(option, leftRows, cols2Exprs(p.LeftJoinKeys), cpuFactor),
-		groupCostVer2(option, rightRows, cols2Exprs(p.RightJoinKeys), cpuFactor))
+	groupCost := costusage.SumCostVer2(groupCostVer2(option, leftRows, expression.Column2Exprs(p.LeftJoinKeys), cpuFactor),
+		groupCostVer2(option, rightRows, expression.Column2Exprs(p.RightJoinKeys), cpuFactor))
 
 	leftChildCost, err := p.Children()[0].GetPlanCostVer2(taskType, option)
 	if err != nil {
@@ -1206,11 +1206,3 @@ func getTableInfo(p base.PhysicalPlan) *model.TableInfo {
 		return getTableInfo(x.Children()[0])
 	}
 }
-
-func cols2Exprs(cols []*expression.Column) []expression.Expression {
-	exprs := make([]expression.Expression, 0, len(cols))
-	for _, c := range cols {
-		exprs = append(exprs, c)
-	}
-	return exprs
-}
diff --git a/pkg/planner/core/rule_join_reorder.go b/pkg/planner/core/rule_join_reorder.go
index 64ac4954ac053..3c5e8a4204001 100644
--- a/pkg/planner/core/rule_join_reorder.go
+++ b/pkg/planner/core/rule_join_reorder.go
@@
-23,6 +23,7 @@ import ( "github.com/pingcap/tidb/pkg/expression" "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/planner/core/base" + "github.com/pingcap/tidb/pkg/planner/core/joinorder" "github.com/pingcap/tidb/pkg/planner/core/operator/logicalop" "github.com/pingcap/tidb/pkg/planner/util" "github.com/pingcap/tidb/pkg/planner/util/optimizetrace" @@ -49,7 +50,7 @@ func extractJoinGroup(p base.LogicalPlan) *joinGroupResult { // The derived output columns are tracked in colExprMap and substituted back into // join predicates before the join graph is built. func extractJoinGroupImpl(p base.LogicalPlan) *joinGroupResult { - joinMethodHintInfo := make(map[int]*joinMethodHint) + joinMethodHintInfo := make(map[int]*joinorder.JoinMethodHint) var ( group []base.LogicalPlan joinOrderHintInfo []*h.PlanHints @@ -140,11 +141,11 @@ func extractJoinGroupImpl(p base.LogicalPlan) *joinGroupResult { if isJoin && p.SCtx().GetSessionVars().EnableAdvancedJoinHint && join.PreferJoinType > uint(0) { // If the current join node has the join method hint, we should store the hint information and restore it when we have finished the join reorder process. 
if join.LeftPreferJoinType > uint(0) { - joinMethodHintInfo[join.Children()[0].ID()] = &joinMethodHint{join.LeftPreferJoinType, join.HintInfo} + joinMethodHintInfo[join.Children()[0].ID()] = &joinorder.JoinMethodHint{PreferJoinMethod: join.LeftPreferJoinType, HintInfo: join.HintInfo} leftHasHint = true } if join.RightPreferJoinType > uint(0) { - joinMethodHintInfo[join.Children()[1].ID()] = &joinMethodHint{join.RightPreferJoinType, join.HintInfo} + joinMethodHintInfo[join.Children()[1].ID()] = &joinorder.JoinMethodHint{PreferJoinMethod: join.RightPreferJoinType, HintInfo: join.HintInfo} rightHasHint = true } } @@ -338,7 +339,7 @@ func (s *JoinReOrderSolver) optimizeRecursive(ctx base.PlanContext, p base.Logic joinGroupNum := len(curJoinGroup) useGreedy := joinGroupNum > ctx.GetSessionVars().TiDBOptJoinReorderThreshold || !isSupportDP - leadingHintInfo, hasDiffLeadingHint := checkAndGenerateLeadingHint(joinOrderHintInfo) + leadingHintInfo, hasDiffLeadingHint := joinorder.CheckAndGenerateLeadingHint(joinOrderHintInfo) if hasDiffLeadingHint { ctx.GetSessionVars().StmtCtx.SetHintWarning( "We can only use one leading hint at most, when multiple leading hints are used, all leading hints will be invalid") @@ -472,11 +473,6 @@ func checkAndGenerateLeadingHint(hintInfo []*h.PlanHints) (*h.PlanHints, bool) { return leadingHintInfo, hasDiffLeadingHint } -type joinMethodHint struct { - preferredJoinMethod uint - joinMethodHintInfo *h.PlanHints -} - // basicJoinGroupInfo represents basic information for a join group in the join reorder process. type basicJoinGroupInfo struct { eqEdges []*expression.ScalarFunction @@ -488,7 +484,7 @@ type basicJoinGroupInfo struct { // `joinMethodHintInfo` is used to map the sub-plan's ID to the join method hint. // The sub-plan will join the join reorder process to build the new plan. // So after we have finished the join reorder process, we can reset the join method hint based on the sub-plan's ID. 
- joinMethodHintInfo map[int]*joinMethodHint + joinMethodHintInfo map[int]*joinorder.JoinMethodHint } type joinGroupResult struct { @@ -821,7 +817,7 @@ func (s *baseSingleGroupJoinOrderSolver) newCartesianJoin(lChild, rChild base.Lo }.Init(s.ctx, offset) join.SetSchema(expression.MergeSchema(lChild.Schema(), rChild.Schema())) join.SetChildren(lChild, rChild) - s.setNewJoinWithHint(join) + joinorder.SetNewJoinWithHint(join, s.joinMethodHintInfo) return join } @@ -848,23 +844,6 @@ func (s *baseSingleGroupJoinOrderSolver) newJoinWithEdges(lChild, rChild base.Lo return newJoin } -// setNewJoinWithHint sets the join method hint for the join node. -// Before the join reorder process, we split the join node and collect the join method hint. -// And we record the join method hint and reset the hint after we have finished the join reorder process. -func (s *baseSingleGroupJoinOrderSolver) setNewJoinWithHint(newJoin *logicalop.LogicalJoin) { - lChild := newJoin.Children()[0] - rChild := newJoin.Children()[1] - if joinMethodHint, ok := s.joinMethodHintInfo[lChild.ID()]; ok { - newJoin.LeftPreferJoinType = joinMethodHint.preferredJoinMethod - newJoin.HintInfo = joinMethodHint.joinMethodHintInfo - } - if joinMethodHint, ok := s.joinMethodHintInfo[rChild.ID()]; ok { - newJoin.RightPreferJoinType = joinMethodHint.preferredJoinMethod - newJoin.HintInfo = joinMethodHint.joinMethodHintInfo - } - newJoin.SetPreferredJoinType() -} - // calcJoinCumCost calculates the cumulative cost of the join node. func (*baseSingleGroupJoinOrderSolver) calcJoinCumCost(join base.LogicalPlan, lNode, rNode *jrNode) float64 { return join.StatsInfo().RowCount + lNode.cumCost + rNode.cumCost