From 920cd75e6daff10f08d192ae29b7d6baf9f812e0 Mon Sep 17 00:00:00 2001 From: Michael Greenberg Date: Wed, 27 May 2026 16:08:38 -0400 Subject: [PATCH 1/2] parameterize MapFilterProject, MfpPlan, and SafeMfpPlan, using a trait to indicate necessary interface --- src/expr/src/explain.rs | 3 +- src/expr/src/explain/text.rs | 91 ++-- src/expr/src/interpret.rs | 4 +- src/expr/src/lib.rs | 4 +- src/expr/src/linear.rs | 394 +++++++++--------- src/expr/src/scalar.rs | 80 ++-- src/expr/src/scalar/columns.rs | 6 + src/expr/src/scalar/func/binary.rs | 3 +- src/expr/src/scalar/func/impls/array.rs | 3 +- .../src/scalar/func/impls/case_literal.rs | 3 +- src/expr/src/scalar/func/impls/list.rs | 3 +- src/expr/src/scalar/func/impls/map.rs | 3 +- src/expr/src/scalar/func/impls/range.rs | 3 +- src/expr/src/scalar/func/impls/record.rs | 3 +- src/expr/src/scalar/func/variadic.rs | 10 +- src/expr/src/scalar/optimizable.rs | 163 ++++++++ src/repr/src/explain.rs | 12 + 17 files changed, 487 insertions(+), 301 deletions(-) create mode 100644 src/expr/src/scalar/optimizable.rs diff --git a/src/expr/src/explain.rs b/src/expr/src/explain.rs index 1826eb1219c45..de9cf5e097999 100644 --- a/src/expr/src/explain.rs +++ b/src/expr/src/explain.rs @@ -30,7 +30,8 @@ use crate::{ }; pub use crate::explain::text::{ - HumanizedExplain, HumanizedExpr, HumanizedNotice, HumanizerMode, fmt_text_constant_rows, + HumanizeDisplay, HumanizedExplain, HumanizedExpr, HumanizedNotice, HumanizerMode, + fmt_text_constant_rows, }; mod json; diff --git a/src/expr/src/explain/text.rs b/src/expr/src/explain/text.rs index 8949be22f86fa..851c7d51ee32f 100644 --- a/src/expr/src/explain/text.rs +++ b/src/expr/src/explain/text.rs @@ -10,7 +10,7 @@ //! `EXPLAIN ... AS TEXT` support for structures defined in this crate. use std::collections::BTreeMap; -use std::fmt; +use std::fmt::{self, Display as _}; use std::sync::Arc; use itertools::Itertools; @@ -28,7 +28,7 @@ use mz_sql_parser::ast::Ident; use crate::explain::{ExplainMultiPlan, ExplainSinglePlan}; use crate::{ AccessStrategy, AggregateExpr, EvalError, Id, JoinImplementation, JoinInputCharacteristics, - LocalId, MapFilterProject, MirRelationExpr, MirScalarExpr, RowSetFinishing, + LocalId, MapFilterProject, MirRelationExpr, MirScalarExpr, OptimizableExpr, RowSetFinishing, }; impl<'a, T: 'a> DisplayText for ExplainSinglePlan<'a, T> @@ -223,10 +223,11 @@ where } } -impl<'a, C, M> DisplayText for HumanizedExpr<'a, MapFilterProject, M> +impl<'a, C, M, E> DisplayText for HumanizedExpr<'a, MapFilterProject, M> where C: AsMut, M: HumanizerMode, + E: OptimizableExpr + HumanizeDisplay, { fn fmt_text(&self, f: &mut fmt::Formatter<'_>, ctx: &mut C) -> fmt::Result { let (scalars, predicates, outputs, input_arity) = ( @@ -258,7 +259,9 @@ where } } -impl<'a, M: HumanizerMode> HumanizedExpr<'a, MapFilterProject, M> { +impl<'a, M: HumanizerMode, E: OptimizableExpr + HumanizeDisplay> + HumanizedExpr<'a, MapFilterProject, M> +{ /// Render an MFP using the default (concise) syntax. pub fn fmt_default_text( &self, @@ -1246,7 +1249,10 @@ where } } -impl<'a, M> ScalarOps for HumanizedExpr<'a, MirScalarExpr, M> { +impl<'a, T, M> ScalarOps for HumanizedExpr<'a, T, M> +where + T: ScalarOps, +{ fn match_col_ref(&self) -> Option { self.expr.match_col_ref() } @@ -1256,57 +1262,54 @@ impl<'a, M> ScalarOps for HumanizedExpr<'a, MirScalarExpr, M> { } } -impl<'a, M> ScalarOps for HumanizedExpr<'a, usize, M> { - fn match_col_ref(&self) -> Option { - Some(*self.expr) - } - - fn references(&self, col_ref: usize) -> bool { - col_ref == *self.expr - } +pub trait HumanizeDisplay: Sized { + fn humanize<'a, M: HumanizerMode>( + e: &HumanizedExpr<'a, Self, M>, + f: &mut fmt::Formatter<'_>, + ) -> fmt::Result; } -impl<'a, M> fmt::Display for HumanizedExpr<'a, MirScalarExpr, M> -where - M: HumanizerMode, -{ - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { +impl HumanizeDisplay for MirScalarExpr { + fn humanize<'a, M: HumanizerMode>( + e: &HumanizedExpr<'a, MirScalarExpr, M>, + f: &mut fmt::Formatter<'_>, + ) -> fmt::Result { use MirScalarExpr::*; - match self.expr { + match e.expr { Column(i, TreatAsEqual(None)) => { // Delegate to the `HumanizedExpr<'a, _>` implementation (plain column reference). - self.child(i).fmt(f) + e.child(i).fmt(f) } Column(i, TreatAsEqual(Some(name))) => { // Delegate to the `HumanizedExpr<'a, _>` implementation (with stored name information) - self.child(&(i, name)).fmt(f) + e.child(&(i, name)).fmt(f) } Literal(row, _) => { // Delegate to the `HumanizedExpr<'a, _>` implementation. - self.child(row).fmt(f) + e.child(row).fmt(f) } CallUnmaterializable(func) => write!(f, "{}()", func), CallUnary { func, expr } => { if let crate::UnaryFunc::Not(_) = *func { if let CallUnary { func, expr } = expr.as_ref() { if let Some(is) = func.is() { - let expr = self.child::(&*expr); + let expr = e.child::(&*expr); return write!(f, "({}) IS NOT {}", expr, is); } } } if let Some(is) = func.is() { - let expr = self.child::(&*expr); + let expr = e.child::(&*expr); write!(f, "({}) IS {}", expr, is) } else { - let expr = self.child::(&*expr); + let expr = e.child::(&*expr); write!(f, "{}({})", func, expr) } } CallBinary { func, expr1, expr2 } => { - let expr1 = self.child::(&*expr1); - let expr2 = self.child::(&*expr2); + let expr1 = e.child::(&*expr1); + let expr2 = e.child::(&*expr2); if func.is_infix_op() { write!(f, "({} {} {})", expr1, func, expr2) } else { @@ -1317,55 +1320,65 @@ where use crate::VariadicFunc::*; match func { CaseLiteral(cl) => { - let input = self.child::(&exprs[0]); + let input = e.child::(&exprs[0]); write!(f, "case_lookup {}", input)?; for entry in &cl.lookup { - let result = self.child::(&exprs[entry.expr_index]); + let result = e.child::(&exprs[entry.expr_index]); write!(f, " when ")?; - self.mode.humanize_datum(entry.literal.unpack_first(), f)?; + e.mode.humanize_datum(entry.literal.unpack_first(), f)?; write!(f, " then {}", result)?; } - let els = self.child::(exprs.last().unwrap()); + let els = e.child::(exprs.last().unwrap()); write!(f, " else {} end", els) } ArrayCreate(..) => { - let exprs = exprs.iter().map(|expr| self.child(expr)); + let exprs = exprs.iter().map(|expr| e.child(expr)); let exprs = separated(", ", exprs); write!(f, "array[{}]", exprs) } ListCreate(..) => { - let exprs = exprs.iter().map(|expr| self.child(expr)); + let exprs = exprs.iter().map(|expr| e.child(expr)); let exprs = separated(", ", exprs); write!(f, "list[{}]", exprs) } RecordCreate(..) => { - let exprs = exprs.iter().map(|expr| self.child(expr)); + let exprs = exprs.iter().map(|expr| e.child(expr)); let exprs = separated(", ", exprs); write!(f, "row({})", exprs) } func if func.is_infix_op() && exprs.len() > 1 => { - let exprs = exprs.iter().map(|expr| self.child(expr)); + let exprs = exprs.iter().map(|expr| e.child(expr)); let func = format!(" {} ", func); let exprs = separated(&func, exprs); write!(f, "({})", exprs) } func => { - let exprs = exprs.iter().map(|expr| self.child(expr)); + let exprs = exprs.iter().map(|expr| e.child(expr)); let exprs = separated(", ", exprs); write!(f, "{}({})", func, exprs) } } } If { cond, then, els } => { - let cond = self.child::(&*cond); - let then = self.child::(&*then); - let els = self.child::(&*els); + let cond = e.child::(&*cond); + let then = e.child::(&*then); + let els = e.child::(&*els); write!(f, "case when {} then {} else {} end", cond, then, els) } } } } +impl<'a, T, M> fmt::Display for HumanizedExpr<'a, T, M> +where + T: HumanizeDisplay, + M: HumanizerMode, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + T::humanize(self, f) + } +} + impl<'a, M> fmt::Display for HumanizedExpr<'a, AggregateExpr, M> where M: HumanizerMode, diff --git a/src/expr/src/interpret.rs b/src/expr/src/interpret.rs index 547973270dc0a..a994a026f2952 100644 --- a/src/expr/src/interpret.rs +++ b/src/expr/src/interpret.rs @@ -12,13 +12,11 @@ use std::fmt::Debug; use mz_repr::{Datum, ReprColumnType, ReprRelationType, ReprScalarType, Row, RowArena}; -use crate::Eval; use crate::scalar::func::variadic::And; use crate::{ - BinaryFunc, EvalError, MapFilterProject, MfpPlan, MirScalarExpr, UnaryFunc, + BinaryFunc, Eval, EvalError, MapFilterProject, MfpPlan, MirScalarExpr, UnaryFunc, UnmaterializableFunc, VariadicFunc, func, }; - /// An inclusive range of non-null datum values. #[derive(Clone, Eq, PartialEq, Debug)] enum Values<'a> { diff --git a/src/expr/src/lib.rs b/src/expr/src/lib.rs index 3f3f3c7bfba30..44d5a6987d7f8 100644 --- a/src/expr/src/lib.rs +++ b/src/expr/src/lib.rs @@ -48,8 +48,8 @@ pub use relation::{ }; pub use scalar::func::{self, BinaryFunc, UnaryFunc, UnmaterializableFunc, VariadicFunc}; pub use scalar::{ - Columns, Eval, EvalError, FilterCharacteristics, MirScalarExpr, ProtoDomainLimit, - ProtoEvalError, like_pattern, + Columns, Eval, EvalError, FilterCharacteristics, MirScalarExpr, OptimizableExpr, + ProtoDomainLimit, ProtoEvalError, like_pattern, }; /// A [`MirRelationExpr`] that claims to have been optimized, e.g., by an diff --git a/src/expr/src/linear.rs b/src/expr/src/linear.rs index 220bd02a55837..67ffdd92a3e04 100644 --- a/src/expr/src/linear.rs +++ b/src/expr/src/linear.rs @@ -12,7 +12,7 @@ use std::fmt::Display; use mz_repr::{Datum, Row}; use serde::{Deserialize, Serialize}; -use crate::scalar::columns::Columns; +use crate::scalar::optimizable::OptimizableExpr; use crate::visit::Visit; use crate::{MirRelationExpr, MirScalarExpr}; @@ -41,12 +41,13 @@ use crate::{MirRelationExpr, MirScalarExpr}; Ord, PartialOrd )] -pub struct MapFilterProject { +#[serde(bound(deserialize = "E: serde::de::DeserializeOwned"))] +pub struct MapFilterProject { /// A sequence of expressions that should be appended to the row. /// /// Many of these expressions may not be produced in the output, /// and may only be present as common subexpressions. - pub expressions: Vec, + pub expressions: Vec, /// Expressions that must evaluate to `Datum::True` for the output /// row to be produced. /// @@ -57,7 +58,7 @@ pub struct MapFilterProject { /// guarded evaluation of predicates. /// /// This list should be sorted by the first field. - pub predicates: Vec<(usize, MirScalarExpr)>, + pub predicates: Vec<(usize, E)>, /// A sequence of column identifiers whose data form the output row. pub projection: Vec, /// The expected number of input columns. @@ -67,7 +68,7 @@ pub struct MapFilterProject { pub input_arity: usize, } -impl Display for MapFilterProject { +impl Display for MapFilterProject { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { writeln!(f, "MapFilterProject(")?; writeln!(f, " expressions:")?; @@ -85,7 +86,7 @@ impl Display for MapFilterProject { } } -impl MapFilterProject { +impl MapFilterProject { /// Create a no-op operator for an input of a supplied arity. pub fn new(input_arity: usize) -> Self { Self { @@ -129,7 +130,7 @@ impl MapFilterProject { /// If fine manipulation is required, the predicates can be added manually. pub fn filter(mut self, predicates: I) -> Self where - I: IntoIterator, + I: IntoIterator, { for mut predicate in predicates { // Correct column references. @@ -157,14 +158,14 @@ impl MapFilterProject { // We put literal errors at the end as a stop-gap to avoid erroring // before we are able to evaluate any predicates that might prevent it. self.predicates - .sort_by_key(|(position, predicate)| (predicate.is_literal_err(), *position)); + .sort_by_key(|(position, predicate)| (E::is_literal_err(predicate), *position)); self } /// Append the result of evaluating expressions to each row. pub fn map(mut self, expressions: I) -> Self where - I: IntoIterator, + I: IntoIterator, { for mut expression in expressions { // Correct column references. @@ -188,7 +189,7 @@ impl MapFilterProject { } /// Like [`MapFilterProject::as_map_filter_project`], but consumes `self` rather than cloning. - pub fn into_map_filter_project(self) -> (Vec, Vec, Vec) { + pub fn into_map_filter_project(self) -> (Vec, Vec, Vec) { let predicates = self .predicates .into_iter() @@ -201,10 +202,13 @@ impl MapFilterProject { /// /// In principle, this operator can be implemented as a sequence of /// more elemental operators, likely less efficiently. - pub fn as_map_filter_project(&self) -> (Vec, Vec, Vec) { + pub fn as_map_filter_project(&self) -> (Vec, Vec, Vec) { self.clone().into_map_filter_project() } +} +// Methods that are specific to MirScalarExpr (use MirRelationExpr or as_literal). +impl MapFilterProject { /// Determines if a scalar expression must be equal to a literal datum. pub fn literal_constraint(&self, expr: &MirScalarExpr) -> Option> { for (_pos, predicate) in self.predicates.iter() { @@ -394,12 +398,14 @@ impl MapFilterProject { x => Self::new(x.arity()), } } +} +impl MapFilterProject { /// Returns `true` if any predicate in this MFP contains a temporal expression (`mz_now()`). pub fn has_temporal_predicates(&self) -> bool { self.predicates .iter() - .any(|(_, predicate)| predicate.contains_temporal()) + .any(|(_, predicate)| OptimizableExpr::contains_temporal(predicate)) } /// Extracts temporal predicates into their own `Self`. @@ -419,12 +425,17 @@ impl MapFilterProject { // Assert that we no longer have temporal expressions to evaluate. This should only // occur if the optimization above results with temporal expressions yielded in the // output, which is out of spec for how the type is meant to be used. - assert!(!self.expressions.iter().any(|e| e.contains_temporal())); + assert!( + !self + .expressions + .iter() + .any(|e| OptimizableExpr::contains_temporal(e)) + ); // Extract temporal predicates from `self.predicates`. let mut temporal_predicates = Vec::new(); self.predicates.retain(|(_position, predicate)| { - if predicate.contains_temporal() { + if OptimizableExpr::contains_temporal(predicate) { temporal_predicates.push(predicate.clone()); false } else { @@ -559,7 +570,7 @@ impl MapFilterProject { } } // As before, but easier: predicates in common to all mfps. - let common_preds: Vec = { + let common_preds: Vec = { let input_arity = result_mfp.projection.len(); let mut prev: BTreeSet<_> = mfps[0] .predicates @@ -629,12 +640,12 @@ impl MapFilterProject { /// /// The main behavior is extract temporal predicates, which cannot be evaluated /// using the standard machinery. - pub fn into_plan(self) -> Result { + pub fn into_plan(self) -> Result, String> { plan::MfpPlan::create_from(self) } } -impl MapFilterProject { +impl MapFilterProject { /// Partitions `self` into two instances, one of which can be eagerly applied. /// /// The `available` argument indicates which input columns are available (keys) @@ -741,8 +752,8 @@ impl MapFilterProject { // we are certainly making sub-optimal decisions by pushing down all available // work. // TODO(mcsherry): establish better principles about what work to push down. - let is_available = - expr.support().into_iter().all(|i| available_expr[i]) && !expr.is_literal(); + let is_available = expr.support().into_iter().all(|i| available_expr[i]) + && !OptimizableExpr::is_literal(&expr); if is_available { before_expr.push(expr); } else { @@ -916,7 +927,7 @@ impl MapFilterProject { } // Optimization routines. -impl MapFilterProject { +impl MapFilterProject { /// Optimize the internal expression evaluation order. /// /// This method performs several optimizations that are meant to streamline @@ -989,22 +1000,10 @@ impl MapFilterProject { for (index, expr) in self.expressions.iter_mut().enumerate() { // If `expr` matches a filter equating it to a column < index + input_arity, rewrite it for (_, predicate) in self.predicates.iter() { - if let MirScalarExpr::CallBinary { - func: crate::BinaryFunc::Eq(_), - expr1, - expr2, - } = predicate + if let Some(col) = + E::equality_column_alias(predicate, expr, index + self.input_arity) { - if let MirScalarExpr::Column(c, name) = &**expr1 { - if *c < index + self.input_arity && &**expr2 == expr { - *expr = MirScalarExpr::Column(*c, name.clone()); - } - } - if let MirScalarExpr::Column(c, name) = &**expr2 { - if *c < index + self.input_arity && &**expr1 == expr { - *expr = MirScalarExpr::Column(*c, name.clone()); - } - } + *expr = col; } } } @@ -1033,8 +1032,15 @@ impl MapFilterProject { /// Total expression sizes across all expressions. pub fn size(&self) -> usize { - self.expressions.iter().map(|e| e.size()).sum::() - + self.predicates.iter().map(|(_, e)| e.size()).sum::() + self.expressions + .iter() + .map(|e| OptimizableExpr::size(e)) + .sum::() + + self + .predicates + .iter() + .map(|(_, e)| OptimizableExpr::size(e)) + .sum::() } /// Place each certainly evaluated expression in its own column. @@ -1253,18 +1259,20 @@ impl MapFilterProject { let mut reference_count = vec![0; input_arity + self.expressions.len()]; // Increment reference counts for each use for expr in self.expressions.iter() { - expr.visit_pre(|e| { - if let MirScalarExpr::Column(i, _name) = e { - reference_count[*i] += 1; + expr.visit_pre(&mut |e| { + if let Some(i) = e.as_column() { + reference_count[i] += 1; } - }); + }) + .expect("visit_pre hit recursion limit"); } for (_, pred) in self.predicates.iter() { - pred.visit_pre(|e| { - if let MirScalarExpr::Column(i, _name) = e { - reference_count[*i] += 1; + pred.visit_pre(&mut |e| { + if let Some(i) = e.as_column() { + reference_count[i] += 1; } - }); + }) + .expect("visit_pre hit recursion limit"); } for proj in self.projection.iter() { reference_count[*proj] += 1; @@ -1275,7 +1283,8 @@ impl MapFilterProject { for expr in self.expressions.iter() { // An express may contain a temporal expression, or reference a column containing such. is_temporal.push( - expr.contains_temporal() || expr.support().into_iter().any(|col| is_temporal[col]), + OptimizableExpr::contains_temporal(expr) + || expr.support().into_iter().any(|col| is_temporal[col]), ); } @@ -1284,7 +1293,7 @@ impl MapFilterProject { // or 2c. reference temporal expressions (which cannot be evaluated). let mut should_inline = vec![false; reference_count.len()]; for i in (input_arity..reference_count.len()).rev() { - if let MirScalarExpr::Column(c, _) = self.expressions[i - input_arity] { + if let Some(c) = self.expressions[i - input_arity].as_column() { should_inline[i] = true; // The reference count of the referenced column should be // incremented with the number of references @@ -1301,7 +1310,7 @@ impl MapFilterProject { // We can only inline column references in `self.projection`, but we should. for proj in self.projection.iter_mut() { if *proj >= self.input_arity { - if let MirScalarExpr::Column(i, _) = self.expressions[*proj - self.input_arity] { + if let Some(i) = self.expressions[*proj - self.input_arity].as_column() { // TODO(mgree) !!! propagate name information to projection *proj = i; } @@ -1314,25 +1323,26 @@ impl MapFilterProject { pub fn perform_inlining(&mut self, should_inline: Vec) { for index in 0..self.expressions.len() { let (prior, expr) = self.expressions.split_at_mut(index); - #[allow(deprecated)] - expr[0].visit_mut_post_nolimit(&mut |e| { - if let MirScalarExpr::Column(i, _name) = e { - if should_inline[*i] { - *e = prior[*i - self.input_arity].clone(); + expr[0] + .visit_mut_post(&mut |e| { + if let Some(i) = e.as_column() { + if should_inline[i] { + *e = prior[i - self.input_arity].clone(); + } } - } - }); + }) + .expect("inlining hit recursion limit"); } for (_index, pred) in self.predicates.iter_mut() { let expressions = &self.expressions; - #[allow(deprecated)] - pred.visit_mut_post_nolimit(&mut |e| { - if let MirScalarExpr::Column(i, _name) = e { - if should_inline[*i] { - *e = expressions[*i - self.input_arity].clone(); + pred.visit_mut_post(&mut |e| { + if let Some(i) = e.as_column() { + if should_inline[i] { + *e = expressions[i - self.input_arity].clone(); } } - }); + }) + .expect("inlining hit recursion limit"); } } @@ -1435,77 +1445,44 @@ impl MapFilterProject { /// `(input_arity + pos)`, where `pos` is the position of the memoized part in /// `memoized_parts`, and `input_arity` is the arity of the input that `expr` /// refers to. -pub fn memoize_expr( - expr: &mut MirScalarExpr, - memoized_parts: &mut Vec, +pub fn memoize_expr( + expr: &mut E, + memoized_parts: &mut Vec, input_arity: usize, ) { #[allow(deprecated)] - expr.visit_mut_pre_post_nolimit( - &mut |e| { - // We should not eagerly memoize `if` branches that might not be taken. - // TODO: Memoize expressions in the intersection of `then` and `els`. - if let MirScalarExpr::If { cond, .. } = e { - return Some(vec![cond]); - } - - // We should not eagerly memoize `COALESCE` expressions after the first, - // as they are only meant to be evaluated if the preceding expressions - // evaluate to NULL. We could memoize any preceding by expressions that - // are certain not to error. - if let MirScalarExpr::CallVariadic { - func: crate::VariadicFunc::Coalesce(_), - exprs, - } = e - { - return Some(exprs.iter_mut().take(1).collect()); - } - - // We should not deconstruct temporal filters, because `MfpPlan::create_from` expects - // those to be in a specific form. However, we _should_ attend to the expression that is - // on the opposite side of mz_now(), because it might be a complex expression in itself, - // and is ok to deconstruct. - if let Some((_func, other_side)) = e.as_mut_temporal_filter().ok() { - return Some(vec![other_side]); - } - - None - }, - &mut |e| { - match e { - MirScalarExpr::Literal(_, _) => { - // Literals do not need to be memoized. - } - MirScalarExpr::Column(col, _) => { - // Column references do not need to be memoized, but may need to be - // updated if they reference a column reference themselves. - if *col > input_arity { - if let MirScalarExpr::Column(col2, _) = memoized_parts[*col - input_arity] { - // We do _not_ propagate column names, since mis-associating names and column - // references will be very confusing (and possibly bug-inducing). - *col = col2; - } - } - } - _ => { - // TODO: OOO (Optimizer Optimization Opportunity): - // we are quadratic in expression size because of this .iter().position - if let Some(position) = memoized_parts.iter().position(|e2| e2 == e) { - // Any complex expression that already exists as a prior column can - // be replaced by a reference to that column. - *e = MirScalarExpr::column(input_arity + position); - } else { - // A complex expression that does not exist should be memoized, and - // replaced by a reference to the column. - memoized_parts.push(std::mem::replace( - e, - MirScalarExpr::column(input_arity + memoized_parts.len()), - )); - } + expr.visit_mut_pre_post(&mut |e| e.eager_children(), &mut |e| { + if E::is_literal(e) { + // Literals do not need to be memoized. + return; + } + if let Some(col) = e.as_column_mut() { + // Column references do not need to be memoized, but may need to be + // updated if they reference a column reference themselves. + if *col > input_arity { + if let Some(col2) = memoized_parts[*col - input_arity].as_column() { + // Update the column index in place, preserving any name information. + *col = col2; } } - }, - ) + return; + } + // TODO: OOO (Optimizer Optimization Opportunity): + // we are quadratic in expression size because of this .iter().position + if let Some(position) = memoized_parts.iter().position(|e2| e2 == e) { + // Any complex expression that already exists as a prior column can + // be replaced by a reference to that column. + *e = E::column(input_arity + position); + } else { + // A complex expression that does not exist should be memoized, and + // replaced by a reference to the column. + memoized_parts.push(std::mem::replace( + e, + E::column(input_arity + memoized_parts.len()), + )); + } + }) + .expect("memoize_expr hit recursion limit"); } pub mod util { @@ -1611,15 +1588,27 @@ pub mod plan { use serde::{Deserialize, Serialize}; use crate::Eval; - use crate::{BinaryFunc, EvalError, MapFilterProject, MirScalarExpr, UnaryFunc, func}; + use crate::scalar::optimizable::OptimizableExpr; + use crate::{EvalError, MapFilterProject, MirScalarExpr}; /// A wrapper type which indicates it is safe to simply evaluate all expressions. #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq, Ord, PartialOrd)] - pub struct SafeMfpPlan { - pub(crate) mfp: MapFilterProject, + #[serde(bound(deserialize = "E: serde::de::DeserializeOwned"))] + pub struct SafeMfpPlan { + pub(crate) mfp: MapFilterProject, } - impl SafeMfpPlan { + impl SafeMfpPlan { + /// Wrap a `MapFilterProject` in a `SafeMfpPlan`. + pub fn from_mfp(mfp: MapFilterProject) -> Self { + Self { mfp } + } + + /// Unwrap the inner `MapFilterProject`. + pub fn into_mfp(self) -> MapFilterProject { + self.mfp + } + /// Remaps references to input columns according to `remap`. /// /// Leaves other column references, e.g. to newly mapped columns, unchanged. @@ -1629,6 +1618,14 @@ pub mod plan { { self.mfp.permute_fn(remap, new_arity); } + + /// Returns true when `Self` is the identity. + pub fn is_identity(&self) -> bool { + self.mfp.is_identity() + } + } + + impl SafeMfpPlan { /// Evaluates the linear operator on a supplied list of datums. /// /// The arguments are the initial datums associated with the row, @@ -1711,15 +1708,10 @@ pub mod plan { self.mfp.predicates.iter().any(|(_pos, e)| e.could_error()) || self.mfp.expressions.iter().any(|e| e.could_error()) } - - /// Returns true when `Self` is the identity. - pub fn is_identity(&self) -> bool { - self.mfp.is_identity() - } } - impl std::ops::Deref for SafeMfpPlan { - type Target = MapFilterProject; + impl std::ops::Deref for SafeMfpPlan { + type Target = MapFilterProject; fn deref(&self) -> &Self::Target { &self.mfp } @@ -1734,17 +1726,42 @@ pub mod plan { /// They must directly constrain `MzNow` from below or above, /// by expressions that do not themselves contain `MzNow`. /// Conjunctions of such constraints are also ok. - #[derive(Clone, Debug, PartialEq)] - pub struct MfpPlan { + #[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)] + #[serde(bound(deserialize = "E: serde::de::DeserializeOwned"))] + pub struct MfpPlan { /// Normal predicates to evaluate on `&[Datum]` and expect `Ok(Datum::True)`. - pub(crate) mfp: SafeMfpPlan, - /// Expressions that when evaluated lower-bound `MzNow`. - pub(crate) lower_bounds: Vec, - /// Expressions that when evaluated upper-bound `MzNow`. - pub(crate) upper_bounds: Vec, + pub(crate) mfp: SafeMfpPlan, + /// Expressions that when evaluated lower-bound or equal (<=) `MzNow`. + pub(crate) lower_bounds: Vec, + /// Expressions that when evaluated strictly upper-bound `MzNow`. + pub(crate) upper_bounds: Vec, } - impl MfpPlan { + impl MfpPlan { + /// Construct an `MfpPlan` from its components. + pub fn from_parts(mfp: SafeMfpPlan, lower_bounds: Vec, upper_bounds: Vec) -> Self { + Self { + mfp, + lower_bounds, + upper_bounds, + } + } + + /// Deconstruct into components. + pub fn into_parts(self) -> (SafeMfpPlan, Vec, Vec) { + (self.mfp, self.lower_bounds, self.upper_bounds) + } + + /// Access the inner `SafeMfpPlan`. + pub fn safe_mfp(&self) -> &SafeMfpPlan { + &self.mfp + } + + /// Borrow all parts: `(safe_mfp, lower_bounds, upper_bounds)`. + pub fn as_parts(&self) -> (&SafeMfpPlan, &[E], &[E]) { + (&self.mfp, &self.lower_bounds, &self.upper_bounds) + } + /// Partitions `predicates` into non-temporal, and lower and upper temporal bounds. /// /// The first returned list is of predicates that do not contain `mz_now`. @@ -1758,17 +1775,12 @@ pub mod plan { /// /// If any unsupported expression is found, for example one that uses `mz_now` /// in an unsupported position, an error is returned. - pub fn create_from(mut mfp: MapFilterProject) -> Result { - let mut lower_bounds = Vec::new(); - let mut upper_bounds = Vec::new(); - - let mut temporal = Vec::new(); - - // Optimize, to ensure that temporal predicates are move in to `mfp.predicates`. + pub fn create_from(mut mfp: MapFilterProject) -> Result { mfp.optimize(); + let mut temporal = Vec::new(); mfp.predicates.retain(|(_position, predicate)| { - if predicate.contains_temporal() { + if OptimizableExpr::contains_temporal(predicate) { temporal.push(predicate.clone()); false } else { @@ -1776,39 +1788,7 @@ pub mod plan { } }); - for mut predicate in temporal.into_iter() { - let (func, expr2) = predicate.as_mut_temporal_filter()?; - let expr2 = expr2.clone(); - - // LogicalTimestamp for several supported operators. - match func { - BinaryFunc::Eq(_) => { - lower_bounds.push(expr2.clone()); - upper_bounds.push( - expr2.call_unary(UnaryFunc::StepMzTimestamp(func::StepMzTimestamp)), - ); - } - BinaryFunc::Lt(_) => { - upper_bounds.push(expr2.clone()); - } - BinaryFunc::Lte(_) => { - upper_bounds.push( - expr2.call_unary(UnaryFunc::StepMzTimestamp(func::StepMzTimestamp)), - ); - } - BinaryFunc::Gt(_) => { - lower_bounds.push( - expr2.call_unary(UnaryFunc::StepMzTimestamp(func::StepMzTimestamp)), - ); - } - BinaryFunc::Gte(_) => { - lower_bounds.push(expr2.clone()); - } - _ => { - return Err(format!("Unsupported binary temporal operation: {:?}", func)); - } - } - } + let (lower_bounds, upper_bounds) = E::extract_temporal_bounds(temporal)?; Ok(Self { mfp: SafeMfpPlan { mfp }, @@ -1824,6 +1804,12 @@ pub mod plan { && self.upper_bounds.is_empty() } + /// Returns `true` if the plan contains temporal bounds + /// (i.e., predicates involving `mz_now()`). + pub fn has_temporal_bounds(&self) -> bool { + !self.lower_bounds.is_empty() || !self.upper_bounds.is_empty() + } + /// Returns `self`, and leaves behind an identity operator that acts on its output. pub fn take(&mut self) -> Self { let mut identity = Self { @@ -1841,7 +1827,7 @@ pub mod plan { /// /// If that is not possible, the original instance is returned as an error. #[allow(clippy::result_large_err)] - pub fn into_nontemporal(self) -> Result { + pub fn into_nontemporal(self) -> Result, Self> { if self.lower_bounds.is_empty() && self.upper_bounds.is_empty() { Ok(self.mfp) } else { @@ -1853,7 +1839,7 @@ pub mod plan { /// scalar expressions in the plan. /// /// The order of iteration is unspecified. - pub fn iter_nontemporal_exprs(&mut self) -> impl Iterator { + pub fn iter_nontemporal_exprs(&mut self) -> impl Iterator { iter::empty() .chain(self.mfp.mfp.predicates.iter_mut().map(|(_, expr)| expr)) .chain(&mut self.mfp.mfp.expressions) @@ -1861,6 +1847,19 @@ pub mod plan { .chain(&mut self.upper_bounds) } + /// Indicates that `Self` ignores its input to the extent that it can be evaluated on `&[]`. + /// + /// At the moment, this is only true if it projects away all columns and applies no filters, + /// but it could be extended to plans that produce literals independent of the input. + pub fn ignores_input(&self) -> bool { + self.lower_bounds.is_empty() + && self.upper_bounds.is_empty() + && self.mfp.mfp.projection.is_empty() + && self.mfp.mfp.predicates.is_empty() + } + } + + impl MfpPlan { /// Evaluate the predicates, temporal and non-, and return times and differences for `data`. /// /// If `self` contains only non-temporal predicates, the result will either be `(time, diff)`, @@ -1869,7 +1868,7 @@ pub mod plan { /// /// The `row_builder` is not cleared first, but emptied if the function /// returns an iterator with any `Ok(_)` element. - pub fn evaluate<'b, 'a: 'b, E: From, V: Fn(&mz_repr::Timestamp) -> bool>( + pub fn evaluate<'b, 'a: 'b, Err: From, V: Fn(&mz_repr::Timestamp) -> bool>( &'a self, datums: &'b mut Vec>, arena: &'a RowArena, @@ -1878,8 +1877,8 @@ pub mod plan { valid_time: V, row_builder: &mut Row, ) -> impl Iterator< - Item = Result<(Row, mz_repr::Timestamp, Diff), (E, mz_repr::Timestamp, Diff)>, - > + use { + Item = Result<(Row, mz_repr::Timestamp, Diff), (Err, mz_repr::Timestamp, Diff)>, + > + use { match self.mfp.evaluate_inner(datums, arena) { Err(e) => { return Some(Err((e.into(), time, diff))).into_iter().chain(None); @@ -1983,16 +1982,5 @@ pub mod plan { || self.lower_bounds.iter().any(|e| e.could_error()) || self.upper_bounds.iter().any(|e| e.could_error()) } - - /// Indicates that `Self` ignores its input to the extent that it can be evaluated on `&[]`. - /// - /// At the moment, this is only true if it projects away all columns and applies no filters, - /// but it could be extended to plans that produce literals independent of the input. - pub fn ignores_input(&self) -> bool { - self.lower_bounds.is_empty() - && self.upper_bounds.is_empty() - && self.mfp.mfp.projection.is_empty() - && self.mfp.mfp.predicates.is_empty() - } } } diff --git a/src/expr/src/scalar.rs b/src/expr/src/scalar.rs index 62632ec9b9ce0..690cb5332f4a8 100644 --- a/src/expr/src/scalar.rs +++ b/src/expr/src/scalar.rs @@ -42,6 +42,7 @@ pub use crate::scalar::columns::Columns; pub use crate::scalar::eval::Eval; use crate::scalar::func::variadic::{And, Or}; use crate::scalar::func::{BinaryFunc, UnaryFunc, UnmaterializableFunc, VariadicFunc}; +pub use crate::scalar::optimizable::OptimizableExpr; use crate::scalar::proto_eval_error::proto_incompatible_array_dimensions::ProtoDims; use crate::visit::{Visit, VisitChildren}; @@ -49,6 +50,7 @@ pub mod columns; pub mod eval; pub mod func; pub mod like_pattern; +pub mod optimizable; mod reduce; include!(concat!(env!("OUT_DIR"), "/mz_expr.scalar.rs")); @@ -1166,39 +1168,6 @@ impl MirScalarExpr { } } -impl Columns for MirScalarExpr { - fn is_column(&self) -> bool { - matches!(self, MirScalarExpr::Column(_col, _name)) - } - - fn as_column(&self) -> Option { - if let MirScalarExpr::Column(c, _) = self { - Some(*c) - } else { - None - } - } - - fn support_into(&self, support: &mut BTreeSet) { - self.visit_pre(|e| { - if let MirScalarExpr::Column(i, _) = e { - support.insert(*i); - } - }); - } - - fn visit_columns(&mut self, mut action: F) - where - F: FnMut(&mut usize), - { - self.visit_pre_mut(|e| { - if let MirScalarExpr::Column(col, _) = e { - action(col); - } - }); - } -} - impl Eval for MirScalarExpr { fn eval<'a>( &'a self, @@ -1255,6 +1224,51 @@ impl Eval for MirScalarExpr { } } +impl Columns for MirScalarExpr { + fn column(c: usize) -> Self { + MirScalarExpr::column(c) + } + + fn is_column(&self) -> bool { + matches!(self, MirScalarExpr::Column(_col, _name)) + } + + fn as_column(&self) -> Option { + if let MirScalarExpr::Column(c, _) = self { + Some(*c) + } else { + None + } + } + + fn as_column_mut(&mut self) -> Option<&mut usize> { + if let MirScalarExpr::Column(c, _) = self { + Some(c) + } else { + None + } + } + + fn support_into(&self, support: &mut BTreeSet) { + self.visit_pre(|e| { + if let MirScalarExpr::Column(i, _) = e { + support.insert(*i); + } + }); + } + + fn visit_columns(&mut self, mut action: F) + where + F: FnMut(&mut usize), + { + self.visit_pre_mut(|e| { + if let MirScalarExpr::Column(col, _) = e { + action(col); + } + }); + } +} + impl VisitChildren for MirScalarExpr { fn visit_children(&self, mut f: F) where diff --git a/src/expr/src/scalar/columns.rs b/src/expr/src/scalar/columns.rs index 3b2e87a9fec0b..afb2ba5e62878 100644 --- a/src/expr/src/scalar/columns.rs +++ b/src/expr/src/scalar/columns.rs @@ -12,12 +12,18 @@ use std::collections::{BTreeMap, BTreeSet}; pub trait Columns: Sized { + /// Construct a column reference expression. + fn column(c: usize) -> Self; + /// True when the outermost structure is a column. fn is_column(&self) -> bool; /// If self is a column, return the column index, otherwise `None`. fn as_column(&self) -> Option; + /// If self is a column, return a mutable reference to the column index. + fn as_column_mut(&mut self) -> Option<&mut usize>; + /// The support of the given set, i.e., the columns that are actually used. /// /// You can use `BTreeSet::last()` to extract the maximum column. diff --git a/src/expr/src/scalar/func/binary.rs b/src/expr/src/scalar/func/binary.rs index 6eef91ccbbb5c..a15ee653c2e60 100644 --- a/src/expr/src/scalar/func/binary.rs +++ b/src/expr/src/scalar/func/binary.rs @@ -12,8 +12,7 @@ use mz_ore::assert_none; use mz_repr::{Datum, InputDatumType, OutputDatumType, ReprColumnType, RowArena, SqlColumnType}; -use crate::Eval; -use crate::EvalError; +use crate::{Eval, EvalError}; /// A description of an SQL binary function that has the ability to lazy evaluate its arguments // This trait will eventually be annotated with #[enum_dispatch] to autogenerate the UnaryFunc enum diff --git a/src/expr/src/scalar/func/impls/array.rs b/src/expr/src/scalar/func/impls/array.rs index 776a44fffb871..6bdf0d41416c4 100644 --- a/src/expr/src/scalar/func/impls/array.rs +++ b/src/expr/src/scalar/func/impls/array.rs @@ -15,9 +15,8 @@ use mz_repr::adt::array::{Array, ArrayDimension}; use mz_repr::{Datum, DatumList, Row, RowArena, RowPacker, SqlColumnType, SqlScalarType}; use serde::{Deserialize, Serialize}; -use crate::Eval; use crate::scalar::func::{LazyUnaryFunc, stringify_datum}; -use crate::{EvalError, MirScalarExpr}; +use crate::{Eval, EvalError, MirScalarExpr}; #[sqlfunc( sqlname = "arraytolist", diff --git a/src/expr/src/scalar/func/impls/case_literal.rs b/src/expr/src/scalar/func/impls/case_literal.rs index 910373342080c..3653e146a9320 100644 --- a/src/expr/src/scalar/func/impls/case_literal.rs +++ b/src/expr/src/scalar/func/impls/case_literal.rs @@ -25,9 +25,8 @@ use mz_lowertest::MzReflect; use mz_repr::{Datum, Row, RowArena, SqlColumnType}; use serde::{Deserialize, Serialize}; -use crate::Eval; -use crate::EvalError; use crate::scalar::func::variadic::LazyVariadicFunc; +use crate::{Eval, EvalError}; /// A single entry in a [`CaseLiteral`] lookup table: a literal `Row` value /// paired with the index of the corresponding result expression in `exprs`. diff --git a/src/expr/src/scalar/func/impls/list.rs b/src/expr/src/scalar/func/impls/list.rs index f18c7d38d0aa2..6a9eb7279a828 100644 --- a/src/expr/src/scalar/func/impls/list.rs +++ b/src/expr/src/scalar/func/impls/list.rs @@ -14,10 +14,9 @@ use mz_lowertest::MzReflect; use mz_repr::{AsColumnType, Datum, DatumList, Row, RowArena, SqlColumnType, SqlScalarType}; use serde::{Deserialize, Serialize}; -use crate::Eval; use crate::func::binary::EagerBinaryFunc; use crate::scalar::func::{LazyUnaryFunc, stringify_datum}; -use crate::{EvalError, MirScalarExpr}; +use crate::{Eval, EvalError, MirScalarExpr}; #[derive( Ord, diff --git a/src/expr/src/scalar/func/impls/map.rs b/src/expr/src/scalar/func/impls/map.rs index 69f9166fc446b..2664120bc2610 100644 --- a/src/expr/src/scalar/func/impls/map.rs +++ b/src/expr/src/scalar/func/impls/map.rs @@ -15,9 +15,8 @@ use mz_lowertest::MzReflect; use mz_repr::{Datum, DatumMap, RowArena, SqlColumnType, SqlScalarType}; use serde::{Deserialize, Serialize}; -use crate::Eval; -use crate::EvalError; use crate::scalar::func::{LazyUnaryFunc, stringify_datum}; +use crate::{Eval, EvalError}; #[derive( Ord, diff --git a/src/expr/src/scalar/func/impls/range.rs b/src/expr/src/scalar/func/impls/range.rs index 1c9f84c123d90..5e18fee7a0f2b 100644 --- a/src/expr/src/scalar/func/impls/range.rs +++ b/src/expr/src/scalar/func/impls/range.rs @@ -15,9 +15,8 @@ use mz_repr::adt::range::Range; use mz_repr::{Datum, RowArena, SqlColumnType, SqlScalarType}; use serde::{Deserialize, Serialize}; -use crate::Eval; -use crate::EvalError; use crate::scalar::func::{LazyUnaryFunc, stringify_datum}; +use crate::{Eval, EvalError}; #[derive( Ord, diff --git a/src/expr/src/scalar/func/impls/record.rs b/src/expr/src/scalar/func/impls/record.rs index aeb36da1ef161..2e246a7444b95 100644 --- a/src/expr/src/scalar/func/impls/record.rs +++ b/src/expr/src/scalar/func/impls/record.rs @@ -14,9 +14,8 @@ use mz_lowertest::MzReflect; use mz_repr::{Datum, RowArena, SqlColumnType, SqlScalarType}; use serde::{Deserialize, Serialize}; -use crate::Eval; use crate::scalar::func::{LazyUnaryFunc, stringify_datum}; -use crate::{EvalError, MirScalarExpr}; +use crate::{Eval, EvalError, MirScalarExpr}; #[derive( Ord, diff --git a/src/expr/src/scalar/func/variadic.rs b/src/expr/src/scalar/func/variadic.rs index 62a7f5eaac4c5..4fde21a9e5ed3 100644 --- a/src/expr/src/scalar/func/variadic.rs +++ b/src/expr/src/scalar/func/variadic.rs @@ -40,14 +40,12 @@ use mz_repr::{ }; use serde::{Deserialize, Serialize}; -use crate::Eval; -use crate::func::CaseLiteral; use crate::func::{ - MAX_STRING_FUNC_RESULT_BYTES, array_create_scalar, build_regex, date_bin, parse_timezone, - regexp_match_static, regexp_replace_parse_flags, regexp_split_to_array_re, stringify_datum, - timezone_time, + CaseLiteral, MAX_STRING_FUNC_RESULT_BYTES, array_create_scalar, build_regex, date_bin, + parse_timezone, regexp_match_static, regexp_replace_parse_flags, regexp_split_to_array_re, + stringify_datum, timezone_time, }; -use crate::{EvalError, MirScalarExpr}; +use crate::{Eval, EvalError, MirScalarExpr}; use mz_repr::adt::date::Date; use mz_repr::adt::interval::Interval; use mz_repr::adt::jsonb::JsonbRef; diff --git a/src/expr/src/scalar/optimizable.rs b/src/expr/src/scalar/optimizable.rs new file mode 100644 index 0000000000000..cf805b07518fd --- /dev/null +++ b/src/expr/src/scalar/optimizable.rs @@ -0,0 +1,163 @@ +// Copyright Materialize, Inc. and contributors. All rights reserved. +// +// Use of this software is governed by the Business Source License +// included in the LICENSE file. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0. + +//! A trait for scalar expressions that can be optimized inside a `MapFilterProject`. +//! +//! This trait is implemented by both `MirScalarExpr` and `LirScalarExpr`, +//! allowing `MapFilterProject` to be parameterized over either. + +use std::fmt::Debug; +use std::hash::Hash; + +use serde::Serialize; + +use crate::scalar::columns::Columns; +use crate::scalar::func::{BinaryFunc, UnaryFunc, VariadicFunc}; +use crate::visit::VisitChildren; +use crate::{MirScalarExpr, func}; + +/// A scalar expression type that can be optimized inside a `MapFilterProject`. +/// +/// Implemented by `MirScalarExpr` and `LirScalarExpr`. +pub trait OptimizableExpr: + Columns + VisitChildren + Clone + Eq + Ord + Hash + Debug + Sized + Serialize +{ + /// True if this expression is a literal. + fn is_literal(&self) -> bool; + + /// True if this expression is a literal error. + fn is_literal_err(&self) -> bool; + + /// True if this expression contains a temporal reference (`mz_now()`). + fn contains_temporal(&self) -> bool; + + /// Count of AST nodes in the expression tree. + fn size(&self) -> usize; + + /// For memoization: which children should be eagerly memoized? + /// + /// Returns `None` to visit all children (the common case). + /// Returns `Some(children)` for selective descent — e.g., for `If`, only the + /// condition should be eagerly memoized (branches may not be taken). + fn eager_children(&mut self) -> Option>; + + /// If `predicate` is `col = expr` (or `expr = col`) where `col` is a column + /// with index < `threshold`, return a clone of that column expression. + /// + /// Used by `optimize()` to detect equality-derived column aliases. + fn equality_column_alias(predicate: &Self, expr: &Self, threshold: usize) -> Option; + + /// Extract temporal bounds from a list of temporal predicates. + /// + /// Returns `(lower_bounds, upper_bounds)` for use in `MfpPlan`. + fn extract_temporal_bounds(temporal: Vec) -> Result<(Vec, Vec), String>; +} + +impl OptimizableExpr for MirScalarExpr { + fn is_literal(&self) -> bool { + self.is_literal() + } + + fn is_literal_err(&self) -> bool { + self.is_literal_err() + } + + fn contains_temporal(&self) -> bool { + self.contains_temporal() + } + + fn size(&self) -> usize { + self.size() + } + + fn eager_children(&mut self) -> Option> { + // Do not eagerly memoize `if` branches that might not be taken. + if let MirScalarExpr::If { cond, .. } = self { + return Some(vec![cond]); + } + + // Do not eagerly memoize `COALESCE` expressions after the first, + // as they are only meant to be evaluated if the preceding expressions + // evaluate to NULL. + if let MirScalarExpr::CallVariadic { + func: VariadicFunc::Coalesce(_), + exprs, + } = self + { + return Some(exprs.iter_mut().take(1).collect()); + } + + // Do not deconstruct temporal filters, because `MfpPlan::create_from` expects + // those to be in a specific form. However, attend to the expression on the + // opposite side of mz_now(). + if let Ok((_func, other_side)) = self.as_mut_temporal_filter() { + return Some(vec![other_side]); + } + + None + } + + fn equality_column_alias(predicate: &Self, expr: &Self, threshold: usize) -> Option { + if let MirScalarExpr::CallBinary { + func: BinaryFunc::Eq(_), + expr1, + expr2, + } = predicate + { + if let MirScalarExpr::Column(c, name) = &**expr1 { + if *c < threshold && &**expr2 == expr { + return Some(MirScalarExpr::Column(*c, name.clone())); + } + } + if let MirScalarExpr::Column(c, name) = &**expr2 { + if *c < threshold && &**expr1 == expr { + return Some(MirScalarExpr::Column(*c, name.clone())); + } + } + } + None + } + + fn extract_temporal_bounds(temporal: Vec) -> Result<(Vec, Vec), String> { + let mut lower_bounds = Vec::new(); + let mut upper_bounds = Vec::new(); + + for mut predicate in temporal.into_iter() { + let (func, expr2) = predicate.as_mut_temporal_filter()?; + let expr2 = expr2.clone(); + + match func { + BinaryFunc::Eq(_) => { + lower_bounds.push(expr2.clone()); + upper_bounds + .push(expr2.call_unary(UnaryFunc::StepMzTimestamp(func::StepMzTimestamp))); + } + BinaryFunc::Lt(_) => { + upper_bounds.push(expr2.clone()); + } + BinaryFunc::Lte(_) => { + upper_bounds + .push(expr2.call_unary(UnaryFunc::StepMzTimestamp(func::StepMzTimestamp))); + } + BinaryFunc::Gt(_) => { + lower_bounds + .push(expr2.call_unary(UnaryFunc::StepMzTimestamp(func::StepMzTimestamp))); + } + BinaryFunc::Gte(_) => { + lower_bounds.push(expr2.clone()); + } + _ => { + return Err(format!("Unsupported binary temporal operation: {:?}", func)); + } + } + } + + Ok((lower_bounds, upper_bounds)) + } +} diff --git a/src/repr/src/explain.rs b/src/repr/src/explain.rs index 931e06511417d..71db6e1f393ba 100644 --- a/src/repr/src/explain.rs +++ b/src/repr/src/explain.rs @@ -649,11 +649,23 @@ where I: Iterator + Clone; pub trait ScalarOps { + /// If this expression is a column-reference, return the column referenced. fn match_col_ref(&self) -> Option; + /// Returns true if this expression is a reference to the given column. fn references(&self, col_ref: usize) -> bool; } +impl ScalarOps for usize { + fn match_col_ref(&self) -> Option { + Some(*self) + } + + fn references(&self, col_ref: usize) -> bool { + *self == col_ref + } +} + /// A somewhat ad-hoc way to keep carry a plan with a set /// of analyses derived for each node in that plan. #[allow(missing_debug_implementations)] From 14f93698dab41b1dcf7f63bf4c2645e037bfe1b4 Mon Sep 17 00:00:00 2001 From: Michael Greenberg Date: Mon, 1 Jun 2026 17:08:37 -0400 Subject: [PATCH 2/2] override visit_pre to avoid stack unsafety for MIR --- src/expr/src/scalar/optimizable.rs | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/src/expr/src/scalar/optimizable.rs b/src/expr/src/scalar/optimizable.rs index cf805b07518fd..84e220ef04edb 100644 --- a/src/expr/src/scalar/optimizable.rs +++ b/src/expr/src/scalar/optimizable.rs @@ -15,11 +15,12 @@ use std::fmt::Debug; use std::hash::Hash; +use mz_ore::stack::RecursionLimitError; use serde::Serialize; use crate::scalar::columns::Columns; use crate::scalar::func::{BinaryFunc, UnaryFunc, VariadicFunc}; -use crate::visit::VisitChildren; +use crate::visit::{Visit, VisitChildren}; use crate::{MirScalarExpr, func}; /// A scalar expression type that can be optimized inside a `MapFilterProject`. @@ -57,6 +58,14 @@ pub trait OptimizableExpr: /// /// Returns `(lower_bounds, upper_bounds)` for use in `MfpPlan`. fn extract_temporal_bounds(temporal: Vec) -> Result<(Vec, Vec), String>; + + /// Visit in a pre-traversal. Defaults to the `Visit` implementation, but overridable. + fn visit_pre(&self, f: &mut F) -> Result<(), RecursionLimitError> + where + F: FnMut(&Self), + { + Visit::visit_pre(self, f) + } } impl OptimizableExpr for MirScalarExpr { @@ -160,4 +169,12 @@ impl OptimizableExpr for MirScalarExpr { Ok((lower_bounds, upper_bounds)) } + + fn visit_pre(&self, f: &mut F) -> Result<(), RecursionLimitError> + where + F: FnMut(&Self), + { + self.visit_pre(f); + Ok(()) + } }