Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,13 @@ class DatabricksSubmitRunOperator(BaseOperator):

.. seealso::
https://docs.databricks.com/dev-tools/api/latest/jobs.html#operation/JobsRunsSubmit
:param openlineage_inject_parent_job_info: If True, injects OpenLineage parent job information
into the ``new_cluster`` ``spark_conf`` so the Spark job emits a ``parentRunFacet`` linking
back to the Airflow task. Defaults to the
``openlineage.spark_inject_parent_job_info`` config value.
:param openlineage_inject_transport_info: If True, injects OpenLineage transport configuration
into the ``new_cluster`` ``spark_conf`` so the Spark job sends OL events to the same backend
as Airflow. Defaults to the ``openlineage.spark_inject_transport_info`` config value.

.. note::
If the operator's ``params`` dict is non-empty, it is automatically forwarded into the
Expand Down Expand Up @@ -606,6 +613,12 @@ def __init__(
wait_for_termination: bool = True,
git_source: dict[str, str] | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
openlineage_inject_parent_job_info: bool = conf.getboolean(
"openlineage", "spark_inject_parent_job_info", fallback=False
),
openlineage_inject_transport_info: bool = conf.getboolean(
"openlineage", "spark_inject_transport_info", fallback=False
),
**kwargs,
) -> None:
"""Create a new ``DatabricksSubmitRunOperator``."""
Expand All @@ -618,6 +631,8 @@ def __init__(
self.databricks_retry_args = databricks_retry_args
self.wait_for_termination = wait_for_termination
self.deferrable = deferrable
self.openlineage_inject_parent_job_info = openlineage_inject_parent_job_info
self.openlineage_inject_transport_info = openlineage_inject_transport_info
if tasks is not None:
self.json["tasks"] = tasks
if spark_jar_task is not None:
Expand Down Expand Up @@ -694,13 +709,36 @@ def execute(self, context: Context):
else:
_inject_airflow_params_into_task(self.json, params_dump)

if self.openlineage_inject_parent_job_info or self.openlineage_inject_transport_info:
self.log.info("Automatic injection of OpenLineage information into Spark properties is enabled.")
self._inject_openlineage_properties_into_databricks_job(context)

json_normalised = normalise_json_content(self.json)
self.run_id = self._hook.submit_run(json_normalised)
if self.deferrable:
_handle_deferrable_databricks_operator_execution(self, self._hook, self.log, context)
else:
_handle_databricks_operator_execution(self, self._hook, self.log, context)

def _inject_openlineage_properties_into_databricks_job(self, context: Context) -> None:
try:
from airflow.providers.databricks.utils.openlineage import (
inject_openlineage_properties_into_databricks_job,
)

self.json = inject_openlineage_properties_into_databricks_job(
job=self.json,
context=context,
inject_parent_job_info=self.openlineage_inject_parent_job_info,
inject_transport_info=self.openlineage_inject_transport_info,
)
except Exception as e:
self.log.warning(
"An error occurred while trying to inject OpenLineage information. "
"Databricks job has not been modified by OpenLineage.",
exc_info=e,
)

def on_kill(self):
if self.run_id:
self._hook.cancel_run(self.run_id)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations

import copy
import datetime
import json
import logging
Expand All @@ -24,12 +25,17 @@
import requests

from airflow.providers.common.compat.openlineage.check import require_openlineage_version
from airflow.providers.common.compat.openlineage.utils.spark import (
inject_parent_job_information_into_spark_properties,
inject_transport_information_into_spark_properties,
)
from airflow.providers.common.compat.sdk import timezone

if TYPE_CHECKING:
from openlineage.client.event_v2 import RunEvent
from openlineage.client.facet_v2 import JobFacet

from airflow.providers.common.compat.sdk import Context
from airflow.providers.databricks.hooks.databricks import DatabricksHook
from airflow.providers.databricks.hooks.databricks_sql import DatabricksSqlHook

Expand Down Expand Up @@ -320,3 +326,128 @@ def emit_openlineage_events_for_databricks_queries(

log.info("OpenLineage has successfully finished processing information about Databricks queries.")
return


def _is_openlineage_provider_accessible() -> bool:
"""
Check if the OpenLineage provider is accessible.

This function attempts to import the necessary OpenLineage modules and checks if the provider
is enabled and the listener is available.

Returns:
True if the OpenLineage provider is accessible, False otherwise.
"""
try:
from airflow.providers.openlineage.conf import is_disabled
from airflow.providers.openlineage.plugins.listener import get_openlineage_listener
except ImportError:
log.debug("OpenLineage provider could not be imported.")
return False

if is_disabled():
log.debug("OpenLineage provider is disabled.")
return False

if not get_openlineage_listener():
log.debug("OpenLineage listener could not be found.")
return False

return True


def _extract_new_clusters_from_databricks_job(job: dict) -> list[dict]:
"""
Collect every ``new_cluster`` definition that can carry Spark properties in a Databricks job.

A ``runs/submit`` payload can define a ``new_cluster`` in three places: at the top level
(single-task form), inline on each task (``tasks[].new_cluster``), or as a shared job cluster
(``job_clusters[].new_cluster``) referenced by tasks through ``job_cluster_key``. Tasks running on
an ``existing_cluster_id`` have no ``new_cluster`` to mutate and are skipped.

Args:
job: The Databricks ``runs/submit`` job definition.

Returns:
The list of ``new_cluster`` dicts found in the job definition.
"""
new_clusters = []
if isinstance(job.get("new_cluster"), dict):
new_clusters.append(job["new_cluster"])
for key in ("tasks", "job_clusters"):
if isinstance(job.get(key), list):
new_clusters.extend(
item["new_cluster"] for item in job[key] if isinstance(item.get("new_cluster"), dict)
)
return new_clusters


def inject_openlineage_properties_into_databricks_job(
job: dict, context: Context, inject_parent_job_info: bool, inject_transport_info: bool
) -> dict:
"""
Inject OpenLineage properties into a Databricks job definition.

This function does not remove existing configurations or modify the job definition in any way,
except to add the required OpenLineage properties if they are not already present.

The entire properties injection process will be skipped if any condition is met:
- The OpenLineage provider is not accessible.
- The job has no ``new_cluster`` definition to inject Spark properties into (e.g. it only uses
an ``existing_cluster_id``, whose Spark configuration is fixed at cluster creation time).
- Both `inject_parent_job_info` and `inject_transport_info` are set to False.

Additionally, specific information will not be injected if relevant OpenLineage properties already
exist.

Parent job information will not be injected if:
- Any property prefixed with `spark.openlineage.parent` exists.
- `inject_parent_job_info` is False.
Transport information will not be injected if:
- Any property prefixed with `spark.openlineage.transport` exists.
- `inject_transport_info` is False.

Args:
job: The original Databricks ``runs/submit`` job definition.
context: The Airflow context in which the job is running.
inject_parent_job_info: Flag indicating whether to inject parent job information.
inject_transport_info: Flag indicating whether to inject transport information.

Returns:
The modified job definition with OpenLineage properties injected, if applicable.
"""
if not inject_parent_job_info and not inject_transport_info:
log.debug("Automatic injection of OpenLineage information is disabled.")
return job

if not _is_openlineage_provider_accessible():
log.warning(
"Could not access OpenLineage provider for automatic OpenLineage "
"properties injection. No action will be performed."
)
return job

job = copy.deepcopy(job)
new_clusters = _extract_new_clusters_from_databricks_job(job)
if not new_clusters:
log.debug(
"Could not find a Databricks `new_cluster` definition for automatic OpenLineage "
"properties injection. No action will be performed."
)
return job

for new_cluster in new_clusters:
properties = new_cluster.get("spark_conf", {})
if inject_parent_job_info:
log.debug("Injecting OpenLineage parent job information into Spark properties.")
properties = inject_parent_job_information_into_spark_properties(
properties=properties, context=context
)
if inject_transport_info:
log.debug("Injecting OpenLineage transport information into Spark properties.")
properties = inject_transport_information_into_spark_properties(
properties=properties, context=context
)
new_cluster["spark_conf"] = properties

return job
Original file line number Diff line number Diff line change
Expand Up @@ -1344,6 +1344,139 @@ def test_submit_run_does_not_override_existing_task_parameters(self, db_mock_cla
assert actual["notebook_task"]["base_parameters"] == {"explicit": "value"}


class TestDatabricksSubmitRunOperatorOpenLineageInjection:
"""Tests for OpenLineage parent job info and transport info injection in DatabricksSubmitRunOperator."""

@mock.patch(
"airflow.providers.databricks.utils.openlineage.inject_openlineage_properties_into_databricks_job"
)
@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
def test_inject_parent_job_info_called_when_enabled(self, db_mock_class, mock_inject):
mock_inject.side_effect = lambda job, context, inject_parent_job_info, inject_transport_info: {
**job,
"new_cluster": {
**job["new_cluster"],
"spark_conf": {"spark.openlineage.parentJobNamespace": "ns"},
},
}
op = DatabricksSubmitRunOperator(
task_id=TASK_ID,
new_cluster=NEW_CLUSTER,
notebook_task=NOTEBOOK_TASK,
openlineage_inject_parent_job_info=True,
)
db_mock = db_mock_class.return_value
db_mock.submit_run.return_value = RUN_ID
db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")

op.execute(None)

mock_inject.assert_called_once()
submitted = db_mock.submit_run.call_args.args[0]
assert submitted["new_cluster"]["spark_conf"]["spark.openlineage.parentJobNamespace"] == "ns"

@mock.patch(
"airflow.providers.databricks.utils.openlineage.inject_openlineage_properties_into_databricks_job"
)
@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
def test_inject_parent_job_info_not_called_when_disabled(self, db_mock_class, mock_inject):
op = DatabricksSubmitRunOperator(
task_id=TASK_ID,
new_cluster=NEW_CLUSTER,
notebook_task=NOTEBOOK_TASK,
openlineage_inject_parent_job_info=False,
)
db_mock = db_mock_class.return_value
db_mock.submit_run.return_value = RUN_ID
db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")

op.execute(None)

mock_inject.assert_not_called()

@mock.patch(
"airflow.providers.databricks.utils.openlineage.inject_openlineage_properties_into_databricks_job"
)
@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
def test_inject_transport_info_called_when_enabled(self, db_mock_class, mock_inject):
mock_inject.side_effect = lambda job, context, inject_parent_job_info, inject_transport_info: {
**job,
"new_cluster": {**job["new_cluster"], "spark_conf": {"spark.openlineage.transport.type": "http"}},
}
op = DatabricksSubmitRunOperator(
task_id=TASK_ID,
new_cluster=NEW_CLUSTER,
notebook_task=NOTEBOOK_TASK,
openlineage_inject_transport_info=True,
)
db_mock = db_mock_class.return_value
db_mock.submit_run.return_value = RUN_ID
db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")

op.execute(None)

mock_inject.assert_called_once()
submitted = db_mock.submit_run.call_args.args[0]
assert submitted["new_cluster"]["spark_conf"]["spark.openlineage.transport.type"] == "http"

@mock.patch(
"airflow.providers.databricks.utils.openlineage.inject_openlineage_properties_into_databricks_job"
)
@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
def test_inject_both_parent_and_transport_info(self, db_mock_class, mock_inject):
mock_inject.side_effect = lambda job, context, inject_parent_job_info, inject_transport_info: job
op = DatabricksSubmitRunOperator(
task_id=TASK_ID,
new_cluster=NEW_CLUSTER,
notebook_task=NOTEBOOK_TASK,
openlineage_inject_parent_job_info=True,
openlineage_inject_transport_info=True,
)
db_mock = db_mock_class.return_value
db_mock.submit_run.return_value = RUN_ID
db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")

op.execute(None)

mock_inject.assert_called_once()
_, call_kwargs = mock_inject.call_args
assert call_kwargs["inject_parent_job_info"] is True
assert call_kwargs["inject_transport_info"] is True

@mock.patch(
"airflow.providers.databricks.utils.openlineage.inject_openlineage_properties_into_databricks_job"
)
@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
def test_inject_parent_job_info_preserves_existing_config(self, db_mock_class, mock_inject):
"""Existing ``spark_conf`` entries are preserved alongside the injected OpenLineage properties."""
new_cluster = {**NEW_CLUSTER, "spark_conf": {"spark.executor.memory": "8g"}}
mock_inject.side_effect = lambda job, context, inject_parent_job_info, inject_transport_info: {
**job,
"new_cluster": {
**job["new_cluster"],
"spark_conf": {
**job["new_cluster"]["spark_conf"],
"spark.openlineage.parentJobNamespace": "ns",
},
},
}
op = DatabricksSubmitRunOperator(
task_id=TASK_ID,
new_cluster=new_cluster,
notebook_task=NOTEBOOK_TASK,
openlineage_inject_parent_job_info=True,
)
db_mock = db_mock_class.return_value
db_mock.submit_run.return_value = RUN_ID
db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")

op.execute(None)

submitted = db_mock.submit_run.call_args.args[0]
assert submitted["new_cluster"]["spark_conf"]["spark.executor.memory"] == "8g"
assert submitted["new_cluster"]["spark_conf"]["spark.openlineage.parentJobNamespace"] == "ns"


class TestDatabricksRunNowOperator:
def test_init_with_named_parameters(self):
"""
Expand Down
Loading
Loading