Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,15 @@ default EngineResultStream fetchByRowIds(
throw new UnsupportedOperationException("fetchByRowIds not implemented for [" + name() + "]");
}

/**
* Cooperatively cancels in-flight backend work for {@code contextId} (e.g. fire the per-context
* cancellation token). Called from a task cancellation listener for the fetch path, which —
* unlike the query path's {@code SearchExecEngine} — returns an opaque {@link EngineResultStream}.
* Implementations must signal the native execution to unwind, not close the stream cross-thread
* (that races the in-flight pull). No-op for an unknown {@code contextId}; default no-op.
*/
default void cancelByContext(long contextId) {}

/**
* Converts a backend-specific exception into an appropriate OpenSearch exception type.
*
Expand Down
18 changes: 8 additions & 10 deletions sandbox/plugins/analytics-backend-datafusion/rust/src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1826,11 +1826,10 @@ pub async unsafe fn execute_local_plan(
// drain this handle unchanged. Use the cancellable variant so the CPU
// task can be aborted mid-execution when cancel_query fires.
let cpu_exec = manager.cpu_executor();
let (cross_rt_stream, abort_handle, task_done) =
CrossRtStream::new_with_df_error_stream_cancellable(df_stream, cpu_exec.clone());
if let Some(h) = abort_handle {
query_tracker::set_abort_handle(context_id, h);
}
let (cross_rt_stream, _abort_handle, task_done) =
CrossRtStream::new_with_df_error_stream_cancellable(df_stream, cpu_exec.clone(), token.clone());
// Reduce path: cancel via the token only, do NOT register the abort handle — an abort() mid-send
// would skip the cross_rt drop+drain cleanup and leak the aggregate's in-flight GroupValues.
if let Some(rt) = cpu_exec.handle() {
query_tracker::set_cpu_runtime_handle(context_id, rt);
}
Expand Down Expand Up @@ -1866,6 +1865,7 @@ pub unsafe fn execute_local_prepared_plan(
// The token is held via the QueryStreamHandle's context and consulted by
// stream_next on each batch pull.
let query_context = QueryTrackingContext::new(context_id, session.memory_pool(), query_tracker::QueryType::Coordinator);
let token = query_tracker::get_cancellation_token(context_id);

// DataFusion's execute_stream is sync, but kicks off RepartitionExec /
// stream channels that require a Tokio reactor. Enter the IO runtime's
Expand All @@ -1874,11 +1874,9 @@ pub unsafe fn execute_local_prepared_plan(
let df_stream = session.execute_prepared()?;

let cpu_exec = manager.cpu_executor();
let (cross_rt_stream, abort_handle, task_done) =
CrossRtStream::new_with_df_error_stream_cancellable(df_stream, cpu_exec.clone());
if let Some(h) = abort_handle {
query_tracker::set_abort_handle(context_id, h);
}
let (cross_rt_stream, _abort_handle, task_done) =
CrossRtStream::new_with_df_error_stream_cancellable(df_stream, cpu_exec.clone(), token.clone());
// Prepared-reduce path: same as execute_local_plan — token-only cancel, no abort handle.
if let Some(rt) = cpu_exec.handle() {
query_tracker::set_cpu_runtime_handle(context_id, rt);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use tokio::sync::mpsc::{channel, Sender};
use tokio::sync::oneshot;
use tokio::task::AbortHandle;
use tokio_stream::wrappers::ReceiverStream;
use tokio_util::sync::CancellationToken;

/// Fires its `oneshot` when dropped. Held inside the spawned task body so the signal is sent on
/// every exit path (drain, error, abort-unwind, panic).
Expand Down Expand Up @@ -72,16 +73,21 @@ impl CrossRtStream {
stream: SendableRecordBatchStream,
exec: DedicatedExecutor,
) -> Self {
let (cross_rt, _abort_handle, _done_rx) = Self::new_with_df_error_stream_cancellable(stream, exec);
let (cross_rt, _abort_handle, _done_rx) = Self::new_with_df_error_stream_cancellable(stream, exec, None);
cross_rt
}

/// Like [`new_with_df_error_stream`](Self::new_with_df_error_stream), but also returns an
/// [`AbortHandle`] and a `oneshot::Receiver` that fires once the spawned task has fully dropped
/// (completion or abort) — the barrier `stream_close` waits on before the allocator closes.
///
/// When `cancel_token` is supplied, cancel is cooperative: the producer breaks on the token and
/// runs the drop(stream)+drain cleanup, instead of an abort() mid-send that skips it. Callers
/// that pass `None` are unchanged.
pub fn new_with_df_error_stream_cancellable(
stream: SendableRecordBatchStream,
exec: DedicatedExecutor,
cancel_token: Option<CancellationToken>,
) -> (Self, Option<AbortHandle>, oneshot::Receiver<()>) {
let schema = stream.schema();
let (tx, rx) = channel(1);
Expand All @@ -90,12 +96,32 @@ impl CrossRtStream {

let fut = async move {
let _done = DoneGuard(Some(done_tx));
tokio::pin!(stream);
while let Some(res) = stream.next().await {
if tx_captured.send(res).await.is_err() {
return;
let mut stream = Box::pin(stream);
// Cooperative cancel: select each await against the token so a cancel breaks the loop
// and falls through to the drop(stream)+drain below, rather than an abort() mid-send.
loop {
let next = tokio::select! {
biased;
_ = async { match &cancel_token { Some(t) => t.cancelled().await, None => std::future::pending::<()>().await } } => break,
n = stream.next() => n,
};
let res = match next {
Some(r) => r,
None => break,
};
let sent = tokio::select! {
biased;
_ = async { match &cancel_token { Some(t) => t.cancelled().await, None => std::future::pending::<()>().await } } => break,
r = tx_captured.send(res) => r,
};
if sent.is_err() {
break;
}
}
// Drop the inner stream while this future is still polled — frees the aggregate's own
// GroupValues and schedules its child producer tasks for abort. The remaining deferred
// child drops are reaped by stream_close (it waits on task_done, then flush_cpu_runtime).
drop(stream);
};
Comment thread
mch2 marked this conversation as resolved.

let (abort_handle, join_fut) = exec.spawn_with_abort_handle(fut);
Expand Down Expand Up @@ -296,7 +322,7 @@ mod tests {
stream::iter(vec![Ok(test_batch(&[1, 2, 3]))]),
));

let (cross, _abort, done_rx) = CrossRtStream::new_with_df_error_stream_cancellable(inner, exec.clone());
let (cross, _abort, done_rx) = CrossRtStream::new_with_df_error_stream_cancellable(inner, exec.clone(), None);
let wrapped = RecordBatchStreamAdapter::new(cross.schema(), cross);
tokio::pin!(wrapped);
while wrapped.next().await.is_some() {}
Expand All @@ -315,7 +341,7 @@ mod tests {
stream::pending::<Result<RecordBatch, DataFusionError>>(),
));

let (cross, abort, done_rx) = CrossRtStream::new_with_df_error_stream_cancellable(inner, exec.clone());
let (cross, abort, done_rx) = CrossRtStream::new_with_df_error_stream_cancellable(inner, exec.clone(), None);
// Hold the stream so the abort, not a drop, is what ends the task.
let _wrapped = RecordBatchStreamAdapter::new(cross.schema(), cross);

Expand All @@ -325,4 +351,76 @@ mod tests {
assert!(fired.is_ok(), "done_rx must fire after the task is aborted");
exec.join_blocking();
}

// Firing the token breaks the loop and fires done_rx without an abort(), so the producer runs
// its drop+drain cleanup that frees the aggregate's GroupValues.
#[tokio::test]
async fn cancellation_token_breaks_loop_and_fires_done_rx() {
let exec = test_exec();
let schema = test_schema();
// Never-ending stream: the only way the task ends is the cooperative cancel.
let inner = Box::pin(RecordBatchStreamAdapter::new(
schema.clone(),
stream::pending::<Result<RecordBatch, DataFusionError>>(),
));

let token = CancellationToken::new();
let (cross, _abort, done_rx) =
CrossRtStream::new_with_df_error_stream_cancellable(inner, exec.clone(), Some(token.clone()));
// Hold the stream so a consumer-side drop is NOT what ends the task — the token is.
let _wrapped = RecordBatchStreamAdapter::new(cross.schema(), cross);

token.cancel();

let fired = tokio::time::timeout(std::time::Duration::from_secs(5), done_rx).await;
assert!(fired.is_ok(), "cancelling the token must break the loop and fire done_rx");
assert!(fired.unwrap().is_ok(), "done_rx must complete, not be dropped");
exec.join_blocking();
}

// A token that is never fired must not perturb the normal drain path: the stream completes and
// all its batches are delivered.
#[tokio::test]
async fn uncancelled_token_drains_normally() {
let exec = test_exec();
let schema = test_schema();
let batches = vec![Ok(test_batch(&[1, 2, 3])), Ok(test_batch(&[4, 5]))];
let inner = Box::pin(RecordBatchStreamAdapter::new(schema.clone(), stream::iter(batches)));

let token = CancellationToken::new(); // never cancelled
let (cross, _abort, done_rx) =
CrossRtStream::new_with_df_error_stream_cancellable(inner, exec.clone(), Some(token));
let wrapped = RecordBatchStreamAdapter::new(cross.schema(), cross);
tokio::pin!(wrapped);

let mut total_rows = 0;
while let Some(batch) = wrapped.next().await {
total_rows += batch.unwrap().num_rows();
}
assert_eq!(total_rows, 5, "all rows delivered when the token is never fired");
assert!(done_rx.await.is_ok(), "done_rx fires on normal drain even with a token present");
exec.join_blocking();
}

// Cancelling BEFORE the first poll still terminates cleanly (the biased select checks the token
// first), exercising the immediate-cancel race.
#[tokio::test]
async fn cancel_before_first_poll_terminates() {
let exec = test_exec();
let schema = test_schema();
let inner = Box::pin(RecordBatchStreamAdapter::new(
schema.clone(),
stream::pending::<Result<RecordBatch, DataFusionError>>(),
));

let token = CancellationToken::new();
token.cancel(); // already cancelled before the task runs
let (cross, _abort, done_rx) =
CrossRtStream::new_with_df_error_stream_cancellable(inner, exec.clone(), Some(token));
let _wrapped = RecordBatchStreamAdapter::new(cross.schema(), cross);

let fired = tokio::time::timeout(std::time::Duration::from_secs(5), done_rx).await;
assert!(fired.is_ok(), "a pre-cancelled token must still terminate the task");
exec.join_blocking();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -745,7 +745,7 @@ async unsafe fn execute_indexed_with_context_inner(
let empty_exec = EmptyExec::new(Arc::clone(&plan_schema));
let df_stream = empty_exec.execute(0, handle.ctx.task_ctx())?;
let (cross_rt_stream, abort_handle, _task_done) =
CrossRtStream::new_with_df_error_stream_cancellable(df_stream, cpu_executor.clone());
CrossRtStream::new_with_df_error_stream_cancellable(df_stream, cpu_executor.clone(), None);
if let Some(h) = abort_handle {
crate::query_tracker::set_abort_handle(context_id_early, h);
}
Expand Down Expand Up @@ -1240,7 +1240,7 @@ async unsafe fn execute_indexed_with_context_inner(
.map_err(|e| DataFusionError::Execution(format!("execute_stream: {}", e)))?;

let (cross_rt_stream, abort_handle, _task_done) =
CrossRtStream::new_with_df_error_stream_cancellable(df_stream, cpu_executor.clone());
CrossRtStream::new_with_df_error_stream_cancellable(df_stream, cpu_executor.clone(), None);

if let Some(h) = abort_handle {
crate::query_tracker::set_abort_handle(context_id, h);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ pub async fn execute_query(

// Wrap in CrossRtStream — CPU work runs on DedicatedExecutor
let (cross_rt_stream, abort_handle, _task_done) =
CrossRtStream::new_with_df_error_stream_cancellable(df_stream, cpu_executor.clone());
CrossRtStream::new_with_df_error_stream_cancellable(df_stream, cpu_executor.clone(), None);

if let Some(h) = abort_handle {
crate::query_tracker::set_abort_handle(context_id, h);
Expand Down Expand Up @@ -352,7 +352,7 @@ pub async fn execute_with_context(
e
})?;
let (cross_rt_stream, abort_handle, _task_done) =
CrossRtStream::new_with_df_error_stream_cancellable(df_stream, cpu_executor.clone());
CrossRtStream::new_with_df_error_stream_cancellable(df_stream, cpu_executor.clone(), None);
if let Some(h) = abort_handle {
crate::query_tracker::set_abort_handle(context_id, h);
}
Expand Down Expand Up @@ -392,7 +392,7 @@ pub async fn execute_with_context(
})?;

let (cross_rt_stream, abort_handle, _task_done) =
CrossRtStream::new_with_df_error_stream_cancellable(df_stream, cpu_executor.clone());
CrossRtStream::new_with_df_error_stream_cancellable(df_stream, cpu_executor.clone(), None);
if let Some(h) = abort_handle {
crate::query_tracker::set_abort_handle(context_id, h);
}
Expand Down Expand Up @@ -421,7 +421,7 @@ pub async fn execute_with_context(
})?;

let (cross_rt_stream, abort_handle, _task_done) =
CrossRtStream::new_with_df_error_stream_cancellable(df_stream, cpu_executor.clone());
CrossRtStream::new_with_df_error_stream_cancellable(df_stream, cpu_executor.clone(), None);

if let Some(h) = abort_handle {
crate::query_tracker::set_abort_handle(context_id, h);
Expand Down Expand Up @@ -543,22 +543,35 @@ pub fn store_url_from_table_path(table_path: &ListingTableUrl) -> Result<datafus
}

/// Wrap a DataFusion stream in CrossRtStream and package as a QueryStreamHandle pointer.
///
/// Wires cancellation like the shard-query path so a `cancel_query` on the QTF fetch-by-rowid
/// stream can break/abort the cross_rt task instead of stranding its pool reservation.
pub fn wrap_stream_as_handle(
df_stream: datafusion::execution::SendableRecordBatchStream,
cpu_executor: DedicatedExecutor,
runtime: &DataFusionRuntime,
context_id: i64,
) -> i64 {
let cross_rt_stream = CrossRtStream::new_with_df_error_stream(df_stream, cpu_executor);
let wrapped = datafusion::physical_plan::stream::RecordBatchStreamAdapter::new(
cross_rt_stream.schema(),
cross_rt_stream,
);
// Create the tracking context first so its cancellation token is registered before the task starts.
let query_context = crate::query_tracker::QueryTrackingContext::new(
context_id,
runtime.runtime_env.memory_pool.clone(),
crate::query_tracker::QueryType::Shard,
);

let (cross_rt_stream, abort_handle, _task_done) =
CrossRtStream::new_with_df_error_stream_cancellable(df_stream, cpu_executor.clone(), None);
if let Some(h) = abort_handle {
crate::query_tracker::set_abort_handle(context_id, h);
}
if let Some(rt) = cpu_executor.handle() {
crate::query_tracker::set_cpu_runtime_handle(context_id, rt);
}

let wrapped = datafusion::physical_plan::stream::RecordBatchStreamAdapter::new(
cross_rt_stream.schema(),
cross_rt_stream,
);
let handle = crate::api::QueryStreamHandle::new(wrapped, query_context, None);
Box::into_raw(Box::new(handle)) as i64
}
Original file line number Diff line number Diff line change
Expand Up @@ -967,7 +967,9 @@ mod tests {
}

let global = make_global_pool(10_000);
let ctx_id = 70_001;
// Unique id: the QUERY_REGISTRY is process-wide and tests run in parallel, so this must not
// collide with any other test's id (70_001 collides with test_top_n_picks_highest_current_bytes).
let ctx_id = 80_001;
let ctx = QueryTrackingContext::new(ctx_id, global, QueryType::Shard);

// Build a dedicated executor with its own tokio runtime.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1022,6 +1022,15 @@ public EngineResultStream fetchByRowIds(
return new DatafusionResultStream(streamHandle, allocator);
}

@Override
public void cancelByContext(long contextId) {
// Fire the per-context cancellation token so the fetch stream's cross_rt task breaks
// cooperatively. No-op for an unknown contextId.
if (contextId != 0) {
NativeBridge.cancelQuery(contextId);
}
}

public Exception convertException(Exception original) {
return NativeErrorConverter.convert(original);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,8 @@ private void drainFetchByRowIds(
AnalyticsShardTask task,
StreamingFragmentResponseHandler responseHandler
) {
if (task != null && task.isCancelled()) {
assert task != null : "fetch on " + shard.shardId() + " requires a non-null AnalyticsShardTask";
if (task.isCancelled()) {
responseHandler.onFailure(new TaskCancelledException("Fetch task cancelled before execution: " + task.getReasonCancelled()));
return;
}
Expand Down Expand Up @@ -354,6 +355,9 @@ private void drainFetchByRowIds(
responseHandler.onFailure(new RuntimeException("Failed to execute fetch-by-row-ids on " + shard.shardId(), e));
return;
}
// On cancel, release a fetch parked in the native pull via cooperative cancellation, not
// stream.close() (which would race the in-flight native pull).
task.setCancellationListener(() -> backend.cancelByContext(task.getId()));
try (FragmentResources ctx = resources) {
Iterator<EngineResultBatch> it = ctx.stream().iterator();
while (it.hasNext()) {
Expand All @@ -362,6 +366,8 @@ private void drainFetchByRowIds(
responseHandler.onComplete();
} catch (Exception e) {
responseHandler.onFailure(e);
} finally {
task.clearCancellationListener();
}
}

Expand Down
Loading