diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index 3e154b491eda7..c2f2c0e00e6a5 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -23,7 +23,7 @@ mod struct_builder; use std::borrow::Borrow; use std::cmp::Ordering; -use std::collections::{HashSet, VecDeque}; +use std::collections::{HashMap, HashSet, VecDeque}; use std::convert::Infallible; use std::fmt; use std::fmt::Write; @@ -4753,6 +4753,18 @@ impl ScalarValue { .sum::<usize>() } + /// Estimates [size](Self::size) of [`HashMap`] keyed by [`ScalarValue`] in bytes. + /// + /// Includes the size of the [`HashMap`] container itself. Heap payload of + /// `V` is not accounted for; callers storing heap-backed values should + /// supplement this estimate. + #[allow(clippy::allow_attributes, clippy::mutable_key_type)] // ScalarValue has interior mutability but is intentionally used as hash key + pub fn size_of_hashmap<V>(map: &HashMap<ScalarValue, V>) -> usize { + size_of_val(map) + + ((size_of::<ScalarValue>() + size_of::<V>()) * map.capacity()) + + map.keys().map(|k| k.size() - size_of_val(k)).sum::<usize>() + } + /// Compacts the allocation referenced by `self` to the minimum, copying the data if /// necessary. /// diff --git a/datafusion/functions-aggregate/src/array_agg.rs b/datafusion/functions-aggregate/src/array_agg.rs index 33c48f8bb725d..8ed3fbf8c3d26 100644 --- a/datafusion/functions-aggregate/src/array_agg.rs +++ b/datafusion/functions-aggregate/src/array_agg.rs @@ -18,7 +18,7 @@ //! 
`ARRAY_AGG` aggregate implementation: [`ArrayAgg`] use std::cmp::Ordering; -use std::collections::{HashSet, VecDeque}; +use std::collections::{HashMap, VecDeque}; use std::mem::{size_of, size_of_val, take}; use std::sync::Arc; @@ -34,7 +34,9 @@ use datafusion_common::cast::as_list_array; use datafusion_common::utils::{ SingleRowListArrayBuilder, compare_rows, get_row_at_idx, take_function_args, }; -use datafusion_common::{Result, ScalarValue, assert_eq_or_internal_err, exec_err}; +use datafusion_common::{ + Result, ScalarValue, assert_eq_or_internal_err, exec_err, internal_err, +}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::format_state_name; use datafusion_expr::{ @@ -814,7 +816,10 @@ impl GroupsAccumulator for ArrayAggGroupsAccumulator { #[derive(Debug)] pub struct DistinctArrayAggAccumulator { - values: HashSet<ScalarValue>, + // Value → live refcount. Multiset state lets `retract_batch` correctly + // drop a duplicate occurrence while keeping the key alive if other + // copies remain in the current window frame. 
+ values: HashMap<ScalarValue, usize>, datatype: DataType, sort_options: Option<SortOptions>, ignore_nulls: bool, @@ -827,7 +832,7 @@ impl DistinctArrayAggAccumulator { ignore_nulls: bool, ) -> Result<Self> { Ok(Self { - values: HashSet::new(), + values: HashMap::new(), datatype: datatype.clone(), sort_options, ignore_nulls, @@ -856,8 +861,8 @@ impl Accumulator for DistinctArrayAggAccumulator { if nulls.is_none_or(|nulls| nulls.null_count() < val.len()) { for i in 0..val.len() { if nulls.is_none_or(|nulls| nulls.is_valid(i)) { - self.values - .insert(ScalarValue::try_from_array(val, i)?.compacted()); + let key = ScalarValue::try_from_array(val, i)?.compacted(); + *self.values.entry(key).or_insert(0) += 1; } } } @@ -872,6 +877,12 @@ impl Accumulator for DistinctArrayAggAccumulator { assert_eq_or_internal_err!(states.len(), 1, "expects single state"); + // The DISTINCT state schema is `List` — partial accumulators + // ship the set of values they saw, not multiplicities. Re-ingesting + // each element here makes the merged counts represent "partitions + // that emitted this value," which is fine because `evaluate` only + // reads keys. Refcount semantics for retract are only valid within + // a single accumulator instance (window execution). 
states[0] .as_list::<i32>() .iter() @@ -880,7 +891,7 @@ impl Accumulator for DistinctArrayAggAccumulator { } fn evaluate(&mut self) -> Result<ScalarValue> { - let mut values: Vec<ScalarValue> = self.values.iter().cloned().collect(); + let mut values: Vec<ScalarValue> = self.values.keys().cloned().collect(); if values.is_empty() { return Ok(ScalarValue::new_null_list(self.datatype.clone(), true, 1)); } @@ -916,8 +927,50 @@ impl Accumulator for DistinctArrayAggAccumulator { Ok(ScalarValue::List(arr)) } + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if values.is_empty() { + return Ok(()); + } + + assert_eq_or_internal_err!(values.len(), 1, "expects single batch"); + + let val = &values[0]; + let nulls = if self.ignore_nulls { + val.logical_nulls() + } else { + None + }; + let nulls = nulls.as_ref(); + + for i in 0..val.len() { + if nulls.is_some_and(|nulls| !nulls.is_valid(i)) { + continue; + } + let key = ScalarValue::try_from_array(val, i)?; + match self.values.get_mut(&key) { + Some(count) => { + *count -= 1; + if *count == 0 { + self.values.remove(&key); + } + } + None => { + return internal_err!( + "DistinctArrayAggAccumulator::retract_batch: value not present in state" + ); + } + } + } + + Ok(()) + } + + fn supports_retract_batch(&self) -> bool { + true + } + fn size(&self) -> usize { - size_of_val(self) + ScalarValue::size_of_hashset(&self.values) + size_of_val(self) + ScalarValue::size_of_hashmap(&self.values) - size_of_val(&self.values) + self.datatype.size() - size_of_val(&self.datatype) @@ -1494,8 +1547,8 @@ mod tests { acc2.update_batch(&[string_list_data([vec!["e", "f", "g"]])])?; acc1 = merge(acc1, acc2)?; - // without compaction, the size is 16660 - assert_eq!(acc1.size(), 1660); + // without compaction, the size is 16684 + assert_eq!(acc1.size(), 1684); Ok(()) } @@ -2415,4 +2468,126 @@ mod tests { Ok(()) } + + // ---- DistinctArrayAggAccumulator retract_batch tests ---- + + // Build a DISTINCT accumulator with ascending sort so evaluate output is + // deterministic 
regardless of HashMap iteration order. + fn distinct_acc(ignore_nulls: bool) -> Result<DistinctArrayAggAccumulator> { + DistinctArrayAggAccumulator::try_new( + &DataType::Utf8, + Some(SortOptions::default()), + ignore_nulls, + ) + } + + #[test] + fn distinct_retract_duplicate_remains() -> Result<()> { + // Canonical regression for the HashSet-can't-retract bug: a value + // that appears multiple times in-frame must survive retraction of + // a single occurrence. + let mut acc = distinct_acc(false)?; + + // Feed [A, A, B] across two batches to exercise multi-batch state. + acc.update_batch(&[data(["A", "A"])])?; + acc.update_batch(&[data(["B"])])?; + assert_eq!(print_nulls(str_arr(acc.evaluate()?)?), vec!["A", "B"]); + + // Retract a single A — the other A is still in the frame. + acc.retract_batch(&[data(["A"])])?; + assert_eq!(print_nulls(str_arr(acc.evaluate()?)?), vec!["A", "B"]); + + // Retract the remaining A — only B left. + acc.retract_batch(&[data(["A"])])?; + assert_eq!(print_nulls(str_arr(acc.evaluate()?)?), vec!["B"]); + + Ok(()) + } + + #[test] + fn distinct_retract_full_removal() -> Result<()> { + let mut acc = distinct_acc(false)?; + + acc.update_batch(&[data(["A", "B"])])?; + acc.retract_batch(&[data(["A", "B"])])?; + + let result = acc.evaluate()?; + assert!( + matches!(&result, ScalarValue::List(arr) if arr.is_null(0)), + "expected null list after full retract, got {result:?}" + ); + + Ok(()) + } + + #[test] + fn distinct_retract_ignore_nulls_skips() -> Result<()> { + // ignore_nulls=true: NULL never enters state on update, so retract + // must also skip NULL — otherwise we'd error on the missing key. + let mut acc = distinct_acc(true)?; + + acc.update_batch(&[data([Some("A"), None, Some("B")])])?; + assert_eq!(print_nulls(str_arr(acc.evaluate()?)?), vec!["A", "B"]); + + // Retract [A, NULL] — the NULL is skipped, only A is removed. 
+ acc.retract_batch(&[data([Some("A"), None])])?; + assert_eq!(print_nulls(str_arr(acc.evaluate()?)?), vec!["B"]); + + Ok(()) + } + + #[test] + fn distinct_retract_null_tracked() -> Result<()> { + // ignore_nulls=false: NULL enters state with a refcount and must + // retract symmetrically; the NULL key must be removed at zero + // (else evaluate still emits a NULL element). + let mut acc = distinct_acc(false)?; + + acc.update_batch(&[data([Some("A"), None, None])])?; + // With nulls_first=true (SortOptions default), NULL sorts before A. + assert_eq!(print_nulls(str_arr(acc.evaluate()?)?), vec!["NULL", "A"]); + + // Retract one NULL — count drops to 1, key still present. + acc.retract_batch(&[data::<Option<&str>, 1>([None])])?; + assert_eq!(print_nulls(str_arr(acc.evaluate()?)?), vec!["NULL", "A"]); + + // Retract the remaining NULL — key is removed. + acc.retract_batch(&[data::<Option<&str>, 1>([None])])?; + assert_eq!(print_nulls(str_arr(acc.evaluate()?)?), vec!["A"]); + + Ok(()) + } + + #[test] + fn distinct_supports_retract_batch() -> Result<()> { + let acc = distinct_acc(false)?; + assert!(acc.supports_retract_batch()); + + let acc_ignore = distinct_acc(true)?; + assert!(acc_ignore.supports_retract_batch()); + + Ok(()) + } + + #[test] + fn distinct_merge_then_evaluate_regression() -> Result<()> { + // Non-window path: state -> merge_batch -> evaluate must still + // produce the union of distinct values across partitions. 
+ let mut acc1 = distinct_acc(false)?; + let mut acc2 = distinct_acc(false)?; + + acc1.update_batch(&[data(["A", "A", "B"])])?; + acc2.update_batch(&[data(["A", "C"])])?; + + let state = acc2.state()?; + let state_arrs: Vec<ArrayRef> = state + .into_iter() + .map(|sv| sv.to_array_of_size(1)) + .collect::<Result<Vec<_>>>()?; + acc1.merge_batch(&state_arrs)?; + + assert_eq!(print_nulls(str_arr(acc1.evaluate()?)?), vec!["A", "B", "C"]); + + Ok(()) + } } diff --git a/datafusion/sqllogictest/test_files/array_agg_sliding_window.slt b/datafusion/sqllogictest/test_files/array_agg_sliding_window.slt index 78d48513a6656..6f0712e2a6929 100644 --- a/datafusion/sqllogictest/test_files/array_agg_sliding_window.slt +++ b/datafusion/sqllogictest/test_files/array_agg_sliding_window.slt @@ -168,6 +168,233 @@ FROM t_nulls; [C] [C, E] +####### +# DISTINCT sliding window tests +# Validates retract_batch implementation on DistinctArrayAggAccumulator. +# DataFusion rejects `array_agg(... ORDER BY ...)` inside window functions, +# so we wrap with array_sort to make output deterministic +# (HashMap iteration order otherwise). +####### + +statement ok +CREATE TABLE t_dist(ts INT, val TEXT) AS VALUES + (1,'A'),(2,'A'),(3,'B'),(4,'C'),(5,'B'); + +# Duplicate stays in frame after partial retract. +# Frame contents per row (ts=1..5): +# [A] -> {A} +# [A,A] -> {A} (A appears twice, still distinct {A}) +# [A,A,B] -> {A,B} +# [A,B,C] -> {A,B,C} (one A retracted, one A remains) +# [B,C,B] -> {B,C} (last A retracted, B duplicate stays) +query ? 
+SELECT array_sort(array_agg(DISTINCT val) + OVER (ORDER BY ts ROWS BETWEEN 1 PRECEDING AND CURRENT ROW)) +FROM t_dist; +---- +[A] +[A] +[A, B] +[B, C] +[B, C] + +# DESC window ORDER BY: frame walks input in reverse temporal order, so +# update/retract are called against the reversed row stream. Validates +# retract still tracks duplicates correctly when rows arrive in DESC order. +# Output rows are emitted in ts DESC order (ts=5,4,3,2,1). +# ts=5 (B): frame [B] -> {B} +# ts=4 (C): frame [B,C] -> {B,C} (1 preceding in DESC = ts=5) +# ts=3 (B): frame [C,B] -> {B,C} +# ts=2 (A): frame [B,A] -> {A,B} +# ts=1 (A): frame [A,A] -> {A} (duplicate A in frame) +query ? +SELECT array_sort(array_agg(DISTINCT val) + OVER (ORDER BY ts DESC ROWS BETWEEN 1 PRECEDING AND CURRENT ROW)) +FROM t_dist; +---- +[B] +[B, C] +[B, C] +[A, B] +[A] + +# RANGE frame with value gaps -> multi-row retract on shift +statement ok +CREATE TABLE t_dist_range(ts INT, val TEXT) AS VALUES + (1,'A'),(2,'A'),(3,'B'),(10,'A'),(11,'C'); + +query ? +SELECT array_sort(array_agg(DISTINCT val) + OVER (ORDER BY ts RANGE BETWEEN 2 PRECEDING AND CURRENT ROW)) +FROM t_dist_range; +---- +[A] +[A] +[A, B] +[A] +[A, C] + +# DISTINCT + IGNORE NULLS in sliding frame: nulls never enter state. +statement ok +CREATE TABLE t_dist_nulls(ts INT, val TEXT) AS VALUES + (1,'A'),(2,NULL),(3,'A'),(4,NULL),(5,'B'); + +query ? +SELECT array_sort(array_agg(DISTINCT val) IGNORE NULLS + OVER (ORDER BY ts ROWS BETWEEN 1 PRECEDING AND CURRENT ROW)) +FROM t_dist_nulls; +---- +[A] +[A] +[A] +[A] +[B] + +# DISTINCT without IGNORE NULLS: NULL enters state with a refcount. +# Retract must remove the NULL key when its last occurrence leaves the frame. +# array_sort defaults to ASC NULLS FIRST, so a live NULL sorts ahead of A/B; +# rows with no live NULL have no NULL element. 
+# ts=1 (A): frame [A] -> {A} sorted [A] +# ts=2 (NULL): frame [A,NULL] -> {A,NULL} sorted [NULL, A] +# ts=3 (A): frame [NULL,A] -> {A,NULL} sorted [NULL, A] +# ts=4 (NULL): frame [A,NULL] -> {A,NULL} sorted [NULL, A] +# (the ts=2 NULL retracts but the ts=4 NULL is still present) +# ts=5 (B): frame [NULL,B] -> {B,NULL} sorted [NULL, B] +query ? +SELECT array_sort(array_agg(DISTINCT val) + OVER (ORDER BY ts ROWS BETWEEN 1 PRECEDING AND CURRENT ROW)) +FROM t_dist_nulls; +---- +[A] +[NULL, A] +[NULL, A] +[NULL, A] +[NULL, B] + +# GROUPS frame with duplicated sort keys: rows tied on the ORDER BY column +# are batched into the same group, so a single shift can update or retract +# multiple rows at once. +statement ok +CREATE TABLE t_dist_groups(ts INT, val TEXT) AS VALUES + (1,'A'),(1,'A'),(2,'B'),(2,'C'),(3,'A'),(3,'D'); + +query ? +SELECT array_sort(array_agg(DISTINCT val) + OVER (ORDER BY ts GROUPS BETWEEN 1 PRECEDING AND CURRENT ROW)) +FROM t_dist_groups; +---- +[A] +[A] +[A, B, C] +[A, B, C] +[A, B, C, D] +[A, B, C, D] + +# PARTITION BY: each partition retracts against its own state only. A leak of +# one partition's state into the next would surface as the next partition's +# first row carrying foreign values, or a retract hitting the +# `value not present in state` internal_err. 'A' lives only in grp 1, 'C' only +# in grp 2, 'B' in both — so leaked grp-1 state would make grp=2/ts=1 emit +# [A, B] instead of [B]. Rows emitted in (grp, ts) order. +# grp 1: ts=1 [A]->{A} ts=2 [A,A]->{A} ts=3 [A,B]->{A,B} +# grp 2: ts=1 [B]->{B} ts=2 [B,C]->{B,C} ts=3 [C,C]->{C} +statement ok +CREATE TABLE t_dist_parts(grp INT, ts INT, val TEXT) AS VALUES + (1,1,'A'),(1,2,'A'),(1,3,'B'), + (2,1,'B'),(2,2,'C'),(2,3,'C'); + +query ? 
+SELECT array_sort(array_agg(DISTINCT val) + OVER (PARTITION BY grp ORDER BY ts ROWS BETWEEN 1 PRECEDING AND CURRENT ROW)) +FROM t_dist_parts +ORDER BY grp, ts; +---- +[A] +[A] +[A, B] +[B] +[B, C] +[C] + +# Numeric element type: retract must hash and compare Int32 ScalarValues +# correctly (every sibling test uses Utf8). Mirrors the t_dist 2-PRECEDING walk. +# ts=1 [10] -> {10} +# ts=2 [10,10] -> {10} (duplicate, stays distinct {10}) +# ts=3 [10,10,20] -> {10,20} +# ts=4 [10,20,30] -> {10,20,30} (one 10 retracted, one 10 remains) +# ts=5 [20,30,20] -> {20,30} (last 10 retracted, 20 duplicate stays) +statement ok +CREATE TABLE t_dist_int(ts INT, val INT) AS VALUES + (1,10),(2,10),(3,20),(4,30),(5,20); + +query ? +SELECT array_sort(array_agg(DISTINCT val) + OVER (ORDER BY ts ROWS BETWEEN 2 PRECEDING AND CURRENT ROW)) +FROM t_dist_int; +---- +[10] +[10] +[10, 20] +[10, 20, 30] +[20, 30] + +# ORDER BY interaction — window context. +# DataFusion's planner rejects ANY aggregate-level ORDER BY inside a window +# function, so neither the valid (DISTINCT x ORDER BY x) nor the invalid +# (DISTINCT x ORDER BY y) form reaches the DISTINCT-arg-equality validator +# in window context. Both error at planning, but at the window-planner stage. +statement error Aggregate ORDER BY is not implemented for window functions +SELECT array_agg(DISTINCT val ORDER BY val) + OVER (ORDER BY ts ROWS BETWEEN 1 PRECEDING AND CURRENT ROW) +FROM t_dist; + +statement error Aggregate ORDER BY is not implemented for window functions +SELECT array_agg(DISTINCT val ORDER BY ts) + OVER (ORDER BY ts ROWS BETWEEN 1 PRECEDING AND CURRENT ROW) +FROM t_dist; + +# ORDER BY interaction — non-window context (regression for the storage swap). +# The DISTINCT-arg-equality validator must still accept the matching case +# and reject the mismatched case after we changed the underlying state. +query ? 
+SELECT array_agg(DISTINCT val ORDER BY val) FROM t_dist; +---- +[A, B, C] + +statement error In an aggregate with DISTINCT, ORDER BY expressions must appear in argument list +SELECT array_agg(DISTINCT val ORDER BY ts) FROM t_dist; + +# Result cardinality bounded by frame cardinality (live-key proxy for state growth). +# Set up 100 rows over 50 cycling distinct values, then run a 2-row sliding frame. +# Since `evaluate` returns the live key set verbatim, max(result_length) == 2 +# proves keys are dropped as their last occurrence leaves the frame. A leaky +# retract would let the result balloon toward 50 (all distinct values seen) or +# error at runtime via the `value not present in state` internal_err!. +statement ok +CREATE TABLE t_dist_growth AS + SELECT i AS ts, ('v' || (i % 50)::TEXT) AS val FROM generate_series(1, 100) t(i); + +query I +SELECT max(cardinality(distinct_arr)) FROM ( + SELECT array_agg(DISTINCT val) + OVER (ORDER BY ts ROWS BETWEEN 1 PRECEDING AND CURRENT ROW) AS distinct_arr + FROM t_dist_growth +); +---- +2 + # Cleanup statement ok DROP TABLE t; @@ -182,4 +409,25 @@ statement ok DROP TABLE t_int; statement ok -DROP TABLE t_groups; \ No newline at end of file +DROP TABLE t_groups; + +statement ok +DROP TABLE t_dist; + +statement ok +DROP TABLE t_dist_range; + +statement ok +DROP TABLE t_dist_nulls; + +statement ok +DROP TABLE t_dist_groups; + +statement ok +DROP TABLE t_dist_growth; + +statement ok +DROP TABLE t_dist_parts; + +statement ok +DROP TABLE t_dist_int; \ No newline at end of file