refactor: Change all Visitors to be iterative, child-based#36852
Conversation
4947ffe to
9b39e43
Compare
9b39e43 to
9ab62e9
Compare
73bffaa to
a21f03a
Compare
antiguru
left a comment
There was a problem hiding this comment.
Seems fine! Left some comments inline.
| // The `T`-typed children of this element. | ||
| fn children(&self) -> Vec<&T>; | ||
|
|
||
| // The `&mut T`-typed children of this element. | ||
| fn children_mut(&mut self) -> Vec<&mut T>; |
There was a problem hiding this comment.
Could we return impl DoubleEndedIterator<Item=&(mut) T> here? This would avoid the otherwise required allocation.
There was a problem hiding this comment.
Maybe? When we're working with HIR, this includes the direct subqueries in MSEs... I ran into issues with just impl Iterator<...>. I can try!
| /// NB that any trait with post-traversal uses unsafe code---any shenanigans | ||
| /// in the visitor could result in unsoundness. Ordinary rust code in the visitor |
There was a problem hiding this comment.
Make this tighter: Any safe code is safe, only unsafe code can cause unsoundness.
| enum VisitAction<'a, T> { | ||
| Enter(&'a T), | ||
| Leave(&'a T), | ||
| } |
There was a problem hiding this comment.
Please document what the variants mean, just with a brief comment.
|
I prototyped changing Notes from the exercise:
Review comments on the PR itself:
DoubleEndedIterator diff (apply on top of this branch)diff --git i/src/expr/src/relation.rs w/src/expr/src/relation.rs
index 62021926ec..87ba88cc6b 100644
--- i/src/expr/src/relation.rs
+++ w/src/expr/src/relation.rs
@@ -2401,12 +2401,18 @@ impl VisitChildren<Self> for MirRelationExpr {
Ok(())
}
- fn children(&self) -> Vec<&MirRelationExpr> {
- self.children().collect()
+ fn children<'a>(&'a self) -> impl DoubleEndedIterator<Item = &'a MirRelationExpr>
+ where
+ MirRelationExpr: 'a,
+ {
+ self.children()
}
- fn children_mut(&mut self) -> Vec<&mut MirRelationExpr> {
- self.children_mut().collect()
+ fn children_mut<'a>(&'a mut self) -> impl DoubleEndedIterator<Item = &'a mut MirRelationExpr>
+ where
+ MirRelationExpr: 'a,
+ {
+ self.children_mut()
}
}
diff --git i/src/expr/src/scalar.rs w/src/expr/src/scalar.rs
index 3ff3d859bc..df923281d8 100644
--- i/src/expr/src/scalar.rs
+++ w/src/expr/src/scalar.rs
@@ -1366,12 +1366,18 @@ impl VisitChildren<Self> for MirScalarExpr {
Ok(())
}
- fn children(&self) -> Vec<&Self> {
- self.children().collect()
+ fn children<'a>(&'a self) -> impl DoubleEndedIterator<Item = &'a Self>
+ where
+ Self: 'a,
+ {
+ self.children()
}
- fn children_mut(&mut self) -> Vec<&mut Self> {
- self.children_mut().collect()
+ fn children_mut<'a>(&'a mut self) -> impl DoubleEndedIterator<Item = &'a mut Self>
+ where
+ Self: 'a,
+ {
+ self.children_mut()
}
}
diff --git i/src/expr/src/visit.rs w/src/expr/src/visit.rs
index 61c6a3e79a..ab6635da10 100644
--- i/src/expr/src/visit.rs
+++ w/src/expr/src/visit.rs
@@ -59,7 +59,7 @@ pub trait VisitChildren<T> {
where
F: FnMut(&T),
{
- self.children().into_iter().for_each(f);
+ self.children().for_each(f);
}
/// Apply an infallible mutable function `f` to each direct child.
@@ -67,7 +67,7 @@ pub trait VisitChildren<T> {
where
F: FnMut(&mut T),
{
- self.children_mut().into_iter().for_each(f);
+ self.children_mut().for_each(f);
}
/// Apply a fallible immutable function `f` to each direct child.
@@ -105,13 +105,17 @@ pub trait VisitChildren<T> {
}
/// The `T`-typed children of this element.
- fn children(&self) -> Vec<&T>;
+ fn children<'a>(&'a self) -> impl DoubleEndedIterator<Item = &'a T>
+ where
+ T: 'a;
/// The `&mut T`-typed children of this element.
///
/// It is critical for the safety of mutable post-order traversals that this
/// function be written using safe code.
- fn children_mut(&mut self) -> Vec<&mut T>;
+ fn children_mut<'a>(&'a mut self) -> impl DoubleEndedIterator<Item = &'a mut T>
+ where
+ T: 'a;
}
/// A trait for types that can recursively visit their children of the
@@ -328,7 +332,7 @@ impl<T: VisitChildren<T>> Visit for T {
Enter(elt) => {
stack.push(Leave(elt));
// Push children in reverse so they pop (and are visited) left-to-right.
- stack.extend(elt.children().into_iter().rev().map(Enter));
+ stack.extend(elt.children().rev().map(Enter));
}
Leave(elt) => f(elt),
}
@@ -365,12 +369,7 @@ impl<T: VisitChildren<T>> Visit for T {
stack.push(Leave(ptr));
let elt = unsafe { &mut *ptr };
// Push children in reverse so they pop (and are visited) left-to-right.
- stack.extend(
- elt.children_mut()
- .into_iter()
- .rev()
- .map(|child| Enter(child as *mut T)),
- );
+ stack.extend(elt.children_mut().rev().map(|child| Enter(child as *mut T)));
}
Leave(elt) => f(unsafe { &mut *elt }),
}
@@ -398,7 +397,7 @@ impl<T: VisitChildren<T>> Visit for T {
Enter(elt) => {
stack.push(Leave(elt));
// Push children in reverse so they pop (and are visited) left-to-right.
- stack.extend(elt.children().into_iter().rev().map(Enter));
+ stack.extend(elt.children().rev().map(Enter));
}
Leave(elt) => f(elt)?,
}
@@ -428,12 +427,7 @@ impl<T: VisitChildren<T>> Visit for T {
stack.push(Leave(ptr));
let elt = unsafe { &mut *ptr };
// Push children in reverse so they pop (and are visited) left-to-right.
- stack.extend(
- elt.children_mut()
- .into_iter()
- .rev()
- .map(|child| Enter(child as *mut T)),
- );
+ stack.extend(elt.children_mut().rev().map(|child| Enter(child as *mut T)));
}
Leave(ptr) => f(unsafe { &mut *ptr })?,
}
@@ -450,7 +444,7 @@ impl<T: VisitChildren<T>> Visit for T {
while let Some(elt) = stack.pop() {
f(elt);
// Push children in reverse so they pop (and are visited) left-to-right.
- stack.extend(elt.children().into_iter().rev());
+ stack.extend(elt.children().rev());
}
Ok(())
@@ -472,12 +466,7 @@ impl<T: VisitChildren<T>> Visit for T {
visitor(&ctx, elt);
let ctx = acc_fun(ctx, elt);
// Push children in reverse so they pop (and are visited) left-to-right.
- stack.extend(
- elt.children()
- .into_iter()
- .rev()
- .map(|child| (child, ctx.clone())),
- );
+ stack.extend(elt.children().rev().map(|child| (child, ctx.clone())));
}
Ok(())
@@ -498,7 +487,7 @@ impl<T: VisitChildren<T>> Visit for T {
while let Some(elt) = stack.pop() {
f(elt);
// Push children in reverse so they pop (and are visited) left-to-right.
- stack.extend(elt.children_mut().into_iter().rev());
+ stack.extend(elt.children_mut().rev());
}
Ok(())
@@ -519,7 +508,7 @@ impl<T: VisitChildren<T>> Visit for T {
while let Some(elt) = stack.pop() {
f(elt)?;
// Push children in reverse so they pop (and are visited) left-to-right.
- stack.extend(elt.children().into_iter().rev());
+ stack.extend(elt.children().rev());
}
Ok(())
@@ -533,7 +522,7 @@ impl<T: VisitChildren<T>> Visit for T {
while let Some(elt) = stack.pop() {
f(elt)?;
// Push children in reverse so they pop (and are visited) left-to-right.
- stack.extend(elt.children_mut().into_iter().rev());
+ stack.extend(elt.children_mut().rev());
}
Ok(())
@@ -548,10 +537,12 @@ impl<T: VisitChildren<T>> Visit for T {
while let Some(action) = stack.pop() {
match action {
Enter(elt) => {
- let children = pre(elt).unwrap_or_else(|| elt.children());
+ let children = pre(elt);
stack.push(Leave(elt));
- for child in children.into_iter().rev() {
- stack.push(Enter(child));
+ // Push children in reverse so they pop (and are visited) left-to-right.
+ match children {
+ Some(children) => stack.extend(children.into_iter().rev().map(Enter)),
+ None => stack.extend(elt.children().rev().map(Enter)),
}
}
Leave(elt) => {
@@ -596,16 +587,25 @@ impl<T: VisitChildren<T>> Visit for T {
match action {
Enter(ptr) => {
let elt = unsafe { &mut *ptr };
- let children: Vec<&mut T> = match pre(elt) {
- Some(explicit) => explicit,
- None => {
- let elt = unsafe { &mut *ptr };
- elt.children_mut()
- }
- };
+ let explicit = pre(elt);
stack.push(Leave(ptr));
- for child in children.into_iter().rev() {
- stack.push(Enter(child));
+ // Push children in reverse so they pop (and are visited) left-to-right.
+ match explicit {
+ Some(children) => {
+ stack.extend(
+ children
+ .into_iter()
+ .rev()
+ .map(|child| Enter(child as *mut T)),
+ );
+ }
+ None => {
+ // Retake the pointer: `pre` may have replaced the node wholesale.
+ let elt = unsafe { &mut *ptr };
+ stack.extend(
+ elt.children_mut().rev().map(|child| Enter(child as *mut T)),
+ );
+ }
}
}
Leave(ptr) => {
@@ -909,7 +909,10 @@ mod tests {
Ok(())
}
- fn children(&self) -> Vec<&A> {
+ fn children<'a>(&'a self) -> impl DoubleEndedIterator<Item = &'a A>
+ where
+ A: 'a,
+ {
let mut v: Vec<&A> = vec![];
match self {
@@ -923,10 +926,13 @@ mod tests {
}
}
- v
+ v.into_iter()
}
- fn children_mut(&mut self) -> Vec<&mut A> {
+ fn children_mut<'a>(&'a mut self) -> impl DoubleEndedIterator<Item = &'a mut A>
+ where
+ A: 'a,
+ {
let mut v: Vec<&mut A> = vec![];
match self {
@@ -940,7 +946,7 @@ mod tests {
}
}
- v
+ v.into_iter()
}
}
@@ -989,18 +995,26 @@ mod tests {
}
}
- fn children(&self) -> Vec<&B> {
+ fn children<'a>(&'a self) -> impl DoubleEndedIterator<Item = &'a B>
+ where
+ B: 'a,
+ {
match self {
A::Add(_, _) | A::Lit(_) => vec![],
- A::FrB(b) => vec![&*b],
+ A::FrB(b) => vec![&**b],
}
+ .into_iter()
}
- fn children_mut(&mut self) -> Vec<&mut B> {
+ fn children_mut<'a>(&'a mut self) -> impl DoubleEndedIterator<Item = &'a mut B>
+ where
+ B: 'a,
+ {
match self {
A::Add(_, _) | A::Lit(_) => vec![],
A::FrB(b) => vec![&mut **b],
}
+ .into_iter()
}
}
@@ -1093,7 +1107,10 @@ mod tests {
Ok(())
}
- fn children(&self) -> Vec<&B> {
+ fn children<'a>(&'a self) -> impl DoubleEndedIterator<Item = &'a B>
+ where
+ B: 'a,
+ {
let mut v: Vec<&B> = vec![];
match self {
B::Mul(lhs, rhs) => {
@@ -1103,10 +1120,13 @@ mod tests {
B::Lit(_) => (),
B::FrA(a) => v.append(&mut a.direct_sub_b()),
}
- v
+ v.into_iter()
}
- fn children_mut(&mut self) -> Vec<&mut B> {
+ fn children_mut<'a>(&'a mut self) -> impl DoubleEndedIterator<Item = &'a mut B>
+ where
+ B: 'a,
+ {
let mut v: Vec<&mut B> = vec![];
match self {
B::Mul(lhs, rhs) => {
@@ -1116,7 +1136,7 @@ mod tests {
B::Lit(_) => (),
B::FrA(a) => v.append(&mut a.direct_sub_b_mut()),
}
- v
+ v.into_iter()
}
}
@@ -1165,18 +1185,26 @@ mod tests {
}
}
- fn children(&self) -> Vec<&A> {
+ fn children<'a>(&'a self) -> impl DoubleEndedIterator<Item = &'a A>
+ where
+ A: 'a,
+ {
match self {
B::Mul(_, _) | B::Lit(_) => vec![],
- B::FrA(a) => vec![&*a],
+ B::FrA(a) => vec![&**a],
}
+ .into_iter()
}
- fn children_mut(&mut self) -> Vec<&mut A> {
+ fn children_mut<'a>(&'a mut self) -> impl DoubleEndedIterator<Item = &'a mut A>
+ where
+ A: 'a,
+ {
match self {
B::Mul(_, _) | B::Lit(_) => vec![],
B::FrA(a) => vec![&mut **a],
}
+ .into_iter()
}
}
diff --git i/src/sql/src/plan/hir.rs w/src/sql/src/plan/hir.rs
index 0f95512d53..4655bfa8f2 100644
--- i/src/sql/src/plan/hir.rs
+++ w/src/sql/src/plan/hir.rs
@@ -382,28 +382,34 @@ impl VisitChildren<HirScalarExpr> for WindowExpr {
Ok(())
}
- fn children(&self) -> Vec<&HirScalarExpr> {
- let mut v = self.func.children();
+ fn children<'a>(&'a self) -> impl DoubleEndedIterator<Item = &'a HirScalarExpr>
+ where
+ HirScalarExpr: 'a,
+ {
+ let mut v: Vec<&HirScalarExpr> = self.func.children().collect();
for c in self.partition_by.iter() {
- v.append(&mut c.children())
+ v.extend(VisitChildren::<HirScalarExpr>::children(c));
}
for c in self.order_by.iter() {
- v.append(&mut c.children());
+ v.extend(VisitChildren::<HirScalarExpr>::children(c));
}
- v
+ v.into_iter()
}
- fn children_mut(&mut self) -> Vec<&mut HirScalarExpr> {
- let mut v = self.func.children_mut();
+ fn children_mut<'a>(&'a mut self) -> impl DoubleEndedIterator<Item = &'a mut HirScalarExpr>
+ where
+ HirScalarExpr: 'a,
+ {
+ let mut v: Vec<&mut HirScalarExpr> = self.func.children_mut().collect();
for c in self.partition_by.iter_mut() {
- v.append(&mut c.children_mut())
+ v.extend(VisitChildren::<HirScalarExpr>::children_mut(c));
}
for c in self.order_by.iter_mut() {
- v.append(&mut c.children_mut());
+ v.extend(VisitChildren::<HirScalarExpr>::children_mut(c));
}
- v
+ v.into_iter()
}
}
@@ -528,20 +534,28 @@ impl VisitChildren<HirScalarExpr> for WindowExprType {
}
}
- fn children(&self) -> Vec<&HirScalarExpr> {
+ fn children<'a>(&'a self) -> impl DoubleEndedIterator<Item = &'a HirScalarExpr>
+ where
+ HirScalarExpr: 'a,
+ {
match self {
Self::Scalar(_) => vec![],
- Self::Value(expr) => expr.children(),
- Self::Aggregate(expr) => expr.children(),
+ Self::Value(expr) => expr.children().collect(),
+ Self::Aggregate(expr) => expr.children().collect(),
}
+ .into_iter()
}
- fn children_mut(&mut self) -> Vec<&mut HirScalarExpr> {
+ fn children_mut<'a>(&'a mut self) -> impl DoubleEndedIterator<Item = &'a mut HirScalarExpr>
+ where
+ HirScalarExpr: 'a,
+ {
match self {
Self::Scalar(_) => vec![],
- Self::Value(expr) => expr.children_mut(),
- Self::Aggregate(expr) => expr.children_mut(),
+ Self::Value(expr) => expr.children_mut().collect(),
+ Self::Aggregate(expr) => expr.children_mut().collect(),
}
+ .into_iter()
}
}
@@ -755,12 +769,18 @@ impl VisitChildren<HirScalarExpr> for ValueWindowExpr {
f(&mut self.args)
}
- fn children(&self) -> Vec<&HirScalarExpr> {
- self.args.children()
+ fn children<'a>(&'a self) -> impl DoubleEndedIterator<Item = &'a HirScalarExpr>
+ where
+ HirScalarExpr: 'a,
+ {
+ VisitChildren::<HirScalarExpr>::children(&*self.args)
}
- fn children_mut(&mut self) -> Vec<&mut HirScalarExpr> {
- self.args.children_mut()
+ fn children_mut<'a>(&'a mut self) -> impl DoubleEndedIterator<Item = &'a mut HirScalarExpr>
+ where
+ HirScalarExpr: 'a,
+ {
+ VisitChildren::<HirScalarExpr>::children_mut(&mut *self.args)
}
}
@@ -949,12 +969,18 @@ impl VisitChildren<HirScalarExpr> for AggregateWindowExpr {
f(&mut self.aggregate_expr.expr)
}
- fn children(&self) -> Vec<&HirScalarExpr> {
- self.aggregate_expr.expr.children()
+ fn children<'a>(&'a self) -> impl DoubleEndedIterator<Item = &'a HirScalarExpr>
+ where
+ HirScalarExpr: 'a,
+ {
+ VisitChildren::<HirScalarExpr>::children(&*self.aggregate_expr.expr)
}
- fn children_mut(&mut self) -> Vec<&mut HirScalarExpr> {
- self.aggregate_expr.expr.children_mut()
+ fn children_mut<'a>(&'a mut self) -> impl DoubleEndedIterator<Item = &'a mut HirScalarExpr>
+ where
+ HirScalarExpr: 'a,
+ {
+ VisitChildren::<HirScalarExpr>::children_mut(&mut *self.aggregate_expr.expr)
}
}
@@ -2914,7 +2940,10 @@ impl VisitChildren<Self> for HirRelationExpr {
Ok(())
}
- fn children(&self) -> Vec<&Self> {
+ fn children<'a>(&'a self) -> impl DoubleEndedIterator<Item = &'a Self>
+ where
+ Self: 'a,
+ {
// we visit subqueries _first_, then the input
let mut v: Vec<&HirRelationExpr> = vec![];
use HirRelationExpr::*;
@@ -2998,12 +3027,15 @@ impl VisitChildren<Self> for HirRelationExpr {
}
}
- v
+ v.into_iter()
}
- fn children_mut(&mut self) -> Vec<&mut Self> {
+ fn children_mut<'a>(&'a mut self) -> impl DoubleEndedIterator<Item = &'a mut Self>
+ where
+ Self: 'a,
+ {
// we visit subqueries _first_, then the input
- let mut v = vec![];
+ let mut v: Vec<&mut HirRelationExpr> = vec![];
use HirRelationExpr::*;
match self {
Constant { rows: _, typ: _ } | Get { id: _, typ: _ } => (),
@@ -3085,7 +3117,7 @@ impl VisitChildren<Self> for HirRelationExpr {
}
}
- v
+ v.into_iter()
}
}
@@ -3403,7 +3435,10 @@ impl VisitChildren<HirScalarExpr> for HirRelationExpr {
Ok(())
}
- fn children(&self) -> Vec<&HirScalarExpr> {
+ fn children<'a>(&'a self) -> impl DoubleEndedIterator<Item = &'a HirScalarExpr>
+ where
+ HirScalarExpr: 'a,
+ {
use HirRelationExpr::*;
match self {
Constant { rows: _, typ: _ }
@@ -3455,11 +3490,18 @@ impl VisitChildren<HirScalarExpr> for HirRelationExpr {
limit,
offset,
expected_group_size: _,
- } => limit.iter().chain(std::iter::once(offset)).collect(),
+ } => limit
+ .iter()
+ .chain(std::iter::once(offset))
+ .collect::<Vec<_>>(),
}
+ .into_iter()
}
- fn children_mut(&mut self) -> Vec<&mut HirScalarExpr> {
+ fn children_mut<'a>(&'a mut self) -> impl DoubleEndedIterator<Item = &'a mut HirScalarExpr>
+ where
+ HirScalarExpr: 'a,
+ {
use HirRelationExpr::*;
match self {
Constant { rows: _, typ: _ }
@@ -3511,8 +3553,12 @@ impl VisitChildren<HirScalarExpr> for HirRelationExpr {
limit,
offset,
expected_group_size: _,
- } => limit.iter_mut().chain(std::iter::once(offset)).collect(),
+ } => limit
+ .iter_mut()
+ .chain(std::iter::once(offset))
+ .collect::<Vec<_>>(),
}
+ .into_iter()
}
}
@@ -4535,9 +4581,12 @@ impl VisitChildren<Self> for HirScalarExpr {
Ok(())
}
- fn children(&self) -> Vec<&Self> {
+ fn children<'a>(&'a self) -> impl DoubleEndedIterator<Item = &'a Self>
+ where
+ Self: 'a,
+ {
use HirScalarExpr::*;
- match self {
+ let v: Vec<&Self> = match self {
Column(..) | Parameter(..) | Literal(..) | CallUnmaterializable(..) => vec![],
CallUnary { expr, .. } => vec![&*expr],
CallBinary { expr1, expr2, .. } => {
@@ -4553,11 +4602,15 @@ impl VisitChildren<Self> for HirScalarExpr {
vec![&*cond, &*then, &*els]
}
Exists(..) | Select(..) => vec![],
- Windowing(expr, _name) => expr.children(),
- }
+ Windowing(expr, _name) => expr.children().collect(),
+ };
+ v.into_iter()
}
- fn children_mut(&mut self) -> Vec<&mut Self> {
+ fn children_mut<'a>(&'a mut self) -> impl DoubleEndedIterator<Item = &'a mut Self>
+ where
+ Self: 'a,
+ {
use HirScalarExpr::*;
match self {
Column(..) | Parameter(..) | Literal(..) | CallUnmaterializable(..) => vec![],
@@ -4575,8 +4628,9 @@ impl VisitChildren<Self> for HirScalarExpr {
vec![&mut **cond, &mut **then, &mut **els]
}
Exists(..) | Select(..) => vec![],
- Windowing(expr, _name) => expr.children_mut(),
+ Windowing(expr, _name) => expr.children_mut().collect(),
}
+ .into_iter()
}
}
@@ -4661,9 +4715,12 @@ impl VisitChildren<HirRelationExpr> for HirScalarExpr {
Ok(())
}
- fn children(&self) -> Vec<&HirRelationExpr> {
+ fn children<'a>(&'a self) -> impl DoubleEndedIterator<Item = &'a HirRelationExpr>
+ where
+ HirRelationExpr: 'a,
+ {
use HirScalarExpr::*;
- match self {
+ let v: Vec<&HirRelationExpr> = match self {
Column(..)
| Parameter(..)
| Literal(..)
@@ -4674,10 +4731,14 @@ impl VisitChildren<HirRelationExpr> for HirScalarExpr {
| If { .. }
| Windowing(..) => vec![],
Exists(expr, _name) | Select(expr, _name) => vec![&*expr],
- }
+ };
+ v.into_iter()
}
- fn children_mut(&mut self) -> Vec<&mut HirRelationExpr> {
+ fn children_mut<'a>(&'a mut self) -> impl DoubleEndedIterator<Item = &'a mut HirRelationExpr>
+ where
+ HirRelationExpr: 'a,
+ {
use HirScalarExpr::*;
match self {
Column(..)
@@ -4691,6 +4752,7 @@ impl VisitChildren<HirRelationExpr> for HirScalarExpr {
| Windowing(..) => vec![],
Exists(expr, _name) | Select(expr, _name) => vec![&mut **expr],
}
+ .into_iter()
}
}
|
Motivation
#36759 (comment)
https://github.com/MaterializeInc/database-issues/issues/3516
https://github.com/MaterializeInc/database-issues/issues/3733
https://github.com/MaterializeInc/database-issues/issues/9996
Description
Changes
VisitChildrento havechildren()andchildren_mut()functions.Changes
Visitorto use these for iterative traversals.NB that mutable post-traversal requires unsafety: we need an
&mutfor the children and an&mutfor the parent when we're done. This is sound---Rust allows it when your stack is the call stack, but not your own data structure. So: this is somewhat unsavory code.Verification
Green CI.
NB maintaing safety in
VisitChildrenmeans slightly changing visit order: if we're going to work viachildren_mut(), we can no longer yield all subqueries before the term itself. This seems to have affected precisely one SLT, which I've rewritten.