Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 33 additions & 17 deletions task-sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1859,6 +1859,38 @@ def _send_error_email_notification(
log.exception("Failed to send email notification")


@detail_span("task.execute")
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The ctx passed in here is snapshotted back in _execute_task (ctx = contextvars.copy_context(), line 1917), which runs before this task.execute span is entered. So inside ctx.run(execute, context=context) the current OTel span in that snapshot is still _execute_task, not task.execute.

Doesn't that mean spans the operator emits during execute() parent to _execute_task and land as siblings of task.execute rather than nesting under it? task.execute would record the right duration but stay empty of the operator's own child spans, which seems opposite to wrapping "the actual task execution".

Was the intent for operator child spans to nest under task.execute? If so, would re-snapshotting the context inside this function (after the span is current) be the fix? One wrinkle: ctx also carries ExecutorSafeguard.tracker.set(task) from line 1919, so a plain copy_context() here would drop that and it would need re-applying.

def _run_execute_callable(
context: Context,
ctx: contextvars.Context,
execute: Callable[..., Any] | functools.partial[Any],
task: BaseOperator,
) -> Any:
"""
Run the task's execute callable, applying the execution timeout if one is set.

:param context: Task context, forwarded to ``execute`` as the ``context`` keyword argument.
:param ctx: The ``contextvars.Context`` in which ``execute`` is invoked (via ``ctx.run``).
:param execute: The callable to run.
:param task: Operator whose ``execution_timeout`` is consulted and whose ``on_kill`` is
called when ``AirflowTaskTimeout`` is caught.
:return: Whatever ``execute`` returns.
:raises AirflowTaskTimeout: Raised directly when the configured timeout is zero or negative,
and re-raised when it is caught while running ``execute`` inside the ``timeout`` wrapper.
In both cases the current span is marked ERROR and ``task.on_kill()`` is called first.
"""
if task.execution_timeout:
from airflow.sdk.execution_time.timeout import timeout

# TODO: handle timeout in case of deferral
timeout_seconds = task.execution_timeout.total_seconds()
try:
# It's possible we're already timed out, so fast-fail if true
if timeout_seconds <= 0:
raise AirflowTaskTimeout()
# Run task in timeout wrapper
with timeout(timeout_seconds):
result = ctx.run(execute, context=context)
except AirflowTaskTimeout:
# AirflowTaskTimeout inherits from BaseException, so OpenTelemetry's
# start_as_current_span won't mark the span as errored on its own
# (it only does so for Exception subclasses). Set it explicitly.
trace.get_current_span().set_status(Status(StatusCode.ERROR, "AirflowTaskTimeout"))
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At detail level 1, task.execute isn't created (the decorator returns INVALID_SPAN), and _execute_task is also a detail_span, so here trace.get_current_span() resolves to the worker.<task_id> span and this marks that span ERROR on timeout. A regular exception at level 1 sets no status explicitly, so the worker span is left untouched unless the exception propagates through its own with block.

Is it intended that a timeout flips the worker span to ERROR at level 1 while an ordinary failure doesn't? Just flagging the asymmetry, not blocking.

# Call the operator's on_kill hook, then let the timeout propagate to the caller.
task.on_kill()
raise
else:
# No execution_timeout configured: run the callable directly inside ``ctx``.
result = ctx.run(execute, context=context)
return result


@detail_span("_execute_task")
def _execute_task(context: Context, ti: RuntimeTaskInstance, log: Logger):
"""Execute Task (optionally with a Timeout) and push Xcom results."""
Expand Down Expand Up @@ -1901,23 +1933,7 @@ def _execute_task(context: Context, ti: RuntimeTaskInstance, log: Logger):

log.info("::endgroup::")

if task.execution_timeout:
from airflow.sdk.execution_time.timeout import timeout

# TODO: handle timeout in case of deferral
timeout_seconds = task.execution_timeout.total_seconds()
try:
# It's possible we're already timed out, so fast-fail if true
if timeout_seconds <= 0:
raise AirflowTaskTimeout()
# Run task in timeout wrapper
with timeout(timeout_seconds):
result = ctx.run(execute, context=context)
except AirflowTaskTimeout:
task.on_kill()
raise
else:
result = ctx.run(execute, context=context)
result = _run_execute_callable(context, ctx, execute, task)

if (post_execute_hook := task._post_execute_hook) is not None:
create_executable_runner(post_execute_hook, outlet_events, logger=log).run(context, result)
Expand Down
152 changes: 152 additions & 0 deletions task-sdk/tests/task_sdk/execution_time/test_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from __future__ import annotations

import contextlib
import contextvars
import functools
import json
import os
Expand Down Expand Up @@ -162,6 +163,7 @@
_execute_task,
_make_task_span,
_push_xcom_if_needed,
_run_execute_callable,
_serialize_outlet_events,
_xcom_push,
detail_span,
Expand Down Expand Up @@ -5268,6 +5270,156 @@ def test_exception_in_context_manager_propagates(self):
raise ValueError("boom")


class TestRunExecuteCallable:
"""Tests for ``_run_execute_callable``.

It runs the task's execute callable inside the task's contextvars context,
applies the execution timeout when one is configured, and wraps the call in
a ``task.execute`` detail span.
"""

@staticmethod
def _make_task(execution_timeout=None):
    """Return a ``BaseOperator``-spec'd mock whose ``execution_timeout`` is the given value."""
    operator = mock.MagicMock(spec=BaseOperator)
    operator.execution_timeout = execution_timeout
    return operator

def test_returns_result_and_uses_contextvars_context(self):
    """The return value is passed back and the callable's contextvar writes stay in the supplied context."""
    marker = contextvars.ContextVar("marker")
    marker.set("outer")
    snapshot = contextvars.copy_context()

    def execute(context):
        marker.set("inner")
        return context["value"] * 2

    operator = self._make_task()
    outcome = _run_execute_callable(
        context={"value": 21},
        ctx=snapshot,
        execute=execute,
        task=operator,
    )

    assert outcome == 42
    # The write landed in the snapshot; the test's own context is untouched.
    assert marker.get() == "outer"
    assert snapshot[marker] == "inner"
    operator.on_kill.assert_not_called()

def test_applies_execution_timeout(self):
    """A callable that outlasts the configured timeout raises AirflowTaskTimeout and triggers on_kill."""
    operator = self._make_task(execution_timeout=timedelta(milliseconds=10))
    snapshot = contextvars.copy_context()

    def slow_execute(context):
        time.sleep(2)

    with pytest.raises(AirflowTaskTimeout):
        _run_execute_callable(context={}, ctx=snapshot, execute=slow_execute, task=operator)

    operator.on_kill.assert_called_once()

def test_fast_fails_when_timeout_already_elapsed(self):
    """With a negative timeout the callable is never invoked, yet on_kill is still called."""
    never_called = mock.MagicMock()
    operator = self._make_task(execution_timeout=timedelta(seconds=-1))
    snapshot = contextvars.copy_context()

    with pytest.raises(AirflowTaskTimeout):
        _run_execute_callable(context={}, ctx=snapshot, execute=never_called, task=operator)

    never_called.assert_not_called()
    operator.on_kill.assert_called_once()

def test_emits_task_execute_span_at_detail_level_2(self):
    """With detail level 2 in the trace carrier, a ``task.execute`` span is among the finished spans."""
    span_exporter = InMemorySpanExporter()
    tracer_provider = TracerProvider()
    tracer_provider.add_span_processor(SimpleSpanProcessor(span_exporter))
    test_tracer = tracer_provider.get_tracer("test")
    parent_ctx = TraceContextTextMapPropagator().extract(
        new_dagrun_trace_carrier(task_span_detail_level=2)
    )

    snapshot = contextvars.copy_context()
    operator = self._make_task()

    patched_tracer = mock.patch("airflow.sdk.execution_time.task_runner.tracer", test_tracer)
    parent_span = test_tracer.start_as_current_span("parent", context=parent_ctx)
    with patched_tracer, parent_span:
        outcome = _run_execute_callable(
            context={}, ctx=snapshot, execute=lambda context: "ok", task=operator
        )

    assert outcome == "ok"
    finished_names = [span.name for span in span_exporter.get_finished_spans()]
    assert "task.execute" in finished_names
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These tests assert that task.execute exists and check its status, but none of them assert its parent/child relationship. Given the nesting question I raised on _run_execute_callable, would it be worth asserting that the operator's child spans actually nest under task.execute (e.g. checking span.parent.span_id)? As written, would this suite stay green even if task.execute ends up a sibling of the operator's spans rather than their parent?


def test_no_task_execute_span_at_detail_level_1(self):
    """With detail level 1 the callable still runs, but no ``task.execute`` span is finished."""
    span_exporter = InMemorySpanExporter()
    tracer_provider = TracerProvider()
    tracer_provider.add_span_processor(SimpleSpanProcessor(span_exporter))
    test_tracer = tracer_provider.get_tracer("test")
    parent_ctx = TraceContextTextMapPropagator().extract(
        new_dagrun_trace_carrier(task_span_detail_level=1)
    )

    snapshot = contextvars.copy_context()
    operator = self._make_task()

    patched_tracer = mock.patch("airflow.sdk.execution_time.task_runner.tracer", test_tracer)
    parent_span = test_tracer.start_as_current_span("parent", context=parent_ctx)
    with patched_tracer, parent_span:
        outcome = _run_execute_callable(
            context={}, ctx=snapshot, execute=lambda context: "ok", task=operator
        )

    assert outcome == "ok"
    finished_names = [span.name for span in span_exporter.get_finished_spans()]
    assert "task.execute" not in finished_names

def test_task_execute_span_marked_error_on_regular_exception(self):
    """When the callable raises an ordinary Exception, the ``task.execute`` span ends with ERROR status."""
    span_exporter = InMemorySpanExporter()
    tracer_provider = TracerProvider()
    tracer_provider.add_span_processor(SimpleSpanProcessor(span_exporter))
    test_tracer = tracer_provider.get_tracer("test")
    parent_ctx = TraceContextTextMapPropagator().extract(
        new_dagrun_trace_carrier(task_span_detail_level=2)
    )

    snapshot = contextvars.copy_context()
    operator = self._make_task()

    def failing_execute(context):
        raise ValueError("boom")

    patched_tracer = mock.patch("airflow.sdk.execution_time.task_runner.tracer", test_tracer)
    parent_span = test_tracer.start_as_current_span("parent", context=parent_ctx)
    with patched_tracer, parent_span, pytest.raises(ValueError, match="boom"):
        _run_execute_callable(context={}, ctx=snapshot, execute=failing_execute, task=operator)

    finished = {span.name: span for span in span_exporter.get_finished_spans()}
    assert finished["task.execute"].status.status_code == trace.StatusCode.ERROR

def test_task_execute_span_marked_error_on_timeout(self):
    """The ``task.execute`` span ends with ERROR status when the callable times out.

    AirflowTaskTimeout is a BaseException, and OpenTelemetry only sets ERROR status
    automatically for Exception subclasses. The timeout handler therefore sets the status
    itself; this test fails if that explicit marking is ever dropped.
    """
    span_exporter = InMemorySpanExporter()
    tracer_provider = TracerProvider()
    tracer_provider.add_span_processor(SimpleSpanProcessor(span_exporter))
    test_tracer = tracer_provider.get_tracer("test")
    parent_ctx = TraceContextTextMapPropagator().extract(
        new_dagrun_trace_carrier(task_span_detail_level=2)
    )

    snapshot = contextvars.copy_context()
    operator = self._make_task(execution_timeout=timedelta(milliseconds=10))

    def slow_execute(context):
        time.sleep(2)

    patched_tracer = mock.patch("airflow.sdk.execution_time.task_runner.tracer", test_tracer)
    parent_span = test_tracer.start_as_current_span("parent", context=parent_ctx)
    with patched_tracer, parent_span, pytest.raises(AirflowTaskTimeout):
        _run_execute_callable(context={}, ctx=snapshot, execute=slow_execute, task=operator)

    finished = {span.name: span for span in span_exporter.get_finished_spans()}
    assert finished["task.execute"].status.status_code == trace.StatusCode.ERROR
    operator.on_kill.assert_called_once()


def test_dag_add_result(create_runtime_ti, mock_supervisor_comms):
with DAG(dag_id="test_dag_add_result") as dag:
task = PythonOperator(task_id="t", python_callable=lambda: 123)
Expand Down
Loading