diff --git a/pkg/bindinfo/binding_auto_test.go b/pkg/bindinfo/binding_auto_test.go new file mode 100644 index 0000000000000..4fc616c268df9 --- /dev/null +++ b/pkg/bindinfo/binding_auto_test.go @@ -0,0 +1,350 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bindinfo_test + +import ( + "fmt" + "slices" + "strings" + "testing" + + "github.com/pingcap/tidb/pkg/bindinfo" + "github.com/pingcap/tidb/pkg/parser" + "github.com/pingcap/tidb/pkg/parser/auth" + "github.com/pingcap/tidb/pkg/sessionctx/vardef" + "github.com/pingcap/tidb/pkg/testkit" + "github.com/pingcap/tidb/pkg/testkit/testdata" + "github.com/stretchr/testify/require" +) + +func TestGenPlanWithSCtx(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec(`create table t1 (a int, b int, c int, key(a), key(b))`) + tk.MustExec(`create table t2 (a int, b int, c int, key(a), key(b))`) + + p := parser.New() + sctx := tk.Session() + sctx.GetSessionVars().CostModelVersion = 2 + check := func(sql, expectedHint, expectedPlan string) { + p.Reset() + stmt, err := p.ParseOneStmt(sql, "", "") + require.NoErrorf(t, err, "sql: %s", sql) + planDigest, planHint, planText, err := bindinfo.GenBriefPlanWithSCtx(sctx, stmt) + require.NoErrorf(t, err, "sql: %s", sql) + require.Greaterf(t, len(planDigest), 0, "sql: %s", sql) + require.Truef(t, strings.Contains(planHint, expectedHint), "sql: %s", sql) + 
planOperators := make([]string, 0, len(planText)) + for _, row := range planText { + planOperators = append(planOperators, row[0]) + } + require.Truef(t, strings.Contains(strings.Join(planOperators, ","), expectedPlan), "sql: %s", sql) + } + check("select count(1) from t1 where a=1", + "stream_agg", "StreamAgg") + + sctx.GetSessionVars().StreamAggCostFactor = 10000 + check("select count(1) from t1 where a=1", + "hash_agg", "HashAgg") + sctx.GetSessionVars().StreamAggCostFactor = 1 + + check("select * from t1, t2 where t1.a=t2.a and t2.b=1", + "inl_hash_join", "IndexHashJoin") + + sctx.GetSessionVars().IndexJoinCostFactor = 100000 + sctx.GetSessionVars().HashJoinCostFactor = 100000 + check("select * from t1, t2 where t1.a=t2.a and t2.b=1", + "merge_join", `MergeJoin`) +} + +func TestExplainExploreBasic(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + + check := func(sql string, expectedRowCount int) { + rows := tk.MustQuery(sql).Rows() + require.Equalf(t, expectedRowCount, len(rows), "sql: %s", sql) + for _, row := range rows { + planDigest := row[3] + require.NotEmptyf(t, planDigest, "sql: %s", sql) + } + } + + tk.MustExec(`create table t (a int, b int, c varchar(10), key(a))`) + check(`explain explore select a from t where b=1`, 1) + tk.MustExec(`create global binding using select a from t where b=1`) + check(`explain explore select a from t where b=1`, 2) + check(`explain explore SELECT a FROM t WHERE b=1`, 2) + check(`explain explore SELECT a FROM t WHERE b= 1`, 2) + check(`explain explore SELECT a FROM test.t WHERE b= 1`, 2) + require.GreaterOrEqual(t, len(tk.MustQuery(`explain explore "23109784b802bcef5398dd81d3b1c5b79200c257c101a5b9f90758206f3d09ed"`).Rows()), 1) + + check(`explain explore select a from t where b in (1, 2, 3)`, 1) + tk.MustExec(`create global binding using select a from t where b in (1, 2, 3)`) + check(`explain explore select a from t where b in (1, 2, 3)`, 2) + 
check(`explain explore select a from t where b in (1, 2)`, 2) + check(`explain explore select a from t where b in (1)`, 2) + check(`explain explore SELECT a from t WHere b in (1)`, 2) + + check(`explain explore select a from t where c = ''`, 1) + tk.MustExec(`create global binding using select a from t where c = ''`) + check(`explain explore select a from t where c = ''`, 2) + check(`explain explore select a from t where c = '123'`, 2) + check(`explain explore select a from t where c = '\"'`, 2) + check(`explain explore select a from t where c = ' '`, 2) + check(`explain explore select a from t where c = ""`, 2) + check(`explain explore select a from t where c = "\'"`, 2) + + tk.MustExecToErr("explain explore 'xxx'", "") + tk.MustExecToErr("explain explore SELECT A FROM", "") +} + +func TestExplainExploreIndexHints(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec(`create table t (a int, b int, c int, key(a), key(b))`) + + rows := tk.MustQuery(`explain explore select * from t where a=1 and b=1`).Rows() + hasIndexA, hasIndexB := false, false + for _, row := range rows { + plan := row[2].(string) + if strings.Contains(plan, "index:a") { + hasIndexA = true + } + if strings.Contains(plan, "index:b") { + hasIndexB = true + } + } + require.True(t, hasIndexA, "expected index a plan in explain explore output") + require.True(t, hasIndexB, "expected index b plan in explain explore output") +} + +func TestExplainExploreIndexHintWithAlias(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec(`create table t (a int, b int, x varchar(10), key(a), key(b))`) + + rows := tk.MustQuery(`explain explore select 1 from t t_alias where a=1 and b=1 and x like "%xx%"`).Rows() + hasIndexA, hasIndexB := false, false + for _, row := range rows { + plan := row[2].(string) + if strings.Contains(plan, "index:a") { + hasIndexA = true + } + 
if strings.Contains(plan, "index:b") { + hasIndexB = true + } + } + require.True(t, hasIndexA, "expected index a plan in explain explore output") + require.True(t, hasIndexB, "expected index b plan in explain explore output") +} + +func TestExplainExploreNoDecorrelateHint(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec(`create table o (a int, b int, c int, d int, key(b))`) + tk.MustExec(`create table r (a int, b int, key(a), key(b))`) + tk.MustExec(`create table o1 (a int, key(a))`) + + rows := tk.MustQuery(`explain explore select o.* from o where exists (select 1 from r inner join o1 on o1.a=r.a where r.b=o.b)`).Rows() + hasNoDecorrelate := false + for _, row := range rows { + if strings.Contains(row[1].(string), "no_decorrelate") { + hasNoDecorrelate = true + break + } + } + require.True(t, hasNoDecorrelate, "expected no_decorrelate plan in explain explore output") +} + +func TestIsSimplePointPlan(t *testing.T) { + require.True(t, bindinfo.IsSimplePointPlan(` id task estRows operator info actRows execution info memory disk + Projection_4 root 1 plus(test.t.a, 1)->Column#3 0 time:173µs, open:24.9µs, close:8.92µs, loops:1, Concurrency:OFF 380 Bytes N/A + └─Point_Get_5 root 1 table:t, handle:2 0 time:143.2µs, open:1.71µs, close:5.92µs, loops:1, Get:{num_rpc:1, total_time:40µs} N/A N/A`)) + require.True(t, bindinfo.IsSimplePointPlan(` id task estRows operator info actRows execution info memory disk + Point_Get_5 root 1 table:t, handle:2 0 time:143.2µs, open:1.71µs, close:5.92µs, loops:1, Get:{num_rpc:1, total_time:40µs} N/A N/A`)) + require.True(t, bindinfo.IsSimplePointPlan(`Point_Get_5 root 1 table:t, handle:2 0 time:143.2µs, open:1.71µs, close:5.92µs, loops:1, Get:{num_rpc:1, total_time:40µs} N/A N/A`)) + require.True(t, bindinfo.IsSimplePointPlan(`id task estRows operator info actRows execution info memory disk + Projection_4 root 3.00 plus(test.t.a, 1)->Column#3 0 time:218.3µs, 
open:14.5µs, close:9.79µs, loops:1, Concurrency:OFF 145 Bytes N/A + └─Batch_Point_Get_5 root 3.00 table:t, handle:[1 2 3], keep order:false, desc:false 0 time:201.1µs, open:3.83µs, close:6.46µs, loops:1, BatchGet:{num_rpc:2, total_time:65.7µs}, rpc_errors:{epoch_not_match:1} N/A N/A `)) + require.True(t, bindinfo.IsSimplePointPlan(`id task estRows operator info actRows execution info memory disk + Batch_Point_Get_5 root 3.00 table:t, handle:[1 2 3], keep order:false, desc:false 0 time:201.1µs, open:3.83µs, close:6.46µs, loops:1, BatchGet:{num_rpc:2, total_time:65.7µs}, rpc_errors:{epoch_not_match:1} N/A N/A `)) + require.True(t, bindinfo.IsSimplePointPlan(`id task estRows operator info actRows execution info memory disk + Selection .... + └─Batch_Point_Get_5 root 3.00 table:t, handle:[1 2 3], keep order:false, desc:false 0 time:201.1µs, open:3.83µs, close:6.46µs, loops:1, BatchGet:{num_rpc:2, total_time:65.7µs}, rpc_errors:{epoch_not_match:1} N/A N/A `)) + + require.False(t, bindinfo.IsSimplePointPlan(` id task estRows operator info actRows execution info memory disk + TableReader_5 root 10000 data:TableFullScan_4 0 time:456.3µs, open:141µs, close:6.79µs, loops:1, cop_task: {num: 1, max: 241.3µs, proc_keys: 0, copr_cache_hit_ratio: 0.00, build_task_duration: 91.5µs, max_distsql_concurrency: 1}, rpc_info:{Cop:{num_rpc:1, total_time:203.9µs}} 182 Bytes N/A + └─TableFullScan_4 cop[tikv] 10000 table:t, keep order:false, stats:pseudo 0 tikv_task:{time:155.2µs, loops:0} N/A N/A `)) + require.False(t, bindinfo.IsSimplePointPlan(`id task estRows operator info actRows execution info memory disk + HashAgg root 3.00 plus(test.t.a, 1)->Column#3 0 time:218.3µs, open:14.5µs, close:9.79µs, loops:1, Concurrency:OFF 145 Bytes N/A + └─Batch_Point_Get_5 root 3.00 table:t, handle:[1 2 3], keep order:false, desc:false 0 time:201.1µs, open:3.83µs, close:6.46µs, loops:1, BatchGet:{num_rpc:2, total_time:65.7µs}, rpc_errors:{epoch_not_match:1} N/A N/A `)) + require.False(t, 
bindinfo.IsSimplePointPlan(` id task estRows operator info actRows execution info memory disk + HashJoin root 1 plus(test.t.a, 1)->Column#3 0 time:173µs, open:24.9µs, close:8.92µs, loops:1, Concurrency:OFF 380 Bytes N/A + └─Point_Get_5 root 1 table:t, handle:2 0 time:143.2µs, open:1.71µs, close:5.92µs, loops:1, Get:{num_rpc:1, total_time:40µs} N/A N/A`)) + require.False(t, bindinfo.IsSimplePointPlan(``)) + require.False(t, bindinfo.IsSimplePointPlan(` \n `)) +} + +func TestRelevantOptVarsAndFixes(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec(`create table t1 (a int, b int, c varchar(10), key(a), key(b))`) + tk.MustExec(`create table t2 (a int, b int, c varchar(10), key(a), key(b))`) + + var input []string + var output []struct { + Vars string + Fixes string + } + bindingAutoSuiteData.LoadTestCases(t, &input, &output) + p := parser.New() + for i, sql := range input { + p.Reset() + stmt, err := p.ParseOneStmt(sql, "", "") + require.NoErrorf(t, err, "sql: %s", sql) + vars, fixes, err := bindinfo.RecordRelevantOptVarsAndFixes(tk.Session(), stmt) + require.NoErrorf(t, err, "sql: %s", sql) + testdata.OnRecord(func() { + output[i].Vars = fmt.Sprintf("%v", vars) + output[i].Fixes = fmt.Sprintf("%v", fixes) + }) + require.Equalf(t, fmt.Sprintf("%v", vars), output[i].Vars, "sql: %s", sql) + require.Equalf(t, fmt.Sprintf("%v", fixes), output[i].Fixes, "sql: %s", sql) + } +} + +func TestRelevantOptVarsCorrelateSubquery(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec(`create table t1 (a int, b int, key(a))`) + tk.MustExec(`create table t2 (a int, b int, key(a))`) + + p := parser.New() + sql := "select * from t1 where a in (select a from t2)" + + // The alternative logical plans variable is recorded as relevant because the + // code path where it affects plan choice (correlate-to-Apply) was reached. 
+ for _, enabled := range []string{"OFF", "ON"} { + tk.MustExec("set tidb_opt_enable_alternative_logical_plans = " + enabled) + p.Reset() + stmt, err := p.ParseOneStmt(sql, "", "") + require.NoError(t, err) + vars, _, err := bindinfo.RecordRelevantOptVarsAndFixes(tk.Session(), stmt) + require.NoError(t, err) + require.True(t, slices.Contains(vars, vardef.TiDBOptEnableAlternativeLogicalPlans), + "enabled=%s: expected %s in recorded vars %v", enabled, vardef.TiDBOptEnableAlternativeLogicalPlans, vars) + } +} + +func TestExplainExploreAnalyze(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec(`create table t (a int, b int, key(a))`) + tk.MustExec(`insert into t values (1, 2), (2, 3), (3, 4), (4, 5)`) + + checkExecInfo := func(sql string, hasExecInfo bool) { + rs := tk.MustQuery(sql).Rows() + for _, row := range rs { + latency := row[4].(string) + execTimes := row[5].(string) + retRows := row[7].(string) + if !hasExecInfo { + require.Equalf(t, "0", latency, "sql: %s", sql) + require.Equalf(t, "0", execTimes, "sql: %s", sql) + require.Equalf(t, "0", retRows, "sql: %s", sql) + } else { + require.NotEqualf(t, "0", latency, "sql: %s", sql) + require.NotEqualf(t, "0", execTimes, "sql: %s", sql) + require.NotEqualf(t, "0", retRows, "sql: %s", sql) + } + } + } + + checkExecInfo(`explain explore select * from t where a=1`, false) + checkExecInfo(`explain explore analyze select * from t where a=1`, true) + checkExecInfo(`explain explore select * from t where b<10`, false) + checkExecInfo(`explain explore analyze select * from t where b<10`, true) + checkExecInfo(`explain explore select count(1) from t where b<10`, false) + checkExecInfo(`explain explore analyze select count(1) from t where b<10`, true) +} + +func TestExplainExploreVerifyAndBind(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + require.NoError(t, 
tk.Session().Auth(&auth.UserIdentity{Username: "root", Hostname: "%"}, nil, nil, nil)) + tk.MustExec("use test") + tk.MustExec(`create table t (a int, b int, key(a))`) + tk.MustExec(`insert into t values (1, 2), (2, 3), (3, 4), (4, 5)`) + + tk.MustQuery(`select * from t`) + tk.MustQuery(`select @@last_plan_from_binding`).Check(testkit.Rows("0")) + require.Equal(t, 0, len(tk.MustQuery(`show global bindings`).Rows())) // no binding + + rs := tk.MustQuery(`explain explore select * from t`).Rows() + runStmt := rs[0][12].(string) // "EXPLAIN ANALYZE " + bindingSQL := rs[0][13].(string) // "CREATE GLOBAL BINDING USING " + + require.True(t, strings.HasPrefix(runStmt, "EXPLAIN ANALYZE")) + require.True(t, strings.HasPrefix(bindingSQL, "CREATE GLOBAL BINDING USING")) + + rs = tk.MustQuery(runStmt).Rows() + require.True(t, strings.Contains(rs[0][0].(string), "TableReader")) // table scan and no error + + tk.MustExec(bindingSQL) + tk.MustQuery(`select * from t`) + tk.MustQuery(`select @@last_plan_from_binding`).Check(testkit.Rows("1")) + require.Equal(t, 1, len(tk.MustQuery(`show global bindings`).Rows())) +} + +func TestPlanGeneration(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec(`create table t (a int, b int, c int, key(a))`) + tk.MustExec(`create table t1 (a int, b int, c int, key(a), key(b))`) + tk.MustExec(`create table t2 (a int, b int, c int, key(a), key(b))`) + tk.MustExec(`create table t3 (a int, b int, c int, key(a), key(b))`) + + var input []string + var output []struct { + SQL string + Plan [][]string + } + bindingAutoSuiteData.LoadTestCases(t, &input, &output) + for i, sql := range input { + rows := tk.MustQuery(sql).Rows() + for rowID, row := range rows { + plan := strings.Split(strings.Replace(row[2].(string), "\t", " ", -1), "\n") + testdata.OnRecord(func() { + output[i].SQL = sql + if len(output[i].Plan) < rowID { + output[i].Plan[rowID] = plan + } else { + output[i].Plan = 
append(output[i].Plan, plan) + } + }) + require.Equalf(t, plan, output[i].Plan[rowID], "sql: %s", sql) + } + } +} diff --git a/pkg/bindinfo/binding_plan_generation.go b/pkg/bindinfo/binding_plan_generation.go new file mode 100644 index 0000000000000..c672f6f96ebd0 --- /dev/null +++ b/pkg/bindinfo/binding_plan_generation.go @@ -0,0 +1,988 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bindinfo + +import ( + "container/list" + "fmt" + "sort" + "strconv" + "strings" + + "github.com/pingcap/tidb/pkg/meta/model" + "github.com/pingcap/tidb/pkg/parser" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/planner/util/fixcontrol" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/vardef" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/hint" +) + +// PlanGenerator is used to generate new Plan Candidates for this specified query. +type PlanGenerator interface { + Generate(defaultSchema, sql, charset, collation string) (plans []*BindingPlanInfo, err error) +} + +// planGenerator implements PlanGenerator. +// It generates new plans via adjusting the optimizer variables and fixes. +type planGenerator struct { + sPool util.DestroyableSessionPool +} + +// Generate generates new plans for the given SQL statement. 
+func (g *planGenerator) Generate(defaultSchema, sql, charset, collation string) (plans []*BindingPlanInfo, err error) { + // TODO: only support SQL starting with SELECT for now, support other types of SQLs later. + // TODO: make this check more strict. + sql = strings.TrimSpace(sql) + prefix := "SELECT" + if len(sql) < len(prefix) || strings.ToUpper(sql[:len(prefix)]) != prefix { + return nil, nil // not a SELECT statement + } + + err = callWithSCtx(g.sPool, false, func(sctx sessionctx.Context) error { + genedPlans, err := generatePlanWithSCtx(sctx, defaultSchema, sql, charset, collation) + if err != nil { + return err + } + plans = make([]*BindingPlanInfo, 0, len(genedPlans)) + + for _, genedPlan := range genedPlans { + // TODO: construct bindingSQL in a more strict way. + bindingSQL := sql[:len(prefix)] + " /*+ " + genedPlan.planHints + " */ " + sql[len(prefix):] + binding := &Binding{ + OriginalSQL: sql, + BindSQL: bindingSQL, + Db: defaultSchema, + Source: "generated", + PlanDigest: genedPlan.planDigest, + } + if err := prepareHints(sctx, binding); err != nil { + return err + } + plan := &BindingPlanInfo{ + Binding: binding, + Plan: genedPlan.PlanText(), + } + plans = append(plans, plan) + } + return nil + }) + return +} + +type tableName struct { + schema string + name string + alias string +} + +func (t *tableName) HintName() string { + if t.alias != "" { + return t.alias + } + return t.name +} + +func (t *tableName) String() string { + return fmt.Sprintf("%s.%s", t.schema, t.HintName()) +} + +type indexHint struct { + table *tableName + index string +} + +func (h *indexHint) String() string { + return fmt.Sprintf("%s:%s", h.table.String(), h.index) +} + +// genedPlan represents a plan generated by planGenerator. 
+type genedPlan struct { + planDigest string // digest of this plan + planHints string // a set of hints to reproduce this plan + planText [][]string // human-readable plan text +} + +func (gp *genedPlan) PlanText() string { + sb := new(strings.Builder) + for i, row := range gp.planText { + if i > 0 { + sb.WriteString("\n") + } + for j, col := range row { + if j > 0 { + sb.WriteString("\t") + } + sb.WriteString(col) + } + } + return sb.String() +} + +// state represents a state of the optimizer variables and fixes. +type state struct { + leading2 [2]*tableName // leading-2 table names + indexHints []*indexHint // optional index hints per table + // noDecorrelateQB stores the query block name for NO_DECORRELATE() hint. + noDecorrelateQB ast.CIStr + varNames []string // relevant variables and their values to generate a certain plan + varValues []any + fixIDs []uint64 // relevant fixes and their values to generate a certain plan + fixValues []string +} + +// Encode encodes the state into a string. +func (s *state) Encode() string { + sb := new(strings.Builder) + for _, t := range s.leading2 { + if t == nil { + continue + } + if sb.Len() > 0 { + sb.WriteString(",") + } + sb.WriteString(t.String()) + } + for _, indexHint := range s.indexHints { + if indexHint == nil { + continue + } + if sb.Len() > 0 { + sb.WriteString(",") + } + sb.WriteString(indexHint.String()) + } + if s.noDecorrelateQB.L != "" { + if sb.Len() > 0 { + sb.WriteString(",") + } + sb.WriteString("no_decorrelate@") + sb.WriteString(s.noDecorrelateQB.L) + } + for _, v := range s.varValues { + if sb.Len() > 0 { + sb.WriteString(",") + } + if _, isFloat := v.(float64); isFloat { + // only consider 4 decimal digits, which should be enough for optimizer tuning. 
+ fmt.Fprintf(sb, "%.4f", v) + continue + } + fmt.Fprintf(sb, "%v", v) + } + for _, v := range s.fixValues { + if sb.Len() > 0 { + sb.WriteString(",") + } + sb.WriteString(v) + } + return sb.String() +} + +func newStateWithLeading2(old *state, leading2 [2]*tableName) *state { + newState := &state{ + leading2: leading2, + noDecorrelateQB: old.noDecorrelateQB, + indexHints: append([]*indexHint(nil), old.indexHints...), + varNames: old.varNames, + varValues: old.varValues, + fixIDs: old.fixIDs, + fixValues: old.fixValues, + } + return newState +} + +func newStateWithIndexHint(old *state, tableIdx int, hint *indexHint) *state { + newHints := append([]*indexHint(nil), old.indexHints...) + if tableIdx >= 0 && tableIdx < len(newHints) { + newHints[tableIdx] = hint + } + return &state{ + leading2: old.leading2, + indexHints: newHints, + noDecorrelateQB: old.noDecorrelateQB, + varNames: old.varNames, + varValues: old.varValues, + fixIDs: old.fixIDs, + fixValues: old.fixValues, + } +} + +func newStateWithNoDecorrelateQB(old *state, qbName ast.CIStr) *state { + return &state{ + leading2: old.leading2, + noDecorrelateQB: qbName, + indexHints: append([]*indexHint(nil), old.indexHints...), + varNames: old.varNames, + varValues: old.varValues, + fixIDs: old.fixIDs, + fixValues: old.fixValues, + } +} + +func newStateWithNewVar(old *state, varName string, varVal any) *state { + newState := &state{ + leading2: old.leading2, + noDecorrelateQB: old.noDecorrelateQB, + indexHints: append([]*indexHint(nil), old.indexHints...), + varNames: old.varNames, + varValues: make([]any, len(old.varValues)), + fixIDs: old.fixIDs, + fixValues: old.fixValues, + } + copy(newState.varValues, old.varValues) + for i := range newState.varNames { + if newState.varNames[i] == varName { + newState.varValues[i] = varVal + break + } + } + return newState +} + +func newStateWithNewFix(old *state, fixID uint64, fixVal string) *state { + newState := &state{ + leading2: old.leading2, + noDecorrelateQB: 
old.noDecorrelateQB, + indexHints: append([]*indexHint(nil), old.indexHints...), + varNames: old.varNames, + varValues: old.varValues, + fixIDs: old.fixIDs, + fixValues: make([]string, len(old.fixValues)), + } + copy(newState.fixValues, old.fixValues) + for i := range newState.fixIDs { + if newState.fixIDs[i] == fixID { + newState.fixValues[i] = fixVal + break + } + } + return newState +} + +func generatePlanWithSCtx(sctx sessionctx.Context, defaultSchema, sql, charset, collation string) (plans []*genedPlan, err error) { + p := parser.New() + stmt, err := p.ParseOneStmt(sql, charset, collation) + if err != nil { + return nil, err + } + sctx.GetSessionVars().CurrentDB = defaultSchema + sctx.GetSessionVars().CostModelVersion = 2 // cost factor only works on cost-model v2 + vars, fixes, err := RecordRelevantOptVarsAndFixes(sctx, stmt) + if err != nil { + return nil, err + } + tableNames := extractSelectTableNames(defaultSchema, stmt) + possibleLeading2 := make([][2]*tableName, 0, 8) // enumerate all possible leading-2 table pairs + for i := range tableNames { + for j := range tableNames { + if i == j { + continue + } + possibleLeading2 = append(possibleLeading2, [2]*tableName{tableNames[i], tableNames[j]}) + } + } + indexHintOptions := extractSelectIndexHints(sctx, defaultSchema, stmt) + possibleNoDecorrelateQBs := extractNoDecorrelateQBs(stmt) + return breadthFirstPlanSearch(sctx, stmt, vars, fixes, possibleLeading2, indexHintOptions, possibleNoDecorrelateQBs) +} + +func breadthFirstPlanSearch(sctx sessionctx.Context, stmt ast.StmtNode, + vars []string, fixes []uint64, possibleLeading2 [][2]*tableName, indexHintOptions [][]*indexHint, possibleNoDecorrelateQBs []ast.CIStr) (plans []*genedPlan, err error) { + // init BFS structures + visitedStates := make(map[string]struct{}) // map[encodedState]struct{}, all visited states + visitedPlans := make(map[string]*genedPlan) // map[planDigest]plan, all visited plans + stateList := list.New() // states in queue to explore + + 
// init the start state and push it into the BFS list + // start state: no specified leading hint + default values of all variables and fix-controls + startState, err := getStartState(vars, fixes, len(indexHintOptions)) + if err != nil { + return nil, err + } + visitedStates[startState.Encode()] = struct{}{} + stateList.PushBack(startState) + + maxPlans, maxExploreState := 30, 10000 + for len(visitedPlans) < maxPlans && len(visitedStates) < maxExploreState && stateList.Len() > 0 { + currState := stateList.Remove(stateList.Front()).(*state) + plan, err := genPlanUnderState(sctx, stmt, currState) + if err != nil { + return nil, err + } + visitedPlans[plan.planDigest] = plan + + // in each step, adjust one variable or fix or join-order + for _, qbName := range possibleNoDecorrelateQBs { + newState := newStateWithNoDecorrelateQB(currState, qbName) + if _, ok := visitedStates[newState.Encode()]; !ok { + visitedStates[newState.Encode()] = struct{}{} + stateList.PushBack(newState) + } + } + for _, leading2 := range possibleLeading2 { + newState := newStateWithLeading2(currState, leading2) + if _, ok := visitedStates[newState.Encode()]; !ok { + visitedStates[newState.Encode()] = struct{}{} + stateList.PushBack(newState) + } + } + for tableIdx := range indexHintOptions { + for _, indexHint := range indexHintOptions[tableIdx] { + newState := newStateWithIndexHint(currState, tableIdx, indexHint) + if _, ok := visitedStates[newState.Encode()]; !ok { + visitedStates[newState.Encode()] = struct{}{} + stateList.PushBack(newState) + } + } + } + for i := range vars { + varName, varVal := vars[i], currState.varValues[i] + newVarVal, err := adjustVar(varName, varVal) + if err != nil { + return nil, err + } + newState := newStateWithNewVar(currState, varName, newVarVal) + if _, ok := visitedStates[newState.Encode()]; !ok { + visitedStates[newState.Encode()] = struct{}{} + stateList.PushBack(newState) + } + } + for i := range fixes { + fixID, fixVal := fixes[i], currState.fixValues[i] 
+ newFixVal, err := adjustFix(fixID, fixVal) + if err != nil { + return nil, err + } + newState := newStateWithNewFix(currState, fixID, newFixVal) + if _, ok := visitedStates[newState.Encode()]; !ok { + visitedStates[newState.Encode()] = struct{}{} + stateList.PushBack(newState) + } + } + } + + plans = make([]*genedPlan, 0, len(visitedPlans)) + for _, plan := range visitedPlans { + plans = append(plans, plan) + } + sort.Slice(plans, func(i, j int) bool { // to make the result stable + return plans[i].planDigest < plans[j].planDigest + }) + return plans, nil +} + +// genPlanUnderState returns a plan generated under the given state (vars and fix-controls). +func genPlanUnderState(sctx sessionctx.Context, stmt ast.StmtNode, state *state) (plan *genedPlan, err error) { + for i, varName := range state.varNames { + switch varName { + case vardef.TiDBOptIndexScanCostFactor: + sctx.GetSessionVars().IndexScanCostFactor = state.varValues[i].(float64) + case vardef.TiDBOptIndexReaderCostFactor: + sctx.GetSessionVars().IndexReaderCostFactor = state.varValues[i].(float64) + case vardef.TiDBOptTableReaderCostFactor: + sctx.GetSessionVars().TableReaderCostFactor = state.varValues[i].(float64) + case vardef.TiDBOptTableFullScanCostFactor: + sctx.GetSessionVars().TableFullScanCostFactor = state.varValues[i].(float64) + case vardef.TiDBOptTableRangeScanCostFactor: + sctx.GetSessionVars().TableRangeScanCostFactor = state.varValues[i].(float64) + case vardef.TiDBOptTableRowIDScanCostFactor: + sctx.GetSessionVars().TableRowIDScanCostFactor = state.varValues[i].(float64) + case vardef.TiDBOptTableTiFlashScanCostFactor: + sctx.GetSessionVars().TableTiFlashScanCostFactor = state.varValues[i].(float64) + case vardef.TiDBOptIndexLookupCostFactor: + sctx.GetSessionVars().IndexLookupCostFactor = state.varValues[i].(float64) + case vardef.TiDBOptIndexMergeCostFactor: + sctx.GetSessionVars().IndexMergeCostFactor = state.varValues[i].(float64) + case vardef.TiDBOptSortCostFactor: + 
sctx.GetSessionVars().SortCostFactor = state.varValues[i].(float64) + case vardef.TiDBOptTopNCostFactor: + sctx.GetSessionVars().TopNCostFactor = state.varValues[i].(float64) + case vardef.TiDBOptLimitCostFactor: + sctx.GetSessionVars().LimitCostFactor = state.varValues[i].(float64) + case vardef.TiDBOptStreamAggCostFactor: + sctx.GetSessionVars().StreamAggCostFactor = state.varValues[i].(float64) + case vardef.TiDBOptHashAggCostFactor: + sctx.GetSessionVars().HashAggCostFactor = state.varValues[i].(float64) + case vardef.TiDBOptMergeJoinCostFactor: + sctx.GetSessionVars().MergeJoinCostFactor = state.varValues[i].(float64) + case vardef.TiDBOptHashJoinCostFactor: + sctx.GetSessionVars().HashJoinCostFactor = state.varValues[i].(float64) + case vardef.TiDBOptIndexJoinCostFactor: + sctx.GetSessionVars().IndexJoinCostFactor = state.varValues[i].(float64) + case vardef.TiDBOptOrderingIdxSelRatio: + sctx.GetSessionVars().OptOrderingIdxSelRatio = state.varValues[i].(float64) + case vardef.TiDBOptRiskEqSkewRatio: + sctx.GetSessionVars().RiskEqSkewRatio = state.varValues[i].(float64) + case vardef.TiDBOptRiskGroupNDVSkewRatio: + sctx.GetSessionVars().RiskGroupNDVSkewRatio = state.varValues[i].(float64) + case vardef.TiDBOptRiskRangeSkewRatio: + sctx.GetSessionVars().RiskRangeSkewRatio = state.varValues[i].(float64) + case vardef.TiDBOptPreferRangeScan: + sctx.GetSessionVars().SetAllowPreferRangeScan(state.varValues[i].(bool)) + case vardef.TiDBOptEnableNoDecorrelateInSelect: + sctx.GetSessionVars().EnableNoDecorrelateInSelect = state.varValues[i].(bool) + case vardef.TiDBOptEnableSemiJoinRewrite: + sctx.GetSessionVars().EnableSemiJoinRewrite = state.varValues[i].(bool) + case vardef.TiDBOptSelectivityFactor: + sctx.GetSessionVars().SelectivityFactor = state.varValues[i].(float64) + case vardef.TiDBOptEnableAlternativeLogicalPlans: + sctx.GetSessionVars().EnableAlternativeLogicalPlans = state.varValues[i].(bool) + default: + return nil, fmt.Errorf("unsupported variable %s in 
plan generation", varName) + } + } + + fixControlStrBuilder := strings.Builder{} + for i, fixID := range state.fixIDs { + if i > 0 { + fixControlStrBuilder.WriteString(",") + } + fmt.Fprintf(&fixControlStrBuilder, "%v:%v", fixID, state.fixValues[i]) + } + fixControlMap, _, err := fixcontrol.ParseToMap(fixControlStrBuilder.String()) + if err != nil { + return nil, err + } + sctx.GetSessionVars().OptimizerFixControl = fixControlMap + + if sel, isSel := stmt.(*ast.SelectStmt); isSel { + hasIndexHint := false + for _, indexHint := range state.indexHints { + if indexHint != nil { + hasIndexHint = true + break + } + } + if (state.leading2[0] != nil && state.leading2[1] != nil) || hasIndexHint { + originalHintsLen := len(sel.TableHints) + defer func() { + sel.TableHints = sel.TableHints[:originalHintsLen] + }() + if state.leading2[0] != nil && state.leading2[1] != nil { + leadingHint := &ast.TableOptimizerHint{ + HintName: ast.NewCIStr(hint.HintLeading), + Tables: []ast.HintTable{ + { + DBName: ast.NewCIStr(state.leading2[0].schema), + TableName: ast.NewCIStr(state.leading2[0].HintName()), + }, + { + DBName: ast.NewCIStr(state.leading2[1].schema), + TableName: ast.NewCIStr(state.leading2[1].HintName()), + }, + }, + } + sel.TableHints = append(sel.TableHints, leadingHint) + } + for _, indexHint := range state.indexHints { + if indexHint == nil { + continue + } + hintNode := &ast.TableOptimizerHint{ + HintName: ast.NewCIStr(hint.HintUseIndex), + Tables: []ast.HintTable{ + { + DBName: ast.NewCIStr(indexHint.table.schema), + TableName: ast.NewCIStr(indexHint.table.HintName()), + }, + }, + Indexes: []ast.CIStr{ast.NewCIStr(indexHint.index)}, + } + sel.TableHints = append(sel.TableHints, hintNode) + } + } + if state.noDecorrelateQB.L != "" { + noDecorrelateHint := &ast.TableOptimizerHint{ + HintName: ast.NewCIStr(hint.HintNoDecorrelate), + QBName: state.noDecorrelateQB, + } + sel.TableHints = append(sel.TableHints, noDecorrelateHint) + } + } + + planDigest, planHints, planText, 
err := GenBriefPlanWithSCtx(sctx, stmt) + if err != nil { + return nil, err + } + return &genedPlan{ + planDigest: planDigest, + planText: planText, + planHints: planHints, + }, nil +} + +// adjustVar returns the new value of the variable for plan generation. +func adjustVar(varName string, varVal any) (newVarVal any, err error) { + switch varName { + case vardef.TiDBOptIndexScanCostFactor, vardef.TiDBOptIndexReaderCostFactor, vardef.TiDBOptTableReaderCostFactor, + vardef.TiDBOptTableFullScanCostFactor, vardef.TiDBOptTableRangeScanCostFactor, vardef.TiDBOptTableRowIDScanCostFactor, + vardef.TiDBOptTableTiFlashScanCostFactor, vardef.TiDBOptIndexLookupCostFactor, vardef.TiDBOptIndexMergeCostFactor, + vardef.TiDBOptSortCostFactor, vardef.TiDBOptTopNCostFactor, vardef.TiDBOptLimitCostFactor, + vardef.TiDBOptStreamAggCostFactor, vardef.TiDBOptHashAggCostFactor, vardef.TiDBOptMergeJoinCostFactor, + vardef.TiDBOptHashJoinCostFactor, vardef.TiDBOptIndexJoinCostFactor: + // for cost factors, we add some penalty (5 times the current value) in each step. + v := varVal.(float64) + if v >= 1e6 { // avoid too large penalty. + return v, nil + } + return v * 5, nil + case vardef.TiDBOptOrderingIdxSelRatio, vardef.TiDBOptRiskEqSkewRatio, vardef.TiDBOptRiskRangeSkewRatio, vardef.TiDBOptRiskGroupNDVSkewRatio, vardef.TiDBOptSelectivityFactor: // range [0, 1], "<=0" means disable + v := varVal.(float64) + if v <= 0 { + return 0.1, nil + } else if v+0.1 > 1 { + return v, nil + } + // increase 0.1 each step + return v + 0.1, nil + case vardef.TiDBOptPreferRangeScan, vardef.TiDBOptEnableNoDecorrelateInSelect, vardef.TiDBOptAlwaysKeepJoinKey, vardef.TiDBOptEnableSemiJoinRewrite, vardef.TiDBOptEnableAlternativeLogicalPlans: // flip the switch + return !varVal.(bool), nil + } + return nil, fmt.Errorf("unsupported variable %s in plan generation", varName) +} + +// adjustFix returns the new value of the fix-control for plan generation.
+func adjustFix(fixID uint64, fixVal string) (newFixVal string, err error) { + switch fixID { + case fixcontrol.Fix44855, fixcontrol.Fix52869: // flip the switch + fixVal = strings.ToUpper(strings.TrimSpace(fixVal)) + if fixVal == vardef.Off { + return vardef.On, nil + } + return vardef.Off, nil + case fixcontrol.Fix45132: + num, err := strconv.ParseInt(fixVal, 10, 64) + if err != nil { + return "", err + } + if num <= 10 { + return fixVal, nil + } + // each time it becomes 50% more aggressive (the value is halved). + return fmt.Sprintf("%v", num/2), nil + default: + return "", fmt.Errorf("unsupported fix-control %d in plan generation", fixID) + } +} + +func getStartState(vars []string, fixes []uint64, indexHintCount int) (*state, error) { + // use the default values of these vars and fix-controls as the initial state. + s := &state{ + varNames: vars, + fixIDs: fixes, + indexHints: make([]*indexHint, indexHintCount), + } + for _, varName := range vars { + switch varName { + case vardef.TiDBOptIndexScanCostFactor: + s.varValues = append(s.varValues, vardef.DefOptIndexScanCostFactor) + case vardef.TiDBOptIndexReaderCostFactor: + s.varValues = append(s.varValues, vardef.DefOptIndexReaderCostFactor) + case vardef.TiDBOptTableReaderCostFactor: + s.varValues = append(s.varValues, vardef.DefOptTableReaderCostFactor) + case vardef.TiDBOptTableFullScanCostFactor: + s.varValues = append(s.varValues, vardef.DefOptTableFullScanCostFactor) + case vardef.TiDBOptTableRangeScanCostFactor: + s.varValues = append(s.varValues, vardef.DefOptTableRangeScanCostFactor) + case vardef.TiDBOptTableRowIDScanCostFactor: + s.varValues = append(s.varValues, vardef.DefOptTableRowIDScanCostFactor) + case vardef.TiDBOptTableTiFlashScanCostFactor: + s.varValues = append(s.varValues, vardef.DefOptTableTiFlashScanCostFactor) + case vardef.TiDBOptIndexLookupCostFactor: + s.varValues = append(s.varValues, vardef.DefOptIndexLookupCostFactor) + case vardef.TiDBOptIndexMergeCostFactor: + s.varValues = append(s.varValues,
vardef.DefOptIndexMergeCostFactor) + case vardef.TiDBOptSortCostFactor: + s.varValues = append(s.varValues, vardef.DefOptSortCostFactor) + case vardef.TiDBOptTopNCostFactor: + s.varValues = append(s.varValues, vardef.DefOptTopNCostFactor) + case vardef.TiDBOptLimitCostFactor: + s.varValues = append(s.varValues, vardef.DefOptLimitCostFactor) + case vardef.TiDBOptStreamAggCostFactor: + s.varValues = append(s.varValues, vardef.DefOptStreamAggCostFactor) + case vardef.TiDBOptHashAggCostFactor: + s.varValues = append(s.varValues, vardef.DefOptHashAggCostFactor) + case vardef.TiDBOptMergeJoinCostFactor: + s.varValues = append(s.varValues, vardef.DefOptMergeJoinCostFactor) + case vardef.TiDBOptHashJoinCostFactor: + s.varValues = append(s.varValues, vardef.DefOptHashJoinCostFactor) + case vardef.TiDBOptIndexJoinCostFactor: + s.varValues = append(s.varValues, vardef.DefOptIndexJoinCostFactor) + case vardef.TiDBOptOrderingIdxSelRatio: + s.varValues = append(s.varValues, vardef.DefTiDBOptOrderingIdxSelRatio) + case vardef.TiDBOptRiskEqSkewRatio: + s.varValues = append(s.varValues, vardef.DefOptRiskEqSkewRatio) + case vardef.TiDBOptRiskRangeSkewRatio: + s.varValues = append(s.varValues, vardef.DefOptRiskRangeSkewRatio) + case vardef.TiDBOptRiskGroupNDVSkewRatio: + s.varValues = append(s.varValues, vardef.DefOptRiskGroupNDVSkewRatio) + case vardef.TiDBOptPreferRangeScan: + s.varValues = append(s.varValues, vardef.DefOptPreferRangeScan) + case vardef.TiDBOptEnableNoDecorrelateInSelect: + s.varValues = append(s.varValues, vardef.DefOptEnableNoDecorrelateInSelect) + case vardef.TiDBOptEnableSemiJoinRewrite: + s.varValues = append(s.varValues, vardef.DefOptEnableSemiJoinRewrite) + case vardef.TiDBOptAlwaysKeepJoinKey: + s.varValues = append(s.varValues, vardef.DefOptAlwaysKeepJoinKey) + case vardef.TiDBOptSelectivityFactor: + s.varValues = append(s.varValues, vardef.DefOptSelectivityFactor) + case vardef.TiDBOptCartesianJoinOrderThreshold: + s.varValues = append(s.varValues, 
vardef.DefOptCartesianJoinOrderThreshold) + case vardef.TiDBOptEnableAlternativeLogicalPlans: + s.varValues = append(s.varValues, vardef.DefOptEnableAlternativeLogicalPlans) + default: + return nil, fmt.Errorf("unsupported variable %s in plan generation", varName) + } + } + + for _, fixID := range fixes { + switch fixID { + case fixcontrol.Fix44855: + s.fixValues = append(s.fixValues, "OFF") + case fixcontrol.Fix45132: + s.fixValues = append(s.fixValues, "1000") + case fixcontrol.Fix52869: + s.fixValues = append(s.fixValues, "OFF") + default: + return nil, fmt.Errorf("unsupported fix-control %d in plan generation", fixID) + } + } + return s, nil +} + +type tableNameExtractor struct { + defaultSchema string + tableNames map[string]*tableName +} + +type selectOffsetAssigner struct { + offset int +} + +func (a *selectOffsetAssigner) Enter(in ast.Node) (node ast.Node, skipChildren bool) { + if sel, ok := in.(*ast.SelectStmt); ok { + a.offset++ + sel.QueryBlockOffset = a.offset + } + return in, false +} + +func (*selectOffsetAssigner) Leave(in ast.Node) (node ast.Node, ok bool) { + return in, true +} + +type subqueryOffsetExtractor struct { + offsets map[int]struct{} +} + +func (e *subqueryOffsetExtractor) Enter(in ast.Node) (node ast.Node, skipChildren bool) { + if subq, ok := in.(*ast.SubqueryExpr); ok { + collectSubqueryOffsets(subq.Query, e.offsets) + } + return in, false +} + +func (*subqueryOffsetExtractor) Leave(in ast.Node) (node ast.Node, ok bool) { + return in, true +} + +func collectSubqueryOffsets(node ast.ResultSetNode, offsets map[int]struct{}) { + if node == nil { + return + } + switch n := node.(type) { + case *ast.SelectStmt: + if n.QueryBlockOffset > 0 { + offsets[n.QueryBlockOffset] = struct{}{} + } + case *ast.SetOprStmt: + collectSubqueryOffsetsFromSelectList(n.SelectList, offsets) + } +} + +func collectSubqueryOffsetsFromSelectList(list *ast.SetOprSelectList, offsets map[int]struct{}) { + if list == nil { + return + } + for _, sel := range 
list.Selects { + switch n := sel.(type) { + case *ast.SelectStmt: + if n.QueryBlockOffset > 0 { + offsets[n.QueryBlockOffset] = struct{}{} + } + case *ast.SetOprStmt: + collectSubqueryOffsetsFromSelectList(n.SelectList, offsets) + } + } +} + +// Enter implements ast.Visitor interface. +func (e *tableNameExtractor) Enter(in ast.Node) (node ast.Node, skipChildren bool) { + if name, ok := in.(*ast.TableName); ok { + t := &tableName{ + schema: name.Schema.L, + name: name.Name.L, + } + if t.schema == "" { + t.schema = e.defaultSchema + } + if _, ok := e.tableNames[t.String()]; !ok { + e.tableNames[t.String()] = t + } + } + return in, false +} + +// Leave implements ast.Visitor interface. +func (*tableNameExtractor) Leave(in ast.Node) (node ast.Node, ok bool) { + return in, true +} + +// extractSelectTableNames returns the table names in the SELECT statement. +func extractSelectTableNames(defaultSchema string, node ast.StmtNode) []*tableName { + selStmt, isSel := node.(*ast.SelectStmt) + if !isSel { + return nil // only support SELECT statement for now + } + extractor := &tableNameExtractor{ + defaultSchema: defaultSchema, + tableNames: make(map[string]*tableName), + } + selStmt.Accept(extractor) + + names := make([]*tableName, 0, len(extractor.tableNames)) + for _, name := range extractor.tableNames { + names = append(names, name) + } + sort.Slice(names, func(i, j int) bool { + return names[i].String() < names[j].String() + }) + return names +} + +func extractNoDecorrelateQBs(node ast.StmtNode) []ast.CIStr { + selStmt, isSel := node.(*ast.SelectStmt) + if !isSel { + return nil // only support SELECT statement for now + } + + assigner := &selectOffsetAssigner{} + selStmt.Accept(assigner) + + extractor := &subqueryOffsetExtractor{offsets: make(map[int]struct{})} + selStmt.Accept(extractor) + + if len(extractor.offsets) == 0 { + return nil + } + + qbNames := make([]ast.CIStr, 0, len(extractor.offsets)) + topOffset := selStmt.QueryBlockOffset + for offset := range 
extractor.offsets { + if offset == topOffset { + continue + } + qbName, err := hint.GenerateQBName(hint.TypeSelect, offset) + if err != nil { + continue + } + qbNames = append(qbNames, qbName) + } + sort.Slice(qbNames, func(i, j int) bool { + return qbNames[i].L < qbNames[j].L + }) + return qbNames +} + +type predicateColumnExtractor struct { + table *tableName + columns map[string]struct{} + allowAnyTable bool +} + +func (e *predicateColumnExtractor) Enter(in ast.Node) (node ast.Node, skipChildren bool) { + switch n := in.(type) { + case *ast.SubqueryExpr: + // Only consider predicates in the current SELECT, skip inner queries to avoid mixing scopes. + return in, true + case *ast.ColumnNameExpr: + if n.Name == nil { + return in, false + } + if !e.allowAnyTable { + if n.Name.Table.L == "" { + return in, false + } + if !matchesColumnTable(e.table, n.Name) { + return in, false + } + } else if n.Name.Schema.L != "" && n.Name.Schema.L != e.table.schema { + return in, false + } + e.columns[n.Name.Name.L] = struct{}{} + } + return in, false +} + +func (*predicateColumnExtractor) Leave(in ast.Node) (node ast.Node, ok bool) { + return in, true +} + +func matchesColumnTable(target *tableName, name *ast.ColumnName) bool { + if target == nil || name == nil { + return false + } + if name.Schema.L != "" && name.Schema.L != target.schema { + return false + } + if name.Table.L == "" { + return false + } + if name.Table.L == target.name { + return true + } + return target.alias != "" && name.Table.L == target.alias +} + +func extractSelectTableNamesWithAlias(defaultSchema string, selStmt *ast.SelectStmt) []*tableName { + if selStmt.From == nil || selStmt.From.TableRefs == nil { + return nil + } + tables := make(map[string]*tableName) + var collectTable func(node ast.ResultSetNode) + collectTable = func(node ast.ResultSetNode) { + switch n := node.(type) { + case *ast.Join: + collectTable(n.Left) + collectTable(n.Right) + case *ast.TableSource: + alias := n.AsName.L + switch src := 
n.Source.(type) { + case *ast.TableName: + t := &tableName{ + schema: src.Schema.L, + name: src.Name.L, + alias: alias, + } + if t.schema == "" { + t.schema = defaultSchema + } + key := fmt.Sprintf("%s.%s:%s", t.schema, t.name, t.alias) + tables[key] = t + case *ast.Join: + collectTable(src) + } + } + } + collectTable(selStmt.From.TableRefs) + if len(tables) == 0 { + return nil + } + names := make([]*tableName, 0, len(tables)) + for _, name := range tables { + names = append(names, name) + } + sort.Slice(names, func(i, j int) bool { + return names[i].String() < names[j].String() + }) + return names +} + +func collectJoinPredicates(node ast.ResultSetNode, extractor *predicateColumnExtractor) { + switch n := node.(type) { + case *ast.Join: + if n.On != nil && n.On.Expr != nil { + n.On.Expr.Accept(extractor) + } + collectJoinPredicates(n.Left, extractor) + collectJoinPredicates(n.Right, extractor) + case *ast.TableSource: + if join, ok := n.Source.(*ast.Join); ok { + collectJoinPredicates(join, extractor) + } + } +} + +func extractSelectIndexHints(sctx sessionctx.Context, defaultSchema string, node ast.StmtNode) [][]*indexHint { + selStmt, isSel := node.(*ast.SelectStmt) + if !isSel { + return nil + } + tableNames := extractSelectTableNamesWithAlias(defaultSchema, selStmt) + if len(tableNames) == 0 { + return nil + } + allowAnyTable := len(tableNames) == 1 + hintOptions := make([][]*indexHint, 0, len(tableNames)) + for _, target := range tableNames { + options := make([]*indexHint, 0, 4) + options = append(options, nil) // empty option to avoid forcing index paths + extractor := &predicateColumnExtractor{ + table: target, + columns: make(map[string]struct{}), + allowAnyTable: allowAnyTable, + } + if selStmt.Where != nil { + selStmt.Where.Accept(extractor) + } + if selStmt.From != nil && selStmt.From.TableRefs != nil { + collectJoinPredicates(selStmt.From.TableRefs, extractor) + } + if len(extractor.columns) == 0 { + hintOptions = append(hintOptions, options) + 
continue + } + tblInfo, err := sctx.GetLatestInfoSchema().TableInfoByName(ast.NewCIStr(target.schema), ast.NewCIStr(target.name)) + if err != nil { + hintOptions = append(hintOptions, options) + continue + } + useInvisible := sctx.GetSessionVars().OptimizerUseInvisibleIndexes + seen := make(map[string]struct{}) + for _, index := range tblInfo.Indices { + if index.State != model.StatePublic { + continue + } + if !useInvisible && index.Invisible { + continue + } + if index.IsColumnarIndex() || index.InvertedInfo != nil { + continue + } + if tblInfo.IsCommonHandle && index.Primary { + continue + } + if len(index.Columns) == 0 { + continue + } + if _, ok := extractor.columns[index.Columns[0].Name.L]; !ok { + continue + } + hint := &indexHint{ + table: target, + index: index.Name.O, + } + if _, ok := seen[hint.index]; ok { + continue + } + seen[hint.index] = struct{}{} + options = append(options, hint) + } + hintOptions = append(hintOptions, options) + } + return hintOptions +} diff --git a/pkg/planner/core/BUILD.bazel b/pkg/planner/core/BUILD.bazel index 6dbe8bd8efca3..026fccb8cbde4 100644 --- a/pkg/planner/core/BUILD.bazel +++ b/pkg/planner/core/BUILD.bazel @@ -53,8 +53,12 @@ go_library( "rule_aggregation_elimination.go", "rule_aggregation_push_down.go", "rule_aggregation_skew_rewrite.go", +<<<<<<< HEAD "rule_collect_plan_stats.go", "rule_column_pruning.go", +======= + "rule_correlate.go", +>>>>>>> 7357a2e2f90 (planner: correlate subquery rule (#66206)) "rule_decorrelate.go", "rule_derive_topn_from_window.go", "rule_eliminate_projection.go", diff --git a/pkg/planner/core/casetest/rule/BUILD.bazel b/pkg/planner/core/casetest/rule/BUILD.bazel index 05152fed4cd66..294d2d7d5fdfc 100644 --- a/pkg/planner/core/casetest/rule/BUILD.bazel +++ b/pkg/planner/core/casetest/rule/BUILD.bazel @@ -5,6 +5,12 @@ go_test( timeout = "short", srcs = [ "main_test.go", +<<<<<<< HEAD +======= + "rule_cdc_join_reorder_test.go", + "rule_common_handle_ordering_test.go", + 
"rule_correlate_test.go", +>>>>>>> 7357a2e2f90 (planner: correlate subquery rule (#66206)) "rule_derive_topn_from_window_test.go", "rule_eliminate_projection_test.go", "rule_inject_extra_projection_test.go", @@ -15,7 +21,11 @@ go_test( ], data = glob(["testdata/**"]), flaky = True, +<<<<<<< HEAD shard_count = 12, +======= + shard_count = 23, +>>>>>>> 7357a2e2f90 (planner: correlate subquery rule (#66206)) deps = [ "//pkg/config", "//pkg/domain", diff --git a/pkg/planner/core/casetest/rule/main_test.go b/pkg/planner/core/casetest/rule/main_test.go index a2f34998b00c4..6cae27bc12bb3 100644 --- a/pkg/planner/core/casetest/rule/main_test.go +++ b/pkg/planner/core/casetest/rule/main_test.go @@ -29,11 +29,24 @@ var testDataMap = make(testdata.BookKeeper) func TestMain(m *testing.M) { testsetup.SetupForCommonTest() flag.Parse() +<<<<<<< HEAD testDataMap.LoadTestSuiteData("testdata", "outer2inner") testDataMap.LoadTestSuiteData("testdata", "derive_topn_from_window") testDataMap.LoadTestSuiteData("testdata", "join_reorder_suite") testDataMap.LoadTestSuiteData("testdata", "predicate_pushdown_suite") testDataMap.LoadTestSuiteData("testdata", "predicate_simplification") +======= + testDataMap.LoadTestSuiteData("testdata", "outer2inner", true) + testDataMap.LoadTestSuiteData("testdata", "derive_topn_from_window", true) + testDataMap.LoadTestSuiteData("testdata", "join_reorder_suite", true) + testDataMap.LoadTestSuiteData("testdata", "predicate_pushdown_suite", true) + testDataMap.LoadTestSuiteData("testdata", "predicate_simplification", true) + testDataMap.LoadTestSuiteData("testdata", "outer_to_semi_join_suite", true) + testDataMap.LoadTestSuiteData("testdata", "correlate_suite", true) + testDataMap.LoadTestSuiteData("testdata", "cdc_join_reorder_suite", true) + testDataMap.LoadTestSuiteData("testdata", "order_aware_join_reorder_suite", true) + +>>>>>>> 7357a2e2f90 (planner: correlate subquery rule (#66206)) opts := []goleak.Option{ 
goleak.IgnoreTopFunction("github.com/golang/glog.(*fileSink).flushDaemon"), goleak.IgnoreTopFunction("github.com/bazelbuild/rules_go/go/tools/bzltestutil.RegisterTimeoutHandler.func1"), @@ -71,3 +84,22 @@ func GetPredicatePushdownSuiteData() testdata.TestData { func GetPredicateSimplificationSuiteData() testdata.TestData { return testDataMap["predicate_simplification"] } +<<<<<<< HEAD +======= + +func GetOuterToSemiJoinSuiteData() testdata.TestData { + return testDataMap["outer_to_semi_join_suite"] +} + +func GetCorrelateSuiteData() testdata.TestData { + return testDataMap["correlate_suite"] +} + +func GetCDCJoinReorderSuiteData() testdata.TestData { + return testDataMap["cdc_join_reorder_suite"] +} + +func GetOrderAwareJoinReorderSuiteData() testdata.TestData { + return testDataMap["order_aware_join_reorder_suite"] +} +>>>>>>> 7357a2e2f90 (planner: correlate subquery rule (#66206)) diff --git a/pkg/planner/core/casetest/rule/rule_correlate_test.go b/pkg/planner/core/casetest/rule/rule_correlate_test.go new file mode 100644 index 0000000000000..0c3a1b61d2f0c --- /dev/null +++ b/pkg/planner/core/casetest/rule/rule_correlate_test.go @@ -0,0 +1,250 @@ +// Copyright 2026 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package rule + +import ( + "fmt" + "strings" + "testing" + + "github.com/pingcap/tidb/pkg/testkit" + "github.com/pingcap/tidb/pkg/testkit/testdata" + "github.com/stretchr/testify/require" +) + +// TestCorrelateNullSemantics verifies that CorrelateSolver does not break +// 3-valued NULL semantics for scalar IN (LeftOuterSemiJoin). +func TestCorrelateNullSemantics(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("set tidb_opt_enable_alternative_logical_plans = ON") + + // Case 1: non-null outer, null inner → must return NULL (not 0). + tk.MustExec("drop table if exists tn, sn") + tk.MustExec("create table tn(a int)") + tk.MustExec("create table sn(a int, key(a))") + tk.MustExec("insert into tn values (1)") + tk.MustExec("insert into sn values (null)") + tk.MustQuery("select tn.a in (select sn.a from sn) as r from tn").Check(testkit.Rows("")) + + // Case 2: null outer, non-null inner → must return NULL (not 0). + tk.MustExec("truncate table tn") + tk.MustExec("truncate table sn") + tk.MustExec("insert into tn values (null)") + tk.MustExec("insert into sn values (1)") + tk.MustQuery("select tn.a in (select sn.a from sn) as r from tn").Check(testkit.Rows("")) + + // Case 3: both columns NOT NULL → correlate is safe; verify correct results. + tk.MustExec("drop table if exists tnn, snn") + tk.MustExec("create table tnn(a int not null)") + tk.MustExec("create table snn(a int not null, key(a))") + tk.MustExec("insert into tnn values (1), (2), (3)") + tk.MustExec("insert into snn values (1), (2)") + tk.MustQuery("select tnn.a in (select snn.a from snn) as r from tnn order by tnn.a").Check(testkit.Rows("1", "1", "0")) +} + +// TestCorrelateAlternativeChoosesApply verifies that the correlate alternative +// round produces an Apply plan that wins the cost comparison for a non-correlated +// IN subquery when an outer WHERE predicate reduces the estimated row count. 
+// Without alternative plans, the InnerJoin+Agg rewrite produces IndexJoin+StreamAgg. +// With alternative plans, the correlate round produces Apply+Limit which is cheaper +// (avoids the StreamAgg overhead and uses Limit 1 for early exit on the inner side). +func TestCorrelateAlternativeChoosesApply(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + + tk.MustExec("drop table if exists t1, t2") + tk.MustExec("create table t1 (a int not null, b int, key(a))") + tk.MustExec("create table t2 (a int not null, b int, key(a))") + tk.MustExec("insert into t1 values (1,1),(2,2),(3,3)") + tk.MustExec("insert into t2 values (1,10),(2,20)") + + sql := "select * from t1 where b = 1 and a in (select a from t2)" + + // Without alternative plans: standard InnerJoin+Agg path produces IndexJoin. + tk.MustExec("set tidb_opt_enable_alternative_logical_plans = OFF") + rows := tk.MustQuery("explain format = 'brief' " + sql).Rows() + require.True(t, explainContains(rows, "IndexJoin"), + "without alternative plans, expected IndexJoin in plan:\n%s", joinExplainRows(rows)) + + // With alternative plans: correlate round produces Apply (cheaper than IndexJoin+StreamAgg). + tk.MustExec("set tidb_opt_enable_alternative_logical_plans = ON") + rows = tk.MustQuery("explain format = 'brief' " + sql).Rows() + require.True(t, explainContains(rows, "Apply"), + "with alternative plans, expected Apply in plan:\n%s", joinExplainRows(rows)) + + // Verify correct results in both modes. 
+ tk.MustExec("set tidb_opt_enable_alternative_logical_plans = OFF") + tk.MustQuery(sql).Check(testkit.Rows("1 1")) + tk.MustExec("set tidb_opt_enable_alternative_logical_plans = ON") + tk.MustQuery(sql).Check(testkit.Rows("1 1")) +} + +func TestCorrelate(tt *testing.T) { + testkit.RunTestUnderCascades(tt, func(t *testing.T, tk *testkit.TestKit, cascades, caller string) { + tk.MustExec("use test") + tk.MustExec("drop table if exists t1, t2, t3") + tk.MustExec("create table t1 (a int, b int, key(a))") + tk.MustExec("create table t2 (a int, b int, key(a))") + tk.MustExec("create table t3 (a int, b int, key(a))") + tk.MustExec("insert into t1 values (1,1),(2,2),(3,3)") + tk.MustExec("insert into t2 values (1,10),(2,20)") + tk.MustExec("insert into t3 values (10,1),(20,2)") + + // Enable the correlate rule. + tk.MustExec("set tidb_opt_enable_alternative_logical_plans = ON") + + var input []string + var output []struct { + SQL string + Plan []string + Result []string + } + suite := GetCorrelateSuiteData() + suite.LoadTestCases(t, &input, &output, cascades, caller) + for i, sql := range input { + testdata.OnRecord(func() { + output[i].SQL = sql + output[i].Plan = testdata.ConvertRowsToStrings(tk.MustQuery("explain format = 'brief' " + sql).Rows()) + output[i].Result = testdata.ConvertRowsToStrings(tk.MustQuery(sql).Rows()) + }) + tk.MustQuery("explain format = 'brief' " + sql).Check(testkit.Rows(output[i].Plan...)) + tk.MustQuery(sql).Check(testkit.Rows(output[i].Result...)) + } + }) +} + +// explainContains scans all explain rows for a substring in the operator column. +func explainContains(rows [][]any, substr string) bool { + for _, row := range rows { + if strings.Contains(row[0].(string), substr) { + return true + } + } + return false +} + +// joinExplainRows formats explain rows into a single string for debug output. 
+func joinExplainRows(rows [][]any) string { + var sb strings.Builder + for _, row := range rows { + sb.WriteString(row[0].(string)) + sb.WriteByte('\n') + } + return sb.String() +} + +// TestCorrelateParallelApply verifies that when the correlate alternative round +// produces an Apply plan and tidb_enable_parallel_apply is ON, the Apply is +// executed with parallel concurrency. This tests the interaction between the +// correlate optimization (converting decorrelated semi-join back to Apply) and +// the parallel apply executor. +func TestCorrelateParallelApply(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + + tk.MustExec("drop table if exists t1, t2") + tk.MustExec("create table t1 (a int not null, b int, key(a))") + tk.MustExec("create table t2 (a int not null, b int, key(a))") + tk.MustExec("insert into t1 values (1,1),(2,2),(3,3),(4,4),(5,5)") + tk.MustExec("insert into t2 values (1,10),(2,20),(3,30)") + + sql := "select * from t1 where b = 1 and a in (select a from t2)" + + // Enable correlate alternative + parallel apply. + tk.MustExec("set tidb_opt_enable_alternative_logical_plans = ON") + tk.MustExec("set tidb_enable_parallel_apply = ON") + tk.MustExec("set tidb_executor_concurrency = 5") + + // Verify the plan contains Apply (correlate alternative won). + rows := tk.MustQuery("explain format = 'brief' " + sql).Rows() + require.True(t, explainContains(rows, "Apply"), + "with correlate alternative + parallel apply, expected Apply in plan:\n%s", joinExplainRows(rows)) + + // Verify EXPLAIN ANALYZE reports Concurrency > 1 for the Apply. 
+ analyzeRows := tk.MustQuery("explain analyze " + sql).Rows() + foundConcurrency := false + for _, row := range analyzeRows { + line := fmt.Sprintf("%v", row) + if strings.Contains(line, "Apply") && strings.Contains(line, "Concurrency:") { + idx := strings.Index(line, "Concurrency:") + if idx >= 0 { + rest := line[idx+len("Concurrency:"):] + var n int + if _, err := fmt.Sscanf(rest, "%d", &n); err == nil && n > 1 { + foundConcurrency = true + } + } + break + } + } + require.True(t, foundConcurrency, + "EXPLAIN ANALYZE must report Concurrency > 1 for Apply when parallel_apply is on") + + // Verify correctness: parallel + correlate must match serial + no correlate. + tk.MustExec("set tidb_enable_parallel_apply = OFF") + tk.MustExec("set tidb_opt_enable_alternative_logical_plans = OFF") + serialRows := tk.MustQuery(sql).Rows() + + tk.MustExec("set tidb_enable_parallel_apply = ON") + tk.MustExec("set tidb_opt_enable_alternative_logical_plans = ON") + parallelRows := tk.MustQuery(sql).Rows() + + require.Equal(t, serialRows, parallelRows, + "correlate alternative + parallel apply must produce the same result as standard path") +} + +// TestCorrelateWithCostFactors verifies that when hash/merge join cost factors +// are increased, the correlate alternative round wins and produces Apply-based +// plans with correlated index access for cases that normally choose HashJoin. 
+func TestCorrelateWithCostFactors(tt *testing.T) { + testkit.RunTestUnderCascades(tt, func(t *testing.T, tk *testkit.TestKit, cascades, caller string) { + tk.MustExec("use test") + tk.MustExec("drop table if exists t1, t2, t3") + tk.MustExec("create table t1 (a int, b int, key(a))") + tk.MustExec("create table t2 (a int, b int, key(a))") + tk.MustExec("create table t3 (a int, b int, key(a))") + tk.MustExec("insert into t1 values (1,1),(2,2),(3,3)") + tk.MustExec("insert into t2 values (1,10),(2,20)") + tk.MustExec("insert into t3 values (10,1),(20,2)") + + // Enable the correlate rule and penalize hash/merge joins so the + // correlate alternative (Apply with index lookup) wins the cost comparison. + tk.MustExec("set tidb_opt_enable_alternative_logical_plans = ON") + tk.MustExec("set tidb_opt_hash_join_cost_factor = 1000") + tk.MustExec("set tidb_opt_merge_join_cost_factor = 1000") + + var input []string + var output []struct { + SQL string + Plan []string + Result []string + } + suite := GetCorrelateSuiteData() + suite.LoadTestCases(t, &input, &output, cascades, caller) + for i, sql := range input { + testdata.OnRecord(func() { + output[i].SQL = sql + output[i].Plan = testdata.ConvertRowsToStrings(tk.MustQuery("explain format = 'brief' " + sql).Rows()) + output[i].Result = testdata.ConvertRowsToStrings(tk.MustQuery(sql).Rows()) + }) + tk.MustQuery("explain format = 'brief' " + sql).Check(testkit.Rows(output[i].Plan...)) + tk.MustQuery(sql).Check(testkit.Rows(output[i].Result...)) + } + }) +} diff --git a/pkg/planner/core/casetest/rule/testdata/correlate_suite_in.json b/pkg/planner/core/casetest/rule/testdata/correlate_suite_in.json new file mode 100644 index 0000000000000..b336f9a8cee93 --- /dev/null +++ b/pkg/planner/core/casetest/rule/testdata/correlate_suite_in.json @@ -0,0 +1,39 @@ +[ + { + "name": "TestCorrelate", + "cases": [ + "select * from t1 where exists (select 1 from t2 where t2.a = t1.a)", + "select * from t1 where not exists (select 1 from t2 where 
t2.a = t1.a)", + "select * from t1 where a in (select a from t2)", + "select * from t1 where exists (select 1 from t2)", + "select * from t1 where a not in (select a from t2)", + "select * from t1 where exists (select 1 from t2 where t2.a > t1.a)", + "select * from t1 where exists (select 1 from t2 where t2.a = t1.a and t2.b > t1.b)", + "select * from t1 where exists (select /*+ NO_DECORRELATE() */ 1 from t2 where t2.a = t1.a)", + "select * from t1 where a in (select t2.a from t2 inner join t3 on t3.a = t2.b where t3.b > 0)", + "select * from t1 where a in (select a from t2) order by a limit 10", + "select * from t1 where a in (select a from t2 where b > 1)", + "select * from t1 where a in (select a from t2 group by a)", + "select * from t1 where a in (select a from t2 where b > 1 group by a)", + "select * from t1 where a in (select a from t2 limit 10)", + "select * from t1 where a in (select a from t2 order by a limit 10)", + "select * from t1 where b = 1 and a in (select a from t2)", + "select * from t1 where b = 1 and exists (select 1 from t2 where t2.a = t1.a) limit 1", + "select * from t1 where b = 1 and a not in (select a from t2) limit 1", + "select * from t1 where b = 1 and a in (select a from t2 where t2.b > 0) limit 1" + ] + }, + { + "name": "TestCorrelateWithCostFactors", + "cases": [ + "select * from t1 where exists (select 1 from t2 where t2.a = t1.a)", + "select * from t1 where not exists (select 1 from t2 where t2.a = t1.a)", + "select * from t1 where a in (select a from t2)", + "select * from t1 where exists (select 1 from t2 where t2.a > t1.a)", + "select * from t1 where exists (select 1 from t2 where t2.a = t1.a and t2.b > t1.b)", + "select * from t1 where a in (select a from t2) order by a limit 10", + "select * from t1 where a in (select a from t2 where b > 1)", + "select * from t1 where a in (select a from t2 order by a limit 10)" + ] + } +] diff --git a/pkg/planner/core/casetest/rule/testdata/correlate_suite_out.json 
b/pkg/planner/core/casetest/rule/testdata/correlate_suite_out.json new file mode 100644 index 0000000000000..ee0fe604da9fb --- /dev/null +++ b/pkg/planner/core/casetest/rule/testdata/correlate_suite_out.json @@ -0,0 +1,466 @@ +[ + { + "Name": "TestCorrelate", + "Cases": [ + { + "SQL": "select * from t1 where exists (select 1 from t2 where t2.a = t1.a)", + "Plan": [ + "HashJoin 7992.00 root semi join, left side:TableReader, equal:[eq(test.t1.a, test.t2.a)]", + "├─IndexReader(Build) 9990.00 root index:IndexFullScan", + "│ └─IndexFullScan 9990.00 cop[tikv] table:t2, index:a(a) keep order:false, stats:pseudo", + "└─TableReader(Probe) 9990.00 root data:Selection", + " └─Selection 9990.00 cop[tikv] not(isnull(test.t1.a))", + " └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo" + ], + "Result": [ + "1 1", + "2 2" + ] + }, + { + "SQL": "select * from t1 where not exists (select 1 from t2 where t2.a = t1.a)", + "Plan": [ + "HashJoin 8000.00 root anti semi join, left side:TableReader, equal:[eq(test.t1.a, test.t2.a)]", + "├─IndexReader(Build) 10000.00 root index:IndexFullScan", + "│ └─IndexFullScan 10000.00 cop[tikv] table:t2, index:a(a) keep order:false, stats:pseudo", + "└─TableReader(Probe) 10000.00 root data:TableFullScan", + " └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo" + ], + "Result": [ + "3 3" + ] + }, + { + "SQL": "select * from t1 where a in (select a from t2)", + "Plan": [ + "HashJoin 9990.00 root inner join, equal:[eq(test.t1.a, test.t2.a)]", + "├─StreamAgg(Build) 7992.00 root group by:test.t2.a, funcs:firstrow(test.t2.a)->test.t2.a", + "│ └─IndexReader 7992.00 root index:StreamAgg", + "│ └─StreamAgg 7992.00 cop[tikv] group by:test.t2.a, ", + "│ └─IndexFullScan 9990.00 cop[tikv] table:t2, index:a(a) keep order:true, stats:pseudo", + "└─TableReader(Probe) 9990.00 root data:Selection", + " └─Selection 9990.00 cop[tikv] not(isnull(test.t1.a))", + " └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, 
stats:pseudo" + ], + "Result": [ + "1 1", + "2 2" + ] + }, + { + "SQL": "select * from t1 where exists (select 1 from t2)", + "Plan": [ + "TableReader 10000.00 root data:TableFullScan", + "└─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo", + "ScalarSubQuery N/A root Output: ScalarQueryCol#10, ScalarQueryCol#11, ScalarQueryCol#12, ScalarQueryCol#13", + "└─TableReader 10000.00 root data:TableFullScan", + " └─TableFullScan 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo" + ], + "Result": [ + "1 1", + "2 2", + "3 3" + ] + }, + { + "SQL": "select * from t1 where a not in (select a from t2)", + "Plan": [ + "HashJoin 8000.00 root Null-aware anti semi join, left side:TableReader, equal:[eq(test.t1.a, test.t2.a)]", + "├─IndexReader(Build) 10000.00 root index:IndexFullScan", + "│ └─IndexFullScan 10000.00 cop[tikv] table:t2, index:a(a) keep order:false, stats:pseudo", + "└─TableReader(Probe) 10000.00 root data:TableFullScan", + " └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo" + ], + "Result": [ + "3 3" + ] + }, + { + "SQL": "select * from t1 where exists (select 1 from t2 where t2.a > t1.a)", + "Plan": [ + "HashJoin 7992.00 root CARTESIAN semi join, left side:TableReader, other cond:gt(test.t2.a, test.t1.a)", + "├─IndexReader(Build) 9990.00 root index:IndexFullScan", + "│ └─IndexFullScan 9990.00 cop[tikv] table:t2, index:a(a) keep order:false, stats:pseudo", + "└─TableReader(Probe) 9990.00 root data:Selection", + " └─Selection 9990.00 cop[tikv] not(isnull(test.t1.a))", + " └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo" + ], + "Result": [ + "1 1" + ] + }, + { + "SQL": "select * from t1 where exists (select 1 from t2 where t2.a = t1.a and t2.b > t1.b)", + "Plan": [ + "HashJoin 7984.01 root semi join, left side:TableReader, equal:[eq(test.t1.a, test.t2.a)], other cond:gt(test.t2.b, test.t1.b)", + "├─TableReader(Build) 9980.01 root data:Selection", + "│ └─Selection 9980.01 cop[tikv] 
not(isnull(test.t2.a)), not(isnull(test.t2.b))", + "│ └─TableFullScan 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo", + "└─TableReader(Probe) 9980.01 root data:Selection", + " └─Selection 9980.01 cop[tikv] not(isnull(test.t1.a)), not(isnull(test.t1.b))", + " └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo" + ], + "Result": [ + "1 1", + "2 2" + ] + }, + { + "SQL": "select * from t1 where exists (select /*+ NO_DECORRELATE() */ 1 from t2 where t2.a = t1.a)", + "Plan": [ + "Apply 8000.00 root CARTESIAN semi join, left side:TableReader", + "├─TableReader(Build) 10000.00 root data:TableFullScan", + "│ └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo", + "└─Limit(Probe) 10000.00 root offset:0, count:1", + " └─IndexReader 10000.00 root index:Limit", + " └─Limit 10000.00 cop[tikv] offset:0, count:1", + " └─IndexRangeScan 10000.00 cop[tikv] table:t2, index:a(a) range: decided by [eq(test.t2.a, test.t1.a)], keep order:false, stats:pseudo" + ], + "Result": [ + "1 1", + "2 2" + ] + }, + { + "SQL": "select * from t1 where a in (select t2.a from t2 inner join t3 on t3.a = t2.b where t3.b > 0)", + "Plan": [ + "HashJoin 5203.12 root inner join, equal:[eq(test.t1.a, test.t2.a)]", + "├─HashAgg(Build) 4162.50 root group by:test.t2.a, funcs:firstrow(test.t2.a)->test.t2.a", + "│ └─HashJoin 4162.50 root inner join, equal:[eq(test.t3.a, test.t2.b)]", + "│ ├─TableReader(Build) 3330.00 root data:Selection", + "│ │ └─Selection 3330.00 cop[tikv] gt(test.t3.b, 0), not(isnull(test.t3.a))", + "│ │ └─TableFullScan 10000.00 cop[tikv] table:t3 keep order:false, stats:pseudo", + "│ └─TableReader(Probe) 9980.01 root data:Selection", + "│ └─Selection 9980.01 cop[tikv] not(isnull(test.t2.a)), not(isnull(test.t2.b))", + "│ └─TableFullScan 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo", + "└─TableReader(Probe) 9990.00 root data:Selection", + " └─Selection 9990.00 cop[tikv] not(isnull(test.t1.a))", + " └─TableFullScan 10000.00 
cop[tikv] table:t1 keep order:false, stats:pseudo" + ], + "Result": [ + "1 1", + "2 2" + ] + }, + { + "SQL": "select * from t1 where a in (select a from t2) order by a limit 10", + "Plan": [ + "Limit 10.00 root offset:0, count:10", + "└─MergeJoin 10.00 root inner join, left key:test.t1.a, right key:test.t2.a", + " ├─StreamAgg(Build) 8.00 root group by:test.t2.a, funcs:firstrow(test.t2.a)->test.t2.a", + " │ └─IndexReader 8.00 root index:StreamAgg", + " │ └─StreamAgg 8.00 cop[tikv] group by:test.t2.a, ", + " │ └─IndexFullScan 10.00 cop[tikv] table:t2, index:a(a) keep order:true, stats:pseudo", + " └─Projection(Probe) 10.00 root test.t1.a, test.t1.b", + " └─IndexLookUp 10.00 root ", + " ├─IndexFullScan(Build) 10.00 cop[tikv] table:t1, index:a(a) keep order:true, stats:pseudo", + " └─TableRowIDScan(Probe) 10.00 cop[tikv] table:t1 keep order:false, stats:pseudo" + ], + "Result": [ + "1 1", + "2 2" + ] + }, + { + "SQL": "select * from t1 where a in (select a from t2 where b > 1)", + "Plan": [ + "HashJoin 3330.00 root inner join, equal:[eq(test.t2.a, test.t1.a)]", + "├─HashAgg(Build) 2664.00 root group by:test.t2.a, funcs:firstrow(test.t2.a)->test.t2.a", + "│ └─TableReader 2664.00 root data:HashAgg", + "│ └─HashAgg 2664.00 cop[tikv] group by:test.t2.a, ", + "│ └─Selection 3330.00 cop[tikv] gt(test.t2.b, 1), not(isnull(test.t2.a))", + "│ └─TableFullScan 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo", + "└─TableReader(Probe) 9990.00 root data:Selection", + " └─Selection 9990.00 cop[tikv] not(isnull(test.t1.a))", + " └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo" + ], + "Result": [ + "1 1", + "2 2" + ] + }, + { + "SQL": "select * from t1 where a in (select a from t2 group by a)", + "Plan": [ + "HashJoin 9990.00 root inner join, equal:[eq(test.t1.a, test.t2.a)]", + "├─StreamAgg(Build) 7992.00 root group by:test.t2.a, funcs:firstrow(test.t2.a)->test.t2.a", + "│ └─IndexReader 7992.00 root index:StreamAgg", + "│ └─StreamAgg 7992.00 
cop[tikv] group by:test.t2.a, ", + "│ └─IndexFullScan 9990.00 cop[tikv] table:t2, index:a(a) keep order:true, stats:pseudo", + "└─TableReader(Probe) 9990.00 root data:Selection", + " └─Selection 9990.00 cop[tikv] not(isnull(test.t1.a))", + " └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo" + ], + "Result": [ + "1 1", + "2 2" + ] + }, + { + "SQL": "select * from t1 where a in (select a from t2 where b > 1 group by a)", + "Plan": [ + "HashJoin 3330.00 root inner join, equal:[eq(test.t2.a, test.t1.a)]", + "├─HashAgg(Build) 2664.00 root group by:test.t2.a, funcs:firstrow(test.t2.a)->test.t2.a", + "│ └─TableReader 2664.00 root data:HashAgg", + "│ └─HashAgg 2664.00 cop[tikv] group by:test.t2.a, ", + "│ └─Selection 3330.00 cop[tikv] gt(test.t2.b, 1), not(isnull(test.t2.a))", + "│ └─TableFullScan 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo", + "└─TableReader(Probe) 9990.00 root data:Selection", + " └─Selection 9990.00 cop[tikv] not(isnull(test.t1.a))", + " └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo" + ], + "Result": [ + "1 1", + "2 2" + ] + }, + { + "SQL": "select * from t1 where a in (select a from t2 limit 10)", + "Plan": [ + "IndexHashJoin 10.00 root inner join, inner:IndexLookUp, outer key:test.t2.a, inner key:test.t1.a, equal cond:eq(test.t2.a, test.t1.a)", + "├─HashAgg(Build) 8.00 root group by:test.t2.a, funcs:firstrow(test.t2.a)->test.t2.a", + "│ └─Selection 8.00 root not(isnull(test.t2.a))", + "│ └─Limit 10.00 root offset:0, count:10", + "│ └─IndexReader 10.00 root index:Limit", + "│ └─Limit 10.00 cop[tikv] offset:0, count:10", + "│ └─IndexFullScan 10.00 cop[tikv] table:t2, index:a(a) keep order:false, stats:pseudo", + "└─IndexLookUp(Probe) 10.00 root ", + " ├─Selection(Build) 10.00 cop[tikv] not(isnull(test.t1.a))", + " │ └─IndexRangeScan 10.01 cop[tikv] table:t1, index:a(a) range: decided by [eq(test.t1.a, test.t2.a)], keep order:false, stats:pseudo", + " └─TableRowIDScan(Probe) 10.00 
cop[tikv] table:t1 keep order:false, stats:pseudo" + ], + "Result": [ + "1 1", + "2 2" + ] + }, + { + "SQL": "select * from t1 where a in (select a from t2 order by a limit 10)", + "Plan": [ + "IndexHashJoin 10.00 root inner join, inner:IndexLookUp, outer key:test.t2.a, inner key:test.t1.a, equal cond:eq(test.t2.a, test.t1.a)", + "├─StreamAgg(Build) 8.00 root group by:test.t2.a, funcs:firstrow(test.t2.a)->test.t2.a", + "│ └─Selection 8.00 root not(isnull(test.t2.a))", + "│ └─Limit 10.00 root offset:0, count:10", + "│ └─IndexReader 10.00 root index:Limit", + "│ └─Limit 10.00 cop[tikv] offset:0, count:10", + "│ └─IndexFullScan 10.00 cop[tikv] table:t2, index:a(a) keep order:true, stats:pseudo", + "└─IndexLookUp(Probe) 10.00 root ", + " ├─Selection(Build) 10.00 cop[tikv] not(isnull(test.t1.a))", + " │ └─IndexRangeScan 10.01 cop[tikv] table:t1, index:a(a) range: decided by [eq(test.t1.a, test.t2.a)], keep order:false, stats:pseudo", + " └─TableRowIDScan(Probe) 10.00 cop[tikv] table:t1 keep order:false, stats:pseudo" + ], + "Result": [ + "1 1", + "2 2" + ] + }, + { + "SQL": "select * from t1 where b = 1 and a in (select a from t2)", + "Plan": [ + "IndexJoin 9.99 root inner join, inner:StreamAgg, outer key:test.t1.a, inner key:test.t2.a, equal cond:eq(test.t1.a, test.t2.a)", + "├─TableReader(Build) 9.99 root data:Selection", + "│ └─Selection 9.99 cop[tikv] eq(test.t1.b, 1), not(isnull(test.t1.a))", + "│ └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo", + "└─StreamAgg(Probe) 9.99 root group by:test.t2.a, funcs:firstrow(test.t2.a)->test.t2.a", + " └─IndexReader 9.99 root index:Selection", + " └─Selection 9.99 cop[tikv] not(isnull(test.t2.a))", + " └─IndexRangeScan 10.00 cop[tikv] table:t2, index:a(a) range: decided by [eq(test.t2.a, test.t1.a)], keep order:true, stats:pseudo" + ], + "Result": [ + "1 1" + ] + }, + { + "SQL": "select * from t1 where b = 1 and exists (select 1 from t2 where t2.a = t1.a) limit 1", + "Plan": [ + "Limit 1.00 root 
offset:0, count:1", + "└─IndexHashJoin 1.00 root semi join, inner:IndexReader, left side:TableReader, outer key:test.t1.a, inner key:test.t2.a, equal cond:eq(test.t1.a, test.t2.a)", + " ├─TableReader(Build) 1.25 root data:Selection", + " │ └─Selection 1.25 cop[tikv] eq(test.t1.b, 1), not(isnull(test.t1.a))", + " │ └─TableFullScan 1251.25 cop[tikv] table:t1 keep order:false, stats:pseudo", + " └─IndexReader(Probe) 1.56 root index:Selection", + " └─Selection 1.56 cop[tikv] not(isnull(test.t2.a))", + " └─IndexRangeScan 1.56 cop[tikv] table:t2, index:a(a) range: decided by [eq(test.t2.a, test.t1.a)], keep order:false, stats:pseudo" + ], + "Result": [ + "1 1" + ] + }, + { + "SQL": "select * from t1 where b = 1 and a not in (select a from t2) limit 1", + "Plan": [ + "Limit 1.00 root offset:0, count:1", + "└─HashJoin 1.00 root Null-aware anti semi join, left side:TableReader, equal:[eq(test.t1.a, test.t2.a)]", + " ├─IndexReader(Build) 10000.00 root index:IndexFullScan", + " │ └─IndexFullScan 10000.00 cop[tikv] table:t2, index:a(a) keep order:false, stats:pseudo", + " └─TableReader(Probe) 1.25 root data:Selection", + " └─Selection 1.25 cop[tikv] eq(test.t1.b, 1)", + " └─TableFullScan 1250.00 cop[tikv] table:t1 keep order:false, stats:pseudo" + ], + "Result": null + }, + { + "SQL": "select * from t1 where b = 1 and a in (select a from t2 where t2.b > 0) limit 1", + "Plan": [ + "Limit 1.00 root offset:0, count:1", + "└─IndexJoin 1.00 root inner join, inner:StreamAgg, outer key:test.t1.a, inner key:test.t2.a, equal cond:eq(test.t1.a, test.t2.a)", + " ├─TableReader(Build) 1.00 root data:Selection", + " │ └─Selection 1.00 cop[tikv] eq(test.t1.b, 1), not(isnull(test.t1.a))", + " │ └─TableFullScan 1001.00 cop[tikv] table:t1 keep order:false, stats:pseudo", + " └─StreamAgg(Probe) 1.00 root group by:test.t2.a, funcs:firstrow(test.t2.a)->test.t2.a", + " └─Projection 1.00 root test.t2.a, test.t2.b", + " └─IndexLookUp 1.00 root ", + " ├─Selection(Build) 3.00 cop[tikv] 
not(isnull(test.t2.a))", + " │ └─IndexRangeScan 3.00 cop[tikv] table:t2, index:a(a) range: decided by [eq(test.t2.a, test.t1.a)], keep order:true, stats:pseudo", + " └─Selection(Probe) 1.00 cop[tikv] gt(test.t2.b, 0)", + " └─TableRowIDScan 3.00 cop[tikv] table:t2 keep order:false, stats:pseudo" + ], + "Result": [ + "1 1" + ] + } + ] + }, + { + "Name": "TestCorrelateWithCostFactors", + "Cases": [ + { + "SQL": "select * from t1 where exists (select 1 from t2 where t2.a = t1.a)", + "Plan": [ + "IndexHashJoin 7992.00 root semi join, inner:IndexReader, left side:TableReader, outer key:test.t1.a, inner key:test.t2.a, equal cond:eq(test.t1.a, test.t2.a)", + "├─TableReader(Build) 9990.00 root data:Selection", + "│ └─Selection 9990.00 cop[tikv] not(isnull(test.t1.a))", + "│ └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo", + "└─IndexReader(Probe) 12487.50 root index:Selection", + " └─Selection 12487.50 cop[tikv] not(isnull(test.t2.a))", + " └─IndexRangeScan 12500.00 cop[tikv] table:t2, index:a(a) range: decided by [eq(test.t2.a, test.t1.a)], keep order:false, stats:pseudo" + ], + "Result": [ + "1 1", + "2 2" + ] + }, + { + "SQL": "select * from t1 where not exists (select 1 from t2 where t2.a = t1.a)", + "Plan": [ + "IndexHashJoin 8000.00 root anti semi join, inner:IndexReader, left side:TableReader, outer key:test.t1.a, inner key:test.t2.a, equal cond:eq(test.t1.a, test.t2.a)", + "├─TableReader(Build) 10000.00 root data:TableFullScan", + "│ └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo", + "└─IndexReader(Probe) 12500.00 root index:IndexRangeScan", + " └─IndexRangeScan 12500.00 cop[tikv] table:t2, index:a(a) range: decided by [eq(test.t2.a, test.t1.a)], keep order:false, stats:pseudo" + ], + "Result": [ + "3 3" + ] + }, + { + "SQL": "select * from t1 where a in (select a from t2)", + "Plan": [ + "IndexHashJoin 9990.00 root inner join, inner:IndexLookUp, outer key:test.t2.a, inner key:test.t1.a, equal cond:eq(test.t2.a, 
test.t1.a)", + "├─StreamAgg(Build) 7992.00 root group by:test.t2.a, funcs:firstrow(test.t2.a)->test.t2.a", + "│ └─IndexReader 7992.00 root index:StreamAgg", + "│ └─StreamAgg 7992.00 cop[tikv] group by:test.t2.a, ", + "│ └─IndexFullScan 9990.00 cop[tikv] table:t2, index:a(a) keep order:true, stats:pseudo", + "└─IndexLookUp(Probe) 9990.00 root ", + " ├─Selection(Build) 9990.00 cop[tikv] not(isnull(test.t1.a))", + " │ └─IndexRangeScan 10000.00 cop[tikv] table:t1, index:a(a) range: decided by [eq(test.t1.a, test.t2.a)], keep order:false, stats:pseudo", + " └─TableRowIDScan(Probe) 9990.00 cop[tikv] table:t1 keep order:false, stats:pseudo" + ], + "Result": [ + "1 1", + "2 2" + ] + }, + { + "SQL": "select * from t1 where exists (select 1 from t2 where t2.a > t1.a)", + "Plan": [ + "HashJoin 7992.00 root CARTESIAN semi join, left side:TableReader, other cond:gt(test.t2.a, test.t1.a)", + "├─IndexReader(Build) 9990.00 root index:IndexFullScan", + "│ └─IndexFullScan 9990.00 cop[tikv] table:t2, index:a(a) keep order:false, stats:pseudo", + "└─TableReader(Probe) 9990.00 root data:Selection", + " └─Selection 9990.00 cop[tikv] not(isnull(test.t1.a))", + " └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo" + ], + "Result": [ + "1 1" + ] + }, + { + "SQL": "select * from t1 where exists (select 1 from t2 where t2.a = t1.a and t2.b > t1.b)", + "Plan": [ + "IndexHashJoin 7984.01 root semi join, inner:IndexLookUp, left side:TableReader, outer key:test.t1.a, inner key:test.t2.a, equal cond:eq(test.t1.a, test.t2.a), other cond:gt(test.t2.b, test.t1.b)", + "├─TableReader(Build) 9980.01 root data:Selection", + "│ └─Selection 9980.01 cop[tikv] not(isnull(test.t1.a)), not(isnull(test.t1.b))", + "│ └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo", + "└─IndexLookUp(Probe) 12475.01 root ", + " ├─Selection(Build) 12487.50 cop[tikv] not(isnull(test.t2.a))", + " │ └─IndexRangeScan 12500.00 cop[tikv] table:t2, index:a(a) range: decided by 
[eq(test.t2.a, test.t1.a)], keep order:false, stats:pseudo", + " └─Selection(Probe) 12475.01 cop[tikv] not(isnull(test.t2.b))", + " └─TableRowIDScan 12487.50 cop[tikv] table:t2 keep order:false, stats:pseudo" + ], + "Result": [ + "1 1", + "2 2" + ] + }, + { + "SQL": "select * from t1 where a in (select a from t2) order by a limit 10", + "Plan": [ + "Limit 10.00 root offset:0, count:10", + "└─IndexJoin 10.00 root inner join, inner:StreamAgg, outer key:test.t1.a, inner key:test.t2.a, equal cond:eq(test.t1.a, test.t2.a)", + " ├─Projection(Build) 10.00 root test.t1.a, test.t1.b", + " │ └─IndexLookUp 10.00 root ", + " │ ├─IndexFullScan(Build) 10.00 cop[tikv] table:t1, index:a(a) keep order:true, stats:pseudo", + " │ └─TableRowIDScan(Probe) 10.00 cop[tikv] table:t1 keep order:false, stats:pseudo", + " └─StreamAgg(Probe) 10.00 root group by:test.t2.a, funcs:firstrow(test.t2.a)->test.t2.a", + " └─IndexReader 10.00 root index:Selection", + " └─Selection 10.00 cop[tikv] not(isnull(test.t2.a))", + " └─IndexRangeScan 10.01 cop[tikv] table:t2, index:a(a) range: decided by [eq(test.t2.a, test.t1.a)], keep order:true, stats:pseudo" + ], + "Result": [ + "1 1", + "2 2" + ] + }, + { + "SQL": "select * from t1 where a in (select a from t2 where b > 1)", + "Plan": [ + "IndexHashJoin 3330.00 root inner join, inner:IndexLookUp, outer key:test.t2.a, inner key:test.t1.a, equal cond:eq(test.t2.a, test.t1.a)", + "├─HashAgg(Build) 2664.00 root group by:test.t2.a, funcs:firstrow(test.t2.a)->test.t2.a", + "│ └─TableReader 2664.00 root data:HashAgg", + "│ └─HashAgg 2664.00 cop[tikv] group by:test.t2.a, ", + "│ └─Selection 3330.00 cop[tikv] gt(test.t2.b, 1), not(isnull(test.t2.a))", + "│ └─TableFullScan 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo", + "└─IndexLookUp(Probe) 3330.00 root ", + " ├─Selection(Build) 3330.00 cop[tikv] not(isnull(test.t1.a))", + " │ └─IndexRangeScan 3333.33 cop[tikv] table:t1, index:a(a) range: decided by [eq(test.t1.a, test.t2.a)], keep order:false, 
stats:pseudo", + " └─TableRowIDScan(Probe) 3330.00 cop[tikv] table:t1 keep order:false, stats:pseudo" + ], + "Result": [ + "1 1", + "2 2" + ] + }, + { + "SQL": "select * from t1 where a in (select a from t2 order by a limit 10)", + "Plan": [ + "IndexHashJoin 10.00 root inner join, inner:IndexLookUp, outer key:test.t2.a, inner key:test.t1.a, equal cond:eq(test.t2.a, test.t1.a)", + "├─StreamAgg(Build) 8.00 root group by:test.t2.a, funcs:firstrow(test.t2.a)->test.t2.a", + "│ └─Selection 8.00 root not(isnull(test.t2.a))", + "│ └─Limit 10.00 root offset:0, count:10", + "│ └─IndexReader 10.00 root index:Limit", + "│ └─Limit 10.00 cop[tikv] offset:0, count:10", + "│ └─IndexFullScan 10.00 cop[tikv] table:t2, index:a(a) keep order:true, stats:pseudo", + "└─IndexLookUp(Probe) 10.00 root ", + " ├─Selection(Build) 10.00 cop[tikv] not(isnull(test.t1.a))", + " │ └─IndexRangeScan 10.01 cop[tikv] table:t1, index:a(a) range: decided by [eq(test.t1.a, test.t2.a)], keep order:false, stats:pseudo", + " └─TableRowIDScan(Probe) 10.00 cop[tikv] table:t1 keep order:false, stats:pseudo" + ], + "Result": [ + "1 1", + "2 2" + ] + } + ] + } +] diff --git a/pkg/planner/core/casetest/rule/testdata/correlate_suite_xut.json b/pkg/planner/core/casetest/rule/testdata/correlate_suite_xut.json new file mode 100644 index 0000000000000..ee0fe604da9fb --- /dev/null +++ b/pkg/planner/core/casetest/rule/testdata/correlate_suite_xut.json @@ -0,0 +1,466 @@ +[ + { + "Name": "TestCorrelate", + "Cases": [ + { + "SQL": "select * from t1 where exists (select 1 from t2 where t2.a = t1.a)", + "Plan": [ + "HashJoin 7992.00 root semi join, left side:TableReader, equal:[eq(test.t1.a, test.t2.a)]", + "├─IndexReader(Build) 9990.00 root index:IndexFullScan", + "│ └─IndexFullScan 9990.00 cop[tikv] table:t2, index:a(a) keep order:false, stats:pseudo", + "└─TableReader(Probe) 9990.00 root data:Selection", + " └─Selection 9990.00 cop[tikv] not(isnull(test.t1.a))", + " └─TableFullScan 10000.00 cop[tikv] table:t1 keep 
order:false, stats:pseudo" + ], + "Result": [ + "1 1", + "2 2" + ] + }, + { + "SQL": "select * from t1 where not exists (select 1 from t2 where t2.a = t1.a)", + "Plan": [ + "HashJoin 8000.00 root anti semi join, left side:TableReader, equal:[eq(test.t1.a, test.t2.a)]", + "├─IndexReader(Build) 10000.00 root index:IndexFullScan", + "│ └─IndexFullScan 10000.00 cop[tikv] table:t2, index:a(a) keep order:false, stats:pseudo", + "└─TableReader(Probe) 10000.00 root data:TableFullScan", + " └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo" + ], + "Result": [ + "3 3" + ] + }, + { + "SQL": "select * from t1 where a in (select a from t2)", + "Plan": [ + "HashJoin 9990.00 root inner join, equal:[eq(test.t1.a, test.t2.a)]", + "├─StreamAgg(Build) 7992.00 root group by:test.t2.a, funcs:firstrow(test.t2.a)->test.t2.a", + "│ └─IndexReader 7992.00 root index:StreamAgg", + "│ └─StreamAgg 7992.00 cop[tikv] group by:test.t2.a, ", + "│ └─IndexFullScan 9990.00 cop[tikv] table:t2, index:a(a) keep order:true, stats:pseudo", + "└─TableReader(Probe) 9990.00 root data:Selection", + " └─Selection 9990.00 cop[tikv] not(isnull(test.t1.a))", + " └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo" + ], + "Result": [ + "1 1", + "2 2" + ] + }, + { + "SQL": "select * from t1 where exists (select 1 from t2)", + "Plan": [ + "TableReader 10000.00 root data:TableFullScan", + "└─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo", + "ScalarSubQuery N/A root Output: ScalarQueryCol#10, ScalarQueryCol#11, ScalarQueryCol#12, ScalarQueryCol#13", + "└─TableReader 10000.00 root data:TableFullScan", + " └─TableFullScan 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo" + ], + "Result": [ + "1 1", + "2 2", + "3 3" + ] + }, + { + "SQL": "select * from t1 where a not in (select a from t2)", + "Plan": [ + "HashJoin 8000.00 root Null-aware anti semi join, left side:TableReader, equal:[eq(test.t1.a, test.t2.a)]", + "├─IndexReader(Build) 
10000.00 root index:IndexFullScan", + "│ └─IndexFullScan 10000.00 cop[tikv] table:t2, index:a(a) keep order:false, stats:pseudo", + "└─TableReader(Probe) 10000.00 root data:TableFullScan", + " └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo" + ], + "Result": [ + "3 3" + ] + }, + { + "SQL": "select * from t1 where exists (select 1 from t2 where t2.a > t1.a)", + "Plan": [ + "HashJoin 7992.00 root CARTESIAN semi join, left side:TableReader, other cond:gt(test.t2.a, test.t1.a)", + "├─IndexReader(Build) 9990.00 root index:IndexFullScan", + "│ └─IndexFullScan 9990.00 cop[tikv] table:t2, index:a(a) keep order:false, stats:pseudo", + "└─TableReader(Probe) 9990.00 root data:Selection", + " └─Selection 9990.00 cop[tikv] not(isnull(test.t1.a))", + " └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo" + ], + "Result": [ + "1 1" + ] + }, + { + "SQL": "select * from t1 where exists (select 1 from t2 where t2.a = t1.a and t2.b > t1.b)", + "Plan": [ + "HashJoin 7984.01 root semi join, left side:TableReader, equal:[eq(test.t1.a, test.t2.a)], other cond:gt(test.t2.b, test.t1.b)", + "├─TableReader(Build) 9980.01 root data:Selection", + "│ └─Selection 9980.01 cop[tikv] not(isnull(test.t2.a)), not(isnull(test.t2.b))", + "│ └─TableFullScan 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo", + "└─TableReader(Probe) 9980.01 root data:Selection", + " └─Selection 9980.01 cop[tikv] not(isnull(test.t1.a)), not(isnull(test.t1.b))", + " └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo" + ], + "Result": [ + "1 1", + "2 2" + ] + }, + { + "SQL": "select * from t1 where exists (select /*+ NO_DECORRELATE() */ 1 from t2 where t2.a = t1.a)", + "Plan": [ + "Apply 8000.00 root CARTESIAN semi join, left side:TableReader", + "├─TableReader(Build) 10000.00 root data:TableFullScan", + "│ └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo", + "└─Limit(Probe) 10000.00 root offset:0, count:1", + " 
└─IndexReader 10000.00 root index:Limit", + " └─Limit 10000.00 cop[tikv] offset:0, count:1", + " └─IndexRangeScan 10000.00 cop[tikv] table:t2, index:a(a) range: decided by [eq(test.t2.a, test.t1.a)], keep order:false, stats:pseudo" + ], + "Result": [ + "1 1", + "2 2" + ] + }, + { + "SQL": "select * from t1 where a in (select t2.a from t2 inner join t3 on t3.a = t2.b where t3.b > 0)", + "Plan": [ + "HashJoin 5203.12 root inner join, equal:[eq(test.t1.a, test.t2.a)]", + "├─HashAgg(Build) 4162.50 root group by:test.t2.a, funcs:firstrow(test.t2.a)->test.t2.a", + "│ └─HashJoin 4162.50 root inner join, equal:[eq(test.t3.a, test.t2.b)]", + "│ ├─TableReader(Build) 3330.00 root data:Selection", + "│ │ └─Selection 3330.00 cop[tikv] gt(test.t3.b, 0), not(isnull(test.t3.a))", + "│ │ └─TableFullScan 10000.00 cop[tikv] table:t3 keep order:false, stats:pseudo", + "│ └─TableReader(Probe) 9980.01 root data:Selection", + "│ └─Selection 9980.01 cop[tikv] not(isnull(test.t2.a)), not(isnull(test.t2.b))", + "│ └─TableFullScan 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo", + "└─TableReader(Probe) 9990.00 root data:Selection", + " └─Selection 9990.00 cop[tikv] not(isnull(test.t1.a))", + " └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo" + ], + "Result": [ + "1 1", + "2 2" + ] + }, + { + "SQL": "select * from t1 where a in (select a from t2) order by a limit 10", + "Plan": [ + "Limit 10.00 root offset:0, count:10", + "└─MergeJoin 10.00 root inner join, left key:test.t1.a, right key:test.t2.a", + " ├─StreamAgg(Build) 8.00 root group by:test.t2.a, funcs:firstrow(test.t2.a)->test.t2.a", + " │ └─IndexReader 8.00 root index:StreamAgg", + " │ └─StreamAgg 8.00 cop[tikv] group by:test.t2.a, ", + " │ └─IndexFullScan 10.00 cop[tikv] table:t2, index:a(a) keep order:true, stats:pseudo", + " └─Projection(Probe) 10.00 root test.t1.a, test.t1.b", + " └─IndexLookUp 10.00 root ", + " ├─IndexFullScan(Build) 10.00 cop[tikv] table:t1, index:a(a) keep order:true, 
stats:pseudo", + " └─TableRowIDScan(Probe) 10.00 cop[tikv] table:t1 keep order:false, stats:pseudo" + ], + "Result": [ + "1 1", + "2 2" + ] + }, + { + "SQL": "select * from t1 where a in (select a from t2 where b > 1)", + "Plan": [ + "HashJoin 3330.00 root inner join, equal:[eq(test.t2.a, test.t1.a)]", + "├─HashAgg(Build) 2664.00 root group by:test.t2.a, funcs:firstrow(test.t2.a)->test.t2.a", + "│ └─TableReader 2664.00 root data:HashAgg", + "│ └─HashAgg 2664.00 cop[tikv] group by:test.t2.a, ", + "│ └─Selection 3330.00 cop[tikv] gt(test.t2.b, 1), not(isnull(test.t2.a))", + "│ └─TableFullScan 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo", + "└─TableReader(Probe) 9990.00 root data:Selection", + " └─Selection 9990.00 cop[tikv] not(isnull(test.t1.a))", + " └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo" + ], + "Result": [ + "1 1", + "2 2" + ] + }, + { + "SQL": "select * from t1 where a in (select a from t2 group by a)", + "Plan": [ + "HashJoin 9990.00 root inner join, equal:[eq(test.t1.a, test.t2.a)]", + "├─StreamAgg(Build) 7992.00 root group by:test.t2.a, funcs:firstrow(test.t2.a)->test.t2.a", + "│ └─IndexReader 7992.00 root index:StreamAgg", + "│ └─StreamAgg 7992.00 cop[tikv] group by:test.t2.a, ", + "│ └─IndexFullScan 9990.00 cop[tikv] table:t2, index:a(a) keep order:true, stats:pseudo", + "└─TableReader(Probe) 9990.00 root data:Selection", + " └─Selection 9990.00 cop[tikv] not(isnull(test.t1.a))", + " └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo" + ], + "Result": [ + "1 1", + "2 2" + ] + }, + { + "SQL": "select * from t1 where a in (select a from t2 where b > 1 group by a)", + "Plan": [ + "HashJoin 3330.00 root inner join, equal:[eq(test.t2.a, test.t1.a)]", + "├─HashAgg(Build) 2664.00 root group by:test.t2.a, funcs:firstrow(test.t2.a)->test.t2.a", + "│ └─TableReader 2664.00 root data:HashAgg", + "│ └─HashAgg 2664.00 cop[tikv] group by:test.t2.a, ", + "│ └─Selection 3330.00 cop[tikv] 
gt(test.t2.b, 1), not(isnull(test.t2.a))", + "│ └─TableFullScan 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo", + "└─TableReader(Probe) 9990.00 root data:Selection", + " └─Selection 9990.00 cop[tikv] not(isnull(test.t1.a))", + " └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo" + ], + "Result": [ + "1 1", + "2 2" + ] + }, + { + "SQL": "select * from t1 where a in (select a from t2 limit 10)", + "Plan": [ + "IndexHashJoin 10.00 root inner join, inner:IndexLookUp, outer key:test.t2.a, inner key:test.t1.a, equal cond:eq(test.t2.a, test.t1.a)", + "├─HashAgg(Build) 8.00 root group by:test.t2.a, funcs:firstrow(test.t2.a)->test.t2.a", + "│ └─Selection 8.00 root not(isnull(test.t2.a))", + "│ └─Limit 10.00 root offset:0, count:10", + "│ └─IndexReader 10.00 root index:Limit", + "│ └─Limit 10.00 cop[tikv] offset:0, count:10", + "│ └─IndexFullScan 10.00 cop[tikv] table:t2, index:a(a) keep order:false, stats:pseudo", + "└─IndexLookUp(Probe) 10.00 root ", + " ├─Selection(Build) 10.00 cop[tikv] not(isnull(test.t1.a))", + " │ └─IndexRangeScan 10.01 cop[tikv] table:t1, index:a(a) range: decided by [eq(test.t1.a, test.t2.a)], keep order:false, stats:pseudo", + " └─TableRowIDScan(Probe) 10.00 cop[tikv] table:t1 keep order:false, stats:pseudo" + ], + "Result": [ + "1 1", + "2 2" + ] + }, + { + "SQL": "select * from t1 where a in (select a from t2 order by a limit 10)", + "Plan": [ + "IndexHashJoin 10.00 root inner join, inner:IndexLookUp, outer key:test.t2.a, inner key:test.t1.a, equal cond:eq(test.t2.a, test.t1.a)", + "├─StreamAgg(Build) 8.00 root group by:test.t2.a, funcs:firstrow(test.t2.a)->test.t2.a", + "│ └─Selection 8.00 root not(isnull(test.t2.a))", + "│ └─Limit 10.00 root offset:0, count:10", + "│ └─IndexReader 10.00 root index:Limit", + "│ └─Limit 10.00 cop[tikv] offset:0, count:10", + "│ └─IndexFullScan 10.00 cop[tikv] table:t2, index:a(a) keep order:true, stats:pseudo", + "└─IndexLookUp(Probe) 10.00 root ", + " ├─Selection(Build) 
10.00 cop[tikv] not(isnull(test.t1.a))", + " │ └─IndexRangeScan 10.01 cop[tikv] table:t1, index:a(a) range: decided by [eq(test.t1.a, test.t2.a)], keep order:false, stats:pseudo", + " └─TableRowIDScan(Probe) 10.00 cop[tikv] table:t1 keep order:false, stats:pseudo" + ], + "Result": [ + "1 1", + "2 2" + ] + }, + { + "SQL": "select * from t1 where b = 1 and a in (select a from t2)", + "Plan": [ + "IndexJoin 9.99 root inner join, inner:StreamAgg, outer key:test.t1.a, inner key:test.t2.a, equal cond:eq(test.t1.a, test.t2.a)", + "├─TableReader(Build) 9.99 root data:Selection", + "│ └─Selection 9.99 cop[tikv] eq(test.t1.b, 1), not(isnull(test.t1.a))", + "│ └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo", + "└─StreamAgg(Probe) 9.99 root group by:test.t2.a, funcs:firstrow(test.t2.a)->test.t2.a", + " └─IndexReader 9.99 root index:Selection", + " └─Selection 9.99 cop[tikv] not(isnull(test.t2.a))", + " └─IndexRangeScan 10.00 cop[tikv] table:t2, index:a(a) range: decided by [eq(test.t2.a, test.t1.a)], keep order:true, stats:pseudo" + ], + "Result": [ + "1 1" + ] + }, + { + "SQL": "select * from t1 where b = 1 and exists (select 1 from t2 where t2.a = t1.a) limit 1", + "Plan": [ + "Limit 1.00 root offset:0, count:1", + "└─IndexHashJoin 1.00 root semi join, inner:IndexReader, left side:TableReader, outer key:test.t1.a, inner key:test.t2.a, equal cond:eq(test.t1.a, test.t2.a)", + " ├─TableReader(Build) 1.25 root data:Selection", + " │ └─Selection 1.25 cop[tikv] eq(test.t1.b, 1), not(isnull(test.t1.a))", + " │ └─TableFullScan 1251.25 cop[tikv] table:t1 keep order:false, stats:pseudo", + " └─IndexReader(Probe) 1.56 root index:Selection", + " └─Selection 1.56 cop[tikv] not(isnull(test.t2.a))", + " └─IndexRangeScan 1.56 cop[tikv] table:t2, index:a(a) range: decided by [eq(test.t2.a, test.t1.a)], keep order:false, stats:pseudo" + ], + "Result": [ + "1 1" + ] + }, + { + "SQL": "select * from t1 where b = 1 and a not in (select a from t2) limit 1", + "Plan": [ 
+ "Limit 1.00 root offset:0, count:1", + "└─HashJoin 1.00 root Null-aware anti semi join, left side:TableReader, equal:[eq(test.t1.a, test.t2.a)]", + " ├─IndexReader(Build) 10000.00 root index:IndexFullScan", + " │ └─IndexFullScan 10000.00 cop[tikv] table:t2, index:a(a) keep order:false, stats:pseudo", + " └─TableReader(Probe) 1.25 root data:Selection", + " └─Selection 1.25 cop[tikv] eq(test.t1.b, 1)", + " └─TableFullScan 1250.00 cop[tikv] table:t1 keep order:false, stats:pseudo" + ], + "Result": null + }, + { + "SQL": "select * from t1 where b = 1 and a in (select a from t2 where t2.b > 0) limit 1", + "Plan": [ + "Limit 1.00 root offset:0, count:1", + "└─IndexJoin 1.00 root inner join, inner:StreamAgg, outer key:test.t1.a, inner key:test.t2.a, equal cond:eq(test.t1.a, test.t2.a)", + " ├─TableReader(Build) 1.00 root data:Selection", + " │ └─Selection 1.00 cop[tikv] eq(test.t1.b, 1), not(isnull(test.t1.a))", + " │ └─TableFullScan 1001.00 cop[tikv] table:t1 keep order:false, stats:pseudo", + " └─StreamAgg(Probe) 1.00 root group by:test.t2.a, funcs:firstrow(test.t2.a)->test.t2.a", + " └─Projection 1.00 root test.t2.a, test.t2.b", + " └─IndexLookUp 1.00 root ", + " ├─Selection(Build) 3.00 cop[tikv] not(isnull(test.t2.a))", + " │ └─IndexRangeScan 3.00 cop[tikv] table:t2, index:a(a) range: decided by [eq(test.t2.a, test.t1.a)], keep order:true, stats:pseudo", + " └─Selection(Probe) 1.00 cop[tikv] gt(test.t2.b, 0)", + " └─TableRowIDScan 3.00 cop[tikv] table:t2 keep order:false, stats:pseudo" + ], + "Result": [ + "1 1" + ] + } + ] + }, + { + "Name": "TestCorrelateWithCostFactors", + "Cases": [ + { + "SQL": "select * from t1 where exists (select 1 from t2 where t2.a = t1.a)", + "Plan": [ + "IndexHashJoin 7992.00 root semi join, inner:IndexReader, left side:TableReader, outer key:test.t1.a, inner key:test.t2.a, equal cond:eq(test.t1.a, test.t2.a)", + "├─TableReader(Build) 9990.00 root data:Selection", + "│ └─Selection 9990.00 cop[tikv] not(isnull(test.t1.a))", + "│ 
└─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo", + "└─IndexReader(Probe) 12487.50 root index:Selection", + " └─Selection 12487.50 cop[tikv] not(isnull(test.t2.a))", + " └─IndexRangeScan 12500.00 cop[tikv] table:t2, index:a(a) range: decided by [eq(test.t2.a, test.t1.a)], keep order:false, stats:pseudo" + ], + "Result": [ + "1 1", + "2 2" + ] + }, + { + "SQL": "select * from t1 where not exists (select 1 from t2 where t2.a = t1.a)", + "Plan": [ + "IndexHashJoin 8000.00 root anti semi join, inner:IndexReader, left side:TableReader, outer key:test.t1.a, inner key:test.t2.a, equal cond:eq(test.t1.a, test.t2.a)", + "├─TableReader(Build) 10000.00 root data:TableFullScan", + "│ └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo", + "└─IndexReader(Probe) 12500.00 root index:IndexRangeScan", + " └─IndexRangeScan 12500.00 cop[tikv] table:t2, index:a(a) range: decided by [eq(test.t2.a, test.t1.a)], keep order:false, stats:pseudo" + ], + "Result": [ + "3 3" + ] + }, + { + "SQL": "select * from t1 where a in (select a from t2)", + "Plan": [ + "IndexHashJoin 9990.00 root inner join, inner:IndexLookUp, outer key:test.t2.a, inner key:test.t1.a, equal cond:eq(test.t2.a, test.t1.a)", + "├─StreamAgg(Build) 7992.00 root group by:test.t2.a, funcs:firstrow(test.t2.a)->test.t2.a", + "│ └─IndexReader 7992.00 root index:StreamAgg", + "│ └─StreamAgg 7992.00 cop[tikv] group by:test.t2.a, ", + "│ └─IndexFullScan 9990.00 cop[tikv] table:t2, index:a(a) keep order:true, stats:pseudo", + "└─IndexLookUp(Probe) 9990.00 root ", + " ├─Selection(Build) 9990.00 cop[tikv] not(isnull(test.t1.a))", + " │ └─IndexRangeScan 10000.00 cop[tikv] table:t1, index:a(a) range: decided by [eq(test.t1.a, test.t2.a)], keep order:false, stats:pseudo", + " └─TableRowIDScan(Probe) 9990.00 cop[tikv] table:t1 keep order:false, stats:pseudo" + ], + "Result": [ + "1 1", + "2 2" + ] + }, + { + "SQL": "select * from t1 where exists (select 1 from t2 where t2.a > t1.a)", + 
"Plan": [ + "HashJoin 7992.00 root CARTESIAN semi join, left side:TableReader, other cond:gt(test.t2.a, test.t1.a)", + "├─IndexReader(Build) 9990.00 root index:IndexFullScan", + "│ └─IndexFullScan 9990.00 cop[tikv] table:t2, index:a(a) keep order:false, stats:pseudo", + "└─TableReader(Probe) 9990.00 root data:Selection", + " └─Selection 9990.00 cop[tikv] not(isnull(test.t1.a))", + " └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo" + ], + "Result": [ + "1 1" + ] + }, + { + "SQL": "select * from t1 where exists (select 1 from t2 where t2.a = t1.a and t2.b > t1.b)", + "Plan": [ + "IndexHashJoin 7984.01 root semi join, inner:IndexLookUp, left side:TableReader, outer key:test.t1.a, inner key:test.t2.a, equal cond:eq(test.t1.a, test.t2.a), other cond:gt(test.t2.b, test.t1.b)", + "├─TableReader(Build) 9980.01 root data:Selection", + "│ └─Selection 9980.01 cop[tikv] not(isnull(test.t1.a)), not(isnull(test.t1.b))", + "│ └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo", + "└─IndexLookUp(Probe) 12475.01 root ", + " ├─Selection(Build) 12487.50 cop[tikv] not(isnull(test.t2.a))", + " │ └─IndexRangeScan 12500.00 cop[tikv] table:t2, index:a(a) range: decided by [eq(test.t2.a, test.t1.a)], keep order:false, stats:pseudo", + " └─Selection(Probe) 12475.01 cop[tikv] not(isnull(test.t2.b))", + " └─TableRowIDScan 12487.50 cop[tikv] table:t2 keep order:false, stats:pseudo" + ], + "Result": [ + "1 1", + "2 2" + ] + }, + { + "SQL": "select * from t1 where a in (select a from t2) order by a limit 10", + "Plan": [ + "Limit 10.00 root offset:0, count:10", + "└─IndexJoin 10.00 root inner join, inner:StreamAgg, outer key:test.t1.a, inner key:test.t2.a, equal cond:eq(test.t1.a, test.t2.a)", + " ├─Projection(Build) 10.00 root test.t1.a, test.t1.b", + " │ └─IndexLookUp 10.00 root ", + " │ ├─IndexFullScan(Build) 10.00 cop[tikv] table:t1, index:a(a) keep order:true, stats:pseudo", + " │ └─TableRowIDScan(Probe) 10.00 cop[tikv] table:t1 keep 
order:false, stats:pseudo", + " └─StreamAgg(Probe) 10.00 root group by:test.t2.a, funcs:firstrow(test.t2.a)->test.t2.a", + " └─IndexReader 10.00 root index:Selection", + " └─Selection 10.00 cop[tikv] not(isnull(test.t2.a))", + " └─IndexRangeScan 10.01 cop[tikv] table:t2, index:a(a) range: decided by [eq(test.t2.a, test.t1.a)], keep order:true, stats:pseudo" + ], + "Result": [ + "1 1", + "2 2" + ] + }, + { + "SQL": "select * from t1 where a in (select a from t2 where b > 1)", + "Plan": [ + "IndexHashJoin 3330.00 root inner join, inner:IndexLookUp, outer key:test.t2.a, inner key:test.t1.a, equal cond:eq(test.t2.a, test.t1.a)", + "├─HashAgg(Build) 2664.00 root group by:test.t2.a, funcs:firstrow(test.t2.a)->test.t2.a", + "│ └─TableReader 2664.00 root data:HashAgg", + "│ └─HashAgg 2664.00 cop[tikv] group by:test.t2.a, ", + "│ └─Selection 3330.00 cop[tikv] gt(test.t2.b, 1), not(isnull(test.t2.a))", + "│ └─TableFullScan 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo", + "└─IndexLookUp(Probe) 3330.00 root ", + " ├─Selection(Build) 3330.00 cop[tikv] not(isnull(test.t1.a))", + " │ └─IndexRangeScan 3333.33 cop[tikv] table:t1, index:a(a) range: decided by [eq(test.t1.a, test.t2.a)], keep order:false, stats:pseudo", + " └─TableRowIDScan(Probe) 3330.00 cop[tikv] table:t1 keep order:false, stats:pseudo" + ], + "Result": [ + "1 1", + "2 2" + ] + }, + { + "SQL": "select * from t1 where a in (select a from t2 order by a limit 10)", + "Plan": [ + "IndexHashJoin 10.00 root inner join, inner:IndexLookUp, outer key:test.t2.a, inner key:test.t1.a, equal cond:eq(test.t2.a, test.t1.a)", + "├─StreamAgg(Build) 8.00 root group by:test.t2.a, funcs:firstrow(test.t2.a)->test.t2.a", + "│ └─Selection 8.00 root not(isnull(test.t2.a))", + "│ └─Limit 10.00 root offset:0, count:10", + "│ └─IndexReader 10.00 root index:Limit", + "│ └─Limit 10.00 cop[tikv] offset:0, count:10", + "│ └─IndexFullScan 10.00 cop[tikv] table:t2, index:a(a) keep order:true, stats:pseudo", + "└─IndexLookUp(Probe) 
10.00 root ", + " ├─Selection(Build) 10.00 cop[tikv] not(isnull(test.t1.a))", + " │ └─IndexRangeScan 10.01 cop[tikv] table:t1, index:a(a) range: decided by [eq(test.t1.a, test.t2.a)], keep order:false, stats:pseudo", + " └─TableRowIDScan(Probe) 10.00 cop[tikv] table:t1 keep order:false, stats:pseudo" + ], + "Result": [ + "1 1", + "2 2" + ] + } + ] + } +] diff --git a/pkg/planner/core/expression_rewriter.go b/pkg/planner/core/expression_rewriter.go index 1f0e4ba23604a..3c89056766e4f 100644 --- a/pkg/planner/core/expression_rewriter.go +++ b/pkg/planner/core/expression_rewriter.go @@ -1051,6 +1051,17 @@ func (er *expressionRewriter) handleExistSubquery(ctx context.Context, planCtx * // Add LIMIT 1 when noDecorrelate is true for EXISTS subqueries to enable early exit corCols := coreusage.ExtractCorColumnsBySchema4LogicalPlan(np, planCtx.plan.Schema()) noDecorrelate := isNoDecorrelate(planCtx, corCols, hintFlags, handlingExistsSubquery) + // When EnableCorrelateSubquery is ON (set by the correlate alternative round), + // prevent decorrelation of correlated subqueries so they stay as Apply with index lookups. + // Skip when SEMI_JOIN_REWRITE() hint is present, since that hint explicitly requires + // decorrelation and would be silently ineffective on LogicalApply nodes. 
+ semiJoinRewriteHint := hintFlags&hint.HintFlagSemiJoinRewrite > 0 + if !noDecorrelate && len(corCols) > 0 && !semiJoinRewriteHint { + b.ctx.GetSessionVars().RecordRelevantOptVar(vardef.TiDBOptEnableAlternativeLogicalPlans) + if b.ctx.GetSessionVars().EnableCorrelateSubquery { + noDecorrelate = true + } + } if noDecorrelate { // Only add LIMIT 1 if the query doesn't already contain a LIMIT clause if !hasLimit(np) { @@ -1066,8 +1077,8 @@ func (er *expressionRewriter) handleExistSubquery(ctx context.Context, planCtx * } } np = er.popExistsSubPlan(planCtx, np) - semiJoinRewrite := hintFlags&hint.HintFlagSemiJoinRewrite > 0 - if semiJoinRewrite && noDecorrelate { + semiJoinRewrite := semiJoinRewriteHint + if semiJoinRewrite && hintFlags&hint.HintFlagNoDecorrelate > 0 { b.ctx.GetSessionVars().StmtCtx.SetHintWarning( "NO_DECORRELATE() and SEMI_JOIN_REWRITE() are in conflict. Both will be ineffective.") noDecorrelate = false @@ -1237,12 +1248,36 @@ func (er *expressionRewriter) handleInSubquery(ctx context.Context, planCtx *exp collFlag := collate.CompatibleCollate(lt.GetCollate(), rt.GetCollate()) corCols := coreusage.ExtractCorColumnsBySchema4LogicalPlan(np, planCtx.plan.Schema()) noDecorrelate := isNoDecorrelate(planCtx, corCols, hintFlags, handlingInSubquery) + // When EnableCorrelateSubquery is ON (set by the correlate alternative round), + // prevent decorrelation of correlated IN subqueries so they stay as Apply with index lookups. 
+ if !noDecorrelate && len(corCols) > 0 && !v.Not { + planCtx.builder.ctx.GetSessionVars().RecordRelevantOptVar(vardef.TiDBOptEnableAlternativeLogicalPlans) + if planCtx.builder.ctx.GetSessionVars().EnableAlternativeLogicalPlans { + planCtx.builder.ctx.GetSessionVars().StmtCtx.MarkAlternativeLogicalPlanPreferCorrelate() + } + if planCtx.builder.ctx.GetSessionVars().EnableCorrelateSubquery { + noDecorrelate = true + } + } // If it's not the form of `not in (SUBQUERY)`, // and has no correlated column from the current level plan(if the correlated column is from upper level, // we can treat it as constant, because the upper LogicalApply cannot be eliminated since current node is a join node), // and don't need to append a scalar value, we can rewrite it to inner join. - if planCtx.builder.ctx.GetSessionVars().GetAllowInSubqToJoinAndAgg() && !v.Not && !asScalar && len(corCols) == 0 && collFlag { + // When EnableCorrelateSubquery is ON (set by the correlate alternative round), skip the + // InnerJoin+Agg rewrite so that a SemiJoin is built instead; the CorrelateSolver rule can + // then convert it to a correlated Apply with index lookups. + canRewriteToJoinAgg := planCtx.builder.ctx.GetSessionVars().GetAllowInSubqToJoinAndAgg() && !v.Not && !asScalar && len(corCols) == 0 && collFlag + if canRewriteToJoinAgg { + // Record that the alternative logical plans variable is relevant — toggling it + // changes whether we take the InnerJoin+Agg path or the SemiApply path. + planCtx.builder.ctx.GetSessionVars().RecordRelevantOptVar(vardef.TiDBOptEnableAlternativeLogicalPlans) + // Signal that a correlate alternative round is worth attempting. 
+ if planCtx.builder.ctx.GetSessionVars().EnableAlternativeLogicalPlans { + planCtx.builder.ctx.GetSessionVars().StmtCtx.MarkAlternativeLogicalPlanPreferCorrelate() + } + } + if canRewriteToJoinAgg && !planCtx.builder.ctx.GetSessionVars().EnableCorrelateSubquery { // We need to try to eliminate the agg and the projection produced by this operation. planCtx.builder.optFlag |= rule.FlagEliminateAgg planCtx.builder.optFlag |= rule.FlagEliminateProjection @@ -1277,6 +1312,17 @@ func (er *expressionRewriter) handleInSubquery(ctx context.Context, planCtx *exp if er.err != nil { return v, true } + // When EnableCorrelateSubquery is ON (set by the correlate alternative round) + // and the subquery is non-correlated, mark the join so that CorrelateSolver + // converts it to a correlated Apply. + if len(corCols) == 0 && !v.Not { + planCtx.builder.ctx.GetSessionVars().RecordRelevantOptVar(vardef.TiDBOptEnableAlternativeLogicalPlans) + if planCtx.builder.ctx.GetSessionVars().EnableCorrelateSubquery { + if ap, ok := planCtx.plan.(*logicalop.LogicalApply); ok { + ap.PreferCorrelate = true + } + } + } } er.ctxStackPop(1) diff --git a/pkg/planner/core/operator/logicalop/logical_join.go b/pkg/planner/core/operator/logicalop/logical_join.go index 8741ffd9fdaad..ad731e8a5a9ed 100644 --- a/pkg/planner/core/operator/logicalop/logical_join.go +++ b/pkg/planner/core/operator/logicalop/logical_join.go @@ -144,6 +144,11 @@ type LogicalJoin struct { FullSchema *expression.Schema FullNames types.NameSlice + // PreferCorrelate is set to true when this SemiJoin originated from a non-correlated + // IN subquery during the correlate alternative round, indicating that the CorrelateSolver + // should convert it back to a correlated Apply with index lookups. + PreferCorrelate bool + // EqualCondOutCnt indicates the estimated count of joined rows after evaluating `EqualConditions`. 
EqualCondOutCnt float64 } diff --git a/pkg/planner/core/optimizer.go b/pkg/planner/core/optimizer.go index f2517c8d4b045..f43c3c1dcc7e4 100644 --- a/pkg/planner/core/optimizer.go +++ b/pkg/planner/core/optimizer.go @@ -98,7 +98,13 @@ var optRuleList = []base.LogicalOptRule{ &PushDownTopNOptimizer{}, &SyncWaitStatsLoadPoint{}, &JoinReOrderSolver{}, +<<<<<<< HEAD &ColumnPruner{}, // column pruning again at last, note it will mess up the results of buildKeySolver +======= + &rule.OuterJoinToSemiJoin{}, + &CorrelateSolver{}, + &rule.ColumnPruner{}, // column pruning again at last, note it will mess up the results of buildKeySolver +>>>>>>> 7357a2e2f90 (planner: correlate subquery rule (#66206)) &PushDownSequenceSolver{}, &ResolveExpand{}, } @@ -304,10 +310,6 @@ func doOptimize( } func adjustOptimizationFlags(flag uint64, logic base.LogicalPlan) uint64 { - // If there is something after flagPrunColumns, do FlagPruneColumnsAgain. - if flag&rule.FlagPruneColumns > 0 && flag-rule.FlagPruneColumns > rule.FlagPruneColumns { - flag |= rule.FlagPruneColumnsAgain - } if checkStableResultMode(logic.SCtx()) { flag |= rule.FlagStabilizeResults } @@ -323,6 +325,16 @@ func adjustOptimizationFlags(flag uint64, logic base.LogicalPlan) uint64 { if !logic.SCtx().GetSessionVars().StmtCtx.UseDynamicPruneMode { flag |= rule.FlagPartitionProcessor // apply partition pruning under static mode } + // FlagCorrelate is added by the correlate alternative round's flag adjuster, + // not here. EnableCorrelateSubquery is an internal flag toggled by the round. + // A second column-prune pass is worthwhile when any rule above column + // pruning is enabled. 
+ if flag&rule.FlagPruneColumns != 0 { + const abovePruneColumns = ^(rule.FlagPruneColumns | (rule.FlagPruneColumns - 1)) + if flag&abovePruneColumns != 0 { + flag |= rule.FlagPruneColumnsAgain + } + } return flag } diff --git a/pkg/planner/core/optimizer_test.go b/pkg/planner/core/optimizer_test.go index 1e5cbf543e3ea..bd5a543aa91d9 100644 --- a/pkg/planner/core/optimizer_test.go +++ b/pkg/planner/core/optimizer_test.go @@ -16,6 +16,7 @@ package core import ( "math" + "math/bits" "reflect" "strings" "testing" @@ -29,6 +30,7 @@ import ( "github.com/pingcap/tidb/pkg/parser/mysql" "github.com/pingcap/tidb/pkg/planner/core/base" "github.com/pingcap/tidb/pkg/planner/core/operator/physicalop" + "github.com/pingcap/tidb/pkg/planner/core/rule" "github.com/pingcap/tidb/pkg/planner/property" "github.com/pingcap/tidb/pkg/statistics" "github.com/pingcap/tidb/pkg/types" @@ -558,3 +560,71 @@ func TestHandleFineGrainedShuffle(t *testing.T) { start(hashJoin, 0, 3, 0) require.NoError(t, failpoint.Disable(fpName2)) } +<<<<<<< HEAD +======= + +func TestCanTiFlashUseHashJoinV2(t *testing.T) { + sctx := coretestsdk.MockContext() + defer func() { + domain.GetDomain(sctx).StatsHandle().Close() + }() + col0 := &expression.Column{ + UniqueID: sctx.GetSessionVars().AllocPlanColumnID(), + RetType: types.NewFieldType(mysql.TypeLonglong), + } + cond, err := expression.NewFunction(sctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), col0, col0) + require.True(t, err == nil) + sf, isSF := cond.(*expression.ScalarFunction) + require.True(t, isSF) + hashJoin := &physicalop.PhysicalHashJoin{} + hashJoin.EqualConditions = append(hashJoin.EqualConditions, sf) + hashJoin.LeftJoinKeys = append(hashJoin.LeftJoinKeys, col0) + + sctx.GetSessionVars().TiFlashHashJoinVersion = joinversion.HashJoinVersionLegacy + sctx.GetSessionVars().TiFlashMaxBytesBeforeExternalJoin = 0 + sctx.GetSessionVars().TiFlashMaxQueryMemoryPerNode = 0 + sctx.GetSessionVars().TiFlashQuerySpillRatio = 0 + require.False(t, 
hashJoin.CanTiFlashUseHashJoinV2(sctx)) + // can use hash join v2 + sctx.GetSessionVars().TiFlashHashJoinVersion = joinversion.HashJoinVersionOptimized + require.True(t, hashJoin.CanTiFlashUseHashJoinV2(sctx)) + // can not use hash join v2 due to enabling join spill + sctx.GetSessionVars().TiFlashMaxBytesBeforeExternalJoin = 1 + require.False(t, hashJoin.CanTiFlashUseHashJoinV2(sctx)) + // can use hash join v2 due to TiFlashMaxQueryMemoryPerNode * TiFlashQuerySpillRatio = 0 + sctx.GetSessionVars().TiFlashMaxBytesBeforeExternalJoin = 0 + sctx.GetSessionVars().TiFlashMaxQueryMemoryPerNode = 1 + require.True(t, hashJoin.CanTiFlashUseHashJoinV2(sctx)) + // can not use hash join v2 due to enabling join spill + sctx.GetSessionVars().TiFlashQuerySpillRatio = 0.7 + require.False(t, hashJoin.CanTiFlashUseHashJoinV2(sctx)) + + sctx.GetSessionVars().TiFlashMaxQueryMemoryPerNode = 0 + sctx.GetSessionVars().TiFlashQuerySpillRatio = 0 + hashJoin = &physicalop.PhysicalHashJoin{} + // can not use hash join v2 due to cross join + require.False(t, hashJoin.CanTiFlashUseHashJoinV2(sctx)) + + hashJoin = &physicalop.PhysicalHashJoin{} + hashJoin.EqualConditions = append(hashJoin.EqualConditions, sf) + hashJoin.LeftJoinKeys = append(hashJoin.LeftJoinKeys, col0) + hashJoin.IsNullEQ = append(hashJoin.IsNullEQ, true) + // can not use hash join v2 due to null eq + require.False(t, hashJoin.CanTiFlashUseHashJoinV2(sctx)) +} + +func TestOptRuleListFlagAlignment(t *testing.T) { + // Each position i in optRuleList is gated by the flag bit 1<>>>>>> 7357a2e2f90 (planner: correlate subquery rule (#66206)) diff --git a/pkg/planner/core/plan_clone_utils.go b/pkg/planner/core/plan_clone_utils.go new file mode 100644 index 0000000000000..2d97d11690bd8 --- /dev/null +++ b/pkg/planner/core/plan_clone_utils.go @@ -0,0 +1,286 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package core + +import ( + "github.com/pingcap/tidb/pkg/expression" + "github.com/pingcap/tidb/pkg/expression/aggregation" + "github.com/pingcap/tidb/pkg/planner/core/base" + "github.com/pingcap/tidb/pkg/planner/core/operator/logicalop" + "github.com/pingcap/tidb/pkg/planner/core/operator/physicalop" + "github.com/pingcap/tidb/pkg/planner/property" + "github.com/pingcap/tidb/pkg/planner/util" + "github.com/pingcap/tidb/pkg/planner/util/utilfuncp" + "github.com/pingcap/tidb/pkg/types" +) + +// FastClonePointGetForPlanCache is a fast path to clone a PointGetPlan for plan cache. 
+func FastClonePointGetForPlanCache(newCtx base.PlanContext, src, dst *physicalop.PointGetPlan) *physicalop.PointGetPlan { + if dst == nil { + dst = new(physicalop.PointGetPlan) + } + dst.Plan = src.Plan + dst.Plan.SetSCtx(newCtx) + dst.ProbeParents = src.ProbeParents + dst.PartitionNames = src.PartitionNames + dst.DBName = src.DBName + dst.SetSchema(src.Schema()) + dst.TblInfo = src.TblInfo + dst.IndexInfo = src.IndexInfo + dst.PartitionIdx = nil // partition prune will be triggered during execution phase + dst.Handle = nil // handle will be set during rebuild phase + if src.HandleConstant == nil { + dst.HandleConstant = nil + } else { + if src.HandleConstant.SafeToShareAcrossSession() { + dst.HandleConstant = src.HandleConstant + } else { + dst.HandleConstant = src.HandleConstant.Clone().(*expression.Constant) + } + } + dst.HandleFieldType = src.HandleFieldType + dst.HandleColOffset = src.HandleColOffset + if len(dst.IndexValues) < len(src.IndexValues) { // actually set during rebuild phase + dst.IndexValues = make([]types.Datum, len(src.IndexValues)) + } else { + dst.IndexValues = dst.IndexValues[:len(src.IndexValues)] + } + dst.IndexConstants = utilfuncp.CloneConstantsForPlanCache(src.IndexConstants, dst.IndexConstants) + dst.ColsFieldType = src.ColsFieldType + dst.IdxCols = utilfuncp.CloneColumnsForPlanCache(src.IdxCols, dst.IdxCols) + dst.IdxColLens = src.IdxColLens + dst.AccessConditions = utilfuncp.CloneExpressionsForPlanCache(src.AccessConditions, dst.AccessConditions) + dst.UnsignedHandle = src.UnsignedHandle + dst.IsTableDual = src.IsTableDual + dst.Lock = src.Lock + dst.SetOutputNames(src.OutputNames()) + dst.LockWaitTime = src.LockWaitTime + dst.Columns = src.Columns + + // remaining fields are unnecessary to clone: + // cost, planCostInit, planCost, planCostVer2, accessCols + return dst +} + +// cloneLogicalSubtree creates a shallow clone of the logical plan subtree, +// ensuring each node has a fresh plan ID and independent mutable state (children, 
+// conditions, AllConds). Immutable data (table info, column info, etc.) is shared. +// This is used to build the Apply alternative's inner plan without modifying the +// Join's original inner subtree when PPD pushes correlated conditions down. +// Returns (clone, true) on success, or (nil, false) if an unhandled operator type +// is encountered. In the failure case, the caller must abort the correlate +// optimization to avoid corrupting the original subtree. +func cloneLogicalSubtree(p base.LogicalPlan) (base.LogicalPlan, bool) { + switch op := p.(type) { + case *logicalop.DataSource: + return cloneDataSource(op), true + case *logicalop.LogicalJoin: + return cloneJoin(op) + case *logicalop.LogicalSelection: + return cloneSelection(op) + case *logicalop.LogicalProjection: + return cloneProjection(op) + case *logicalop.LogicalAggregation: + return cloneAggregation(op) + case *logicalop.LogicalLimit: + return cloneLimit(op) + case *logicalop.LogicalSort: + return cloneSort(op) + case *logicalop.LogicalTopN: + return cloneTopN(op) + default: + // Unknown operator type — cannot safely clone. Return failure + // so the caller aborts the correlate optimization. + return nil, false + } +} + +func cloneWithChildren(p base.LogicalPlan) ([]base.LogicalPlan, bool) { + children := make([]base.LogicalPlan, len(p.Children())) + for i, child := range p.Children() { + cloned, ok := cloneLogicalSubtree(child) + if !ok { + return nil, false + } + children[i] = cloned + } + return children, true +} + +func cloneDataSource(ds *logicalop.DataSource) *logicalop.DataSource { + clone := *ds + clone.BaseLogicalPlan = logicalop.NewBaseLogicalPlan( + ds.SCtx(), ds.TP(), &clone, ds.QueryBlockOffset()) + clone.SetSchema(ds.Schema().Clone()) + // Independent slices that PPD replaces. + clone.AllConds = append([]expression.Expression(nil), ds.AllConds...) + clone.PushedDownConds = append([]expression.Expression(nil), ds.PushedDownConds...) 
+ // Deep-clone AccessPaths so the Join and Apply alternatives have fully + // independent path objects. Stats derivation (fillIndexPath, etc.) mutates + // AccessPath fields in place; without deep cloning, costing one alternative + // can corrupt the other and destabilize CBO. + clone.AllPossibleAccessPaths = make([]*util.AccessPath, len(ds.AllPossibleAccessPaths)) + for i, ap := range ds.AllPossibleAccessPaths { + clone.AllPossibleAccessPaths[i] = ap.Clone() + } + clone.PossibleAccessPaths = make([]*util.AccessPath, len(ds.PossibleAccessPaths)) + for i, ap := range ds.PossibleAccessPaths { + clone.PossibleAccessPaths[i] = ap.Clone() + } + // Preserve original stats so DeriveStats returns early for DataSources + // that don't receive correlated conditions. Without this, DeriveStats + // re-runs fillIndexPath on all DataSources, which fails when conditions + // reference columns that column pruning removed from the schema. + if origStats := ds.StatsInfo(); origStats != nil { + clone.SetStats(origStats) + } + return &clone +} + +func cloneJoin(j *logicalop.LogicalJoin) (*logicalop.LogicalJoin, bool) { + children, ok := cloneWithChildren(j) + if !ok { + return nil, false + } + clone := *j + clone.BaseLogicalPlan = logicalop.NewBaseLogicalPlan( + j.SCtx(), j.TP(), &clone, j.QueryBlockOffset()) + clone.SetSchema(j.Schema().Clone()) + // Independent condition slices that PPD may modify. + clone.EqualConditions = append([]*expression.ScalarFunction(nil), j.EqualConditions...) + clone.LeftConditions = append(expression.CNFExprs(nil), j.LeftConditions...) + clone.RightConditions = append(expression.CNFExprs(nil), j.RightConditions...) + clone.OtherConditions = append(expression.CNFExprs(nil), j.OtherConditions...) + // Clear PreferCorrelate on cloned inner joins to prevent CorrelateSolver + // from processing nested semi-joins in the cloned subtree. + clone.PreferCorrelate = false + clone.SetChildren(children...) 
+ return &clone, true +} + +func cloneSelection(s *logicalop.LogicalSelection) (*logicalop.LogicalSelection, bool) { + children, ok := cloneWithChildren(s) + if !ok { + return nil, false + } + clone := *s + clone.BaseLogicalPlan = logicalop.NewBaseLogicalPlan( + s.SCtx(), s.TP(), &clone, s.QueryBlockOffset()) + clone.Conditions = append(expression.CNFExprs(nil), s.Conditions...) + clone.SetChildren(children...) + return &clone, true +} + +func cloneProjection(proj *logicalop.LogicalProjection) (*logicalop.LogicalProjection, bool) { + children, ok := cloneWithChildren(proj) + if !ok { + return nil, false + } + clone := *proj + clone.BaseLogicalPlan = logicalop.NewBaseLogicalPlan( + proj.SCtx(), proj.TP(), &clone, proj.QueryBlockOffset()) + clone.SetSchema(proj.Schema().Clone()) + clone.Exprs = append([]expression.Expression(nil), proj.Exprs...) + clone.SetChildren(children...) + return &clone, true +} + +func cloneAggregation(agg *logicalop.LogicalAggregation) (*logicalop.LogicalAggregation, bool) { + children, ok := cloneWithChildren(agg) + if !ok { + return nil, false + } + clone := *agg + clone.BaseLogicalPlan = logicalop.NewBaseLogicalPlan( + agg.SCtx(), agg.TP(), &clone, agg.QueryBlockOffset()) + clone.SetSchema(agg.Schema().Clone()) + clone.AggFuncs = append([]*aggregation.AggFuncDesc(nil), agg.AggFuncs...) + clone.GroupByItems = append([]expression.Expression(nil), agg.GroupByItems...) + clone.SetChildren(children...) + return &clone, true +} + +func cloneLimit(lim *logicalop.LogicalLimit) (*logicalop.LogicalLimit, bool) { + children, ok := cloneWithChildren(lim) + if !ok { + return nil, false + } + clone := *lim + clone.BaseLogicalPlan = logicalop.NewBaseLogicalPlan( + lim.SCtx(), lim.TP(), &clone, lim.QueryBlockOffset()) + clone.SetSchema(lim.Schema().Clone()) + if len(lim.PartitionBy) > 0 { + clone.PartitionBy = append([]property.SortItem(nil), lim.PartitionBy...) + } + clone.SetChildren(children...) 
+ return &clone, true +} + +func cloneSort(s *logicalop.LogicalSort) (*logicalop.LogicalSort, bool) { + children, ok := cloneWithChildren(s) + if !ok { + return nil, false + } + clone := *s + clone.BaseLogicalPlan = logicalop.NewBaseLogicalPlan( + s.SCtx(), s.TP(), &clone, s.QueryBlockOffset()) + // LogicalSort embeds BaseLogicalPlan (not LogicalSchemaProducer), + // so it inherits schema from its child — no SetSchema needed. + clone.ByItems = append([]*util.ByItems(nil), s.ByItems...) + clone.SetChildren(children...) + return &clone, true +} + +func cloneTopN(tn *logicalop.LogicalTopN) (*logicalop.LogicalTopN, bool) { + children, ok := cloneWithChildren(tn) + if !ok { + return nil, false + } + clone := *tn + clone.BaseLogicalPlan = logicalop.NewBaseLogicalPlan( + tn.SCtx(), tn.TP(), &clone, tn.QueryBlockOffset()) + clone.SetSchema(tn.Schema().Clone()) + clone.ByItems = append([]*util.ByItems(nil), tn.ByItems...) + if len(tn.PartitionBy) > 0 { + clone.PartitionBy = append([]property.SortItem(nil), tn.PartitionBy...) + } + clone.SetChildren(children...) + return &clone, true +} + +// freshAccessPath creates a new AccessPath with only the structural identity +// fields from the source path (Index, StoreType, handle flags, hint flags). +// Analysis fields (Ranges, AccessConds, IdxCols, etc.) are left at zero so +// that fillIndexPath / deriveTablePathStats start from a clean state. +// +// Index-merge fields (PartialIndexPaths, PartialAlternativeIndexPaths, etc.) +// are intentionally omitted: AllPossibleAccessPaths contains only individual +// index paths; index merge paths are synthesized later by generateIndexMergePath +// which runs as part of DeriveStats after fillIndexPath populates these fresh paths. 
+func freshAccessPath(src *util.AccessPath) *util.AccessPath { + return &util.AccessPath{ + Index: src.Index, + StoreType: src.StoreType, + IsIntHandlePath: src.IsIntHandlePath, + IsCommonHandlePath: src.IsCommonHandlePath, + Forced: src.Forced, + ForceKeepOrder: src.ForceKeepOrder, + ForceNoKeepOrder: src.ForceNoKeepOrder, + ForcePartialOrder: src.ForcePartialOrder, + IsUkShardIndexPath: src.IsUkShardIndexPath, + IndexLookUpPushDownBy: src.IndexLookUpPushDownBy, + NoncacheableReason: src.NoncacheableReason, + } +} diff --git a/pkg/planner/core/rule/logical_rules.go b/pkg/planner/core/rule/logical_rules.go index 0d84115b4a87d..8e28fb94bfedb 100644 --- a/pkg/planner/core/rule/logical_rules.go +++ b/pkg/planner/core/rule/logical_rules.go @@ -39,6 +39,11 @@ const ( FlagPushDownTopN FlagSyncWaitStatsLoadPoint FlagJoinReOrder +<<<<<<< HEAD +======= + FlagOuterJoinToSemiJoin + FlagCorrelate +>>>>>>> 7357a2e2f90 (planner: correlate subquery rule (#66206)) FlagPruneColumnsAgain FlagPushDownSequence FlagResolveExpand diff --git a/pkg/planner/core/rule_correlate.go b/pkg/planner/core/rule_correlate.go new file mode 100644 index 0000000000000..ac7d6e99785ad --- /dev/null +++ b/pkg/planner/core/rule_correlate.go @@ -0,0 +1,345 @@ +// Copyright 2026 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package core + +import ( + "context" + "fmt" + + "github.com/pingcap/tidb/pkg/expression" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/planner/core/base" + "github.com/pingcap/tidb/pkg/planner/core/operator/logicalop" + "github.com/pingcap/tidb/pkg/planner/util" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util/logutil" + "go.uber.org/zap" +) + +// CorrelateSolver tries to convert semi-join LogicalJoin back to correlated LogicalApply. +// This is the reverse of DecorrelateSolver and is useful when a correlated nested-loop +// (index lookup per outer row) might be more efficient than a hash semi-join. +type CorrelateSolver struct{} + +// Optimize implements base.LogicalOptRule.<0th> interface. +func (s *CorrelateSolver) Optimize(ctx context.Context, p base.LogicalPlan) (retPlan base.LogicalPlan, retChanged bool, retErr error) { + defer func() { + if r := recover(); r != nil { + logutil.BgLogger().Warn("CorrelateSolver panic", + zap.Any("recover", r), + zap.Stack("stack")) + retPlan = nil + retChanged = false + retErr = fmt.Errorf("CorrelateSolver panic: %v", r) + } + }() + return s.correlate(ctx, p) +} + +func (s *CorrelateSolver) correlate(ctx context.Context, p base.LogicalPlan) (base.LogicalPlan, bool, error) { + // CTE's logical optimization is independent. + if _, ok := p.(*logicalop.LogicalCTE); ok { + return p, false, nil + } + + // First recurse into children. + planChanged := false + newChildren := make([]base.LogicalPlan, 0, len(p.Children())) + for _, child := range p.Children() { + np, changed, err := s.correlate(ctx, child) + if err != nil { + return nil, false, err + } + planChanged = planChanged || changed + newChildren = append(newChildren, np) + } + p.SetChildren(newChildren...) + + // Check if this node is a LogicalApply — if so, skip (already correlated). 
+ if _, isApply := p.(*logicalop.LogicalApply); isApply { + return p, planChanged, nil + } + + // Check if this node is a LogicalJoin with a semi-join type that was + // marked for re-correlation (from a non-correlated IN subquery). + join, isJoin := p.(*logicalop.LogicalJoin) + if !isJoin || !join.JoinType.IsSemiJoin() || !join.PreferCorrelate { + return p, planChanged, nil + } + + // Must have EqualConditions to correlate (skip if only NAEQConditions). + if len(join.EqualConditions) == 0 { + return p, planChanged, nil + } + + // For v1: skip null-aware conditions, LeftConditions, and OtherConditions. + if len(join.NAEQConditions) > 0 || len(join.LeftConditions) > 0 || len(join.OtherConditions) > 0 { + return p, planChanged, nil + } + + leftSchema := join.Children()[0].Schema() + rightSchema := join.Children()[1].Schema() + + // Left outer semi joins (scalar IN / NOT IN) require 3-valued NULL + // semantics: the joiner must distinguish "no match" (→ 0) from "unknown + // due to NULL" (→ NULL). It does this by evaluating the equality join + // condition and tracking whether any comparison returned NULL. + // + // When we push the equality into the inner side as a correlated filter + // (rightCol = CorCol(leftCol)), two problems arise: + // 1. If the inner column is nullable, NULL inner values are silently + // filtered out (NULL = X → NULL → filtered), so the joiner never + // sees them and returns 0 instead of NULL. + // 2. If the outer column is nullable and its value is NULL, the + // correlated filter becomes rightCol = NULL, which filters out all + // inner rows, and the joiner returns 0 instead of NULL. + // + // Skip unless ALL equality columns on both sides are proven NOT NULL. 
+ if join.JoinType == base.LeftOuterSemiJoin || join.JoinType == base.AntiLeftOuterSemiJoin { + for _, eqCond := range join.EqualConditions { + col0, col1, ok := expression.IsColOpCol(eqCond) + if !ok { + return p, planChanged, nil + } + leftCol := leftSchema.RetrieveColumn(col0) + rightCol := rightSchema.RetrieveColumn(col1) + if leftCol == nil || rightCol == nil { + leftCol = leftSchema.RetrieveColumn(col1) + rightCol = rightSchema.RetrieveColumn(col0) + } + if leftCol == nil || rightCol == nil { + return p, planChanged, nil + } + if !mysql.HasNotNullFlag(leftCol.RetType.GetFlag()) || !mysql.HasNotNullFlag(rightCol.RetType.GetFlag()) { + return p, planChanged, nil + } + } + } + + selConds := make([]expression.Expression, 0, len(join.EqualConditions)+len(join.RightConditions)) + corCols := make([]*expression.CorrelatedColumn, 0, len(join.EqualConditions)) + + // Convert EqualConditions to correlated conditions. + for _, eqCond := range join.EqualConditions { + cond, corCol := s.buildCorrelatedCond(eqCond, leftSchema, rightSchema, join) + if cond == nil { + // Can't correlate this condition; abort. + return p, planChanged, nil + } + selConds = append(selConds, cond) + corCols = append(corCols, corCol) + } + + // Move RightConditions to the selection (they reference only the inner side). + selConds = append(selConds, join.RightConditions...) + + // Clone the inner subtree so PPD can modify the clone without affecting + // the Join's inner child (which must retain its original conditions). + // If the subtree contains an unhandled operator type, abort to avoid corruption. + clonedInner, ok := cloneLogicalSubtree(join.Children()[1]) + if !ok { + return p, planChanged, nil + } + + // Lift DataSource conditions back into Selection nodes. The original PPD + // pushed conditions all the way into DataSource.AllConds and cleared them + // from ancestor operators (e.g., Join.RightConditions). 
When we re-run PPD + // below, the Join re-collects conditions from its own fields (not from + // DataSource.AllConds), so conditions that were pushed past the Join would + // be lost. Wrapping each DataSource in a Selection restores the pre-PPD + // state so the re-run can properly redistribute all conditions. + clonedInner = liftDataSourceConds(clonedInner) + + sel := logicalop.LogicalSelection{Conditions: selConds}.Init(join.SCtx(), join.QueryBlockOffset()) + sel.SetChildren(clonedInner) + + // Run predicate push-down on the inner subtree so the new correlated + // predicates reach the DataSource (for index access path selection). + // PPD has already finished by the time this rule runs, so without this + // local pass the predicates would stay in the Selection and the inner + // side could only do full scans. + _, innerPlan, err := sel.PredicatePushDown(nil) + if err != nil { + // PPD failed (e.g., conditions reference columns pruned from the + // DataSource schema); abort the correlate optimization. + return p, planChanged, nil + } + + // Reset stats on DataSources that received correlated conditions so DeriveStats + // re-runs during physical optimization. This is necessary because the original + // DeriveStats ran before the correlate rule added correlated conditions, so the + // index access paths were built without them. + resetStatsForCorrelatedDS(innerPlan) + + // For semi-join semantics (EXISTS/IN and NOT EXISTS/NOT IN), add Limit 1 on + // the inner side. The Apply executor materializes all inner rows per outer + // key via fetchAllInners; a Limit 1 enables early exit since semi/anti-semi + // joins only need to know whether any matching row exists. + // This mirrors what expression_rewriter does for NO_DECORRELATE EXISTS. + if !hasLimit(innerPlan) { + limit := logicalop.LogicalLimit{Count: 1}.Init(join.SCtx(), join.QueryBlockOffset()) + limit.SetChildren(innerPlan) + innerPlan = limit + } + + // Build the LogicalApply. 
+ ap := logicalop.LogicalApply{}.Init(join.SCtx(), join.QueryBlockOffset()) + ap.JoinType = join.JoinType + ap.CorCols = corCols + // Copy hint fields so hint behavior is preserved in the alternative. + ap.HintInfo = join.HintInfo + ap.PreferJoinType = join.PreferJoinType + ap.PreferJoinOrder = join.PreferJoinOrder + ap.LeftPreferJoinType = join.LeftPreferJoinType + ap.RightPreferJoinType = join.RightPreferJoinType + ap.SetChildren(join.Children()[0], innerPlan) + ap.SetSchema(join.Schema().Clone()) + ap.SetOutputNames(join.OutputNames()) + + // Replace the Join with the Apply. In the alternative logical plans framework, + // this round produces a complete plan; the top-level cost comparison across + // rounds selects the winner. + return ap, true, nil +} + +// buildCorrelatedCond converts an equal condition from the join into a correlated condition +// for the inner selection. It identifies which column comes from the left (outer) side and +// creates a CorrelatedColumn for it, then builds a new condition: rightCol OP CorCol(leftCol), with OP taken from eqCond. +func (*CorrelateSolver) buildCorrelatedCond( + eqCond *expression.ScalarFunction, + leftSchema *expression.Schema, + rightSchema *expression.Schema, + join *logicalop.LogicalJoin, +) (expression.Expression, *expression.CorrelatedColumn) { + col0, col1, ok := expression.IsColOpCol(eqCond) + if !ok { + return nil, nil + } + + // Determine which column is from the left (outer) side and which from the right (inner). + leftCol := leftSchema.RetrieveColumn(col0) + rightCol := rightSchema.RetrieveColumn(col1) + if leftCol == nil || rightCol == nil { + // Try swapped order. + leftCol = leftSchema.RetrieveColumn(col1) + rightCol = rightSchema.RetrieveColumn(col0) + } + if leftCol == nil || rightCol == nil { + return nil, nil + } + + // Create a CorrelatedColumn for the outer (left) column. + // Data must be initialized (non-nil) to avoid panics during physical planning.
+ corCol := &expression.CorrelatedColumn{Column: *leftCol, Data: new(types.Datum)} + + // Create the correlated condition: rightCol OP CorCol(leftCol), reusing eqCond's function name as OP. + cond := expression.NewFunctionInternal( + join.SCtx().GetExprCtx(), + eqCond.FuncName.L, + types.NewFieldType(mysql.TypeTiny), + rightCol, corCol, + ) + + return cond, corCol +} + +// liftDataSourceConds walks the plan tree and for each DataSource with +// non-empty AllConds, wraps it in a Selection node containing those conditions. +// This "un-pushes" conditions that the original PPD pushed into DataSources, +// so that a subsequent PPD re-run (in correlate()) can properly redistribute +// all conditions — including those that would otherwise be silently dropped +// when DataSource.PredicatePushDown overwrites AllConds. +func liftDataSourceConds(p base.LogicalPlan) base.LogicalPlan { + // Recurse into children first, potentially replacing them. + for i, child := range p.Children() { + newChild := liftDataSourceConds(child) + if newChild != child { + p.Children()[i] = newChild + } + } + + // If this is a DataSource with AllConds, wrap it in a Selection. + if ds, ok := p.(*logicalop.DataSource); ok && len(ds.AllConds) > 0 { + sel := logicalop.LogicalSelection{ + Conditions: ds.AllConds, + }.Init(ds.SCtx(), ds.QueryBlockOffset()) + sel.SetChildren(ds) + + // Clear DataSource conditions; the PPD re-run will push them back. + ds.AllConds = nil + ds.PushedDownConds = nil + + return sel + } + + return p +} + +// resetStatsForCorrelatedDS walks the inner subtree and clears StatsInfo on +// DataSources that have correlated conditions in AllConds, plus all ancestor +// plan nodes up to the root. This forces DeriveStats to re-run during physical +// optimization so that index access paths are rebuilt with the correlated +// conditions. +// +// For correlated DataSources, fresh AccessPaths are created so fillIndexPath +// starts from a clean state with the new correlated conditions.
Non-correlated +// DataSources retain their deep-cloned AccessPaths and stats (set during +// cloning) so DeriveStats returns early — this avoids failures when conditions +// reference columns that column pruning removed from the DataSource's schema. +func resetStatsForCorrelatedDS(p base.LogicalPlan) bool { + hasCorrelated := false + + // Check if this is a DataSource with correlated conditions. + if ds, ok := p.(*logicalop.DataSource); ok { + for _, cond := range ds.AllConds { + if len(expression.ExtractCorColumns(cond)) > 0 { + hasCorrelated = true + break + } + } + if hasCorrelated { + // Create fresh AccessPaths so fillIndexPath rebuilds them with the + // correlated conditions from a clean state. + origPaths := ds.AllPossibleAccessPaths + ds.AllPossibleAccessPaths = make([]*util.AccessPath, len(origPaths)) + for i, ap := range origPaths { + ds.AllPossibleAccessPaths[i] = freshAccessPath(ap) + } + ds.PossibleAccessPaths = append([]*util.AccessPath(nil), ds.AllPossibleAccessPaths...) + } + } + + // Recurse into children. + for _, child := range p.Children() { + if resetStatsForCorrelatedDS(child) { + hasCorrelated = true + } + } + + // Reset stats on this node if it or any descendant has correlated conditions. + // This ensures DeriveStats re-runs for the affected subtree path. + if hasCorrelated { + if blp, ok := p.GetBaseLogicalPlan().(*logicalop.BaseLogicalPlan); ok { + blp.SetStats(nil) + } + } + + return hasCorrelated +} + +// Name implements base.LogicalOptRule.<1st> interface. +func (*CorrelateSolver) Name() string { + return "correlate" +} diff --git a/pkg/planner/optimize.go b/pkg/planner/optimize.go index 8f19e77467d85..2886743e8ab38 100644 --- a/pkg/planner/optimize.go +++ b/pkg/planner/optimize.go @@ -469,6 +469,59 @@ var planBuilderPool = sync.Pool{ // optimizeCnt is a global variable only used for test.
var optimizeCnt int +func shouldTryNonDecorrelationRound(sessVars *variable.SessionVars) bool { + return sessVars.EnableAlternativeLogicalPlans && + sessVars.StmtCtx.AlternativeLogicalPlanDecorrelatedApply && + !sessVars.StmtCtx.AlternativeLogicalPlanSameOrderIndexJoin +} + +func shouldTryOrderAwareReorderRound(sessVars *variable.SessionVars) bool { + return sessVars.EnableAlternativeLogicalPlans && + sessVars.StmtCtx.AlternativeLogicalPlanOrderAwareJoinReorder +} + +func shouldTryCorrelateRound(sessVars *variable.SessionVars) bool { + return sessVars.EnableAlternativeLogicalPlans && + sessVars.StmtCtx.AlternativeLogicalPlanPreferCorrelate +} + +// alternativeRound describes one alternative logical-plan round. +// adjustFlag adjusts the optimization flags for the round. +// enabled returns true when the round should be attempted. +// setup optionally modifies session state before plan building and returns a +// non-nil restore function that undoes the change afterwards. The saved state +// lives in the returned closure, not in a package-level variable, so it is +// private to each call of setup. +type alternativeRound struct { + name string + adjustFlag func(uint64) uint64 + enabled func(*variable.SessionVars) bool + setup func(*variable.SessionVars) (restore func()) +}
+ +var alternativeRounds = [...]alternativeRound{ + { + name: "non-decorrelate", + adjustFlag: func(flag uint64) uint64 { return flag &^ rule.FlagDecorrelate }, + enabled: shouldTryNonDecorrelationRound, + }, + { + name: "order-aware-reorder", + adjustFlag: func(flag uint64) uint64 { return flag | rule.FlagOrderAwareJoinReorder }, + enabled: shouldTryOrderAwareReorderRound, + }, + { + name: "correlate", + adjustFlag: func(flag uint64) uint64 { return flag | rule.FlagCorrelate }, + enabled: shouldTryCorrelateRound, + setup: func(sv *variable.SessionVars) func() { + saved := sv.EnableCorrelateSubquery + sv.EnableCorrelateSubquery = true + return func() { sv.EnableCorrelateSubquery = saved } + }, + }, +} + func optimize(ctx context.Context, sctx planctx.PlanContext, node *resolve.NodeW, is infoschema.InfoSchema) (base.Plan, types.NameSlice, float64, error) { failpoint.Inject("checkOptimizeCountOne", func(val failpoint.Value) { // only count the optimization for SQL with specified text @@ -536,6 +589,64 @@ func optimize(ctx context.Context, sctx planctx.PlanContext, node *resolve.NodeW return finalPlan, names, cost, err } + // Pre-compute which rounds are enabled based on the signals from the first + // (default) build. This prevents signal leakage: alternative rounds rebuild + // the plan and may set AlternativeLogicalPlan* signals as a side effect, + // which are not reset by restoreLogicalPlanBuildCtx. Evaluating enabled() + // upfront ensures each round's eligibility is determined solely by the + // original build's signals.
+ enabledRounds := make([]alternativeRound, 0, len(alternativeRounds)) + for _, round := range alternativeRounds { + if round.enabled(sessVars) { + enabledRounds = append(enabledRounds, round) + } + } + for _, round := range enabledRounds { + restoreLogicalPlanBuildCtx(sessVars, initialLogicalPlanCtx) + failpoint.Inject("failIfAlternativeLogicalPlanRoundTriggered", func(val failpoint.Value) { + if testSQL, ok := val.(string); ok && testSQL == node.Node.OriginalText() { + failpoint.Return(nil, nil, 0, errors.New("unexpected alternative logical plan round")) + } + }) + + // Use a closure so that the deferred restore runs at the end of each + // iteration, not at function exit. This ensures session state (e.g. + // EnableCorrelateSubquery) is restored even if the round panics. + func() { + if round.setup != nil { + restore := round.setup(sessVars) + defer restore() + } + p, names, nonLogical, err = buildAndOptimizeLogicalPlanRound( + ctx, + sctx, + node, + is, + hintProcessor, + &checked, + &optimizeStarted, + &beginOpt, + needRestoreLogicalPlanCtx, + &bestPlan, + &bestNames, + &bestCost, + &bestLogicalPlanCtx, + round.adjustFlag, + ) + }() + if err != nil { + // Alternative rounds are optional optimizations. If one fails, + // log and continue — the first round's plan is still valid. + logutil.BgLogger().Warn("alternative logical plan round failed", + zap.String("round", round.name), + zap.Error(err)) + continue + } + if nonLogical { + return p, names, 0, nil + } + } + beginOpt := time.Now() finalPlan, cost, err := core.DoOptimize(ctx, sctx, builder.GetOptFlag(), logic) // TODO: capture plan replayer here if it matches sql and plan digest diff --git a/pkg/sessionctx/stmtctx/stmtctx.go b/pkg/sessionctx/stmtctx/stmtctx.go index 7e3feeb3732f0..9c11f91e8a9b2 100644 --- a/pkg/sessionctx/stmtctx/stmtctx.go +++ b/pkg/sessionctx/stmtctx/stmtctx.go @@ -65,6 +65,26 @@ func AllocateTaskID() uint64 { // SQLWarn relates a sql warning and it's level.
type SQLWarn = contextutil.SQLWarn +// LogicalPlanBuildState stores the statement-scoped planner state that is mutated while +// building a logical plan from AST. +type LogicalPlanBuildState struct { + warnings []SQLWarn + extraWarnings []SQLWarn + tables []TableEntry + tableStats map[int64]any + lockTableIDs map[int64]struct{} + tblInfo2UnionScan map[*model.TableInfo]bool + useDynamicPruneMode bool + viewDepth int32 + colRefFromUpdatePlan intset.FastIntSet + // plan cache related stuff + planCacheUseCache bool + planCacheType contextutil.PlanCacheType + planCacheUnqualified string + planCacheForce bool + planCacheAlwaysWarn bool +} + type jsonSQLWarn struct { Level string `json:"level"` SQLErr *terror.Error `json:"err,omitempty"` @@ -274,6 +294,7 @@ type StatementContext struct { // in stmtCtx IsStaleness bool InRestrictedSQL bool + ViewDepth int32 // mu struct holds variables that change during execution. mu *stmtCtxMu @@ -446,6 +467,20 @@ type StatementContext struct { UseDynamicPruneMode bool // ColRefFromPlan mark the column ref used by assignment in update statement. ColRefFromUpdatePlan intset.FastIntSet + // AlternativeLogicalPlanDecorrelatedApply indicates whether the current logical + // optimization round decorrelated at least one Apply into Join. + AlternativeLogicalPlanDecorrelatedApply bool + // AlternativeLogicalPlanSameOrderIndexJoin indicates whether the current first + // round already produced a same-order index join candidate for a decorrelated Apply. + AlternativeLogicalPlanSameOrderIndexJoin bool + // AlternativeLogicalPlanOrderAwareJoinReorder indicates whether at least one + // logical build round produced an order-aware join reorder candidate that is + // worth exploring in a dedicated alternative round. 
+ AlternativeLogicalPlanOrderAwareJoinReorder bool + // AlternativeLogicalPlanPreferCorrelate indicates whether the current logical + // build round encountered a non-correlated IN subquery eligible for the + // correlate-to-Apply alternative. + AlternativeLogicalPlanPreferCorrelate bool // IsExplainAnalyzeDML is true if the statement is "explain analyze DML executors", before responding the explain // results to the client, the transaction should be committed first. See issue #37373 for more details. @@ -572,6 +607,78 @@ func (sc *StatementContext) Reset() bool { return true } +// SaveLogicalPlanBuildState captures the statement-scoped planner state before building +// another logical plan candidate from the same AST. +func (sc *StatementContext) SaveLogicalPlanBuildState() LogicalPlanBuildState { + planCacheUseCache, planCacheType, planCacheUnqualified, planCacheForce, planCacheAlwaysWarn := sc.PlanCacheTracker.Save() + return LogicalPlanBuildState{ + warnings: slices.Clone(sc.GetWarnings()), + extraWarnings: slices.Clone(sc.GetExtraWarnings()), + tables: slices.Clone(sc.Tables), + tableStats: maps.Clone(sc.TableStats), + lockTableIDs: maps.Clone(sc.LockTableIDs), + tblInfo2UnionScan: maps.Clone(sc.TblInfo2UnionScan), + useDynamicPruneMode: sc.UseDynamicPruneMode, + viewDepth: sc.ViewDepth, + colRefFromUpdatePlan: sc.ColRefFromUpdatePlan.Copy(), + planCacheUseCache: planCacheUseCache, + planCacheType: planCacheType, + planCacheUnqualified: planCacheUnqualified, + planCacheForce: planCacheForce, + planCacheAlwaysWarn: planCacheAlwaysWarn, + } +} + +// RestoreLogicalPlanBuildState restores the statement-scoped planner state after a +// discarded logical plan build attempt. 
+func (sc *StatementContext) RestoreLogicalPlanBuildState(state LogicalPlanBuildState) { + sc.SetWarnings(slices.Clone(state.warnings)) + sc.SetExtraWarnings(slices.Clone(state.extraWarnings)) + sc.Tables = slices.Clone(state.tables) + sc.TableStats = maps.Clone(state.tableStats) + sc.LockTableIDs = maps.Clone(state.lockTableIDs) + sc.TblInfo2UnionScan = maps.Clone(state.tblInfo2UnionScan) + sc.UseDynamicPruneMode = state.useDynamicPruneMode + sc.ViewDepth = state.viewDepth + sc.ColRefFromUpdatePlan.CopyFrom(state.colRefFromUpdatePlan) + sc.PlanCacheTracker.Restore(state.planCacheUseCache, state.planCacheType, state.planCacheUnqualified, state.planCacheForce, state.planCacheAlwaysWarn) + sc.RangeFallbackHandler = contextutil.NewRangeFallbackHandler(&sc.PlanCacheTracker, sc) +} + +// ResetAlternativeLogicalPlanSignals clears the statement-local signals used by the +// alternative logical plan feature. +func (sc *StatementContext) ResetAlternativeLogicalPlanSignals() { + sc.AlternativeLogicalPlanDecorrelatedApply = false + sc.AlternativeLogicalPlanSameOrderIndexJoin = false + sc.AlternativeLogicalPlanOrderAwareJoinReorder = false + sc.AlternativeLogicalPlanPreferCorrelate = false +} + +// MarkAlternativeLogicalPlanDecorrelatedApply records that at least one Apply has +// been decorrelated into a Join in the current round. +func (sc *StatementContext) MarkAlternativeLogicalPlanDecorrelatedApply() { + sc.AlternativeLogicalPlanDecorrelatedApply = true +} + +// MarkAlternativeLogicalPlanSameOrderIndexJoin records that the current first round +// has already produced a same-order index join candidate for a decorrelated Apply. +func (sc *StatementContext) MarkAlternativeLogicalPlanSameOrderIndexJoin() { + sc.AlternativeLogicalPlanSameOrderIndexJoin = true +} + +// MarkAlternativeLogicalPlanOrderAwareJoinReorder records that the current +// logical build round produced an order-aware join reorder candidate. 
+func (sc *StatementContext) MarkAlternativeLogicalPlanOrderAwareJoinReorder() { + sc.AlternativeLogicalPlanOrderAwareJoinReorder = true +} + +// MarkAlternativeLogicalPlanPreferCorrelate records that the current logical +// build round encountered a non-correlated IN subquery that is eligible for +// the correlate-to-Apply alternative. +func (sc *StatementContext) MarkAlternativeLogicalPlanPreferCorrelate() { + sc.AlternativeLogicalPlanPreferCorrelate = true +} + // CtxID returns the context id of the statement func (sc *StatementContext) CtxID() uint64 { return sc.ctxID diff --git a/pkg/sessionctx/variable/session.go b/pkg/sessionctx/variable/session.go index 94ed72e817207..659598264b236 100644 --- a/pkg/sessionctx/variable/session.go +++ b/pkg/sessionctx/variable/session.go @@ -1127,6 +1127,11 @@ type SessionVars struct { // EnableSemiJoinRewrite enables the SEMI_JOIN_REWRITE hint for subqueries in the where clause. EnableSemiJoinRewrite bool + // EnableCorrelateSubquery is an internal flag (not user-facing) toggled by the + // correlate alternative round to enable conversion of non-correlated semi-joins + // to correlated Apply during plan building. + EnableCorrelateSubquery bool + // AllowProjectionPushDown enables pushdown projection on TiKV. AllowProjectionPushDown bool diff --git a/pkg/util/context/plancache.go b/pkg/util/context/plancache.go index 9cac5ec81012e..d6c46cd4db81e 100644 --- a/pkg/util/context/plancache.go +++ b/pkg/util/context/plancache.go @@ -122,6 +122,26 @@ func (h *PlanCacheTracker) EnablePlanCache() { h.useCache = true } +// Save captures the mutable planning-time state of the tracker. 
+func (h *PlanCacheTracker) Save() (useCache bool, cacheType PlanCacheType, planCacheUnqualified string, forcePlanCache bool, alwaysWarnSkipCache bool) { + h.mu.Lock() + defer h.mu.Unlock() + + return h.useCache, h.cacheType, h.planCacheUnqualified, h.forcePlanCache, h.alwaysWarnSkipCache +} + +// Restore restores the mutable planning-time state of the tracker. +func (h *PlanCacheTracker) Restore(useCache bool, cacheType PlanCacheType, planCacheUnqualified string, forcePlanCache bool, alwaysWarnSkipCache bool) { + h.mu.Lock() + defer h.mu.Unlock() + + h.useCache = useCache + h.cacheType = cacheType + h.planCacheUnqualified = planCacheUnqualified + h.forcePlanCache = forcePlanCache + h.alwaysWarnSkipCache = alwaysWarnSkipCache +} + // UseCache returns whether to use plan cache. func (h *PlanCacheTracker) UseCache() bool { h.mu.Lock()