diff --git a/pkg/planner/core/operator/logicalop/logical_aggregation.go b/pkg/planner/core/operator/logicalop/logical_aggregation.go index 310588601b20c..a016db449ab58 100644 --- a/pkg/planner/core/operator/logicalop/logical_aggregation.go +++ b/pkg/planner/core/operator/logicalop/logical_aggregation.go @@ -396,7 +396,7 @@ func (la *LogicalAggregation) ExtractFD() *fd.FDSet { determinants.Insert(int(one.UniqueID)) groupByColsOutputCols.Insert(int(one.UniqueID)) } - notnull := util.IsNullRejected(la.SCtx(), la.Schema(), x, true) + notnull := util.IsNullRejected(la.SCtx(), la.Schema(), x) if notnull || determinants.SubsetOf(fds.NotNullCols) { notnullColsUniqueIDs.Insert(scalarUniqueID) } diff --git a/pkg/planner/core/operator/logicalop/logical_join.go b/pkg/planner/core/operator/logicalop/logical_join.go index cd38a0fda18f3..2dd33439093dd 100644 --- a/pkg/planner/core/operator/logicalop/logical_join.go +++ b/pkg/planner/core/operator/logicalop/logical_join.go @@ -319,7 +319,7 @@ func simplifyOuterJoin(p *LogicalJoin, predicates []expression.Expression) { if expression.ExprFromSchema(expr, outerTable.Schema()) { continue } - isOk := util.IsNullRejected(p.SCtx(), innerTable.Schema(), expr, true) + isOk := util.IsNullRejected(p.SCtx(), innerTable.Schema(), expr) if isOk { canBeSimplified = true break @@ -727,7 +727,7 @@ func (p *LogicalJoin) ConvertOuterToInnerJoin(predicates []expression.Expression if p.JoinType == base.LeftOuterJoin || p.JoinType == base.RightOuterJoin { canBeSimplified := false for _, expr := range predicates { - isOk := util.IsNullRejected(p.SCtx(), innerTable.Schema(), expr, true) + isOk := util.IsNullRejected(p.SCtx(), innerTable.Schema(), expr) if isOk { canBeSimplified = true break @@ -1471,13 +1471,13 @@ func (p *LogicalJoin) ExtractOnCondition( } if leftCol != nil && rightCol != nil { if deriveLeft { - if util.IsNullRejected(ctx, leftSchema, expr, true) && !mysql.HasNotNullFlag(leftCol.RetType.GetFlag()) { + if util.IsNullRejected(ctx, leftSchema, expr) && !mysql.HasNotNullFlag(leftCol.RetType.GetFlag()) { notNullExpr := expression.BuildNotNullExpr(ctx.GetExprCtx(), leftCol) leftCond = append(leftCond, notNullExpr) } } if deriveRight { - if util.IsNullRejected(ctx, rightSchema, expr, true) && !mysql.HasNotNullFlag(rightCol.RetType.GetFlag()) { + if util.IsNullRejected(ctx, rightSchema, expr) && !mysql.HasNotNullFlag(rightCol.RetType.GetFlag()) { notNullExpr := expression.BuildNotNullExpr(ctx.GetExprCtx(), rightCol) rightCond = append(rightCond, notNullExpr) } @@ -2219,7 +2219,7 @@ func deriveNotNullExpr(ctx base.PlanContext, expr expression.Expression, schema if childCol == nil { childCol = schema.RetrieveColumn(arg1) } - if util.IsNullRejected(ctx, schema, expr, true) && !mysql.HasNotNullFlag(childCol.RetType.GetFlag()) { + if util.IsNullRejected(ctx, schema, expr) && !mysql.HasNotNullFlag(childCol.RetType.GetFlag()) { return expression.BuildNotNullExpr(ctx.GetExprCtx(), childCol) } return nil diff --git a/pkg/planner/core/operator/logicalop/logical_projection.go b/pkg/planner/core/operator/logicalop/logical_projection.go index 8ca9a7ced1e82..5b50002d05ad9 100644 --- a/pkg/planner/core/operator/logicalop/logical_projection.go +++ b/pkg/planner/core/operator/logicalop/logical_projection.go @@ -459,7 +459,7 @@ func (p *LogicalProjection) ExtractFD() *fd.FDSet { // the dependent columns in scalar function should be also considered as output columns as well. outputColsUniqueIDs.Insert(int(one.UniqueID)) } - notnull := util.IsNullRejected(p.SCtx(), p.Schema(), x, true) + notnull := util.IsNullRejected(p.SCtx(), p.Schema(), x) if notnull || determinants.SubsetOf(fds.NotNullCols) { notnullColsUniqueIDs.Insert(scalarUniqueID) } diff --git a/pkg/planner/util/funcdep_misc.go b/pkg/planner/util/funcdep_misc.go index fa724284f651f..7aec43d199bcc 100644 --- a/pkg/planner/util/funcdep_misc.go +++ b/pkg/planner/util/funcdep_misc.go @@ -41,7 +41,7 @@ func ExtractNotNullFromConds(conditions []expression.Expression, p base.LogicalP if len(cols) == 0 { continue } - if IsNullRejected(p.SCtx(), p.Schema(), condition, false) { + if IsNullRejected(p.SCtx(), p.Schema(), condition) { for _, col := range cols { notnullColsUniqueIDs.Insert(int(col.UniqueID)) } diff --git a/pkg/planner/util/null_misc.go b/pkg/planner/util/null_misc.go index 0f492d6401035..6735ded890b7a 100644 --- a/pkg/planner/util/null_misc.go +++ b/pkg/planner/util/null_misc.go @@ -62,7 +62,9 @@ import ( // classify that exact value. This recovers cases such as COALESCE/IF/IFNULL // that may hide NULL but still collapse after nullification. The bridge stays // conservative for plan-cache-sensitive expressions by refusing to treat -// ParamMarker/DeferredExpr values as static fold results. +// ParamMarker/DeferredExpr values as static fold results. DeferredExpr can +// still be inspected symbolically, but its runtime value must not be folded or +// classified as a compile-time constant. // nullRejectProof holds the two proof results for a sub-expression. // See the file-level comment above for the full model. @@ -71,7 +73,8 @@ type nullRejectProof struct { mustNull bool } -// allConstants checks whether the expression tree consists entirely of constants. +// allConstants checks whether the expression tree can be attempted as a static +// constant tree without lazy constants. func allConstants(ctx expression.BuildContext, expr expression.Expression) bool { if expression.MaybeOverOptimized4PlanCache(ctx, expr) { return false @@ -85,18 +88,16 @@ func allConstants(ctx expression.BuildContext, expr expression.Expression) bool } return true case *expression.Constant: - return true + return v.ParamMarker == nil && v.DeferredExpr == nil } return false } // IsNullRejected proves whether `predicate` can be TRUE after every column in // `innerSchema` is replaced with SQL NULL. -func IsNullRejected(ctx base.PlanContext, innerSchema *expression.Schema, predicate expression.Expression, - skipPlanCacheCheck bool) bool { - _ = skipPlanCacheCheck // kept for API compatibility; the new proof does not use EvaluateExprWithNull +func IsNullRejected(ctx base.PlanContext, innerSchema *expression.Schema, predicate expression.Expression) bool { predicate = expression.PushDownNot(ctx.GetNullRejectCheckExprCtx(), predicate) - return proveNullRejected(ctx, innerSchema, predicate).nonTrue + return proveNullRejected(ctx, innerSchema, predicate, true).nonTrue } // proveNullRejected recursively proves the two proof bits for one expression. @@ -119,13 +120,20 @@ func IsNullRejected(ctx base.PlanContext, innerSchema *expression.Schema, predic // 2 > 2 // // so the predicate is nonTrue. +// +// allowNullifiedFold is false when proving a Constant.DeferredExpr. In that +// mode the proof remains purely symbolic so execution-time dependent values are +// not folded during optimization. func proveNullRejected( ctx base.PlanContext, innerSchema *expression.Schema, expr expression.Expression, + allowNullifiedFold bool, ) nullRejectProof { - if cons, ok := tryFoldNullifiedConstant(ctx, innerSchema, expr); ok { - return proofFromConstant(ctx, cons) + if allowNullifiedFold { + if cons, ok := tryFoldNullifiedConstant(ctx, innerSchema, expr); ok { + return proofFromConstant(ctx, cons) + } } switch x := expr.(type) { @@ -141,9 +149,12 @@ func proveNullRejected( return nullRejectProof{nonTrue: true, mustNull: true} } case *expression.Constant: + if x.ParamMarker == nil && x.DeferredExpr != nil { + return proveNullRejected(ctx, innerSchema, x.DeferredExpr, false) + } return proofFromConstant(ctx, x) case *expression.ScalarFunction: - return proveNullRejectedScalarFunc(ctx, innerSchema, x) + return proveNullRejectedScalarFunc(ctx, innerSchema, x, allowNullifiedFold) } return nullRejectProof{} } @@ -162,18 +173,19 @@ func proveNullRejectedScalarFunc( ctx base.PlanContext, innerSchema *expression.Schema, expr *expression.ScalarFunction, + allowNullifiedFold bool, ) nullRejectProof { switch expr.FuncName.L { case ast.LogicAnd: - lhs := proveNullRejected(ctx, innerSchema, expr.GetArgs()[0]) - rhs := proveNullRejected(ctx, innerSchema, expr.GetArgs()[1]) + lhs := proveNullRejected(ctx, innerSchema, expr.GetArgs()[0], allowNullifiedFold) + rhs := proveNullRejected(ctx, innerSchema, expr.GetArgs()[1], allowNullifiedFold) return nullRejectProof{ nonTrue: lhs.nonTrue || rhs.nonTrue, mustNull: lhs.mustNull && rhs.mustNull, } case ast.LogicOr: - lhs := proveNullRejected(ctx, innerSchema, expr.GetArgs()[0]) - rhs := proveNullRejected(ctx, innerSchema, expr.GetArgs()[1]) + lhs := proveNullRejected(ctx, innerSchema, expr.GetArgs()[0], allowNullifiedFold) + rhs := proveNullRejected(ctx, innerSchema, expr.GetArgs()[1], allowNullifiedFold) return nullRejectProof{ nonTrue: lhs.nonTrue && rhs.nonTrue, mustNull: lhs.mustNull && rhs.mustNull, @@ -190,25 +202,25 @@ func proveNullRejectedScalarFunc( // NOT(TRUE) = FALSE, so it is null-rejected. if child, ok := expr.GetArgs()[0].(*expression.ScalarFunction); ok && child.FuncName.L == ast.IsNull { return nullRejectProof{ - nonTrue: proveNullRejected(ctx, innerSchema, child.GetArgs()[0]).mustNull, + nonTrue: proveNullRejected(ctx, innerSchema, child.GetArgs()[0], allowNullifiedFold).mustNull, } } // General NOT: NOT(NULL) = NULL (nonTrue), but NOT(FALSE) = TRUE // (not nonTrue). So nonTrue requires child.mustNull, not just // child.nonTrue. - child := proveNullRejected(ctx, innerSchema, expr.GetArgs()[0]) + child := proveNullRejected(ctx, innerSchema, expr.GetArgs()[0], allowNullifiedFold) return nullRejectProof{ nonTrue: child.mustNull, mustNull: child.mustNull, } case ast.In: - return proveNullRejectedIn(ctx, innerSchema, expr) + return proveNullRejectedIn(ctx, innerSchema, expr, allowNullifiedFold) case ast.IsNull: return nullRejectProof{} } if mode, ok := nullRejectRejectNullTests[expr.FuncName.L]; ok { - child := proveNullRejected(ctx, innerSchema, expr.GetArgs()[0]) + child := proveNullRejected(ctx, innerSchema, expr.GetArgs()[0], allowNullifiedFold) return nullRejectProof{ nonTrue: child.mustNull, mustNull: child.mustNull && mode == nullRejectTestKeepsNull, @@ -217,7 +229,7 @@ func proveNullRejectedScalarFunc( if _, ok := nullRejectNullPreservingFunctions[expr.FuncName.L]; ok { for _, arg := range expr.GetArgs() { - if proveNullRejected(ctx, innerSchema, arg).mustNull { + if proveNullRejected(ctx, innerSchema, arg, allowNullifiedFold).mustNull { return nullRejectProof{nonTrue: true, mustNull: true} } } @@ -232,18 +244,19 @@ func proveNullRejectedIn( ctx base.PlanContext, innerSchema *expression.Schema, expr *expression.ScalarFunction, + allowNullifiedFold bool, ) nullRejectProof { args := expr.GetArgs() if len(args) == 0 { return nullRejectProof{} } - valueProof := proveNullRejected(ctx, innerSchema, args[0]) + valueProof := proveNullRejected(ctx, innerSchema, args[0], allowNullifiedFold) if valueProof.mustNull { return nullRejectProof{nonTrue: true, mustNull: true} } allListMustNull := true for _, arg := range args[1:] { - if !proveNullRejected(ctx, innerSchema, arg).mustNull { + if !proveNullRejected(ctx, innerSchema, arg, allowNullifiedFold).mustNull { allListMustNull = false break } diff --git a/pkg/planner/util/null_misc_test.go b/pkg/planner/util/null_misc_test.go index 6cb83ef4854f2..3461068806664 100644 --- a/pkg/planner/util/null_misc_test.go +++ b/pkg/planner/util/null_misc_test.go @@ -181,6 +181,11 @@ func TestIsNullRejectedProofModes(t *testing.T) { newNullRejectStringConst("abc"), innerS, ) + deferredInnerGTZero := newNullRejectDeferredConst(exprCtx, gtInnerAZero) + deferredCoalesceInnerATwoGTTwo := newNullRejectDeferredConst(exprCtx, + newNullRejectFunc(t, exprCtx, ast.GT, types.NewFieldType(mysql.TypeTiny), coalesceInnerATwo, newNullRejectIntConst(2)), + ) + deferredOneWithNullPlaceholder := newNullRejectDeferredConst(exprCtx, expression.NewOne()) cases := []struct { name string @@ -337,11 +342,26 @@ func TestIsNullRejectedProofModes(t *testing.T) { expr: newNullRejectNotNull(t, exprCtx, jsonSearchNullableEscape), expected: false, }, + { + name: "deferred_expr_uses_symbolic_null_reject_proof", + expr: deferredInnerGTZero, + expected: true, + }, + { + name: "deferred_expr_skips_nullified_fold", + expr: deferredCoalesceInnerATwoGTTwo, + expected: false, + }, + { + name: "deferred_expr_does_not_classify_placeholder_null", + expr: deferredOneWithNullPlaceholder, + expected: false, + }, } for _, tt := range cases { t.Run(tt.name, func(t *testing.T) { - require.Equal(t, tt.expected, IsNullRejected(sctx, innerSchema, tt.expr, true)) + require.Equal(t, tt.expected, IsNullRejected(sctx, innerSchema, tt.expr)) }) } } @@ -394,6 +414,15 @@ func newNullRejectUintConst(value uint64) *expression.Constant { } } +// newNullRejectDeferredConst builds a deferred constant with a NULL placeholder value. +func newNullRejectDeferredConst(ctx expression.BuildContext, deferred expression.Expression) *expression.Constant { + return &expression.Constant{ + Value: types.NewDatum(nil), + RetType: deferred.GetType(ctx.GetEvalCtx()), + DeferredExpr: deferred, + } +} + func newNullRejectUintFieldType(tp byte) *types.FieldType { fieldType := types.NewFieldType(tp) fieldType.AddFlag(mysql.UnsignedFlag)