diff --git a/pkg/planner/core/logical_plan_builder.go b/pkg/planner/core/logical_plan_builder.go index a59915cad5582..92fb175a7b670 100644 --- a/pkg/planner/core/logical_plan_builder.go +++ b/pkg/planner/core/logical_plan_builder.go @@ -495,8 +495,10 @@ func (b *PlanBuilder) buildResultSetNode(ctx context.Context, node ast.ResultSet return nil, err } - // Clone output names before modifying to avoid mutating shared structs if x.AsName.L != "" { + // Clone output names before modifying to avoid mutating shared structs. + // This is critical for CTEs whose output names are shared across multiple + // references — in-place mutation would corrupt other consumers. clonedNames := make([]*types.FieldName, len(p.OutputNames())) for i, name := range p.OutputNames() { if name.Hidden { @@ -674,7 +676,12 @@ func findJoinFullSchema(p base.LogicalPlan) (*expression.Schema, types.NameSlice func containsLateralTableSource(node ast.ResultSetNode) bool { switch n := node.(type) { case *ast.TableSource: - return n.Lateral + if n.Lateral { + return true + } + // Descend into the inner source (derived table / set-op) so nested + // LATERAL inside a subquery or set-op used as a table source is detected. + return containsLateralTableSource(n.Source) case *ast.Join: // For parenthesized single table refs, the parser creates Join{Left: TableSource, Right: nil} if n.Right == nil { @@ -682,6 +689,22 @@ func containsLateralTableSource(node ast.ResultSetNode) bool { } // Check both sides for nested LATERAL return containsLateralTableSource(n.Left) || containsLateralTableSource(n.Right) + case *ast.SelectStmt: + // Descend into the FROM clause of a derived subquery. + if n.From != nil { + return containsLateralTableSource(n.From.TableRefs) + } + return false + case *ast.SetOprStmt: + // Check each operand in the UNION/INTERSECT/EXCEPT list. + if n.SelectList != nil { + for _, sel := range n.SelectList.Selects { + if rs, ok := sel.(ast.ResultSetNode); ok && containsLateralTableSource(rs) { + return true + } + } + } + return false default: return false } @@ -982,7 +1005,6 @@ func (b *PlanBuilder) buildLateralJoin(ctx context.Context, leftPlan, rightPlan ap.SetChildren(leftPlan, rightPlan) ap.SetSchema(expression.MergeSchema(leftPlan.Schema(), rightPlan.Schema())) - setIsInApplyForCTE(rightPlan, ap.Schema()) // Note: nullability adjustment is not needed for InnerJoin (the only type supported currently). // When LEFT/RIGHT JOIN support is added, ResetNotNullFlag must be called here. @@ -1040,6 +1062,10 @@ func (b *PlanBuilder) buildLateralJoin(ctx context.Context, leftPlan, rightPlan ap.FullNames = append(ap.FullNames, &name) } + // Mark inner CTEs against FullSchema so correlations via USING/NATURAL + // merged columns are detected and the CTE storage is reset per outer row. + setIsInApplyForCTE(rightPlan, ap.FullSchema) + // Handle ON conditions if present if joinNode.On != nil { b.curClause = onClause diff --git a/pkg/planner/core/operator/logicalop/logical_cte.go b/pkg/planner/core/operator/logicalop/logical_cte.go index 0e73911e2ab7d..0eed46336b99e 100644 --- a/pkg/planner/core/operator/logicalop/logical_cte.go +++ b/pkg/planner/core/operator/logicalop/logical_cte.go @@ -218,8 +218,8 @@ func (p *LogicalCTE) DeriveStats(_ []*property.StatsInfo, selfSchema *expression vars := p.SCtx().GetSessionVars() savedParallelApply := vars.EnableParallelApply vars.EnableParallelApply = false + defer func() { vars.EnableParallelApply = savedParallelApply }() _, p.Cte.RecursivePartPhysicalPlan, _, err = utilfuncp.DoOptimize(context.TODO(), p.SCtx(), p.Cte.OptFlag, p.Cte.RecursivePartLogicalPlan) - vars.EnableParallelApply = savedParallelApply if err != nil { return nil, false, err } diff --git a/pkg/planner/core/rule_decorrelate.go b/pkg/planner/core/rule_decorrelate.go index c2e1fae7ec0bd..9d4fa95c27029 100644 --- a/pkg/planner/core/rule_decorrelate.go +++ b/pkg/planner/core/rule_decorrelate.go @@ -237,8 +237,10 @@ func (s *DecorrelateSolver) optimize(ctx context.Context, p base.LogicalPlan, gr // to find the underlying LogicalJoin, matching the schema used for name // resolution in LATERAL subqueries (see logical_plan_builder.go buildJoin). outerSchema := outerPlan.Schema() - if fullSchema, _ := findJoinFullSchema(outerPlan); fullSchema != nil { - outerSchema = fullSchema + if apply.IsLateral { + if fullSchema, _ := findJoinFullSchema(outerPlan); fullSchema != nil { + outerSchema = fullSchema + } } apply.CorCols = coreusage.ExtractCorColumnsBySchema4LogicalPlan(innerPlan, outerSchema) if len(apply.CorCols) == 0 { @@ -258,7 +260,7 @@ func (s *DecorrelateSolver) optimize(ctx context.Context, p base.LogicalPlan, gr // Notice that no matter what kind of join is, it's always right. newConds := make([]expression.Expression, 0, len(sel.Conditions)) for _, cond := range sel.Conditions { - newConds = append(newConds, cond.Decorrelate(outerPlan.Schema())) + newConds = append(newConds, cond.Decorrelate(outerSchema)) } apply.AttachOnConds(newConds) innerPlan = sel.Children()[0] @@ -296,9 +298,9 @@ func (s *DecorrelateSolver) optimize(ctx context.Context, p base.LogicalPlan, gr } // step2: when it can be substituted all, we then just do the de-correlation (apply conditions included). for i, expr := range proj.Exprs { - proj.Exprs[i] = expr.Decorrelate(outerPlan.Schema()) + proj.Exprs[i] = expr.Decorrelate(outerSchema) } - apply.Decorrelate(outerPlan.Schema()) + apply.Decorrelate(outerSchema) innerPlan = proj.Children()[0] apply.SetChildren(outerPlan, innerPlan)