-
Notifications
You must be signed in to change notification settings - Fork 17.2k
Add task.execute detail span around task execute callable #67877
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1859,6 +1859,38 @@ def _send_error_email_notification( | |
| log.exception("Failed to send email notification") | ||
|
|
||
|
|
||
| @detail_span("task.execute") | ||
| def _run_execute_callable( | ||
| context: Context, | ||
| ctx: contextvars.Context, | ||
| execute: Callable[..., Any] | functools.partial[Any], | ||
| task: BaseOperator, | ||
| ) -> Any: | ||
| """Run the task's execute callable, applying the execution timeout if one is set.""" | ||
| if task.execution_timeout: | ||
| from airflow.sdk.execution_time.timeout import timeout | ||
|
|
||
| # TODO: handle timeout in case of deferral | ||
| timeout_seconds = task.execution_timeout.total_seconds() | ||
| try: | ||
| # It's possible we're already timed out, so fast-fail if true | ||
| if timeout_seconds <= 0: | ||
| raise AirflowTaskTimeout() | ||
| # Run task in timeout wrapper | ||
| with timeout(timeout_seconds): | ||
| result = ctx.run(execute, context=context) | ||
| except AirflowTaskTimeout: | ||
| # AirflowTaskTimeout inherits from BaseException, so OpenTelemetry's | ||
| # start_as_current_span won't mark the span as errored on its own | ||
| # (it only does so for Exception subclasses). Set it explicitly. | ||
| trace.get_current_span().set_status(Status(StatusCode.ERROR, "AirflowTaskTimeout")) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. At detail level 1, Is it intended that a timeout flips the worker span to ERROR at level 1 while an ordinary failure doesn't? Just flagging the asymmetry, not blocking. |
||
| task.on_kill() | ||
| raise | ||
| else: | ||
| result = ctx.run(execute, context=context) | ||
| return result | ||
|
|
||
|
|
||
| @detail_span("_execute_task") | ||
| def _execute_task(context: Context, ti: RuntimeTaskInstance, log: Logger): | ||
| """Execute Task (optionally with a Timeout) and push Xcom results.""" | ||
|
|
@@ -1901,23 +1933,7 @@ def _execute_task(context: Context, ti: RuntimeTaskInstance, log: Logger): | |
|
|
||
| log.info("::endgroup::") | ||
|
|
||
| if task.execution_timeout: | ||
| from airflow.sdk.execution_time.timeout import timeout | ||
|
|
||
| # TODO: handle timeout in case of deferral | ||
| timeout_seconds = task.execution_timeout.total_seconds() | ||
| try: | ||
| # It's possible we're already timed out, so fast-fail if true | ||
| if timeout_seconds <= 0: | ||
| raise AirflowTaskTimeout() | ||
| # Run task in timeout wrapper | ||
| with timeout(timeout_seconds): | ||
| result = ctx.run(execute, context=context) | ||
| except AirflowTaskTimeout: | ||
| task.on_kill() | ||
| raise | ||
| else: | ||
| result = ctx.run(execute, context=context) | ||
| result = _run_execute_callable(context, ctx, execute, task) | ||
|
|
||
| if (post_execute_hook := task._post_execute_hook) is not None: | ||
| create_executable_runner(post_execute_hook, outlet_events, logger=log).run(context, result) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,6 +18,7 @@ | |
| from __future__ import annotations | ||
|
|
||
| import contextlib | ||
| import contextvars | ||
| import functools | ||
| import json | ||
| import os | ||
|
|
@@ -162,6 +163,7 @@ | |
| _execute_task, | ||
| _make_task_span, | ||
| _push_xcom_if_needed, | ||
| _run_execute_callable, | ||
| _serialize_outlet_events, | ||
| _xcom_push, | ||
| detail_span, | ||
|
|
@@ -5268,6 +5270,156 @@ def test_exception_in_context_manager_propagates(self): | |
| raise ValueError("boom") | ||
|
|
||
|
|
||
| class TestRunExecuteCallable: | ||
| """Tests for ``_run_execute_callable``. | ||
|
|
||
| It runs the task's execute callable inside the task's contextvars context, | ||
| applies the execution timeout when one is configured, and wraps the call in | ||
| a ``task.execute`` detail span. | ||
| """ | ||
|
|
||
| @staticmethod | ||
| def _make_task(execution_timeout=None): | ||
| task = mock.MagicMock(spec=BaseOperator) | ||
| task.execution_timeout = execution_timeout | ||
| return task | ||
|
|
||
| def test_returns_result_and_uses_contextvars_context(self): | ||
| """The callable runs inside the provided contextvars context and its return value is passed back.""" | ||
| var = contextvars.ContextVar("marker") | ||
| var.set("outer") | ||
| ctx = contextvars.copy_context() | ||
|
|
||
| def execute(context): | ||
| var.set("inner") | ||
| return context["value"] * 2 | ||
|
|
||
| task = self._make_task() | ||
| result = _run_execute_callable(context={"value": 21}, ctx=ctx, execute=execute, task=task) | ||
|
|
||
| assert result == 42 | ||
| # The mutation happened inside the copied context, not the current one. | ||
| assert var.get() == "outer" | ||
| assert ctx[var] == "inner" | ||
| task.on_kill.assert_not_called() | ||
|
|
||
| def test_applies_execution_timeout(self): | ||
| """When a timeout is set and the callable overruns, AirflowTaskTimeout is raised and on_kill is called.""" | ||
| ctx = contextvars.copy_context() | ||
| task = self._make_task(execution_timeout=timedelta(milliseconds=10)) | ||
|
|
||
| def execute(context): | ||
| time.sleep(2) | ||
|
|
||
| with pytest.raises(AirflowTaskTimeout): | ||
| _run_execute_callable(context={}, ctx=ctx, execute=execute, task=task) | ||
|
|
||
| task.on_kill.assert_called_once() | ||
|
|
||
| def test_fast_fails_when_timeout_already_elapsed(self): | ||
| """A non-positive timeout fast-fails before running the callable and still calls on_kill.""" | ||
| ctx = contextvars.copy_context() | ||
| task = self._make_task(execution_timeout=timedelta(seconds=-1)) | ||
| execute = mock.MagicMock() | ||
|
|
||
| with pytest.raises(AirflowTaskTimeout): | ||
| _run_execute_callable(context={}, ctx=ctx, execute=execute, task=task) | ||
|
|
||
| execute.assert_not_called() | ||
| task.on_kill.assert_called_once() | ||
|
|
||
| def test_emits_task_execute_span_at_detail_level_2(self): | ||
| """At detail level 2, running the callable produces a recorded ``task.execute`` span.""" | ||
| exporter = InMemorySpanExporter() | ||
| provider = TracerProvider() | ||
| provider.add_span_processor(SimpleSpanProcessor(exporter)) | ||
| t = provider.get_tracer("test") | ||
| carrier = new_dagrun_trace_carrier(task_span_detail_level=2) | ||
| parent_ctx = TraceContextTextMapPropagator().extract(carrier) | ||
|
|
||
| ctx = contextvars.copy_context() | ||
| task = self._make_task() | ||
|
|
||
| with mock.patch("airflow.sdk.execution_time.task_runner.tracer", t): | ||
| with t.start_as_current_span("parent", context=parent_ctx): | ||
| result = _run_execute_callable(context={}, ctx=ctx, execute=lambda context: "ok", task=task) | ||
|
|
||
| assert result == "ok" | ||
| names = [s.name for s in exporter.get_finished_spans()] | ||
| assert "task.execute" in names | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These tests assert that |
||
|
|
||
| def test_no_task_execute_span_at_detail_level_1(self): | ||
| """At detail level 1, no ``task.execute`` span is recorded but the callable still runs.""" | ||
| exporter = InMemorySpanExporter() | ||
| provider = TracerProvider() | ||
| provider.add_span_processor(SimpleSpanProcessor(exporter)) | ||
| t = provider.get_tracer("test") | ||
| carrier = new_dagrun_trace_carrier(task_span_detail_level=1) | ||
| parent_ctx = TraceContextTextMapPropagator().extract(carrier) | ||
|
|
||
| ctx = contextvars.copy_context() | ||
| task = self._make_task() | ||
|
|
||
| with mock.patch("airflow.sdk.execution_time.task_runner.tracer", t): | ||
| with t.start_as_current_span("parent", context=parent_ctx): | ||
| result = _run_execute_callable(context={}, ctx=ctx, execute=lambda context: "ok", task=task) | ||
|
|
||
| assert result == "ok" | ||
| names = [s.name for s in exporter.get_finished_spans()] | ||
| assert "task.execute" not in names | ||
|
|
||
| def test_task_execute_span_marked_error_on_regular_exception(self): | ||
| """A regular Exception from the callable marks the ``task.execute`` span as ERROR.""" | ||
| exporter = InMemorySpanExporter() | ||
| provider = TracerProvider() | ||
| provider.add_span_processor(SimpleSpanProcessor(exporter)) | ||
| t = provider.get_tracer("test") | ||
| carrier = new_dagrun_trace_carrier(task_span_detail_level=2) | ||
| parent_ctx = TraceContextTextMapPropagator().extract(carrier) | ||
|
|
||
| ctx = contextvars.copy_context() | ||
| task = self._make_task() | ||
|
|
||
| def execute(context): | ||
| raise ValueError("boom") | ||
|
|
||
| with mock.patch("airflow.sdk.execution_time.task_runner.tracer", t): | ||
| with t.start_as_current_span("parent", context=parent_ctx): | ||
| with pytest.raises(ValueError, match="boom"): | ||
| _run_execute_callable(context={}, ctx=ctx, execute=execute, task=task) | ||
|
|
||
| spans = {s.name: s for s in exporter.get_finished_spans()} | ||
| assert spans["task.execute"].status.status_code == trace.StatusCode.ERROR | ||
|
|
||
| def test_task_execute_span_marked_error_on_timeout(self): | ||
| """A timeout (AirflowTaskTimeout, a BaseException) is explicitly marked ERROR on the span. | ||
|
|
||
| OpenTelemetry only auto-sets ERROR status for Exception subclasses, so the timeout handler | ||
| sets it explicitly; this guards against that being lost. | ||
| """ | ||
| exporter = InMemorySpanExporter() | ||
| provider = TracerProvider() | ||
| provider.add_span_processor(SimpleSpanProcessor(exporter)) | ||
| t = provider.get_tracer("test") | ||
| carrier = new_dagrun_trace_carrier(task_span_detail_level=2) | ||
| parent_ctx = TraceContextTextMapPropagator().extract(carrier) | ||
|
|
||
| ctx = contextvars.copy_context() | ||
| task = self._make_task(execution_timeout=timedelta(milliseconds=10)) | ||
|
|
||
| def execute(context): | ||
| time.sleep(2) | ||
|
|
||
| with mock.patch("airflow.sdk.execution_time.task_runner.tracer", t): | ||
| with t.start_as_current_span("parent", context=parent_ctx): | ||
| with pytest.raises(AirflowTaskTimeout): | ||
| _run_execute_callable(context={}, ctx=ctx, execute=execute, task=task) | ||
|
|
||
| spans = {s.name: s for s in exporter.get_finished_spans()} | ||
| assert spans["task.execute"].status.status_code == trace.StatusCode.ERROR | ||
| task.on_kill.assert_called_once() | ||
|
|
||
|
|
||
| def test_dag_add_result(create_runtime_ti, mock_supervisor_comms): | ||
| with DAG(dag_id="test_dag_add_result") as dag: | ||
| task = PythonOperator(task_id="t", python_callable=lambda: 123) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
ctxpassed in here is snapshotted back in_execute_task(ctx = contextvars.copy_context(), line 1917), which runs before thistask.executespan is entered. So insidectx.run(execute, context=context)the current OTel span in that snapshot is still_execute_task, nottask.execute.Doesn't that mean spans the operator emits during
execute()parent to_execute_taskand land as siblings oftask.executerather than nesting under it?task.executewould record the right duration but stay empty of the operator's own child spans, which seems opposite to wrapping "the actual task execution".Was the intent for operator child spans to nest under
task.execute? If so, would re-snapshotting the context inside this function (after the span is current) be the fix? One wrinkle:ctxalso carriesExecutorSafeguard.tracker.set(task)from line 1919, so a plaincopy_context()here would drop that and it would need re-applying.