Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion datafusion/common/src/scalar/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ mod struct_builder;

use std::borrow::Borrow;
use std::cmp::Ordering;
use std::collections::{HashSet, VecDeque};
use std::collections::{HashMap, HashSet, VecDeque};
use std::convert::Infallible;
use std::fmt;
use std::fmt::Write;
Expand Down Expand Up @@ -4753,6 +4753,18 @@ impl ScalarValue {
.sum::<usize>()
}

/// Estimates [size](Self::size) of [`HashMap`] keyed by [`ScalarValue`] in bytes.
///
/// Includes the size of the [`HashMap`] container itself. Heap payload of
/// `V` is not accounted for; callers storing heap-backed values should
/// supplement this estimate.
#[allow(clippy::allow_attributes, clippy::mutable_key_type)] // ScalarValue has interior mutability but is intentionally used as hash key
pub fn size_of_hashmap<V, S>(map: &HashMap<Self, V, S>) -> usize {
size_of_val(map)
+ ((size_of::<ScalarValue>() + size_of::<V>()) * map.capacity())
+ map.keys().map(|k| k.size() - size_of_val(k)).sum::<usize>()
}

/// Compacts the allocation referenced by `self` to the minimum, copying the data if
/// necessary.
///
Expand Down
195 changes: 185 additions & 10 deletions datafusion/functions-aggregate/src/array_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
//! `ARRAY_AGG` aggregate implementation: [`ArrayAgg`]

use std::cmp::Ordering;
use std::collections::{HashSet, VecDeque};
use std::collections::{HashMap, VecDeque};
use std::mem::{size_of, size_of_val, take};
use std::sync::Arc;

Expand All @@ -34,7 +34,9 @@ use datafusion_common::cast::as_list_array;
use datafusion_common::utils::{
SingleRowListArrayBuilder, compare_rows, get_row_at_idx, take_function_args,
};
use datafusion_common::{Result, ScalarValue, assert_eq_or_internal_err, exec_err};
use datafusion_common::{
Result, ScalarValue, assert_eq_or_internal_err, exec_err, internal_err,
};
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion_expr::utils::format_state_name;
use datafusion_expr::{
Expand Down Expand Up @@ -814,7 +816,10 @@ impl GroupsAccumulator for ArrayAggGroupsAccumulator {

#[derive(Debug)]
pub struct DistinctArrayAggAccumulator {
values: HashSet<ScalarValue>,
// Value → live refcount. Multiset state lets `retract_batch` correctly
// drop a duplicate occurrence while keeping the key alive if other
// copies remain in the current window frame.
values: HashMap<ScalarValue, u64>,
datatype: DataType,
sort_options: Option<SortOptions>,
ignore_nulls: bool,
Expand All @@ -827,7 +832,7 @@ impl DistinctArrayAggAccumulator {
ignore_nulls: bool,
) -> Result<Self> {
Ok(Self {
values: HashSet::new(),
values: HashMap::new(),
datatype: datatype.clone(),
sort_options,
ignore_nulls,
Expand Down Expand Up @@ -856,8 +861,8 @@ impl Accumulator for DistinctArrayAggAccumulator {
if nulls.is_none_or(|nulls| nulls.null_count() < val.len()) {
for i in 0..val.len() {
if nulls.is_none_or(|nulls| nulls.is_valid(i)) {
self.values
.insert(ScalarValue::try_from_array(val, i)?.compacted());
let key = ScalarValue::try_from_array(val, i)?.compacted();
*self.values.entry(key).or_insert(0) += 1;
}
}
}
Expand All @@ -872,6 +877,12 @@ impl Accumulator for DistinctArrayAggAccumulator {

assert_eq_or_internal_err!(states.len(), 1, "expects single state");

// The DISTINCT state schema is `List<value>` — partial accumulators
// ship the set of values they saw, not multiplicities. Re-ingesting
// each element here makes the merged counts represent "partitions
// that emitted this value," which is fine because `evaluate` only
// reads keys. Refcount semantics for retract are only valid within
// a single accumulator instance (window execution).
states[0]
.as_list::<i32>()
.iter()
Expand All @@ -880,7 +891,7 @@ impl Accumulator for DistinctArrayAggAccumulator {
}

fn evaluate(&mut self) -> Result<ScalarValue> {
let mut values: Vec<ScalarValue> = self.values.iter().cloned().collect();
let mut values: Vec<ScalarValue> = self.values.keys().cloned().collect();
if values.is_empty() {
return Ok(ScalarValue::new_null_list(self.datatype.clone(), true, 1));
}
Expand Down Expand Up @@ -916,8 +927,50 @@ impl Accumulator for DistinctArrayAggAccumulator {
Ok(ScalarValue::List(arr))
}

fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
if values.is_empty() {
return Ok(());
}

assert_eq_or_internal_err!(values.len(), 1, "expects single batch");

let val = &values[0];
let nulls = if self.ignore_nulls {
val.logical_nulls()
} else {
None
};
let nulls = nulls.as_ref();

for i in 0..val.len() {
if nulls.is_some_and(|nulls| !nulls.is_valid(i)) {
continue;
}
let key = ScalarValue::try_from_array(val, i)?.compacted();
match self.values.get_mut(&key) {
Some(count) => {
*count -= 1;
if *count == 0 {
self.values.remove(&key);
}
}
None => {
return internal_err!(
"DistinctArrayAggAccumulator::retract_batch: value not present in state"
);
}
}
}

Ok(())
}

fn supports_retract_batch(&self) -> bool {
true
}

fn size(&self) -> usize {
size_of_val(self) + ScalarValue::size_of_hashset(&self.values)
size_of_val(self) + ScalarValue::size_of_hashmap(&self.values)
- size_of_val(&self.values)
+ self.datatype.size()
- size_of_val(&self.datatype)
Expand Down Expand Up @@ -1494,8 +1547,8 @@ mod tests {
acc2.update_batch(&[string_list_data([vec!["e", "f", "g"]])])?;
acc1 = merge(acc1, acc2)?;

// without compaction, the size is 16660
assert_eq!(acc1.size(), 1660);
// without compaction, the size is 16684
assert_eq!(acc1.size(), 1684);

Ok(())
}
Expand Down Expand Up @@ -2415,4 +2468,126 @@ mod tests {

Ok(())
}

// ---- DistinctArrayAggAccumulator retract_batch tests ----

// Build a DISTINCT accumulator with ascending sort so evaluate output is
// deterministic regardless of HashMap iteration order.
fn distinct_acc(ignore_nulls: bool) -> Result<DistinctArrayAggAccumulator> {
DistinctArrayAggAccumulator::try_new(
&DataType::Utf8,
Some(SortOptions::default()),
ignore_nulls,
)
}

#[test]
fn distinct_retract_duplicate_remains() -> Result<()> {
// Canonical regression for the HashSet-can't-retract bug: a value
// that appears multiple times in-frame must survive retraction of
// a single occurrence.
let mut acc = distinct_acc(false)?;

// Feed [A, A, B] across two batches to exercise multi-batch state.
acc.update_batch(&[data(["A", "A"])])?;
acc.update_batch(&[data(["B"])])?;
assert_eq!(print_nulls(str_arr(acc.evaluate()?)?), vec!["A", "B"]);

// Retract a single A — the other A is still in the frame.
acc.retract_batch(&[data(["A"])])?;
assert_eq!(print_nulls(str_arr(acc.evaluate()?)?), vec!["A", "B"]);

// Retract the remaining A — only B left.
acc.retract_batch(&[data(["A"])])?;
assert_eq!(print_nulls(str_arr(acc.evaluate()?)?), vec!["B"]);

Ok(())
}

#[test]
fn distinct_retract_full_removal() -> Result<()> {
let mut acc = distinct_acc(false)?;

acc.update_batch(&[data(["A", "B"])])?;
acc.retract_batch(&[data(["A", "B"])])?;

let result = acc.evaluate()?;
assert!(
matches!(&result, ScalarValue::List(arr) if arr.is_null(0)),
"expected null list after full retract, got {result:?}"
);

Ok(())
}

#[test]
fn distinct_retract_ignore_nulls_skips() -> Result<()> {
// ignore_nulls=true: NULL never enters state on update, so retract
// must also skip NULL — otherwise we'd error on the missing key.
let mut acc = distinct_acc(true)?;

acc.update_batch(&[data([Some("A"), None, Some("B")])])?;
assert_eq!(print_nulls(str_arr(acc.evaluate()?)?), vec!["A", "B"]);

// Retract [A, NULL] — the NULL is skipped, only A is removed.
acc.retract_batch(&[data([Some("A"), None])])?;
assert_eq!(print_nulls(str_arr(acc.evaluate()?)?), vec!["B"]);

Ok(())
}

#[test]
fn distinct_retract_null_tracked() -> Result<()> {
// ignore_nulls=false: NULL enters state with a refcount and must
// retract symmetrically; the NULL key must be removed at zero
// (else evaluate still emits a NULL element).
let mut acc = distinct_acc(false)?;

acc.update_batch(&[data([Some("A"), None, None])])?;
// With nulls_first=true (SortOptions default), NULL sorts before A.
assert_eq!(print_nulls(str_arr(acc.evaluate()?)?), vec!["NULL", "A"]);

// Retract one NULL — count drops to 1, key still present.
acc.retract_batch(&[data::<Option<&str>, 1>([None])])?;
assert_eq!(print_nulls(str_arr(acc.evaluate()?)?), vec!["NULL", "A"]);

// Retract the remaining NULL — key is removed.
acc.retract_batch(&[data::<Option<&str>, 1>([None])])?;
assert_eq!(print_nulls(str_arr(acc.evaluate()?)?), vec!["A"]);

Ok(())
}

#[test]
fn distinct_supports_retract_batch() -> Result<()> {
let acc = distinct_acc(false)?;
assert!(acc.supports_retract_batch());

let acc_ignore = distinct_acc(true)?;
assert!(acc_ignore.supports_retract_batch());

Ok(())
}

#[test]
fn distinct_merge_then_evaluate_regression() -> Result<()> {
// Non-window path: state -> merge_batch -> evaluate must still
// produce the union of distinct values across partitions.
let mut acc1 = distinct_acc(false)?;
let mut acc2 = distinct_acc(false)?;

acc1.update_batch(&[data(["A", "A", "B"])])?;
acc2.update_batch(&[data(["A", "C"])])?;

let state = acc2.state()?;
let state_arrs: Vec<ArrayRef> = state
.into_iter()
.map(|sv| sv.to_array_of_size(1))
.collect::<Result<Vec<_>>>()?;
acc1.merge_batch(&state_arrs)?;

assert_eq!(print_nulls(str_arr(acc1.evaluate()?)?), vec!["A", "B", "C"]);

Ok(())
}
}
Loading
Loading